Skip to content

Commit 8b17361

Browse files
author
Joan Massich
committed
Add Digitization class
1 parent f4ff170 commit 8b17361

File tree

10 files changed

+177
-12
lines changed

10 files changed

+177
-12
lines changed

mne/digitization/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base import DigPoint
1+
from .base import DigPoint, Digitization
22

33
__all__ = [
44
'DigPoint',

mne/digitization/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ..utils.check import _check_option
3535
from .. import __version__
3636

37-
from .base import _format_dig_points
37+
from .base import _format_dig_points, Digitization
3838

3939
b = bytes # alias
4040

@@ -261,7 +261,7 @@ def _make_dig_points(nasion=None, lpa=None, rpa=None, hpi=None,
261261
'kind': FIFF.FIFFV_POINT_EEG,
262262
'coord_frame': FIFF.FIFFV_COORD_HEAD})
263263

264-
return _format_dig_points(dig)
264+
return Digitization(_format_dig_points(dig))
265265

266266

267267
def _call_make_dig_points(nasion, lpa, rpa, hpi, extra, convert=True):

mne/digitization/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
from ..transforms import _coord_frame_name
99
from ..io.constants import FIFF
10+
from ..utils._bunch import MNEObjectsList
1011

1112
_dig_kind_dict = {
1213
'cardinal': FIFF.FIFFV_POINT_CARDINAL,
@@ -73,3 +74,16 @@ def __eq__(self, other): # noqa: D105
7374
return False
7475
else:
7576
return np.allclose(self['r'], other['r'])
77+
78+
79+
class Digitization(MNEObjectsList):
80+
"""Represent a list of DigPoint objects.
81+
82+
Parameters
83+
----------
84+
elements : list
85+
A list of DigPoint objects.
86+
"""
87+
88+
def __init__(self, elements=None):
89+
super(Digitization, self).__init__(elements=elements, kls=DigPoint)

mne/io/ctf/ctf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from .info import _compose_meas_info, _read_bad_chans, _annotate_bad_segments
2222
from .constants import CTF
2323

24+
from ...digitization.base import _format_dig_points
25+
from ...digitization import Digitization
26+
2427

2528
@fill_doc
2629
def read_raw_ctf(directory, system_clock='truncate', preload=False,
@@ -116,6 +119,7 @@ def __init__(self, directory, system_clock='truncate', preload=False,
116119
# Compose a structure which makes fiff writing a piece of cake
117120
info = _compose_meas_info(res4, coils, coord_trans, eeg)
118121
info['dig'] += digs
122+
info['dig'] = Digitization(_format_dig_points(info['dig']))
119123
info['bads'] += _read_bad_chans(directory, info)
120124

121125
# Determine how our data is distributed across files

mne/io/meas_info.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from .proc_history import _read_proc_history, _write_proc_history
3030
from ..transforms import invert_transform
3131
from ..utils import logger, verbose, warn, object_diff, _validate_type
32+
from ..digitization.base import _format_dig_points
33+
from ..digitization import Digitization
3234
from .compensator import get_current_comp
3335

3436
# XXX: most probably the functions needing this, should go somewhere else
@@ -156,6 +158,7 @@ def _unique_channel_names(ch_names):
156158

157159

158160
# XXX Eventually this should be de-duplicated with the MNE-MATLAB stuff...
161+
# XXX: Digitization, Docstrings need a pass changing lists for Digitization
159162
class Info(dict):
160163
"""Measurement information.
161164
@@ -1129,7 +1132,7 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
11291132
info['dev_ctf_t'] = Transform('meg', 'ctf_head', dev_ctf_trans)
11301133

11311134
# All kinds of auxliary stuff
1132-
info['dig'] = dig
1135+
info['dig'] = Digitization(_format_dig_points(dig))
11331136
info['bads'] = bads
11341137
info._update_redundant()
11351138
if clean_bads:

mne/io/tests/test_raw.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from mne.utils import _TempDir, catch_logging, _raw_annot
2121
from mne.io.meas_info import _get_valid_units
2222

23+
from mne.digitization import Digitization
24+
2325

2426
def test_orig_units():
2527
"""Test the error handling for original units."""
@@ -87,6 +89,7 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, **kwargs):
8789
full_data = raw._data
8890
assert raw.__class__.__name__ in repr(raw) # to test repr
8991
assert raw.info.__class__.__name__ in repr(raw.info)
92+
assert isinstance(raw.info['dig'], (type(None), Digitization))
9093

9194
# gh-5604
9295
assert _handle_meas_date(raw.info['meas_date']) >= 0

mne/tests/test_digitization.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# -*- coding: utf-8 -*-
2+
# Authors: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
3+
# Joan Massich <mailsik@gmail.com>
4+
#
5+
# License: BSD (3-clause)
6+
import pytest
7+
import numpy as np
8+
from mne.digitization import Digitization
9+
from mne.digitization.base import _format_dig_points
10+
11+
dig_dict_list = [
12+
dict(kind=_, ident=_, r=np.empty((3,)), coord_frame=_)
13+
for _ in [1, 2, 42]
14+
]
15+
16+
digpoints_list = _format_dig_points(dig_dict_list)
17+
18+
19+
@pytest.mark.parametrize('data', [
20+
pytest.param(digpoints_list, id='list of digpoints'),
21+
pytest.param(dig_dict_list, id='list of digpoint dicts',
22+
marks=pytest.mark.xfail(raises=ValueError)),
23+
pytest.param(['foo', 'bar'], id='list of strings',
24+
marks=pytest.mark.xfail(raises=ValueError)),
25+
])
26+
def test_digitization_constructor(data):
27+
"""Test Digitization constructor."""
28+
dig = Digitization(data)
29+
assert dig == data
30+
31+
dig[0]['kind'] = data[0]['kind'] - 1 # modify something in dig
32+
assert dig != data
33+
34+
35+
def test_delete_elements():
36+
"""Test deleting some Digitization elements."""
37+
dig = Digitization(digpoints_list)
38+
original_length = len(dig)
39+
del dig[0]
40+
assert len(dig) == original_length - 1

mne/utils/_bunch.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
"""Bunch-related classes."""
33
# Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
44
# Eric Larson <larson.eric.d@gmail.com>
5+
# Joan Massich <mailsik@gmail.com>
56
#
67
# License: BSD (3-clause)
78

89
from copy import deepcopy
10+
from collections.abc import MutableSequence
911

1012

1113
###############################################################################
@@ -94,10 +96,58 @@ def _named_subclass(klass):
9496
class NamedInt(_Named, int):
9597
"""Int with a name in __repr__."""
9698

97-
pass
98-
9999

100100
class NamedFloat(_Named, float):
101101
"""Float with a name in __repr__."""
102102

103-
pass
103+
104+
class MNEObjectsList(MutableSequence):
105+
"""All the bolierplate for a list of specific MNE objects.
106+
107+
Parameters
108+
----------
109+
elements : list
110+
A list of Objects objects.
111+
112+
Attributes
113+
----------
114+
_items : list
115+
The container
116+
"""
117+
118+
def __init__(self, elements=None, kls=None):
119+
if kls is None:
120+
raise ValueError('kls is necessary')
121+
if elements is None:
122+
self._items = list()
123+
elif all([isinstance(_, kls) for _ in elements]):
124+
if elements is None:
125+
self._items = list()
126+
else:
127+
self._items = deepcopy(list(elements))
128+
else:
129+
# XXX: _msg should not be Digitization related
130+
_msg = 'Digitization expected a iterable of DigPoint objects.'
131+
raise ValueError(_msg)
132+
133+
def __len__(self): # noqa: D105
134+
return len(self._items)
135+
136+
def __getitem__(self, index): # noqa: D105
137+
return self._items[index]
138+
139+
def __setitem__(self, index, value): # noqa: D105
140+
self._items[index] = value
141+
142+
def __delitem__(self, index): # noqa: D105
143+
del self._items[index]
144+
145+
def insert(self, index, value): # noqa: D102
146+
self._items.insert(index, value)
147+
148+
def __eq__(self, other): # noqa: D105
149+
# if not isinstance(other, Digitization) or len(self) != len(other):
150+
if len(self) != len(other):
151+
return False
152+
else:
153+
return all([ss == oo for ss, oo in zip(self, other)])

mne/utils/numerics.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -663,13 +663,15 @@ def object_size(x):
663663
x : object
664664
Object to approximate the size of.
665665
Can be anything comprised of nested versions of:
666-
{dict, list, tuple, ndarray, str, bytes, float, int, None}.
666+
{dict, list, tuple, ndarray, str, bytes, float, int, None,
667+
Digitization}.
667668
668669
Returns
669670
-------
670671
size : int
671672
The estimated size in bytes of the object.
672673
"""
674+
from ..digitization import Digitization
673675
# Note: this will not process object arrays properly (since those only)
674676
# hold references
675677
if isinstance(x, (bytes, str, int, float, type(None))):
@@ -685,7 +687,7 @@ def object_size(x):
685687
for key, value in x.items():
686688
size += object_size(key)
687689
size += object_size(value)
688-
elif isinstance(x, (list, tuple)):
690+
elif isinstance(x, (list, tuple, Digitization)):
689691
size = sys.getsizeof(x) + sum(object_size(xx) for xx in x)
690692
elif sparse.isspmatrix_csc(x) or sparse.isspmatrix_csr(x):
691693
size = sum(sys.getsizeof(xx)
@@ -710,9 +712,9 @@ def object_diff(a, b, pre=''):
710712
----------
711713
a : object
712714
Currently supported: dict, list, tuple, ndarray, int, str, bytes,
713-
float, StringIO, BytesIO.
715+
float, StringIO, BytesIO, Digitization.
714716
b : object
715-
Must be same type as x1.
717+
Must be same type as ``a``.
716718
pre : str
717719
String to prepend to each line.
718720
@@ -721,6 +723,7 @@ def object_diff(a, b, pre=''):
721723
diffs : str
722724
A string representation of the differences.
723725
"""
726+
from ..digitization import Digitization
724727
out = ''
725728
if type(a) != type(b):
726729
out += pre + ' type mismatch (%s, %s)\n' % (type(a), type(b))
@@ -741,7 +744,7 @@ def object_diff(a, b, pre=''):
741744
else:
742745
for ii, (xx1, xx2) in enumerate(zip(a, b)):
743746
out += object_diff(xx1, xx2, pre + '[%s]' % ii)
744-
elif isinstance(a, (str, int, float, bytes, np.generic)):
747+
elif isinstance(a, (str, int, float, bytes, np.generic, Digitization)):
745748
if a != b:
746749
out += pre + ' value mismatch (%s, %s)\n' % (a, b)
747750
elif a is None:
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# -*- coding: utf-8 -*-
2+
# Authors: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
3+
# Joan Massich <mailsik@gmail.com>
4+
#
5+
# License: BSD (3-clause)
6+
import pytest
7+
from mne.utils._bunch import MNEObjectsList
8+
9+
10+
class MyIntList(MNEObjectsList): # noqa: D101
11+
def __init__(self, elements=None):
12+
super(MyIntList, self).__init__(elements=elements, kls=int)
13+
14+
15+
@pytest.mark.parametrize('data', [
16+
pytest.param([1, 2], id='list of ints'),
17+
pytest.param(['foo', 'bar'], id='list of strings',
18+
marks=pytest.mark.xfail(raises=ValueError)),
19+
])
20+
def test_mne_objects_list_constructor(data):
21+
"""Test MyIntList constructor."""
22+
my_int_list = MyIntList(data)
23+
assert my_int_list == data
24+
25+
26+
# XXX: to fix
27+
def test_check_proper_constructor_error():
28+
"""Test constructor error wording."""
29+
with pytest.raises(ValueError): # , match='')
30+
MyIntList(['foo', 'bar'])
31+
32+
33+
@pytest.mark.parametrize('data, expected_len, expected_bool', [
34+
pytest.param(None, 0, False, id='None'),
35+
pytest.param([], 0, False, id='empty list'),
36+
pytest.param([1, 2], 2, True, id='list of ints'),
37+
pytest.param(['foo', 'bar'], 2, True, id='list of strings',
38+
marks=pytest.mark.xfail(raises=ValueError)),
39+
])
40+
def test_emptylist_none_behaviour_in_conditionals(
41+
data, expected_len, expected_bool):
42+
"""Test MyIntList constructor."""
43+
my_int_list = MyIntList(data)
44+
assert len(my_int_list) == expected_len
45+
if my_int_list:
46+
assert expected_bool
47+
else:
48+
assert not expected_bool

0 commit comments

Comments
 (0)