Skip to content

Commit d6f8c5a

Browse files
authored
Add similarity_method param in automerge and fix multi-segment cross-contamination (#4201)
1 parent 3abb033 commit d6f8c5a

File tree

3 files changed

+100
-102
lines changed

3 files changed

+100
-102
lines changed

src/spikeinterface/curation/auto_merge.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
"sigma_smooth_ms": 0.6,
7575
"adaptative_window_thresh": 0.5,
7676
},
77-
"template_similarity": {"template_diff_thresh": 0.25},
77+
"template_similarity": {"similarity_method": "l1", "template_diff_thresh": 0.25},
7878
"presence_distance": {"presence_distance_thresh": 100},
7979
"knn": {"k_nn": 10},
8080
"cross_contamination": {
@@ -310,7 +310,13 @@ def compute_merge_unit_groups(
310310
# STEP : check if potential merge with CC also have template similarity
311311
elif step == "template_similarity":
312312
template_similarity_ext = sorting_analyzer.get_extension("template_similarity")
313-
templates_similarity = template_similarity_ext.get_data()
313+
if template_similarity_ext.params["method"] == params["similarity_method"]:
314+
templates_similarity = template_similarity_ext.get_data()
315+
else:
316+
template_similarity_ext = sorting_analyzer.compute(
317+
"template_similarity", method=params["similarity_method"], save=False
318+
)
319+
templates_similarity = template_similarity_ext.get_data()
314320
templates_diff = 1 - templates_similarity
315321
pair_mask = pair_mask & (templates_diff < params["template_diff_thresh"])
316322
outs["templates_diff"] = templates_diff
@@ -1057,17 +1063,29 @@ def compute_cross_contaminations(analyzer, pair_mask, cc_thresh, refractory_peri
10571063
CC = np.zeros((n, n), dtype=np.float32)
10581064
p_values = np.zeros((n, n), dtype=np.float32)
10591065

1060-
for unit_ind1 in range(len(unit_ids)):
1066+
if sorting.get_num_segments() > 1:
1067+
# for multi-segment sortings, we need to concatenate segments.
1068+
from spikeinterface import concatenate_sortings, select_segment_sorting
1069+
1070+
sorting_list = []
1071+
total_samples_list = []
1072+
for segment_index in range(sorting.get_num_segments()):
1073+
sorting_list.append(select_segment_sorting(sorting, segment_index))
1074+
total_samples_list.append(analyzer.get_num_samples(segment_index))
1075+
# concatenate segments
1076+
sorting_concat = concatenate_sortings(sorting_list=sorting_list, total_samples_list=total_samples_list)
1077+
else:
1078+
sorting_concat = sorting
10611079

1080+
for unit_ind1 in range(len(unit_ids)):
10621081
unit_id1 = unit_ids[unit_ind1]
1063-
spike_train1 = np.array(sorting.get_unit_spike_train(unit_id1))
1064-
1082+
spike_train1 = np.array(sorting_concat.get_unit_spike_train(unit_id1))
10651083
for unit_ind2 in range(unit_ind1 + 1, len(unit_ids)):
10661084
if not pair_mask[unit_ind1, unit_ind2]:
10671085
continue
10681086

10691087
unit_id2 = unit_ids[unit_ind2]
1070-
spike_train2 = np.array(sorting.get_unit_spike_train(unit_id2))
1088+
spike_train2 = np.array(sorting_concat.get_unit_spike_train(unit_id2))
10711089
# Computing the cross-contamination difference
10721090
if contaminations is not None:
10731091
C1 = contaminations[unit_ind1]
@@ -1194,8 +1212,8 @@ def presence_distance(sorting, unit1, unit2, bin_duration_s=2, bins=None, num_sa
11941212
ns = num_samples[segment_index]
11951213
bins = np.arange(0, ns, bin_size)
11961214

1197-
st1 = sorting.get_unit_spike_train(unit_id=unit1)
1198-
st2 = sorting.get_unit_spike_train(unit_id=unit2)
1215+
st1 = sorting.get_unit_spike_train(unit_id=unit1, segment_index=segment_index)
1216+
st2 = sorting.get_unit_spike_train(unit_id=unit2, segment_index=segment_index)
11991217

12001218
h1, _ = np.histogram(st1, bins)
12011219
h1 = h1.astype(float)

src/spikeinterface/curation/tests/common.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,27 @@
33
import pytest
44

55
from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
6+
from spikeinterface.core.generate import inject_some_split_units
67

78
job_kwargs = dict(n_jobs=-1)
89

10+
extensions = [
11+
"noise_levels",
12+
"random_spikes",
13+
"waveforms",
14+
"templates",
15+
"unit_locations",
16+
"spike_amplitudes",
17+
"spike_locations",
18+
"correlograms",
19+
"template_similarity",
20+
]
921

10-
def make_sorting_analyzer(sparse=True, num_units=5):
22+
23+
def make_sorting_analyzer(sparse=True, num_units=5, durations=[300.0]):
24+
job_kwargs = dict(n_jobs=-1)
1125
recording, sorting = generate_ground_truth_recording(
12-
durations=[300.0],
26+
durations=durations,
1327
sampling_frequency=30000.0,
1428
num_channels=4,
1529
num_units=num_units,
@@ -23,23 +37,51 @@ def make_sorting_analyzer(sparse=True, num_units=5):
2337
recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
2438
sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
2539

26-
sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse)
27-
sorting_analyzer.compute("random_spikes")
28-
sorting_analyzer.compute("waveforms", **job_kwargs)
29-
sorting_analyzer.compute("templates")
30-
sorting_analyzer.compute("noise_levels")
31-
# sorting_analyzer.compute("principal_components")
32-
# sorting_analyzer.compute("template_similarity")
33-
# sorting_analyzer.compute("quality_metrics", metric_names=["snr"])
40+
sorting_analyzer = create_sorting_analyzer(
41+
sorting=sorting, recording=recording, format="memory", sparse=sparse, **job_kwargs
42+
)
43+
sorting_analyzer.compute(extensions, **job_kwargs)
3444

3545
return sorting_analyzer
3646

3747

48+
def make_sorting_analyzer_with_splits(sorting_analyzer, num_unit_splitted=1, num_split=2):
49+
job_kwargs = dict(n_jobs=-1)
50+
sorting = sorting_analyzer.sorting
51+
52+
split_ids = sorting.unit_ids[:num_unit_splitted]
53+
sorting_with_split, other_ids = inject_some_split_units(
54+
sorting,
55+
split_ids=split_ids,
56+
num_split=num_split,
57+
output_ids=True,
58+
seed=42,
59+
)
60+
61+
sorting_analyzer_with_splits = create_sorting_analyzer(
62+
sorting=sorting_with_split, recording=sorting_analyzer.recording, format="memory", sparse=True
63+
)
64+
sorting_analyzer_with_splits.compute(extensions, **job_kwargs)
65+
66+
return sorting_analyzer_with_splits, num_unit_splitted, other_ids
67+
68+
3869
@pytest.fixture(scope="module")
3970
def sorting_analyzer_for_curation():
4071
return make_sorting_analyzer(sparse=True)
4172

4273

74+
@pytest.fixture(scope="module")
75+
def sorting_analyzer_multi_segment_for_curation():
76+
return make_sorting_analyzer(sparse=True, durations=[50.0, 30.0])
77+
78+
79+
@pytest.fixture(scope="module")
80+
def sorting_analyzer_with_splits():
81+
sorting_analyzer = make_sorting_analyzer(sparse=True, durations=[50.0])
82+
return make_sorting_analyzer_with_splits(sorting_analyzer)
83+
84+
4385
if __name__ == "__main__":
4486
sorting_analyzer = make_sorting_analyzer(sparse=False)
4587
print(sorting_analyzer)

src/spikeinterface/curation/tests/test_auto_merge.py

Lines changed: 22 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,50 +2,25 @@
22

33

44
from spikeinterface.core import create_sorting_analyzer
5-
from spikeinterface.core.generate import inject_some_split_units
65
from spikeinterface.curation import compute_merge_unit_groups, auto_merge_units
76
from spikeinterface.generation import split_sorting_by_times
87

98

10-
from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation
9+
from spikeinterface.curation.tests.common import (
10+
make_sorting_analyzer,
11+
sorting_analyzer_for_curation,
12+
sorting_analyzer_with_splits,
13+
sorting_analyzer_multi_segment_for_curation,
14+
)
1115

1216

1317
@pytest.mark.parametrize(
1418
"preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", None]
1519
)
16-
def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset):
17-
18-
print(sorting_analyzer_for_curation)
19-
sorting = sorting_analyzer_for_curation.sorting
20-
recording = sorting_analyzer_for_curation.recording
21-
num_unit_splited = 1
22-
num_split = 2
23-
24-
split_ids = sorting.unit_ids[:num_unit_splited]
25-
sorting_with_split, other_ids = inject_some_split_units(
26-
sorting,
27-
split_ids=split_ids,
28-
num_split=num_split,
29-
output_ids=True,
30-
seed=42,
31-
)
20+
def test_compute_merge_unit_groups(sorting_analyzer_with_splits, preset):
3221

3322
job_kwargs = dict(n_jobs=-1)
34-
35-
sorting_analyzer = create_sorting_analyzer(sorting_with_split, recording, format="memory")
36-
sorting_analyzer.compute(
37-
[
38-
"random_spikes",
39-
"waveforms",
40-
"templates",
41-
"unit_locations",
42-
"spike_amplitudes",
43-
"spike_locations",
44-
"correlograms",
45-
"template_similarity",
46-
],
47-
**job_kwargs,
48-
)
23+
sorting_analyzer, num_unit_splitted, other_ids = sorting_analyzer_with_splits
4924

5025
if preset is not None:
5126
# do not resolve graph for checking true pairs
@@ -67,7 +42,7 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset):
6742
**job_kwargs,
6843
)
6944
if preset == "x_contaminations":
70-
assert len(merge_unit_groups) == num_unit_splited
45+
assert len(merge_unit_groups) == num_unit_splitted
7146
for true_pair in other_ids.values():
7247
true_pair = tuple(true_pair)
7348
assert true_pair in merge_unit_groups
@@ -83,56 +58,19 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset):
8358
)
8459

