Skip to content

Commit 6763f7f

Browse files
committed
improvements to the realtime module
currently the test is breaking when it comes to using the RtEpochs object.
1 parent a00e5ed commit 6763f7f

File tree

4 files changed

+99
-63
lines changed

4 files changed

+99
-63
lines changed

mne/realtime/base_client.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@ def register_receive_callback(self, callback):
133133
if callback not in self._recv_callbacks:
134134
self._recv_callbacks.append(callback)
135135

136+
def start(self):
137+
"""Start the client."""
138+
self.__enter__()
139+
140+
return self
141+
136142
def start_receive_thread(self, nchan):
137143
"""Start the receive thread.
138144
@@ -150,6 +156,12 @@ def start_receive_thread(self, nchan):
150156
self._recv_thread.daemon = True
151157
self._recv_thread.start()
152158

159+
def stop(self):
160+
"""Stop the client."""
161+
self.__exit__()
162+
163+
return self
164+
153165
def stop_receive_thread(self, stop_measurement=False):
154166
"""Stop the receive thread.
155167

mne/realtime/lsl_client.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,8 @@ def get_data_as_epoch(self, n_samples=1024, picks=None):
7474

7575
def iter_raw_buffers(self):
7676
"""Return an iterator over raw buffers."""
77-
pylsl = _check_pylsl_installed(strict=True)
78-
inlet = pylsl.StreamInlet(self.client)
79-
8077
while True:
81-
samples, _ = inlet.pull_chunk(max_samples=self.buffer_size)
78+
samples, _ = self.client.pull_chunk(max_samples=self.buffer_size)
8279

8380
yield np.vstack(samples).T
8481

mne/realtime/mock_lsl_stream.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -31,63 +31,68 @@ class MockLSLStream:
3131

3232
def __init__(self, host, n_channels=8, ch_type="eeg", sfreq=100,
3333
testing=False):
34-
self.host = host
35-
self.n_channels = n_channels
36-
self.ch_type = ch_type
37-
self.sfreq = sfreq
38-
self.streaming = False
39-
self.testing = testing
34+
self._host = host
35+
self._n_channels = n_channels
36+
self._ch_type = ch_type
37+
self._sfreq = sfreq
38+
self._testing = testing
39+
self._streaming = False
4040

4141
def start(self):
4242
"""Start a mock LSL stream."""
43-
pylsl = _check_pylsl_installed(strict=True)
44-
self.streaming = True
45-
info = pylsl.StreamInfo('MNE', self.ch_type.upper(), self.n_channels,
46-
self.sfreq, 'float32', self.host)
47-
info.desc().append_child_value("manufacturer", "MNE")
48-
channels = info.desc().append_child("channels")
49-
for c_id in range(1, self.n_channels + 1):
50-
channels.append_child("channel") \
51-
.append_child_value("label", "MNE {:03d}".format(c_id)) \
52-
.append_child_value("type", self.ch_type.lower()) \
53-
.append_child_value("unit", "microvolt")
54-
55-
# next make an outlet
56-
outlet = pylsl.StreamOutlet(info)
57-
5843
print("now sending data...")
59-
self.process = Process(target=self._initiate_stream, args=(outlet,))
44+
self.process = Process(target=self._initiate_stream)
6045
self.process.daemon = True
6146
self.process.start()
6247

6348
return self
6449

65-
def close(self):
50+
def stop(self):
6651
"""Stop a mock LSL stream."""
52+
self._streaming = False
6753
self.process.terminate()
6854

6955
print("Stopping stream...")
7056

7157
return self
7258

73-
def _initiate_stream(self, outlet):
59+
def _initiate_stream(self):
60+
# outlet needs to be made on the same process
61+
pylsl = _check_pylsl_installed(strict=True)
62+
self._streaming = True
63+
info = pylsl.StreamInfo(name='MNE', type=self._ch_type.upper(),
64+
channel_count=self._n_channels,
65+
nominal_srate=self._sfreq,
66+
channel_format='float32', source_id=self._host)
67+
info.desc().append_child_value("manufacturer", "MNE")
68+
channels = info.desc().append_child("channels")
69+
for c_id in range(1, self._n_channels + 1):
70+
channels.append_child("channel") \
71+
.append_child_value("label", "MNE {:03d}".format(c_id)) \
72+
.append_child_value("type", self._ch_type.lower()) \
73+
.append_child_value("unit", "microvolt")
74+
75+
# next make an outlet
76+
outlet = pylsl.StreamOutlet(info)
77+
78+
# let's make some data
7479
counter = 0
7580
trigger = 0
76-
while True:
77-
sample = counter % self.sfreq
81+
while self._streaming:
82+
sample = counter % self._sfreq
7883
# let's bound trigger to be between 1 and 10 so the max cycle
7984
# is ten seconds
80-
if trigger == 10:
81-
trigger = 0
85+
trigger = 0 if trigger == 10 else trigger
8286
if sample == 0:
8387
trigger += 1
8488

