Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ jobs:
conda install --yes --quiet numpy matplotlib scipy scikit-learn
conda install --yes --quiet cython pillow
pip install --upgrade pip
pip install scikit-optimize tqdm mne
pip install scikit-optimize tqdm
pip install https://api.github.com/repos/mne-tools/mne-python/zipball/master
pip install sphinx sphinx-gallery sphinx_bootstrap_theme numpydoc
pip install -e .
- run:
Expand Down
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ install:
- source activate testenv
- conda install --yes --quiet numpy scipy scikit-learn matplotlib
- conda install --yes --quiet nose coverage
- pip install -q flake8 mne check-manifest h5py
- pip install -q flake8 check-manifest h5py
- pip install https://api.github.com/repos/mne-tools/mne-python/zipball/master
- pip install coverage
- python setup.py install
script:
Expand Down
48 changes: 41 additions & 7 deletions autoreject/autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import mne
from mne import pick_types
from mne.externals.h5io import read_hdf5, write_hdf5
from mne.viz import plot_epochs as plot_mne_epochs

from sklearn.base import BaseEstimator
from sklearn.model_selection import RandomizedSearchCV
Expand All @@ -24,7 +25,6 @@
_pbar, _handle_picks, _check_data, _compute_dots,
_get_picks_by_type, _pprint)
from .bayesopt import expected_improvement, bayes_opt
from .viz import plot_epochs

_INIT_PARAMS = ('consensus', 'n_interpolate', 'picks',
'verbose', 'n_jobs', 'cv', 'random_state',
Expand Down Expand Up @@ -528,7 +528,7 @@ def __repr__(self):
consensus=self.consensus,
verbose=self.verbose, picks=self.picks)
return '%s(%s)' % (class_name, _pprint(params,
offset=len(class_name),),)
offset=len(class_name),),)

def _vote_bad_epochs(self, epochs, picks):
"""Each channel votes for an epoch as good or bad.
Expand Down Expand Up @@ -769,7 +769,7 @@ def _run_local_reject_cv(epochs, thresh_func, picks_, n_interpolate, cv,
desc = 'n_interp'

for jdx, n_interp in enumerate(_pbar(n_interpolate, desc=desc,
position=1, verbose=verbose)):
position=1, verbose=verbose)):
# we can interpolate before doing cross-valida(tion
# because interpolation is independent across trials.
local_reject.n_interpolate_[ch_type] = n_interp
Expand Down Expand Up @@ -906,7 +906,7 @@ def __repr__(self):
thresh_method=self.thresh_method,
random_state=self.random_state, n_jobs=self.n_jobs)
return '%s(%s)' % (class_name, _pprint(params,
offset=len(class_name),),)
offset=len(class_name),),)

def __getstate__(self):
"""Get the state of autoreject as a dictionary."""
Expand Down Expand Up @@ -1289,8 +1289,42 @@ def plot_epochs(self, epochs, scalings=None, title=''):
fig : Instance of matplotlib.figure.Figure
Epochs traces.
"""
return plot_epochs(
labels = self.labels
n_epochs, n_channels = labels.shape

if not labels.shape[0] == len(epochs.events):
raise ValueError('The number of epochs should match the number of'
'epochs *before* autoreject. Please provide'
'the epochs object before running autoreject')
if not labels.shape[1] == len(epochs.ch_names):
raise ValueError('The number of channels should match the number'
' of channels before running autoreject.')
bad_epochs_idx = np.where(self.bad_epochs)[0]
if len(bad_epochs_idx) > 0 and \
bad_epochs_idx.max() > len(epochs.events):
raise ValueError('You had a bad_epoch with index'
'%d but there are only %d epochs. Make sure'
' to provide the epochs *before* running'
'autoreject.'
% (bad_epochs_idx.max(),
len(epochs.events)))

color_map = {0: None, 1: 'r', 2: (0.6, 0.6, 0.6, 1.0)}
epoch_colors = list()
for epoch_idx, label_epoch in enumerate(labels):
if self.bad_epochs[epoch_idx]:
epoch_color = ['r'] * n_channels
epoch_colors.append(epoch_color)
continue
epoch_color = list()
for this_label in label_epoch:
if not np.isnan(this_label):
epoch_color.append(color_map[this_label])
else:
epoch_color.append(None)
epoch_colors.append(epoch_color)

return plot_mne_epochs(
epochs=epochs,
bad_epochs_idx=np.where(self.bad_epochs)[0],
log_labels=self.labels, scalings=scalings,
epoch_colors=epoch_colors, scalings=scalings,
title='')
Loading