@@ -216,16 +216,28 @@ def psd_array_welch(
216216 )
217217
218218 parallel , my_spect_func , n_jobs = parallel_func (_spect_func , n_jobs = n_jobs )
219- func = partial (
220- spectrogram ,
221- detrend = detrend ,
222- noverlap = n_overlap ,
223- nperseg = n_per_seg ,
224- nfft = n_fft ,
225- fs = sfreq ,
226- window = window ,
227- mode = mode ,
228- )
219+
220+ def func (* args , ** kwargs ):
221+ # swallow SciPy warnings
222+ with warnings .catch_warnings ():
223+ warnings .filterwarnings (
224+ action = "ignore" ,
225+ module = "scipy" ,
226+ category = UserWarning ,
227+ message = r"nperseg = \d+ is greater than input length" ,
228+ )
229+ return spectrogram (
230+ * args ,
231+ ** kwargs ,
232+ detrend = detrend ,
233+ noverlap = n_overlap ,
234+ nperseg = n_per_seg ,
235+ nfft = n_fft ,
236+ fs = sfreq ,
237+ window = window ,
238+ mode = mode ,
239+ )
240+
229241 if np .any (np .isnan (x )):
230242 good_mask = ~ np .isnan (x )
231243 # NaNs originate from annot, so must match for all channels. Note that we CANNOT
@@ -256,18 +268,10 @@ def psd_array_welch(
256268 else :
257269 x_splits = [arr for arr in np .array_split (x , n_jobs ) if arr .size != 0 ]
258270 agg_func = np .concatenate
259- # swallow SciPy warnings
260- with warnings .catch_warnings ():
261- warnings .filterwarnings (
262- action = "ignore" ,
263- module = "scipy" ,
264- category = UserWarning ,
265- message = r"nperseg = \d+ is greater than input length" ,
266- )
267- f_spect = parallel (
268- my_spect_func (d , func = func , freq_sl = freq_sl , average = average , output = output )
269- for d in x_splits
270- )
271+ f_spect = parallel (
272+ my_spect_func (d , func = func , freq_sl = freq_sl , average = average , output = output )
273+ for d in x_splits
274+ )
271275 psds = agg_func (f_spect , axis = 0 )
272276 shape = dshape + (len (freqs ),)
273277 if average is None :
0 commit comments