22
33
44from spikeinterface .core import create_sorting_analyzer
5- from spikeinterface .core .generate import inject_some_split_units
65from spikeinterface .curation import compute_merge_unit_groups , auto_merge_units
76from spikeinterface .generation import split_sorting_by_times
87
98
10- from spikeinterface .curation .tests .common import make_sorting_analyzer , sorting_analyzer_for_curation
9+ from spikeinterface .curation .tests .common import (
10+ make_sorting_analyzer ,
11+ sorting_analyzer_for_curation ,
12+ sorting_analyzer_with_splits ,
13+ sorting_analyzer_multi_segment_for_curation ,
14+ )
1115
1216
1317@pytest .mark .parametrize (
1418 "preset" , ["x_contaminations" , "feature_neighbors" , "temporal_splits" , "similarity_correlograms" , None ]
1519)
16- def test_compute_merge_unit_groups (sorting_analyzer_for_curation , preset ):
17-
18- print (sorting_analyzer_for_curation )
19- sorting = sorting_analyzer_for_curation .sorting
20- recording = sorting_analyzer_for_curation .recording
21- num_unit_splited = 1
22- num_split = 2
23-
24- split_ids = sorting .unit_ids [:num_unit_splited ]
25- sorting_with_split , other_ids = inject_some_split_units (
26- sorting ,
27- split_ids = split_ids ,
28- num_split = num_split ,
29- output_ids = True ,
30- seed = 42 ,
31- )
20+ def test_compute_merge_unit_groups (sorting_analyzer_with_splits , preset ):
3221
3322 job_kwargs = dict (n_jobs = - 1 )
34-
35- sorting_analyzer = create_sorting_analyzer (sorting_with_split , recording , format = "memory" )
36- sorting_analyzer .compute (
37- [
38- "random_spikes" ,
39- "waveforms" ,
40- "templates" ,
41- "unit_locations" ,
42- "spike_amplitudes" ,
43- "spike_locations" ,
44- "correlograms" ,
45- "template_similarity" ,
46- ],
47- ** job_kwargs ,
48- )
23+ sorting_analyzer , num_unit_splitted , other_ids = sorting_analyzer_with_splits
4924
5025 if preset is not None :
5126 # do not resolve graph for checking true pairs
@@ -67,7 +42,7 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset):
6742 ** job_kwargs ,
6843 )
6944 if preset == "x_contaminations" :
70- assert len (merge_unit_groups ) == num_unit_splited
45+ assert len (merge_unit_groups ) == num_unit_splitted
7146 for true_pair in other_ids .values ():
7247 true_pair = tuple (true_pair )
7348 assert true_pair in merge_unit_groups
@@ -83,56 +58,19 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset):
8358 )
8459
8560
86- # DEBUG
87- # import matplotlib.pyplot as plt
88- # from spikeinterface.curation.auto_merge import normalize_correlogram
89- # templates_diff = outs['templates_diff']
90- # correlogram_diff = outs['correlogram_diff']
91- # bins = outs['bins']
92- # correlograms_smoothed = outs['correlograms_smoothed']
93- # correlograms = outs['correlograms']
94- # win_sizes = outs['win_sizes']
95-
96- # fig, ax = plt.subplots()
97- # ax.hist(correlogram_diff.flatten(), bins=np.arange(0, 1, 0.05))
98-
99- # fig, ax = plt.subplots()
100- # ax.hist(templates_diff.flatten(), bins=np.arange(0, 1, 0.05))
101-
102- # m = correlograms.shape[2] // 2
103-
104- # for unit_id1, unit_id2 in merge_unit_groups[:5]:
105- # unit_ind1 = sorting_with_split.id_to_index(unit_id1)
106- # unit_ind2 = sorting_with_split.id_to_index(unit_id2)
107-
108- # bins2 = bins[:-1] + np.mean(np.diff(bins))
109- # fig, axs = plt.subplots(ncols=3)
110- # ax = axs[0]
111- # ax.plot(bins2, correlograms[unit_ind1, unit_ind1, :], color='b')
112- # ax.plot(bins2, correlograms[unit_ind2, unit_ind2, :], color='r')
113- # ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind1, :], color='b')
114- # ax.plot(bins2, correlograms_smoothed[unit_ind2, unit_ind2, :], color='r')
115-
116- # ax.set_title(f'{unit_id1} {unit_id2}')
117- # ax = axs[1]
118- # ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind2, :], color='g')
119-
120- # auto_corr1 = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind1, :])
121- # auto_corr2 = normalize_correlogram(correlograms_smoothed[unit_ind2, unit_ind2, :])
122- # cross_corr = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind2, :])
123-
124- # ax = axs[2]
125- # ax.plot(bins2, auto_corr1, color='b')
126- # ax.plot(bins2, auto_corr2, color='r')
127- # ax.plot(bins2, cross_corr, color='g')
128-
129- # ax.axvline(bins2[m - win_sizes[unit_ind1]], color='b')
130- # ax.axvline(bins2[m + win_sizes[unit_ind1]], color='b')
131- # ax.axvline(bins2[m - win_sizes[unit_ind2]], color='r')
132- # ax.axvline(bins2[m + win_sizes[unit_ind2]], color='r')
133-
134- # ax.set_title(f'corr diff {correlogram_diff[unit_ind1, unit_ind2]} - temp diff {templates_diff[unit_ind1, unit_ind2]}')
135- # plt.show()
61+ @pytest .mark .parametrize (
62+ "preset" , ["x_contaminations" , "feature_neighbors" , "temporal_splits" , "similarity_correlograms" ]
63+ )
64+ def test_compute_merge_unit_groups_multi_segment (sorting_analyzer_multi_segment_for_curation , preset ):
65+ job_kwargs = dict (n_jobs = - 1 )
66+ sorting_analyzer = sorting_analyzer_multi_segment_for_curation
67+ print (sorting_analyzer )
68+
69+ merge_unit_groups = compute_merge_unit_groups (
70+ sorting_analyzer ,
71+ preset = preset ,
72+ ** job_kwargs ,
73+ )
13674
13775
13876def test_auto_merge_units (sorting_analyzer_for_curation ):
0 commit comments