diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py
index 2d078c4d28..9cce5e3f23 100644
--- a/src/spikeinterface/curation/auto_merge.py
+++ b/src/spikeinterface/curation/auto_merge.py
@@ -74,7 +74,7 @@
         "sigma_smooth_ms": 0.6,
         "adaptative_window_thresh": 0.5,
     },
-    "template_similarity": {"template_diff_thresh": 0.25},
+    "template_similarity": {"similarity_method": "l1", "template_diff_thresh": 0.25},
     "presence_distance": {"presence_distance_thresh": 100},
     "knn": {"k_nn": 10},
     "cross_contamination": {
@@ -310,7 +310,13 @@ def compute_merge_unit_groups(
         # STEP : check if potential merge with CC also have template similarity
         elif step == "template_similarity":
             template_similarity_ext = sorting_analyzer.get_extension("template_similarity")
-            templates_similarity = template_similarity_ext.get_data()
+            if template_similarity_ext.params["method"] == params["similarity_method"]:
+                templates_similarity = template_similarity_ext.get_data()
+            else:
+                template_similarity_ext = sorting_analyzer.compute(
+                    "template_similarity", method=params["similarity_method"], save=False
+                )
+                templates_similarity = template_similarity_ext.get_data()
             templates_diff = 1 - templates_similarity
             pair_mask = pair_mask & (templates_diff < params["template_diff_thresh"])
             outs["templates_diff"] = templates_diff
@@ -1057,17 +1063,29 @@ def compute_cross_contaminations(analyzer, pair_mask, cc_thresh, refractory_peri
     CC = np.zeros((n, n), dtype=np.float32)
     p_values = np.zeros((n, n), dtype=np.float32)
 
-    for unit_ind1 in range(len(unit_ids)):
+    if sorting.get_num_segments() > 1:
+        # for multi-segment sortings, we need to concatenate segments,
+        from spikeinterface import concatenate_sortings, select_segment_sorting
+
+        sorting_list = []
+        total_samples_list = []
+        for segment_index in range(sorting.get_num_segments()):
+            sorting_list.append(select_segment_sorting(sorting, segment_index))
+            total_samples_list.append(analyzer.get_num_samples(segment_index))
+        # concatenate segments
+        sorting_concat = concatenate_sortings(sorting_list=sorting_list, total_samples_list=total_samples_list)
+    else:
+        sorting_concat = sorting
+    for unit_ind1 in range(len(unit_ids)):
         unit_id1 = unit_ids[unit_ind1]
-        spike_train1 = np.array(sorting.get_unit_spike_train(unit_id1))
-
+        spike_train1 = np.array(sorting_concat.get_unit_spike_train(unit_id1))
         for unit_ind2 in range(unit_ind1 + 1, len(unit_ids)):
             if not pair_mask[unit_ind1, unit_ind2]:
                 continue
 
             unit_id2 = unit_ids[unit_ind2]
-            spike_train2 = np.array(sorting.get_unit_spike_train(unit_id2))
+            spike_train2 = np.array(sorting_concat.get_unit_spike_train(unit_id2))
 
             # Compuyting the cross-contamination difference
             if contaminations is not None:
                 C1 = contaminations[unit_ind1]
@@ -1194,8 +1212,8 @@ def presence_distance(sorting, unit1, unit2, bin_duration_s=2, bins=None, num_sa
             ns = num_samples[segment_index]
             bins = np.arange(0, ns, bin_size)
 
-        st1 = sorting.get_unit_spike_train(unit_id=unit1)
-        st2 = sorting.get_unit_spike_train(unit_id=unit2)
+        st1 = sorting.get_unit_spike_train(unit_id=unit1, segment_index=segment_index)
+        st2 = sorting.get_unit_spike_train(unit_id=unit2, segment_index=segment_index)
 
         h1, _ = np.histogram(st1, bins)
         h1 = h1.astype(float)
diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py
index 239353a93b..20ad84efa2 100644
--- a/src/spikeinterface/curation/tests/common.py
+++ b/src/spikeinterface/curation/tests/common.py
@@ -3,13 +3,27 @@
 import pytest
 
 from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
+from spikeinterface.core.generate import inject_some_split_units
 
 job_kwargs = dict(n_jobs=-1)
 
+extensions = [
+    "noise_levels",
+    "random_spikes",
+    "waveforms",
+    "templates",
+    "unit_locations",
+    "spike_amplitudes",
+    "spike_locations",
+    "correlograms",
+    "template_similarity",
+]
 
-def make_sorting_analyzer(sparse=True, num_units=5):
+
+def make_sorting_analyzer(sparse=True, num_units=5, durations=[300.0]):
+    job_kwargs = dict(n_jobs=-1)
     recording, sorting = generate_ground_truth_recording(
-        durations=[300.0],
+        durations=durations,
         sampling_frequency=30000.0,
         num_channels=4,
         num_units=num_units,
@@ -23,23 +37,51 @@ def make_sorting_analyzer(sparse=True, num_units=5):
     recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
     sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
 
-    sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse)
-    sorting_analyzer.compute("random_spikes")
-    sorting_analyzer.compute("waveforms", **job_kwargs)
-    sorting_analyzer.compute("templates")
-    sorting_analyzer.compute("noise_levels")
-    # sorting_analyzer.compute("principal_components")
-    # sorting_analyzer.compute("template_similarity")
-    # sorting_analyzer.compute("quality_metrics", metric_names=["snr"])
+    sorting_analyzer = create_sorting_analyzer(
+        sorting=sorting, recording=recording, format="memory", sparse=sparse, **job_kwargs
+    )
+    sorting_analyzer.compute(extensions, **job_kwargs)
 
     return sorting_analyzer
 
 
+def make_sorting_analyzer_with_splits(sorting_analyzer, num_unit_splitted=1, num_split=2):
+    job_kwargs = dict(n_jobs=-1)
+    sorting = sorting_analyzer.sorting
+
+    split_ids = sorting.unit_ids[:num_unit_splitted]
+    sorting_with_split, other_ids = inject_some_split_units(
+        sorting,
+        split_ids=split_ids,
+        num_split=num_split,
+        output_ids=True,
+        seed=42,
+    )
+
+    sorting_analyzer_with_splits = create_sorting_analyzer(
+        sorting=sorting_with_split, recording=sorting_analyzer.recording, format="memory", sparse=True
+    )
+    sorting_analyzer_with_splits.compute(extensions, **job_kwargs)
+
+    return sorting_analyzer_with_splits, num_unit_splitted, other_ids
+
+
 @pytest.fixture(scope="module")
 def sorting_analyzer_for_curation():
     return make_sorting_analyzer(sparse=True)
 
 
+@pytest.fixture(scope="module")
+def sorting_analyzer_multi_segment_for_curation():
+    return make_sorting_analyzer(sparse=True, durations=[50.0, 30.0])
+
+
+@pytest.fixture(scope="module")
+def sorting_analyzer_with_splits():
+    sorting_analyzer = make_sorting_analyzer(sparse=True, durations=[50.0])
+    return make_sorting_analyzer_with_splits(sorting_analyzer)
+
+
 if __name__ == "__main__":
     sorting_analyzer = make_sorting_analyzer(sparse=False)
     print(sorting_analyzer)
diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py
index 83c942bb2c..4daf1118a2 100644
--- a/src/spikeinterface/curation/tests/test_auto_merge.py
+++ b/src/spikeinterface/curation/tests/test_auto_merge.py
@@ -2,50 +2,25 @@
 
 from spikeinterface.core import create_sorting_analyzer
-from spikeinterface.core.generate import inject_some_split_units
 
 from spikeinterface.curation import compute_merge_unit_groups, auto_merge_units
 from spikeinterface.generation import split_sorting_by_times
 
-from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation
+from spikeinterface.curation.tests.common import (
+    make_sorting_analyzer,
+    sorting_analyzer_for_curation,
+    sorting_analyzer_with_splits,
+    sorting_analyzer_multi_segment_for_curation,
+)
 
 
 @pytest.mark.parametrize(
     "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", None]
 )
-def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset):
-
-    print(sorting_analyzer_for_curation)
-    sorting = sorting_analyzer_for_curation.sorting
-    recording = sorting_analyzer_for_curation.recording
-    num_unit_splited = 1
-    num_split = 2
-
-    split_ids = sorting.unit_ids[:num_unit_splited]
-    sorting_with_split, other_ids = inject_some_split_units(
-        sorting,
-        split_ids=split_ids,
-        num_split=num_split,
-        output_ids=True,
-        seed=42,
-    )
+def test_compute_merge_unit_groups(sorting_analyzer_with_splits, preset):
 
     job_kwargs = dict(n_jobs=-1)
-
-    sorting_analyzer = create_sorting_analyzer(sorting_with_split, recording, format="memory")
-    sorting_analyzer.compute(
-        [
-            "random_spikes",
-            "waveforms",
-            "templates",
-            "unit_locations",
-            "spike_amplitudes",
-            "spike_locations",
-            "correlograms",
-            "template_similarity",
-        ],
-        **job_kwargs,
-    )
+    sorting_analyzer, num_unit_splitted, other_ids = sorting_analyzer_with_splits
 
     if preset is not None:
         # do not resolve graph for checking true pairs
@@ -67,7 +42,7 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset):
         **job_kwargs,
     )
     if preset == "x_contaminations":
