54 changes: 33 additions & 21 deletions src/spikeinterface/curation/auto_merge.py
@@ -74,7 +74,7 @@
         "sigma_smooth_ms": 0.6,
         "adaptative_window_thresh": 0.5,
     },
-    "template_similarity": {"template_diff_thresh": 0.25},
+    "template_similarity": {"similarity_method": "l1", "template_diff_thresh": 0.25},
     "presence_distance": {"presence_distance_thresh": 100},
     "knn": {"k_nn": 10},
     "cross_contamination": {
@@ -310,7 +310,13 @@ def compute_merge_unit_groups(
         # STEP : check if potential merges with CC also have template similarity
         elif step == "template_similarity":
             template_similarity_ext = sorting_analyzer.get_extension("template_similarity")
-            templates_similarity = template_similarity_ext.get_data()
+            if template_similarity_ext.params["method"] == params["similarity_method"]:
+                templates_similarity = template_similarity_ext.get_data()
+            else:
+                template_similarity_ext = sorting_analyzer.compute(
+                    "template_similarity", method=params["similarity_method"], save=False
+                )
+                templates_similarity = template_similarity_ext.get_data()
             templates_diff = 1 - templates_similarity
             pair_mask = pair_mask & (templates_diff < params["template_diff_thresh"])
             outs["templates_diff"] = templates_diff
@@ -1054,28 +1060,34 @@ def compute_cross_contaminations(analyzer, pair_mask, cc_thresh, refractory_peri
     if pair_mask is None:
         pair_mask = np.ones((n, n), dtype="bool")
 
-    CC = np.zeros((n, n), dtype=np.float32)
-    p_values = np.zeros((n, n), dtype=np.float32)
+    num_segments = sorting.get_num_segments()
+    CC = np.zeros((num_segments, n, n), dtype=np.float32)
+    p_values = np.zeros((num_segments, n, n), dtype=np.float32)
 
-    for unit_ind1 in range(len(unit_ids)):
-        unit_id1 = unit_ids[unit_ind1]
-        spike_train1 = np.array(sorting.get_unit_spike_train(unit_id1))
-
+    for segment_index in range(num_segments):
Member Author: instead of doing this, concatenate_sortings
-        for unit_ind2 in range(unit_ind1 + 1, len(unit_ids)):
-            if not pair_mask[unit_ind1, unit_ind2]:
-                continue
-            unit_id2 = unit_ids[unit_ind2]
-            spike_train2 = np.array(sorting.get_unit_spike_train(unit_id2))
-            # Compuyting the cross-contamination difference
-            if contaminations is not None:
-                C1 = contaminations[unit_ind1]
-            else:
-                C1 = None
-            CC[unit_ind1, unit_ind2], p_values[unit_ind1, unit_ind2] = estimate_cross_contamination(
-                spike_train1, spike_train2, sf, n_frames, refractory_period, limit=cc_thresh, C1=C1
-            )
+        for unit_ind1 in range(len(unit_ids)):
+            unit_id1 = unit_ids[unit_ind1]
+            spike_train1 = np.array(sorting.get_unit_spike_train(unit_id1, segment_index=segment_index))
+            for unit_ind2 in range(unit_ind1 + 1, len(unit_ids)):
+                if not pair_mask[unit_ind1, unit_ind2]:
+                    continue
+                unit_id2 = unit_ids[unit_ind2]
+                spike_train2 = np.array(sorting.get_unit_spike_train(unit_id2, segment_index=segment_index))
+                # Computing the cross-contamination difference
+                if contaminations is not None:
+                    C1 = contaminations[unit_ind1]
+                else:
+                    C1 = None
+                CC[segment_index, unit_ind1, unit_ind2], p_values[segment_index, unit_ind1, unit_ind2] = (
+                    estimate_cross_contamination(
+                        spike_train1, spike_train2, sf, n_frames, refractory_period, limit=cc_thresh, C1=C1
+                    )
+                )
+    # average over segments
+    CC = np.mean(CC, axis=0)
+    p_values = np.mean(p_values, axis=0)
 
     return CC, p_values
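The inline review comment above proposes concatenating segments rather than looping and averaging. A rough sketch of that alternative, with several assumptions: concatenate_sortings is taken to live in spikeinterface.core and to need total sample counts (or an attached recording) to offset spike times across segments; samples_per_segment is a hypothetical list, and sf, n_frames, refractory_period, and cc_thresh are the variables already in scope in this function:

    import numpy as np
    from spikeinterface.core import concatenate_sortings

    # Flatten the multi-segment sorting into one segment, then estimate each
    # pair's cross-contamination once over the full concatenated spike trains
    # instead of averaging per-segment estimates. The total_samples_list
    # handling here is an assumption about the concatenation API.
    mono_sorting = concatenate_sortings([sorting], total_samples_list=[samples_per_segment])
    spike_train1 = np.array(mono_sorting.get_unit_spike_train(unit_id1))
    spike_train2 = np.array(mono_sorting.get_unit_spike_train(unit_id2))
    CC_pair, p_value = estimate_cross_contamination(
        spike_train1, spike_train2, sf, n_frames, refractory_period, limit=cc_thresh, C1=None
    )

Whether concatenation and per-segment averaging yield equivalent statistics (edge effects at segment boundaries, in particular) is a design question for the PR rather than something this sketch settles.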
