@@ -264,8 +264,11 @@ def _make_dpss(
264264 -------
265265 Ws : list of array
266266 The wavelets time series.
267+ Cs : list of array
268+ The concentration weights. Only returned if return_weights=True.
267269 """
268270 Ws = list ()
271+ Cs = list ()
269272
270273 freqs = np .array (freqs )
271274 if np .any (freqs <= 0 ):
@@ -281,6 +284,7 @@ def _make_dpss(
281284
282285 for m in range (n_taps ):
283286 Wm = list ()
287+ Cm = list ()
284288 for k , f in enumerate (freqs ):
285289 if len (n_cycles ) != 1 :
286290 this_n_cycles = n_cycles [k ]
@@ -302,12 +306,15 @@ def _make_dpss(
302306 real_offset = Wk .mean ()
303307 Wk -= real_offset
304308 Wk /= np .sqrt (0.5 ) * np .linalg .norm (Wk .ravel ())
309+ Ck = np .sqrt (conc [m ])
305310
306311 Wm .append (Wk )
312+ Cm .append (Ck )
307313
308314 Ws .append (Wm )
315+ Cs .append (Cm )
309316 if return_weights :
310- return Ws , conc
317+ return Ws , Cs
311318 return Ws
312319
313320
@@ -428,6 +435,7 @@ def _compute_tfr(
428435 use_fft = True ,
429436 decim = 1 ,
430437 output = "complex" ,
438+ return_weights = False ,
431439 n_jobs = None ,
432440 * ,
433441 verbose = None ,
@@ -479,6 +487,9 @@ def _compute_tfr(
479487 * 'avg_power_itc' : average of single trial power and inter-trial
480488 coherence across trials.
481489
490+ return_weights : bool, default False
491+ Whether to return the taper weights. Only applies if method='multitaper' and
492+ output='complex' or 'phase'.
482493 %(n_jobs)s
483494 The number of epochs to process at the same time. The parallelization
484495 is implemented across channels.
@@ -495,6 +506,10 @@ def _compute_tfr(
495506 n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the
496507 real values in the ``output`` contain average power' and the imaginary
497508 values contain the ITC: ``out = avg_power + i * itc``.
509+
510+ weights : array of shape (n_tapers, n_freqs)
511+ The taper weights. Only returned if method='multitaper', output='complex' or
512+ 'phase', and return_weights=True.
498513 """
499514 # Check data
500515 epoch_data = np .asarray (epoch_data )
@@ -516,6 +531,9 @@ def _compute_tfr(
516531 decim ,
517532 output ,
518533 )
534+ return_weights = (
535+ return_weights and method == "multitaper" and output in ["complex" , "phase" ]
536+ )
519537
520538 decim = _ensure_slice (decim )
521539 if (freqs > sfreq / 2.0 ).any ():
@@ -531,13 +549,18 @@ def _compute_tfr(
531549 Ws = [W ] # to have same dimensionality as the 'multitaper' case
532550
533551 elif method == "multitaper" :
534- Ws = _make_dpss (
552+ out = _make_dpss (
535553 sfreq ,
536554 freqs ,
537555 n_cycles = n_cycles ,
538556 time_bandwidth = time_bandwidth ,
539557 zero_mean = zero_mean ,
558+ return_weights = return_weights ,
540559 )
560+ if return_weights :
561+ Ws , weights = out
562+ else :
563+ Ws = out
541564
542565 # Check wavelets
543566 if len (Ws [0 ][0 ]) > epoch_data .shape [2 ]:
@@ -561,6 +584,8 @@ def _compute_tfr(
561584 out = np .empty ((n_chans , n_freqs , n_times ), dtype )
562585 elif output in ["complex" , "phase" ] and method == "multitaper" :
563586 out = np .empty ((n_chans , n_tapers , n_epochs , n_freqs , n_times ), dtype )
587+ if return_weights :
588+ weights = np .array (weights )
564589 else :
565590 out = np .empty ((n_chans , n_epochs , n_freqs , n_times ), dtype )
566591
@@ -585,6 +610,9 @@ def _compute_tfr(
585610 out = out .transpose (2 , 0 , 1 , 3 , 4 )
586611 else :
587612 out = out .transpose (1 , 0 , 2 , 3 )
613+
614+ if return_weights :
615+ return out , weights
588616 return out
589617
590618
@@ -1203,6 +1231,9 @@ def __init__(
12031231 method_kw .setdefault ("output" , "power" )
12041232 self ._freqs = np .asarray (freqs , dtype = np .float64 )
12051233 del freqs
1234+ # always store weights for per-taper outputs
1235+ if method == "multitaper" and method_kw .get ("output" ) in ["complex" , "phase" ]:
1236+ method_kw ["return_weights" ] = True
12061237 # check validity of kwargs manually to save compute time if any are invalid
12071238 tfr_funcs = dict (
12081239 morlet = tfr_array_morlet ,
@@ -1224,6 +1255,7 @@ def __init__(
12241255 self ._method = method
12251256 self ._inst_type = type (inst )
12261257 self ._baseline = None
1258+ self ._weights = None
12271259 self .preload = True # needed for __getitem__, never False for TFRs
12281260 # self._dims may also get updated by child classes
12291261 self ._dims = ["channel" , "freq" , "time" ]
@@ -1382,6 +1414,7 @@ def __getstate__(self):
13821414 info = self .info ,
13831415 baseline = self ._baseline ,
13841416 decim = self ._decim ,
1417+ weights = self ._weights ,
13851418 )
13861419
13871420 def __setstate__ (self , state ):
@@ -1410,6 +1443,7 @@ def __setstate__(self, state):
14101443 self ._decim = defaults ["decim" ]
14111444 self .preload = True
14121445 self ._set_times (self ._raw_times )
1446+ self ._weights = state .get ("weights" ) # objs saved before #XXX won't have
14131447 # Handle instance type. Prior to gh-11282, Raw was not a possibility so if
14141448 # `inst_type_str` is missing it must be Epochs or Evoked
14151449 unknown_class = Epochs if "epoch" in self ._dims else Evoked
@@ -1516,6 +1550,10 @@ def _compute_tfr(self, data, n_jobs, verbose):
15161550 if self .method == "stockwell" :
15171551 self ._data , self ._itc , freqs = result
15181552 assert np .array_equal (self ._freqs , freqs )
1553+ elif self .method == "multitaper" and self ._tfr_func .keywords .get (
1554+ "output" , ""
1555+ ) in ["complex" , "phase" ]:
1556+ self ._data , self ._weights = result
15191557 elif self ._tfr_func .keywords .get ("output" , "" ).endswith ("_itc" ):
15201558 self ._data , self ._itc = result .real , result .imag
15211559 else :
@@ -1694,6 +1732,11 @@ def times(self):
16941732 """The time points present in the data (in seconds)."""
16951733 return self ._times_readonly
16961734
1735+ @property
1736+ def weights (self ):
1737+ """The weights used for each taper in the time-frequency estimates."""
1738+ return self ._weights
1739+
16971740 @fill_doc
16981741 def crop (self , tmin = None , tmax = None , fmin = None , fmax = None , include_tmax = True ):
16991742 """Crop data to a given time interval in place.
0 commit comments