85-
if not self.testing:
86-
const = np.sin(2 * np.pi * sample / self.sfreq) * 1e-6
87-
mysample = rand(self.n_channels).dot(const).tolist()
89+
if self._testing:
90+
const = trigger
8891
else:
89-
mysample = np.ones(self.n_channels).dot(trigger).tolist()
92+
const = np.sin(2 * np.pi * sample / self._sfreq) * 1e-6
93+
94+
mysample = rand(self._n_channels).dot(const).tolist()
9095
# now send it and wait for a bit
9196
outlet.push_sample(mysample)
9297
counter += 1
93-
time.sleep(self.sfreq**-1)
98+
time.sleep(self._sfreq**-1)

mne/realtime/tests/test_lsl_client.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,26 @@
22
#
33
# License: BSD (3-clause)
44
from random import random as rand
5+
import numpy as np
56

67
from mne.realtime import LSLClient, MockLSLStream, RtEpochs
78
from mne.utils import run_tests_if_main, requires_pylsl
89
from mne import create_info
10+
from mne.io import constants
911

1012

1113
host = 'myuid34234'
12-
14+
event_id = 5
15+
n_channels = 8
16+
sfreq = 100
17+
tmin = -0.1
18+
tmax = 0.5
19+
ch_types = ["eeg" for n in range(n_channels - 1)] + ['stim']
20+
ch_names = ["MNE {:03d}".format(ch_id) for ch_id
21+
in range(1, n_channels + 1)]
22+
stim_channel = ch_names[-1]
23+
24+
info = create_info(ch_names, sfreq, ch_types)
1325

1426
@requires_pylsl
1527
def test_lsl_client():
@@ -21,47 +33,54 @@ def test_lsl_client():
2133
stream = MockLSLStream(host, n_channels, testing=True)
2234
stream.start()
2335

24-
with LSLClient(info=None, host=host, wait_max=wait_max) as client:
36+
with LSLClient(info=info, host=host, wait_max=wait_max) as client:
2537
client_info = client.get_measurement_info()
26-
27-
assert ([ch["ch_name"] for ch in client_info["chs"]] ==
28-
["MNE {:03d}".format(ch_id) for ch_id in
29-
range(1, n_channels + 1)])
30-
3138
epoch = client.get_data_as_epoch(n_samples=n_samples)
32-
assert n_channels, n_samples == epoch.get_data().shape[1:]
3339

34-
stream.close()
40+
assert client_info['nchan'] == n_channels
41+
assert ([ch["ch_name"] for ch in client_info["chs"]] ==
42+
[ch_name for ch_name in ch_names])
43+
assert any([constants.FIFF.FIFFV_STIM_CH == ch['kind']
44+
for ch in info['chs']])
45+
assert n_channels, n_samples == epoch.get_data().shape[1:]
46+
47+
stream.stop()
3548

3649

3750
def test_lsl_rt_epochs():
3851
"""Test the functionality of the LSL Client with RtEpochs."""
39-
event_id = 5
40-
n_channels = 8
41-
sfreq = 100
42-
tmin = -0.1
43-
tmax = 0.5
44-
ch_types = ["eeg" for n in range(n_channels - 1)] + ['stim']
45-
ch_names = ["MNE {:03d}".format(ch_id) for ch_id
46-
in range(1, n_channels + 1)]
47-
stim_channel = ch_names[-1]
48-
4952
stream = MockLSLStream(host, n_channels=n_channels, ch_type="eeg",
5053
sfreq=sfreq, testing=True)
5154
stream.start()
5255

53-
info = create_info(ch_names, sfreq, ch_types)
54-
with LSLClient(info=info, host=host) as client:
55-
epochs_rt = RtEpochs(client, event_id, tmin, tmax, stim_channel)
56-
epochs_rt.start()
57-
time.sleep(10)
58-
epochs.stop(stop_receive_thread=True)
56+
try:
57+
data = None
58+
events_ids = None
59+
with LSLClient(info=info, host=host) as client:
60+
epochs_rt = RtEpochs(client, event_id, tmin, tmax, stim_channel)
61+
epochs_rt.start()
62+
time.sleep(15)
63+
64+
for ev_num, ev in enumerate(epochs_rt.iter_evoked()):
65+
if ev_num == 0:
66+
data = ev.data[None, :, :]
67+
events_ids = int(
68+
ev.comment) # comment attribute contains event_id
69+
else:
70+
data = np.concatenate(
71+
(data, ev.data[None, :, :]), axis=0)
72+
events_ids = np.append(events_ids,
73+
int(ev.comment))
74+
75+
epochs_rt.stop(stop_receive_thread=True)
5976

6077
data = epochs_rt.get_data()
78+
6179
n_samples_prestim = np.abs(tmin * sfreq)
6280
n_sample_poststim = tmax * sfreq
6381

64-
data_channels, data_samples = data.get_data().shape[1:]
82+
data_channels, data_samples = data.shape[1:]
83+
6584
assert data_channels == n_channels
6685
assert data_samples == n_samples_prestim + n_sample_poststim
6786

@@ -72,4 +91,7 @@ def test_lsl_rt_epochs():
7291

7392
assert data == data_expected
7493

94+
finally:
95+
stream.stop()
96+
7597
run_tests_if_main()

0 commit comments

Comments
 (0)