8560

86-
# DEBUG
87-
# import matplotlib.pyplot as plt
88-
# from spikeinterface.curation.auto_merge import normalize_correlogram
89-
# templates_diff = outs['templates_diff']
90-
# correlogram_diff = outs['correlogram_diff']
91-
# bins = outs['bins']
92-
# correlograms_smoothed = outs['correlograms_smoothed']
93-
# correlograms = outs['correlograms']
94-
# win_sizes = outs['win_sizes']
95-
96-
# fig, ax = plt.subplots()
97-
# ax.hist(correlogram_diff.flatten(), bins=np.arange(0, 1, 0.05))
98-
99-
# fig, ax = plt.subplots()
100-
# ax.hist(templates_diff.flatten(), bins=np.arange(0, 1, 0.05))
101-
102-
# m = correlograms.shape[2] // 2
103-
104-
# for unit_id1, unit_id2 in merge_unit_groups[:5]:
105-
# unit_ind1 = sorting_with_split.id_to_index(unit_id1)
106-
# unit_ind2 = sorting_with_split.id_to_index(unit_id2)
107-
108-
# bins2 = bins[:-1] + np.mean(np.diff(bins))
109-
# fig, axs = plt.subplots(ncols=3)
110-
# ax = axs[0]
111-
# ax.plot(bins2, correlograms[unit_ind1, unit_ind1, :], color='b')
112-
# ax.plot(bins2, correlograms[unit_ind2, unit_ind2, :], color='r')
113-
# ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind1, :], color='b')
114-
# ax.plot(bins2, correlograms_smoothed[unit_ind2, unit_ind2, :], color='r')
115-
116-
# ax.set_title(f'{unit_id1} {unit_id2}')
117-
# ax = axs[1]
118-
# ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind2, :], color='g')
119-
120-
# auto_corr1 = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind1, :])
121-
# auto_corr2 = normalize_correlogram(correlograms_smoothed[unit_ind2, unit_ind2, :])
122-
# cross_corr = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind2, :])
123-
124-
# ax = axs[2]
125-
# ax.plot(bins2, auto_corr1, color='b')
126-
# ax.plot(bins2, auto_corr2, color='r')
127-
# ax.plot(bins2, cross_corr, color='g')
128-
129-
# ax.axvline(bins2[m - win_sizes[unit_ind1]], color='b')
130-
# ax.axvline(bins2[m + win_sizes[unit_ind1]], color='b')
131-
# ax.axvline(bins2[m - win_sizes[unit_ind2]], color='r')
132-
# ax.axvline(bins2[m + win_sizes[unit_ind2]], color='r')
133-
134-
# ax.set_title(f'corr diff {correlogram_diff[unit_ind1, unit_ind2]} - temp diff {templates_diff[unit_ind1, unit_ind2]}')
135-
# plt.show()
61+
@pytest.mark.parametrize(
62+
"preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms"]
63+
)
64+
def test_compute_merge_unit_groups_multi_segment(sorting_analyzer_multi_segment_for_curation, preset):
65+
job_kwargs = dict(n_jobs=-1)
66+
sorting_analyzer = sorting_analyzer_multi_segment_for_curation
67+
print(sorting_analyzer)
68+
69+
merge_unit_groups = compute_merge_unit_groups(
70+
sorting_analyzer,
71+
preset=preset,
72+
**job_kwargs,
73+
)
13674

13775

13876
def test_auto_merge_units(sorting_analyzer_for_curation):

0 commit comments

Comments
 (0)