Skip to content

Commit 9fe1fb6

Browse files
committed
Add option to store and return tfr taper weights
1 parent fa841cb commit 9fe1fb6

3 files changed

Lines changed: 65 additions & 7 deletions

File tree

mne/time_frequency/multitaper.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,7 @@ def tfr_array_multitaper(
469469
use_fft=True,
470470
decim=1,
471471
output="complex",
472+
return_weights=False,
472473
n_jobs=None,
473474
*,
474475
verbose=None,
@@ -502,6 +503,12 @@ def tfr_array_multitaper(
502503
* ``'itc'`` : inter-trial coherence.
503504
* ``'avg_power_itc'`` : average of single trial power and inter-trial
504505
coherence across trials.
506+
507+
return_weights : bool, default False
508+
If True, return the taper weights. Only applies if ``output="complex"``.
509+
510+
.. versionadded:: 1.9.0
511+
505512
%(n_jobs)s
506513
The parallelization is implemented across channels.
507514
%(verbose)s
@@ -520,6 +527,9 @@ def tfr_array_multitaper(
520527
If ``output`` is ``'avg_power_itc'``, the real values in ``out``
521528
contain the average power and the imaginary values contain the
522529
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
530+
weights : array of shape (n_tapers, n_freqs)
531+
The taper weights. Only returned if ``output="complex"`` and
532+
``return_weights=True``.
523533
524534
See Also
525535
--------
@@ -550,6 +560,7 @@ def tfr_array_multitaper(
550560
use_fft=use_fft,
551561
decim=decim,
552562
output=output,
563+
return_weights=return_weights,
553564
n_jobs=n_jobs,
554565
verbose=verbose,
555566
)

mne/time_frequency/tests/test_tfr.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -432,17 +432,21 @@ def test_tfr_morlet():
432432
def test_dpsswavelet():
433433
"""Test DPSS tapers."""
434434
freqs = np.arange(5, 25, 3)
435-
Ws = _make_dpss(
436-
1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True
435+
Ws, weights = _make_dpss(
436+
1000,
437+
freqs=freqs,
438+
n_cycles=freqs / 2.0,
439+
time_bandwidth=4.0,
440+
zero_mean=True,
441+
return_weights=True,
437442
)
438443

439-
assert len(Ws) == 3 # 3 tapers expected
444+
assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected
445+
assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs)
440446

441447
# Check that zero mean is true
442448
assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5
443449

444-
assert len(Ws[0]) == len(freqs) # As many wavelets as asked for
445-
446450

447451
@pytest.mark.slowtest
448452
def test_tfr_multitaper():

mne/time_frequency/tfr.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,11 @@ def _make_dpss(
264264
-------
265265
Ws : list of array
266266
The wavelets time series.
267+
Cs : list of array
268+
The concentration weights. Only returned if return_weights=True.
267269
"""
268270
Ws = list()
271+
Cs = list()
269272

270273
freqs = np.array(freqs)
271274
if np.any(freqs <= 0):
@@ -281,6 +284,7 @@ def _make_dpss(
281284

282285
for m in range(n_taps):
283286
Wm = list()
287+
Cm = list()
284288
for k, f in enumerate(freqs):
285289
if len(n_cycles) != 1:
286290
this_n_cycles = n_cycles[k]
@@ -302,12 +306,15 @@ def _make_dpss(
302306
real_offset = Wk.mean()
303307
Wk -= real_offset
304308
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
309+
Ck = np.sqrt(conc[m])
305310

306311
Wm.append(Wk)
312+
Cm.append(Ck)
307313

308314
Ws.append(Wm)
315+
Cs.append(Cm)
309316
if return_weights:
310-
return Ws, conc
317+
return Ws, Cs
311318
return Ws
312319

313320

@@ -428,6 +435,7 @@ def _compute_tfr(
428435
use_fft=True,
429436
decim=1,
430437
output="complex",
438+
return_weights=False,
431439
n_jobs=None,
432440
*,
433441
verbose=None,
@@ -479,6 +487,9 @@ def _compute_tfr(
479487
* 'avg_power_itc' : average of single trial power and inter-trial
480488
coherence across trials.
481489
490+
return_weights : bool, default False
491+
Whether to return the taper weights. Only applies if method='multitaper' and
492+
output='complex' or 'phase'.
482493
%(n_jobs)s
483494
The number of epochs to process at the same time. The parallelization
484495
is implemented across channels.
@@ -495,6 +506,10 @@ def _compute_tfr(
495506
n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the
496507
real values in the ``output`` contain average power' and the imaginary
497508
values contain the ITC: ``out = avg_power + i * itc``.
509+
510+
weights : array of shape (n_tapers, n_freqs)
511+
The taper weights. Only returned if method='multitaper', output='complex' or
512+
'phase', and return_weights=True.
498513
"""
499514
# Check data
500515
epoch_data = np.asarray(epoch_data)
@@ -516,6 +531,9 @@ def _compute_tfr(
516531
decim,
517532
output,
518533
)
534+
return_weights = (
535+
return_weights and method == "multitaper" and output in ["complex", "phase"]
536+
)
519537

520538
decim = _ensure_slice(decim)
521539
if (freqs > sfreq / 2.0).any():
@@ -531,13 +549,18 @@ def _compute_tfr(
531549
Ws = [W] # to have same dimensionality as the 'multitaper' case
532550

533551
elif method == "multitaper":
534-
Ws = _make_dpss(
552+
out = _make_dpss(
535553
sfreq,
536554
freqs,
537555
n_cycles=n_cycles,
538556
time_bandwidth=time_bandwidth,
539557
zero_mean=zero_mean,
558+
return_weights=return_weights,
540559
)
560+
if return_weights:
561+
Ws, weights = out
562+
else:
563+
Ws = out
541564

542565
# Check wavelets
543566
if len(Ws[0][0]) > epoch_data.shape[2]:
@@ -561,6 +584,8 @@ def _compute_tfr(
561584
out = np.empty((n_chans, n_freqs, n_times), dtype)
562585
elif output in ["complex", "phase"] and method == "multitaper":
563586
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
587+
if return_weights:
588+
weights = np.array(weights)
564589
else:
565590
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)
566591

@@ -585,6 +610,9 @@ def _compute_tfr(
585610
out = out.transpose(2, 0, 1, 3, 4)
586611
else:
587612
out = out.transpose(1, 0, 2, 3)
613+
614+
if return_weights:
615+
return out, weights
588616
return out
589617

590618

@@ -1203,6 +1231,9 @@ def __init__(
12031231
method_kw.setdefault("output", "power")
12041232
self._freqs = np.asarray(freqs, dtype=np.float64)
12051233
del freqs
1234+
# always store weights for per-taper outputs
1235+
if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]:
1236+
method_kw["return_weights"] = True
12061237
# check validity of kwargs manually to save compute time if any are invalid
12071238
tfr_funcs = dict(
12081239
morlet=tfr_array_morlet,
@@ -1224,6 +1255,7 @@ def __init__(
12241255
self._method = method
12251256
self._inst_type = type(inst)
12261257
self._baseline = None
1258+
self._weights = None
12271259
self.preload = True # needed for __getitem__, never False for TFRs
12281260
# self._dims may also get updated by child classes
12291261
self._dims = ["channel", "freq", "time"]
@@ -1382,6 +1414,7 @@ def __getstate__(self):
13821414
info=self.info,
13831415
baseline=self._baseline,
13841416
decim=self._decim,
1417+
weights=self._weights,
13851418
)
13861419

13871420
def __setstate__(self, state):
@@ -1410,6 +1443,7 @@ def __setstate__(self, state):
14101443
self._decim = defaults["decim"]
14111444
self.preload = True
14121445
self._set_times(self._raw_times)
1446+
self._weights = state.get("weights") # objs saved before #XXX won't have
14131447
# Handle instance type. Prior to gh-11282, Raw was not a possibility so if
14141448
# `inst_type_str` is missing it must be Epochs or Evoked
14151449
unknown_class = Epochs if "epoch" in self._dims else Evoked
@@ -1516,6 +1550,10 @@ def _compute_tfr(self, data, n_jobs, verbose):
15161550
if self.method == "stockwell":
15171551
self._data, self._itc, freqs = result
15181552
assert np.array_equal(self._freqs, freqs)
1553+
elif self.method == "multitaper" and self._tfr_func.keywords.get(
1554+
"output", ""
1555+
) in ["complex", "phase"]:
1556+
self._data, self._weights = result
15191557
elif self._tfr_func.keywords.get("output", "").endswith("_itc"):
15201558
self._data, self._itc = result.real, result.imag
15211559
else:
@@ -1694,6 +1732,11 @@ def times(self):
16941732
"""The time points present in the data (in seconds)."""
16951733
return self._times_readonly
16961734

1735+
@property
1736+
def weights(self):
1737+
"""The weights used for each taper in the time-frequency estimates."""
1738+
return self._weights
1739+
16971740
@fill_doc
16981741
def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True):
16991742
"""Crop data to a given time interval in place.

0 commit comments

Comments
 (0)