@@ -1421,7 +1421,6 @@ def __setstate__(self, state):
14211421
14221422 defaults = dict (
14231423 method = "unknown" ,
1424- dims = ("epoch" , "channel" , "freq" , "time" )[- state ["data" ].ndim :],
14251424 baseline = None ,
14261425 decim = 1 ,
14271426 data_type = "TFR" ,
@@ -1445,7 +1444,7 @@ def __setstate__(self, state):
14451444 unknown_class = Epochs if "epoch" in self ._dims else Evoked
14461445 inst_types = dict (Raw = Raw , Epochs = Epochs , Evoked = Evoked , Unknown = unknown_class )
14471446 self ._inst_type = inst_types [defaults ["inst_type_str" ]]
1448- # sanity check data/freqs/times/info agreement
1447+ # sanity check data/freqs/times/info/weights agreement
14491448 self ._check_state ()
14501449
14511450 def __repr__ (self ):
@@ -1498,14 +1497,26 @@ def _check_compatibility(self, other):
14981497 raise RuntimeError (msg .format (problem , extra ))
14991498
15001499 def _check_state (self ):
1501- """Check data/freqs/times/info agreement during __setstate__."""
1500+ """Check data/freqs/times/info/weights agreement during __setstate__."""
15021501 msg = "{} axis of data ({}) doesn't match {} attribute ({})"
15031502 n_chan_info = len (self .info ["chs" ])
15041503 n_chan = self ._data .shape [self ._dims .index ("channel" )]
1504+ n_taper = (
1505+ self ._data .shape [self ._dims .index ("taper" )]
1506+ if "taper" in self ._dims
1507+ else None
1508+ )
15051509 n_freq = self ._data .shape [self ._dims .index ("freq" )]
15061510 n_time = self ._data .shape [self ._dims .index ("time" )]
15071511 if n_chan_info != n_chan :
15081512 msg = msg .format ("Channel" , n_chan , "info" , n_chan_info )
1513+ elif n_taper is not None :
1514+ if self ._weights is None :
1515+ raise RuntimeError ("Taper dimension in data, but no weights found." )
1516+ if n_taper != self ._weights .shape [0 ]:
1517+ msg = msg .format ("Taper" , n_taper , "weights" , self ._weights .shape [0 ])
1518+ elif n_freq != self ._weights .shape [1 ]:
1519+ msg = msg .format ("Frequency" , n_freq , "weights" , self ._weights .shape [1 ])
15091520 elif n_freq != len (self .freqs ):
15101521 msg = msg .format ("Frequency" , n_freq , "freqs" , self .freqs .size )
15111522 elif n_time != len (self .times ):
@@ -2775,6 +2786,7 @@ class AverageTFR(BaseTFR):
27752786 %(nave_tfr_attr)s
27762787 %(sfreq_tfr_attr)s
27772788 %(shape_tfr_attr)s
2789+ %(weights_tfr_attr)s
27782790
27792791 See Also
27802792 --------
@@ -2891,6 +2903,10 @@ def __getstate__(self):
28912903
28922904 def __setstate__ (self , state ):
28932905 """Unpack AverageTFR from serialized format."""
2906+ if state ["data" ].ndim != 3 :
2907+ raise ValueError (f"RawTFR data should be 3D, got { state ['data' ].ndim } ." )
2908+ # Set dims now since optional tapers makes it difficult to disentangle later
2909+ state ["dims" ] = ("channel" , "freq" , "time" )
28942910 super ().__setstate__ (state )
28952911 self ._comment = state .get ("comment" , "" )
28962912 self ._nave = state .get ("nave" , 1 )
@@ -3046,6 +3062,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin):
30463062 %(selection_attr)s
30473063 %(sfreq_tfr_attr)s
30483064 %(shape_tfr_attr)s
3065+ %(weights_tfr_attr)s
30493066
30503067 See Also
30513068 --------
@@ -3130,8 +3147,15 @@ def __getstate__(self):
31303147
31313148 def __setstate__ (self , state ):
31323149 """Unpack EpochsTFR from serialized format."""
3133- if state ["data" ].ndim != 4 :
3134- raise ValueError (f"EpochsTFR data should be 4D, got { state ['data' ].ndim } ." )
3150+ if state ["data" ].ndim not in [4 , 5 ]:
3151+ raise ValueError (
3152+ f"EpochsTFR data should be 4D or 5D, got { state ['data' ].ndim } ."
3153+ )
3154+ # Set dims now since optional tapers makes it difficult to disentangle later
3155+ state ["dims" ] = ("epoch" , "channel" )
3156+ if state ["data" ].ndim == 5 :
3157+ state ["dims" ] += ("taper" ,)
3158+ state ["dims" ] += ("freq" , "time" )
31353159 super ().__setstate__ (state )
31363160 self ._metadata = state .get ("metadata" , None )
31373161 n_epochs = self .shape [0 ]
@@ -3235,7 +3259,16 @@ def average(self, method="mean", *, dim="epochs", copy=False):
32353259 See discussion here:
32363260
32373261 https://github.com/scipy/scipy/pull/12676#issuecomment-783370228
3262+
3263+ Averaging is not supported for data containing a taper dimension.
32383264 """
3265+ if "taper" in self ._dims :
3266+ raise NotImplementedError (
3267+ "Averaging multitaper tapers across epochs, frequencies, or times is "
3268+ "not supported. If averaging across epochs, consider averaging the "
3269+ "epochs before computing the complex/phase spectrum."
3270+ )
3271+
32393272 _check_option ("dim" , dim , ("epochs" , "freqs" , "times" ))
32403273 axis = self ._dims .index (dim [:- 1 ]) # self._dims entries aren't plural
32413274
@@ -3607,6 +3640,7 @@ class EpochsTFRArray(EpochsTFR):
36073640 %(selection)s
36083641 %(drop_log)s
36093642 %(metadata_epochstfr)s
3643+ %(weights_tfr_array)s
36103644
36113645 Attributes
36123646 ----------
@@ -3623,6 +3657,7 @@ class EpochsTFRArray(EpochsTFR):
36233657 %(selection_attr)s
36243658 %(sfreq_tfr_attr)s
36253659 %(shape_tfr_attr)s
3660+ %(weights_tfr_attr)s
36263661
36273662 See Also
36283663 --------
@@ -3645,6 +3680,7 @@ def __init__(
36453680 selection = None ,
36463681 drop_log = None ,
36473682 metadata = None ,
3683+ weights = None ,
36483684 ):
36493685 state = dict (info = info , data = data , times = times , freqs = freqs )
36503686 optional = dict (
@@ -3655,6 +3691,7 @@ def __init__(
36553691 selection = selection ,
36563692 drop_log = drop_log ,
36573693 metadata = metadata ,
3694+ weights = weights ,
36583695 )
36593696 for name , value in optional .items ():
36603697 if value is not None :
@@ -3697,6 +3734,7 @@ class RawTFR(BaseTFR):
36973734 method : str
36983735 The method used to compute the spectra (``'morlet'``, ``'multitaper'``
36993736 or ``'stockwell'``).
3737+ %(weights_tfr_attr)s
37003738
37013739 See Also
37023740 --------
@@ -3746,6 +3784,19 @@ def __init__(
37463784 ** method_kw ,
37473785 )
37483786
3787+ def __setstate__ (self , state ):
3788+ """Unpack RawTFR from serialized format."""
3789+ if state ["data" ].ndim not in [3 , 4 ]:
3790+ raise ValueError (
3791+ f"RawTFR data should be 3D or 4D, got { state ['data' ].ndim } ."
3792+ )
3793+ # Set dims now since optional tapers makes it difficult to disentangle later
3794+ state ["dims" ] = ("channel" ,)
3795+ if state ["data" ].ndim == 4 :
3796+ state ["dims" ] += ("taper" ,)
3797+ state ["dims" ] += ("freq" , "time" )
3798+ super ().__setstate__ (state )
3799+
37493800 def __getitem__ (self , item ):
37503801 """Get RawTFR data.
37513802
@@ -3811,6 +3862,7 @@ class RawTFRArray(RawTFR):
38113862 %(times)s
38123863 %(freqs_tfr_array)s
38133864 %(method_tfr_array)s
3865+ %(weights_tfr_array)s
38143866
38153867 Attributes
38163868 ----------
@@ -3821,6 +3873,7 @@ class RawTFRArray(RawTFR):
38213873 %(method_tfr_attr)s
38223874 %(sfreq_tfr_attr)s
38233875 %(shape_tfr_attr)s
3876+ %(weights_tfr_attr)s
38243877
38253878 See Also
38263879 --------
@@ -3838,10 +3891,13 @@ def __init__(
38383891 freqs ,
38393892 * ,
38403893 method = None ,
3894+ weights = None ,
38413895 ):
38423896 state = dict (info = info , data = data , times = times , freqs = freqs )
3843- if method is not None :
3844- state ["method" ] = method
3897+ optional = dict (method = method , weights = weights )
3898+ for name , value in optional .items ():
3899+ if value is not None :
3900+ state [name ] = value
38453901 self .__setstate__ (state )
38463902
38473903
0 commit comments