-        assert len(merge_unit_groups) == num_unit_splited
+        assert len(merge_unit_groups) == num_unit_splitted
         for true_pair in other_ids.values():
             true_pair = tuple(true_pair)
             assert true_pair in merge_unit_groups
@@ -83,56 +58,19 @@
     )
 
 
-# DEBUG
-# import matplotlib.pyplot as plt
-# from spikeinterface.curation.auto_merge import normalize_correlogram
-# templates_diff = outs['templates_diff']
-# correlogram_diff = outs['correlogram_diff']
-# bins = outs['bins']
-# correlograms_smoothed = outs['correlograms_smoothed']
-# correlograms = outs['correlograms']
-# win_sizes = outs['win_sizes']
-
-# fig, ax = plt.subplots()
-# ax.hist(correlogram_diff.flatten(), bins=np.arange(0, 1, 0.05))
-
-# fig, ax = plt.subplots()
-# ax.hist(templates_diff.flatten(), bins=np.arange(0, 1, 0.05))
-
-# m = correlograms.shape[2] // 2
-
-# for unit_id1, unit_id2 in merge_unit_groups[:5]:
-#     unit_ind1 = sorting_with_split.id_to_index(unit_id1)
-#     unit_ind2 = sorting_with_split.id_to_index(unit_id2)
-
-#     bins2 = bins[:-1] + np.mean(np.diff(bins))
-#     fig, axs = plt.subplots(ncols=3)
-#     ax = axs[0]
-#     ax.plot(bins2, correlograms[unit_ind1, unit_ind1, :], color='b')
-#     ax.plot(bins2, correlograms[unit_ind2, unit_ind2, :], color='r')
-#     ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind1, :], color='b')
-#     ax.plot(bins2, correlograms_smoothed[unit_ind2, unit_ind2, :], color='r')
-
-#     ax.set_title(f'{unit_id1} {unit_id2}')
-#     ax = axs[1]
-#     ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind2, :], color='g')
-
-#     auto_corr1 = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind1, :])
-#     auto_corr2 = normalize_correlogram(correlograms_smoothed[unit_ind2, unit_ind2, :])
-#     cross_corr = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind2, :])
-
-#     ax = axs[2]
-#     ax.plot(bins2, auto_corr1, color='b')
-#     ax.plot(bins2, auto_corr2, color='r')
-#     ax.plot(bins2, cross_corr, color='g')
-
-#     ax.axvline(bins2[m - win_sizes[unit_ind1]], color='b')
-#     ax.axvline(bins2[m + win_sizes[unit_ind1]], color='b')
-#     ax.axvline(bins2[m - win_sizes[unit_ind2]], color='r')
-#     ax.axvline(bins2[m + win_sizes[unit_ind2]], color='r')
-
-#     ax.set_title(f'corr diff {correlogram_diff[unit_ind1, unit_ind2]} - temp diff {templates_diff[unit_ind1, unit_ind2]}')
-# plt.show()
+@pytest.mark.parametrize(
+    "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms"]
+)
+def test_compute_merge_unit_groups_multi_segment(sorting_analyzer_multi_segment_for_curation, preset):
+    job_kwargs = dict(n_jobs=-1)
+    sorting_analyzer = sorting_analyzer_multi_segment_for_curation
+    print(sorting_analyzer)
+
+    merge_unit_groups = compute_merge_unit_groups(
+        sorting_analyzer,
+        preset=preset,
+        **job_kwargs,
+    )
 
 
 def test_auto_merge_units(sorting_analyzer_for_curation):