
Commit d143d8f

Merge branch 'main' into cap_python_proyect_toml

2 parents 5347cba + a606364

19 files changed: +1109 additions, -490 deletions

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,12 +1,12 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/psf/black
-    rev: 24.8.0
+    rev: 24.10.0
     hooks:
       - id: black
         files: ^src/

src/spikeinterface/benchmark/benchmark_matching.py

Lines changed: 9 additions & 65 deletions
@@ -33,7 +33,7 @@ def run(self, **job_kwargs):
         sorting["unit_index"] = spikes["cluster_index"]
         sorting["segment_index"] = spikes["segment_index"]
         sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids)
-        self.result = {"sorting": sorting}
+        self.result = {"sorting": sorting, "spikes": spikes}
         self.result["templates"] = self.templates

     def compute_result(self, with_collision=False, **result_params):
@@ -45,6 +45,7 @@ def compute_result(self, with_collision=False, **result_params):

     _run_key_saved = [
         ("sorting", "sorting"),
+        ("spikes", "npy"),
         ("templates", "zarr_templates"),
     ]
     _result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")]
@@ -71,9 +72,15 @@ def plot_performances_vs_snr(self, **kwargs):

         return plot_performances_vs_snr(self, **kwargs)

+    def plot_performances_comparison(self, **kwargs):
+        from .benchmark_plot_tools import plot_performances_comparison
+
+        return plot_performances_comparison(self, **kwargs)
+
     def plot_collisions(self, case_keys=None, figsize=None):
         if case_keys is None:
             case_keys = list(self.cases.keys())
+        import matplotlib.pyplot as plt

         fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)

@@ -90,70 +97,6 @@ def plot_collisions(self, case_keys=None, figsize=None):

         return fig

-    def plot_comparison_matching(
-        self,
-        case_keys=None,
-        performance_names=["accuracy", "recall", "precision"],
-        colors=["g", "b", "r"],
-        ylim=(-0.1, 1.1),
-        figsize=None,
-    ):
-
-        if case_keys is None:
-            case_keys = list(self.cases.keys())
-
-        num_methods = len(case_keys)
-        import pylab as plt
-
-        fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10))
-        for i, key1 in enumerate(case_keys):
-            for j, key2 in enumerate(case_keys):
-                if len(axs.shape) > 1:
-                    ax = axs[i, j]
-                else:
-                    ax = axs[j]
-                comp1 = self.get_result(key1)["gt_comparison"]
-                comp2 = self.get_result(key2)["gt_comparison"]
-                if i <= j:
-                    for performance, color in zip(performance_names, colors):
-                        perf1 = comp1.get_performance()[performance]
-                        perf2 = comp2.get_performance()[performance]
-                        ax.plot(perf2, perf1, ".", label=performance, color=color)
-
-                    ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
-                    ax.set_ylim(ylim)
-                    ax.set_xlim(ylim)
-                    ax.spines[["right", "top"]].set_visible(False)
-                    ax.set_aspect("equal")
-
-                    label1 = self.cases[key1]["label"]
-                    label2 = self.cases[key2]["label"]
-                    if j == i:
-                        ax.set_ylabel(f"{label1}")
-                    else:
-                        ax.set_yticks([])
-                    if i == j:
-                        ax.set_xlabel(f"{label2}")
-                    else:
-                        ax.set_xticks([])
-                    if i == num_methods - 1 and j == num_methods - 1:
-                        patches = []
-                        import matplotlib.patches as mpatches
-
-                        for color, name in zip(colors, performance_names):
-                            patches.append(mpatches.Patch(color=color, label=name))
-                        ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)
-                else:
-                    ax.spines["bottom"].set_visible(False)
-                    ax.spines["left"].set_visible(False)
-                    ax.spines["top"].set_visible(False)
-                    ax.spines["right"].set_visible(False)
-                    ax.set_xticks([])
-                    ax.set_yticks([])
-        plt.tight_layout(h_pad=0, w_pad=0)
-
-        return fig
-
     def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
         import pandas as pd

@@ -196,6 +139,7 @@ def plot_unit_counts(self, case_keys=None, figsize=None):
         plot_study_unit_counts(self, case_keys, figsize=figsize)

     def plot_unit_losses(self, before, after, metric=["precision"], figsize=None):
+        import matplotlib.pyplot as plt

         fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False)

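Note (not part of the commit): a minimal usage sketch of the pieces added above — the persisted "spikes" result and the new plot_performances_comparison method — assuming an already-run matching study object named `study` and an existing case key `key`:

# Hedged sketch only; `study`, `key`, and computed results are assumed to exist.
res = study.get_result(key)
sorting = res["sorting"]                     # NumpySorting built from the matched spikes
spikes = res["spikes"]                       # raw spike array, now saved alongside the sorting (.npy)
fig = study.plot_performances_comparison()   # pairwise comparison across cases
fig.savefig("matching_comparison.png")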
src/spikeinterface/benchmark/benchmark_plot_tools.py

Lines changed: 63 additions & 1 deletion
@@ -235,9 +235,71 @@ def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accu
             ax.scatter(x, y, marker=".", label=label)
             ax.set_title(k)

-        ax.set_ylim(0, 1.05)
+        ax.set_ylim(-0.05, 1.05)

         if count == 2:
             ax.legend()

     return fig
+
+
+def plot_performances_comparison(
+    study,
+    case_keys=None,
+    figsize=None,
+    metrics=["accuracy", "recall", "precision"],
+    colors=["g", "b", "r"],
+    ylim=(-0.1, 1.1),
+):
+    import matplotlib.pyplot as plt
+
+    if case_keys is None:
+        case_keys = list(study.cases.keys())
+
+    num_methods = len(case_keys)
+    assert num_methods >= 2, "plot_performances_comparison need at least 2 cases!"
+
+    fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(10, 10), squeeze=False)
+    for i, key1 in enumerate(case_keys):
+        for j, key2 in enumerate(case_keys):
+
+            if i < j:
+                ax = axs[i, j - 1]
+
+                comp1 = study.get_result(key1)["gt_comparison"]
+                comp2 = study.get_result(key2)["gt_comparison"]
+
+                for performance, color in zip(metrics, colors):
+                    perf1 = comp1.get_performance()[performance]
+                    perf2 = comp2.get_performance()[performance]
+                    ax.scatter(perf2, perf1, marker=".", label=performance, color=color)
+
+                ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
+                ax.set_ylim(ylim)
+                ax.set_xlim(ylim)
+                ax.spines[["right", "top"]].set_visible(False)
+                ax.set_aspect("equal")
+
+                label1 = study.cases[key1]["label"]
+                label2 = study.cases[key2]["label"]
+
+                if i == j - 1:
+                    ax.set_xlabel(label2)
+                    ax.set_ylabel(label1)
+
+            else:
+                if j >= 1 and i < num_methods - 1:
+                    ax = axs[i, j - 1]
+                    ax.spines[["right", "top", "left", "bottom"]].set_visible(False)
+                    ax.set_xticks([])
+                    ax.set_yticks([])
+
+    ax = axs[num_methods - 2, 0]
+    patches = []
+    from matplotlib.patches import Patch
+
+    for color, name in zip(colors, metrics):
+        patches.append(Patch(color=color, label=name))
+    ax.legend(handles=patches)
+    fig.tight_layout()
+    return fig

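For reference, a hedged sketch (not part of the commit) of calling the new helper directly rather than through the study method; `study` is assumed to be an existing benchmark study with at least two cases and computed "gt_comparison" results:

from spikeinterface.benchmark.benchmark_plot_tools import plot_performances_comparison

# Lower-triangle grid: each panel scatters one case's performance against another's,
# with the identity line as reference. Requires >= 2 cases (asserted in the function).
fig = plot_performances_comparison(
    study,                                        # assumed study object
    metrics=["accuracy", "recall", "precision"],
    colors=["g", "b", "r"],
)
fig.show()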
src/spikeinterface/benchmark/tests/test_benchmark_sorter.py

Lines changed: 4 additions & 2 deletions
@@ -64,10 +64,10 @@ def test_SorterStudy(setup_module):
     print(study)

     # # this run the sorters
-    # study.run()
+    study.run()

     # # this run comparisons
-    # study.compute_results()
+    study.compute_results()
     print(study)

     # this is from the base class
@@ -84,5 +84,7 @@ def test_SorterStudy(setup_module):

 if __name__ == "__main__":
     study_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy"
+    if study_folder.exists():
+        shutil.rmtree(study_folder)
     create_a_study(study_folder)
     test_SorterStudy(study_folder)

src/spikeinterface/extractors/neoextractors/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 from .mearec import MEArecRecordingExtractor, MEArecSortingExtractor, read_mearec
 from .mcsraw import MCSRawRecordingExtractor, read_mcsraw
 from .neuralynx import NeuralynxRecordingExtractor, NeuralynxSortingExtractor, read_neuralynx, read_neuralynx_sorting
+from .neuronexus import NeuroNexusRecordingExtractor, read_neuronexus
 from .neuroscope import (
     NeuroScopeRecordingExtractor,
     NeuroScopeSortingExtractor,
@@ -54,6 +55,7 @@
     MCSRawRecordingExtractor,
     NeuralynxRecordingExtractor,
     NeuroScopeRecordingExtractor,
+    NeuroNexusRecordingExtractor,
     NixRecordingExtractor,
     OpenEphysBinaryRecordingExtractor,
     OpenEphysLegacyRecordingExtractor,
src/spikeinterface/extractors/neoextractors/neuronexus.py (new file)

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+from spikeinterface.core.core_tools import define_function_from_class
+
+from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor
+
+
+class NeuroNexusRecordingExtractor(NeoBaseRecordingExtractor):
+    """
+    Class for reading data from NeuroNexus Allego.
+
+    Based on :py:class:`neo.rawio.NeuronexusRawIO`
+
+    Parameters
+    ----------
+    file_path : str | Path
+        The file path to the metadata .xdat.json file of an Allego session
+    stream_id : str | None, default: None
+        If there are several streams, specify the stream id you want to load.
+    stream_name : str | None, default: None
+        If there are several streams, specify the stream name you want to load.
+    all_annotations : bool, default: False
+        Load exhaustively all annotations from neo.
+    use_names_as_ids : bool, default: False
+        Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the
+        names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO.
+
+        In NeuroNexus the ids provided by NeoRawIO are the hardware channel ids stored as `ntv_chan_name` within
+        the metadata and the names are the `chan_names`.
+
+
+    """
+
+    NeoRawIOClass = "NeuroNexusRawIO"
+
+    def __init__(
+        self,
+        file_path: str | Path,
+        stream_id: str | None = None,
+        stream_name: str | None = None,
+        all_annotations: bool = False,
+        use_names_as_ids: bool = False,
+    ):
+        neo_kwargs = self.map_to_neo_kwargs(file_path)
+        NeoBaseRecordingExtractor.__init__(
+            self,
+            stream_id=stream_id,
+            stream_name=stream_name,
+            all_annotations=all_annotations,
+            use_names_as_ids=use_names_as_ids,
+            **neo_kwargs,
+        )
+
+        self._kwargs.update(dict(file_path=str(Path(file_path).resolve())))
+
+    @classmethod
+    def map_to_neo_kwargs(cls, file_path):
+
+        neo_kwargs = {"filename": str(file_path)}
+
+        return neo_kwargs
+
+
+read_neuronexus = define_function_from_class(source_class=NeuroNexusRecordingExtractor, name="read_neuronexus")

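A hedged usage sketch for the new extractor (not part of the commit); the path is a placeholder, and the reader function is assumed to be re-exported from spikeinterface.extractors like the other neo-based readers:

from spikeinterface.extractors import read_neuronexus

recording = read_neuronexus(
    "allego_session.xdat.json",  # placeholder: the .xdat.json metadata file of a real Allego session
    stream_id="0",               # optional: select one stream when several are present
    use_names_as_ids=False,      # False -> hardware ids (ntv_chan_name), True -> chan_names
)
print(recording)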
src/spikeinterface/extractors/neoextractors/plexon2.py

Lines changed: 25 additions & 3 deletions
@@ -28,6 +28,10 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor):
         ids: ["source3.1" , "source3.2", "source3.3", "source3.4"]
     all_annotations : bool, default: False
         Load exhaustively all annotations from neo.
+    reading_attempts : int, default: 25
+        Number of attempts to read the file before raising an error
+        This opening process is somewhat unreliable and might fail occasionally. Adjust this higher
+        if you encounter problems in opening the file.

     Examples
     --------
@@ -37,8 +41,16 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor):

     NeoRawIOClass = "Plexon2RawIO"

-    def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids=True, all_annotations=False):
-        neo_kwargs = self.map_to_neo_kwargs(file_path)
+    def __init__(
+        self,
+        file_path,
+        stream_id=None,
+        stream_name=None,
+        use_names_as_ids=True,
+        all_annotations=False,
+        reading_attempts: int = 25,
+    ):
+        neo_kwargs = self.map_to_neo_kwargs(file_path, reading_attempts=reading_attempts)
         NeoBaseRecordingExtractor.__init__(
             self,
             stream_id=stream_id,
@@ -50,8 +62,18 @@ def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids
         self._kwargs.update({"file_path": str(file_path)})

     @classmethod
-    def map_to_neo_kwargs(cls, file_path):
+    def map_to_neo_kwargs(cls, file_path, reading_attempts: int = 25):
+
         neo_kwargs = {"filename": str(file_path)}
+
+        from packaging.version import Version
+        import neo
+
+        neo_version = Version(neo.__version__)
+
+        if neo_version > Version("0.13.3"):
+            neo_kwargs["reading_attempts"] = reading_attempts
+
         return neo_kwargs


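A hedged sketch of the new reading_attempts option (not part of the commit); the file path is a placeholder, and per the code above the retry count is only forwarded to neo when neo > 0.13.3:

from spikeinterface.extractors import read_plexon2

recording = read_plexon2(
    "session.pl2",              # placeholder path
    stream_name="WB-Wideband",
    reading_attempts=50,        # raise this if opening the file fails intermittently
)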
src/spikeinterface/extractors/tests/common_tests.py

Lines changed: 5 additions & 2 deletions
@@ -52,8 +52,11 @@ def test_open(self):
             num_samples = rec.get_num_samples(segment_index=segment_index)

             full_traces = rec.get_traces(segment_index=segment_index)
-            assert full_traces.shape == (num_samples, num_chans)
-            assert full_traces.dtype == dtype
+            assert full_traces.shape == (
+                num_samples,
+                num_chans,
+            ), f"{full_traces.shape} != {(num_samples, num_chans)}"
+            assert full_traces.dtype == dtype, f"{full_traces.dtype} != {dtype=}"

             traces_sample_first = rec.get_traces(segment_index=segment_index, start_frame=0, end_frame=1)
             assert traces_sample_first.shape == (1, num_chans)

src/spikeinterface/extractors/tests/test_neoextractors.py

Lines changed: 9 additions & 1 deletion
@@ -181,6 +181,14 @@ class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase):
     ]


+class NeuroNexusRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
+    ExtractorClass = NeuroNexusRecordingExtractor
+    downloads = ["neuronexus"]
+    entities = [
+        ("neuronexus/allego_1/allego_2__uid0701-13-04-49.xdat.json", {"stream_id": "0"}),
+    ]
+
+
 class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
     ExtractorClass = PlexonRecordingExtractor
     downloads = ["plexon"]
@@ -360,7 +368,7 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase):
     ExtractorClass = Plexon2RecordingExtractor
     downloads = ["plexon"]
     entities = [
-        ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}),
+        ("plexon/4chDemoPL2.pl2", {"stream_name": "WB-Wideband"}),
     ]
