58 changes: 35 additions & 23 deletions src/spikeinterface/curation/auto_merge.py
@@ -74,7 +74,7 @@
         "sigma_smooth_ms": 0.6,
         "adaptative_window_thresh": 0.5,
     },
-    "template_similarity": {"template_diff_thresh": 0.25},
+    "template_similarity": {"similarity_method": "l1", "template_diff_thresh": 0.25},
     "presence_distance": {"presence_distance_thresh": 100},
     "knn": {"k_nn": 10},
     "cross_contamination": {
@@ -310,7 +310,13 @@ def compute_merge_unit_groups(
         # STEP : check if potential merge with CC also have template similarity
         elif step == "template_similarity":
             template_similarity_ext = sorting_analyzer.get_extension("template_similarity")
-            templates_similarity = template_similarity_ext.get_data()
+            if template_similarity_ext.params["method"] == params["similarity_method"]:
+                templates_similarity = template_similarity_ext.get_data()
+            else:
+                template_similarity_ext = sorting_analyzer.compute(
+                    "template_similarity", method=params["similarity_method"], save=False
+                )
+                templates_similarity = template_similarity_ext.get_data()
             templates_diff = 1 - templates_similarity
             pair_mask = pair_mask & (templates_diff < params["template_diff_thresh"])
             outs["templates_diff"] = templates_diff
@@ -1054,28 +1060,34 @@ def compute_cross_contaminations(analyzer, pair_mask, cc_thresh, refractory_period
     if pair_mask is None:
         pair_mask = np.ones((n, n), dtype="bool")

-    CC = np.zeros((n, n), dtype=np.float32)
-    p_values = np.zeros((n, n), dtype=np.float32)
-
-    for unit_ind1 in range(len(unit_ids)):
-        unit_id1 = unit_ids[unit_ind1]
-        spike_train1 = np.array(sorting.get_unit_spike_train(unit_id1))
+    num_segments = sorting.get_num_segments()
+    CC = np.zeros((num_segments, n, n), dtype=np.float32)
+    p_values = np.zeros((num_segments, n, n), dtype=np.float32)
+
+    for segment_index in range(num_segments):
Review comment (member author), on the `for segment_index` line above: instead of doing this, concatenate_sortings
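The suggestion is not applied in this diff; a minimal sketch of that alternative, assuming `concatenate_sortings` from `spikeinterface.core` returns a mono-segment sorting whose spike trains span all segments (the exact import path and signature are assumptions here):

```python
# Sketch of the reviewer's alternative, not what this PR implements:
# flatten the segments once, then keep the original single-loop code.
import numpy as np
from spikeinterface.core import concatenate_sortings  # assumed location

sorting_mono = concatenate_sortings([sorting])  # one segment spanning all data
for unit_ind1, unit_id1 in enumerate(unit_ids):
    # spike trains now cover every segment, so no per-segment loop
    # and no averaging over segments is needed
    spike_train1 = np.array(sorting_mono.get_unit_spike_train(unit_id1))
```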

-        for unit_ind2 in range(unit_ind1 + 1, len(unit_ids)):
-            if not pair_mask[unit_ind1, unit_ind2]:
-                continue
-
-            unit_id2 = unit_ids[unit_ind2]
-            spike_train2 = np.array(sorting.get_unit_spike_train(unit_id2))
-            # Compuyting the cross-contamination difference
-            if contaminations is not None:
-                C1 = contaminations[unit_ind1]
-            else:
-                C1 = None
-            CC[unit_ind1, unit_ind2], p_values[unit_ind1, unit_ind2] = estimate_cross_contamination(
-                spike_train1, spike_train2, sf, n_frames, refractory_period, limit=cc_thresh, C1=C1
-            )
+        for unit_ind1 in range(len(unit_ids)):
+            unit_id1 = unit_ids[unit_ind1]
+            spike_train1 = np.array(sorting.get_unit_spike_train(unit_id1, segment_index=segment_index))
+            for unit_ind2 in range(unit_ind1 + 1, len(unit_ids)):
+                if not pair_mask[unit_ind1, unit_ind2]:
+                    continue
+
+                unit_id2 = unit_ids[unit_ind2]
+                spike_train2 = np.array(sorting.get_unit_spike_train(unit_id2, segment_index=segment_index))
+                # Computing the cross-contamination difference
+                if contaminations is not None:
+                    C1 = contaminations[unit_ind1]
+                else:
+                    C1 = None
+                CC[segment_index, unit_ind1, unit_ind2], p_values[segment_index, unit_ind1, unit_ind2] = (
+                    estimate_cross_contamination(
+                        spike_train1, spike_train2, sf, n_frames, refractory_period, limit=cc_thresh, C1=C1
+                    )
+                )
+
+    # average over segments
+    CC = np.mean(CC, axis=0)
+    p_values = np.mean(p_values, axis=0)

     return CC, p_values
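Note that `np.mean` over the segment axis weights every segment equally, regardless of its duration. A toy illustration of that choice (values invented):

```python
import numpy as np

# cross-contamination of one pair measured in two segments of very different length
cc_per_segment = np.array([0.10, 0.30])
print(np.mean(cc_per_segment))  # 0.2 -- a short segment counts as much as a long one
```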

@@ -1194,8 +1206,8 @@ def presence_distance(sorting, unit1, unit2, bin_duration_s=2, bins=None, num_samples
             ns = num_samples[segment_index]
             bins = np.arange(0, ns, bin_size)

-        st1 = sorting.get_unit_spike_train(unit_id=unit1)
-        st2 = sorting.get_unit_spike_train(unit_id=unit2)
+        st1 = sorting.get_unit_spike_train(unit_id=unit1, segment_index=segment_index)
+        st2 = sorting.get_unit_spike_train(unit_id=unit2, segment_index=segment_index)

         h1, _ = np.histogram(st1, bins)
         h1 = h1.astype(float)
61 changes: 51 additions & 10 deletions src/spikeinterface/curation/tests/common.py
@@ -3,13 +3,26 @@
 import pytest

 from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
+from spikeinterface.core.generate import inject_some_split_units

+job_kwargs = dict(n_jobs=-1)
+
+extensions = [
+    "random_spikes",
+    "waveforms",
+    "templates",
+    "unit_locations",
+    "spike_amplitudes",
+    "spike_locations",
+    "correlograms",
+    "template_similarity",
+]

-def make_sorting_analyzer(sparse=True, num_units=5):
-    job_kwargs = dict(n_jobs=-1)
+
+def make_sorting_analyzer(sparse=True, num_units=5, durations=[100.0]):
     recording, sorting = generate_ground_truth_recording(
-        durations=[300.0],
+        durations=durations,
         sampling_frequency=30000.0,
         num_channels=4,
         num_units=num_units,
@@ -23,23 +36,51 @@ def make_sorting_analyzer(sparse=True, num_units=5):
     recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
     sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

-    sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse)
-    sorting_analyzer.compute("random_spikes")
-    sorting_analyzer.compute("waveforms", **job_kwargs)
-    sorting_analyzer.compute("templates")
-    sorting_analyzer.compute("noise_levels")
-    # sorting_analyzer.compute("principal_components")
-    # sorting_analyzer.compute("template_similarity")
-    # sorting_analyzer.compute("quality_metrics", metric_names=["snr"])
+    sorting_analyzer = create_sorting_analyzer(
+        sorting=sorting, recording=recording, format="memory", sparse=sparse, **job_kwargs
+    )
+    sorting_analyzer.compute(extensions, **job_kwargs)

     return sorting_analyzer


+def make_sorting_analyzer_with_splits(sorting_analyzer, num_unit_splitted=1, num_split=2):
+    job_kwargs = dict(n_jobs=-1)
+    sorting = sorting_analyzer.sorting
+
+    split_ids = sorting.unit_ids[:num_unit_splitted]
+    sorting_with_split, other_ids = inject_some_split_units(
+        sorting,
+        split_ids=split_ids,
+        num_split=num_split,
+        output_ids=True,
+        seed=42,
+    )
+
+    sorting_analyzer_with_splits = create_sorting_analyzer(
+        sorting=sorting_with_split, recording=sorting_analyzer.recording, format="memory", sparse=True
+    )
+    sorting_analyzer_with_splits.compute(extensions, **job_kwargs)
+
+    return sorting_analyzer_with_splits, num_unit_splitted, other_ids
+
+
 @pytest.fixture(scope="module")
 def sorting_analyzer_for_curation():
     return make_sorting_analyzer(sparse=True)


+@pytest.fixture(scope="module")
+def sorting_analyzer_multi_segment_for_curation():
+    return make_sorting_analyzer(sparse=True, durations=[50.0, 30.0])
+
+
+@pytest.fixture(scope="module")
+def sorting_analyzer_with_splits():
+    sorting_analyzer = make_sorting_analyzer(sparse=True, durations=[50.0])
+    return make_sorting_analyzer_with_splits(sorting_analyzer)
+
+
 if __name__ == "__main__":
     sorting_analyzer = make_sorting_analyzer(sparse=False)
     print(sorting_analyzer)
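For reference, the new `sorting_analyzer_with_splits` fixture returns a tuple rather than a bare analyzer, so tests unpack it; a minimal sketch mirroring the test changes below (the test name is illustrative):

```python
def test_something(sorting_analyzer_with_splits):
    # the fixture yields (analyzer, number of injected split units, id mapping)
    sorting_analyzer, num_unit_splitted, other_ids = sorting_analyzer_with_splits
    assert len(other_ids) == num_unit_splitted
```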
106 changes: 22 additions & 84 deletions src/spikeinterface/curation/tests/test_auto_merge.py
@@ -2,50 +2,25 @@


 from spikeinterface.core import create_sorting_analyzer
-from spikeinterface.core.generate import inject_some_split_units
 from spikeinterface.curation import compute_merge_unit_groups, auto_merge_units
 from spikeinterface.generation import split_sorting_by_times


-from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation
+from spikeinterface.curation.tests.common import (
+    make_sorting_analyzer,
+    sorting_analyzer_for_curation,
+    sorting_analyzer_with_splits,
+    sorting_analyzer_multi_segment_for_curation,
+)


 @pytest.mark.parametrize(
     "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", None]
 )
-def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset):
-
-    print(sorting_analyzer_for_curation)
-    sorting = sorting_analyzer_for_curation.sorting
-    recording = sorting_analyzer_for_curation.recording
-    num_unit_splited = 1
-    num_split = 2
-
-    split_ids = sorting.unit_ids[:num_unit_splited]
-    sorting_with_split, other_ids = inject_some_split_units(
-        sorting,
-        split_ids=split_ids,
-        num_split=num_split,
-        output_ids=True,
-        seed=42,
-    )
+def test_compute_merge_unit_groups(sorting_analyzer_with_splits, preset):

     job_kwargs = dict(n_jobs=-1)

-    sorting_analyzer = create_sorting_analyzer(sorting_with_split, recording, format="memory")
-    sorting_analyzer.compute(
-        [
-            "random_spikes",
-            "waveforms",
-            "templates",
-            "unit_locations",
-            "spike_amplitudes",
-            "spike_locations",
-            "correlograms",
-            "template_similarity",
-        ],
-        **job_kwargs,
-    )
+    sorting_analyzer, num_unit_splitted, other_ids = sorting_analyzer_with_splits

     if preset is not None:
         # do not resolve graph for checking true pairs
@@ -67,7 +42,7 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset):
         **job_kwargs,
     )
     if preset == "x_contaminations":
-        assert len(merge_unit_groups) == num_unit_splited
+        assert len(merge_unit_groups) == num_unit_splitted
         for true_pair in other_ids.values():
             true_pair = tuple(true_pair)
             assert true_pair in merge_unit_groups
@@ -83,56 +58,19 @@
     )


-# DEBUG
-# import matplotlib.pyplot as plt
-# from spikeinterface.curation.auto_merge import normalize_correlogram
-# templates_diff = outs['templates_diff']
-# correlogram_diff = outs['correlogram_diff']
-# bins = outs['bins']
-# correlograms_smoothed = outs['correlograms_smoothed']
-# correlograms = outs['correlograms']
-# win_sizes = outs['win_sizes']

-# fig, ax = plt.subplots()
-# ax.hist(correlogram_diff.flatten(), bins=np.arange(0, 1, 0.05))

-# fig, ax = plt.subplots()
-# ax.hist(templates_diff.flatten(), bins=np.arange(0, 1, 0.05))

-# m = correlograms.shape[2] // 2

-# for unit_id1, unit_id2 in merge_unit_groups[:5]:
-#     unit_ind1 = sorting_with_split.id_to_index(unit_id1)
-#     unit_ind2 = sorting_with_split.id_to_index(unit_id2)

-#     bins2 = bins[:-1] + np.mean(np.diff(bins))
-#     fig, axs = plt.subplots(ncols=3)
-#     ax = axs[0]
-#     ax.plot(bins2, correlograms[unit_ind1, unit_ind1, :], color='b')
-#     ax.plot(bins2, correlograms[unit_ind2, unit_ind2, :], color='r')
-#     ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind1, :], color='b')
-#     ax.plot(bins2, correlograms_smoothed[unit_ind2, unit_ind2, :], color='r')

-#     ax.set_title(f'{unit_id1} {unit_id2}')
-#     ax = axs[1]
-#     ax.plot(bins2, correlograms_smoothed[unit_ind1, unit_ind2, :], color='g')

-#     auto_corr1 = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind1, :])
-#     auto_corr2 = normalize_correlogram(correlograms_smoothed[unit_ind2, unit_ind2, :])
-#     cross_corr = normalize_correlogram(correlograms_smoothed[unit_ind1, unit_ind2, :])

-#     ax = axs[2]
-#     ax.plot(bins2, auto_corr1, color='b')
-#     ax.plot(bins2, auto_corr2, color='r')
-#     ax.plot(bins2, cross_corr, color='g')

-#     ax.axvline(bins2[m - win_sizes[unit_ind1]], color='b')
-#     ax.axvline(bins2[m + win_sizes[unit_ind1]], color='b')
-#     ax.axvline(bins2[m - win_sizes[unit_ind2]], color='r')
-#     ax.axvline(bins2[m + win_sizes[unit_ind2]], color='r')

-#     ax.set_title(f'corr diff {correlogram_diff[unit_ind1, unit_ind2]} - temp diff {templates_diff[unit_ind1, unit_ind2]}')
-#     plt.show()
+@pytest.mark.parametrize(
+    "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms"]
+)
+def test_compute_merge_unit_groups_multi_segment(sorting_analyzer_multi_segment_for_curation, preset):
+    job_kwargs = dict(n_jobs=-1)
+    sorting_analyzer = sorting_analyzer_multi_segment_for_curation
+    print(sorting_analyzer)
+
+    merge_unit_groups = compute_merge_unit_groups(
+        sorting_analyzer,
+        preset=preset,
+        **job_kwargs,
+    )


 def test_auto_merge_units(sorting_analyzer_for_curation):