@@ -94,7 +94,7 @@ def _merge_extension_data(
9494 new_sorting_analyzer .sparsity .mask [keep , :], new_unit_ids , new_sorting_analyzer .channel_ids
9595 )
9696
97- new_similarity = compute_similarity_with_templates_array (
97+ new_similarity , _ = compute_similarity_with_templates_array (
9898 new_templates_array ,
9999 all_templates_array ,
100100 method = self .params ["method" ],
@@ -146,7 +146,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
146146 new_sorting_analyzer .sparsity .mask [keep , :], new_unit_ids_f , new_sorting_analyzer .channel_ids
147147 )
148148
149- new_similarity = compute_similarity_with_templates_array (
149+ new_similarity , _ = compute_similarity_with_templates_array (
150150 new_templates_array ,
151151 all_templates_array ,
152152 method = self .params ["method" ],
@@ -188,7 +188,7 @@ def _run(self, verbose=False):
188188 self .sorting_analyzer , return_in_uV = self .sorting_analyzer .return_in_uV
189189 )
190190 sparsity = self .sorting_analyzer .sparsity
191- similarity = compute_similarity_with_templates_array (
191+ similarity , _ = compute_similarity_with_templates_array (
192192 templates_array ,
193193 templates_array ,
194194 method = self .params ["method" ],
@@ -393,7 +393,13 @@ def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsi
393393
394394
395395def compute_similarity_with_templates_array (
396- templates_array , other_templates_array , method , support = "union" , num_shifts = 0 , sparsity = None , other_sparsity = None
396+ templates_array ,
397+ other_templates_array ,
398+ method ,
399+ support = "union" ,
400+ num_shifts = 0 ,
401+ sparsity = None ,
402+ other_sparsity = None ,
397403):
398404
399405 if method == "cosine_similarity" :
@@ -432,10 +438,11 @@ def compute_similarity_with_templates_array(
432438 templates_array , other_templates_array , num_shifts , method , sparsity_mask , other_sparsity_mask , support = support
433439 )
434440
441+ lags = np .argmin (distances , axis = 0 ) - num_shifts
435442 distances = np .min (distances , axis = 0 )
436443 similarity = 1 - distances
437444
438- return similarity
445+ return similarity , lags
439446
440447
441448def compute_template_similarity_by_pair (
@@ -445,7 +452,7 @@ def compute_template_similarity_by_pair(
445452 templates_array_2 = get_dense_templates_array (sorting_analyzer_2 , return_in_uV = True )
446453 sparsity_1 = sorting_analyzer_1 .sparsity
447454 sparsity_2 = sorting_analyzer_2 .sparsity
448- similarity = compute_similarity_with_templates_array (
455+ similarity , _ = compute_similarity_with_templates_array (
449456 templates_array_1 ,
450457 templates_array_2 ,
451458 method = method ,
0 commit comments