From 66ae3446887990830234d0444c811d6a83361be7 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 29 Sep 2025 13:17:17 +0200 Subject: [PATCH 01/47] Reducing memory footprint --- .../postprocessing/template_similarity.py | 68 ++++++++++++------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index cf0c72952b..31aeedbb24 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -208,7 +208,7 @@ def _get_data(self): compute_template_similarity = ComputeTemplateSimilarity.function_factory() -def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method): +def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] @@ -232,15 +232,16 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in range(num_templates): src_template = src_sliced_templates[i] - overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) + overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount, j in enumerate(overlapping_templates): # symmetric values are handled later if same_array and j < i: # no need exhaustive looping when same template continue - src = src_template[:, mask[i, j]].reshape(1, -1) - tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1) + src = src_template[:, local_mask[j]].reshape(1, -1) + tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1) if method == "l1": norm_i = np.sum(np.abs(src)) @@ -273,7 +274,7 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num import numba @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) - def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method): + def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] other_num_templates = other_templates_array.shape[0] @@ -304,7 +305,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in numba.prange(num_templates): src_template = src_sliced_templates[i] - overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) + overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount in range(len(overlapping_templates)): @@ -313,8 +315,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num if same_array and j < i: # no need exhaustive looping when same template continue - src = src_template[:, mask[i, j]].flatten() - tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten() + src = src_template[:, local_mask[j]].flatten() + tgt = 
(tgt_templates[gcount][:, local_mask[j]]).flatten() norm_i = 0 norm_j = 0 @@ -360,6 +362,34 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num _compute_similarity_matrix = _compute_similarity_matrix_numpy + +def get_mask_for_sparse_template(template_index, + sparsity, + other_sparsity, + support="union") -> np.ndarray: + + other_num_templates = other_sparsity.shape[0] + num_channels = sparsity.shape[1] + + mask = np.ones((other_num_templates, num_channels), dtype=bool) + + if support == "intersection": + mask = np.logical_and( + sparsity[template_index, :], other_sparsity[:, :] + ) # shape (num_templates, other_num_templates, num_channels) + elif support == "union": + mask = np.logical_and( + sparsity[template_index, :], other_sparsity[:, :] + ) # shape (num_templates, other_num_templates, num_channels) + units_overlaps = np.sum(mask, axis=1) > 0 + mask = np.logical_or( + sparsity[template_index, :], other_sparsity[:, :] + ) # shape (num_templates, other_num_templates, num_channels) + mask[~units_overlaps] = False + + return mask + + def compute_similarity_with_templates_array( templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None ): @@ -378,29 +408,17 @@ def compute_similarity_with_templates_array( assert ( templates_array.shape[2] == other_templates_array.shape[2] ), "The number of channels in the templates should be the same for both arrays" - num_templates = templates_array.shape[0] + #num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] - num_channels = templates_array.shape[2] - other_num_templates = other_templates_array.shape[0] - - mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) + #num_channels = templates_array.shape[2] + #other_num_templates = other_templates_array.shape[0] if sparsity is not None and other_sparsity is not None: - - # make the input more flexible with either The object or the array mask sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity - - if support == "intersection": - mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) - elif support == "union": - mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) - units_overlaps = np.sum(mask, axis=2) > 0 - mask = np.logical_or(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) - mask[~units_overlaps] = False - + assert num_shifts < num_samples, "max_lag is too large" - distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method) + distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support) distances = np.min(distances, axis=0) similarity = 1 - distances From 11958d75fa29c94292376e69feaf2eedaeba46f0 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 29 Sep 2025 13:33:49 +0200 Subject: [PATCH 02/47] WIP --- .../postprocessing/template_similarity.py | 68 +++++++------------ 1 file changed, 25 insertions(+), 43 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 31aeedbb24..cf0c72952b 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ 
b/src/spikeinterface/postprocessing/template_similarity.py @@ -208,7 +208,7 @@ def _get_data(self): compute_template_similarity = ComputeTemplateSimilarity.function_factory() -def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"): +def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] @@ -232,16 +232,15 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in range(num_templates): src_template = src_sliced_templates[i] - local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) - overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) + overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount, j in enumerate(overlapping_templates): # symmetric values are handled later if same_array and j < i: # no need exhaustive looping when same template continue - src = src_template[:, local_mask[j]].reshape(1, -1) - tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1) + src = src_template[:, mask[i, j]].reshape(1, -1) + tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1) if method == "l1": norm_i = np.sum(np.abs(src)) @@ -274,7 +273,7 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num import numba @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) - def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"): + def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] other_num_templates = other_templates_array.shape[0] @@ -305,8 +304,7 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in numba.prange(num_templates): src_template = src_sliced_templates[i] - local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) - overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) + overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount in range(len(overlapping_templates)): @@ -315,8 +313,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num if same_array and j < i: # no need exhaustive looping when same template continue - src = src_template[:, local_mask[j]].flatten() - tgt = (tgt_templates[gcount][:, local_mask[j]]).flatten() + src = src_template[:, mask[i, j]].flatten() + tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten() norm_i = 0 norm_j = 0 @@ -362,34 +360,6 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num _compute_similarity_matrix = _compute_similarity_matrix_numpy - -def get_mask_for_sparse_template(template_index, - sparsity, - other_sparsity, - support="union") -> np.ndarray: - - other_num_templates = other_sparsity.shape[0] - num_channels = sparsity.shape[1] - - mask = np.ones((other_num_templates, num_channels), dtype=bool) - - if support == 
"intersection": - mask = np.logical_and( - sparsity[template_index, :], other_sparsity[:, :] - ) # shape (num_templates, other_num_templates, num_channels) - elif support == "union": - mask = np.logical_and( - sparsity[template_index, :], other_sparsity[:, :] - ) # shape (num_templates, other_num_templates, num_channels) - units_overlaps = np.sum(mask, axis=1) > 0 - mask = np.logical_or( - sparsity[template_index, :], other_sparsity[:, :] - ) # shape (num_templates, other_num_templates, num_channels) - mask[~units_overlaps] = False - - return mask - - def compute_similarity_with_templates_array( templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None ): @@ -408,17 +378,29 @@ def compute_similarity_with_templates_array( assert ( templates_array.shape[2] == other_templates_array.shape[2] ), "The number of channels in the templates should be the same for both arrays" - #num_templates = templates_array.shape[0] + num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] - #num_channels = templates_array.shape[2] - #other_num_templates = other_templates_array.shape[0] + num_channels = templates_array.shape[2] + other_num_templates = other_templates_array.shape[0] + + mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) if sparsity is not None and other_sparsity is not None: + + # make the input more flexible with either The object or the array mask sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity - + + if support == "intersection": + mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) + elif support == "union": + mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) + units_overlaps = np.sum(mask, axis=2) > 0 + mask = np.logical_or(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) + mask[~units_overlaps] = False + assert num_shifts < num_samples, "max_lag is too large" - distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support) + distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method) distances = np.min(distances, axis=0) similarity = 1 - distances From 76fd5d19995c01d553c5dc8ce37f38ec2724d915 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 29 Sep 2025 13:34:59 +0200 Subject: [PATCH 03/47] WIP --- .../postprocessing/template_similarity.py | 68 ++++++++++++------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index cf0c72952b..31aeedbb24 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -208,7 +208,7 @@ def _get_data(self): compute_template_similarity = ComputeTemplateSimilarity.function_factory() -def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method): +def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] @@ -232,15 +232,16 @@ def 
_compute_similarity_matrix_numpy(templates_array, other_templates_array, num tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in range(num_templates): src_template = src_sliced_templates[i] - overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) + overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount, j in enumerate(overlapping_templates): # symmetric values are handled later if same_array and j < i: # no need exhaustive looping when same template continue - src = src_template[:, mask[i, j]].reshape(1, -1) - tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1) + src = src_template[:, local_mask[j]].reshape(1, -1) + tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1) if method == "l1": norm_i = np.sum(np.abs(src)) @@ -273,7 +274,7 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num import numba @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) - def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method): + def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] other_num_templates = other_templates_array.shape[0] @@ -304,7 +305,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in numba.prange(num_templates): src_template = src_sliced_templates[i] - overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) + overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount in range(len(overlapping_templates)): @@ -313,8 +315,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num if same_array and j < i: # no need exhaustive looping when same template continue - src = src_template[:, mask[i, j]].flatten() - tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten() + src = src_template[:, local_mask[j]].flatten() + tgt = (tgt_templates[gcount][:, local_mask[j]]).flatten() norm_i = 0 norm_j = 0 @@ -360,6 +362,34 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num _compute_similarity_matrix = _compute_similarity_matrix_numpy + +def get_mask_for_sparse_template(template_index, + sparsity, + other_sparsity, + support="union") -> np.ndarray: + + other_num_templates = other_sparsity.shape[0] + num_channels = sparsity.shape[1] + + mask = np.ones((other_num_templates, num_channels), dtype=bool) + + if support == "intersection": + mask = np.logical_and( + sparsity[template_index, :], other_sparsity[:, :] + ) # shape (num_templates, other_num_templates, num_channels) + elif support == "union": + mask = np.logical_and( + sparsity[template_index, :], other_sparsity[:, :] + ) # shape (num_templates, other_num_templates, num_channels) + units_overlaps = np.sum(mask, axis=1) > 0 + mask = np.logical_or( + sparsity[template_index, :], other_sparsity[:, :] + ) # shape (num_templates, other_num_templates, num_channels) + mask[~units_overlaps] = False + + 
return mask + + def compute_similarity_with_templates_array( templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None ): @@ -378,29 +408,17 @@ def compute_similarity_with_templates_array( assert ( templates_array.shape[2] == other_templates_array.shape[2] ), "The number of channels in the templates should be the same for both arrays" - num_templates = templates_array.shape[0] + #num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] - num_channels = templates_array.shape[2] - other_num_templates = other_templates_array.shape[0] - - mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) + #num_channels = templates_array.shape[2] + #other_num_templates = other_templates_array.shape[0] if sparsity is not None and other_sparsity is not None: - - # make the input more flexible with either The object or the array mask sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity - - if support == "intersection": - mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) - elif support == "union": - mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) - units_overlaps = np.sum(mask, axis=2) > 0 - mask = np.logical_or(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :]) - mask[~units_overlaps] = False - + assert num_shifts < num_samples, "max_lag is too large" - distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method) + distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support) distances = np.min(distances, axis=0) similarity = 1 - distances From 2e1098a35b5e70cc2ff95a3ce2e00ec7d65a26e8 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 29 Sep 2025 13:39:18 +0200 Subject: [PATCH 04/47] WIP --- .../postprocessing/template_similarity.py | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 31aeedbb24..1de81d6b7e 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -277,6 +277,7 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] + num_channels = sparsity.shape[1] other_num_templates = other_templates_array.shape[0] num_shifts_both_sides = 2 * num_shifts + 1 @@ -285,7 +286,6 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t # So the matrix can be computed only for negative lags and be transposed - if same_array: # optimisation when array are the same because of symetry in shift shift_loop = list(range(-num_shifts, 1)) @@ -305,7 +305,28 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in 
numba.prange(num_templates): src_template = src_sliced_templates[i] - local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) + + ## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays + ## So we inline the function here + #local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) + + local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_) + + if support == "intersection": + local_mask = np.logical_and( + sparsity[i], other_sparsity + ) # shape (num_templates, other_num_templates, num_channels) + elif support == "union": + local_mask = np.logical_and( + sparsity[i], other_sparsity + ) # shape (num_templates, other_num_templates, num_channels) + units_overlaps = np.sum(local_mask, axis=1) > 0 + local_mask = np.logical_or( + sparsity[i], other_sparsity + ) # shape (num_templates, other_num_templates, num_channels) + local_mask[~units_overlaps] = False + + overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount in range(len(overlapping_templates)): @@ -371,19 +392,19 @@ def get_mask_for_sparse_template(template_index, other_num_templates = other_sparsity.shape[0] num_channels = sparsity.shape[1] - mask = np.ones((other_num_templates, num_channels), dtype=bool) + mask = np.ones((other_num_templates, num_channels), dtype=np.bool_) if support == "intersection": mask = np.logical_and( - sparsity[template_index, :], other_sparsity[:, :] + sparsity[template_index], other_sparsity ) # shape (num_templates, other_num_templates, num_channels) elif support == "union": mask = np.logical_and( - sparsity[template_index, :], other_sparsity[:, :] + sparsity[template_index], other_sparsity ) # shape (num_templates, other_num_templates, num_channels) units_overlaps = np.sum(mask, axis=1) > 0 mask = np.logical_or( - sparsity[template_index, :], other_sparsity[:, :] + sparsity[template_index], other_sparsity ) # shape (num_templates, other_num_templates, num_channels) mask[~units_overlaps] = False From a37b8f1f38c8cfc5f83baa527bfce3b610dcfdd6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Sep 2025 11:43:31 +0000 Subject: [PATCH 05/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/template_similarity.py | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 1de81d6b7e..6ce24c2c00 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -208,7 +208,9 @@ def _get_data(self): compute_template_similarity = ComputeTemplateSimilarity.function_factory() -def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"): +def _compute_similarity_matrix_numpy( + templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union" +): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] @@ -274,7 +276,9 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num import numba @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) - def 
_compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union"): + def _compute_similarity_matrix_numba( + templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union" + ): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] num_channels = sparsity.shape[1] @@ -305,11 +309,11 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in numba.prange(num_templates): src_template = src_sliced_templates[i] - + ## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays ## So we inline the function here - #local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) - + # local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) + local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_) if support == "intersection": @@ -325,8 +329,7 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num sparsity[i], other_sparsity ) # shape (num_templates, other_num_templates, num_channels) local_mask[~units_overlaps] = False - - + overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount in range(len(overlapping_templates)): @@ -383,11 +386,7 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num _compute_similarity_matrix = _compute_similarity_matrix_numpy - -def get_mask_for_sparse_template(template_index, - sparsity, - other_sparsity, - support="union") -> np.ndarray: +def get_mask_for_sparse_template(template_index, sparsity, other_sparsity, support="union") -> np.ndarray: other_num_templates = other_sparsity.shape[0] num_channels = sparsity.shape[1] @@ -429,17 +428,19 @@ def compute_similarity_with_templates_array( assert ( templates_array.shape[2] == other_templates_array.shape[2] ), "The number of channels in the templates should be the same for both arrays" - #num_templates = templates_array.shape[0] + # num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] - #num_channels = templates_array.shape[2] - #other_num_templates = other_templates_array.shape[0] + # num_channels = templates_array.shape[2] + # other_num_templates = other_templates_array.shape[0] if sparsity is not None and other_sparsity is not None: sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity - + assert num_shifts < num_samples, "max_lag is too large" - distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support) + distances = _compute_similarity_matrix( + templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support + ) distances = np.min(distances, axis=0) similarity = 1 - distances From 40b1f6c517487e6bf5b7fb6963a0aaff42f6c311 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 29 Sep 2025 16:03:37 +0200 Subject: [PATCH 06/47] Fixing tests --- .../postprocessing/template_similarity.py | 5 ++++- .../tests/test_template_similarity.py | 17 +++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff 
--git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 1de81d6b7e..b0a7445e2e 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -437,7 +437,10 @@ def compute_similarity_with_templates_array( if sparsity is not None and other_sparsity is not None: sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity - + else: + sparsity_mask = np.ones((templates_array.shape[0], templates_array.shape[2]), dtype=bool) + other_sparsity_mask = np.ones((other_templates_array.shape[0], other_templates_array.shape[2]), dtype=bool) + assert num_shifts < num_samples, "max_lag is too large" distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 9a25af444c..7633e8f3b5 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -107,10 +107,19 @@ def test_equal_results_numba(params): rng = np.random.default_rng(seed=2205) templates_array = rng.random(size=(4, 20, 5), dtype=np.float32) other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32) - mask = np.ones((4, 2, 5), dtype=bool) - - result_numpy = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params) - result_numba = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params) + sparsity_mask = np.ones((4, 5), dtype=bool) + other_sparsity_mask = np.ones((2, 5), dtype=bool) + + result_numpy = _compute_similarity_matrix_numba(templates_array, + other_templates_array, + sparsity=sparsity_mask, + other_sparsity=other_sparsity_mask, + **params) + result_numba = _compute_similarity_matrix_numpy(templates_array, + other_templates_array, + sparsity=sparsity_mask, + other_sparsity=other_sparsity_mask, + **params) assert np.allclose(result_numpy, result_numba, 1e-3) From d7c2e890ecacaf25d0646de961e3a8423e6b364e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:06:48 +0000 Subject: [PATCH 07/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../tests/test_template_similarity.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 7633e8f3b5..62d4be2318 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -110,16 +110,12 @@ def test_equal_results_numba(params): sparsity_mask = np.ones((4, 5), dtype=bool) other_sparsity_mask = np.ones((2, 5), dtype=bool) - result_numpy = _compute_similarity_matrix_numba(templates_array, - other_templates_array, - sparsity=sparsity_mask, - other_sparsity=other_sparsity_mask, - **params) - result_numba = _compute_similarity_matrix_numpy(templates_array, - other_templates_array, - sparsity=sparsity_mask, - 
other_sparsity=other_sparsity_mask, - **params) + result_numpy = _compute_similarity_matrix_numba( + templates_array, other_templates_array, sparsity=sparsity_mask, other_sparsity=other_sparsity_mask, **params + ) + result_numba = _compute_similarity_matrix_numpy( + templates_array, other_templates_array, sparsity=sparsity_mask, other_sparsity=other_sparsity_mask, **params + ) assert np.allclose(result_numpy, result_numba, 1e-3) From b844f3e0beeb202c8cac374d4c783fd851c502ed Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 30 Sep 2025 08:47:27 +0200 Subject: [PATCH 08/47] WIP --- src/spikeinterface/postprocessing/template_similarity.py | 7 +++++-- src/spikeinterface/sortingcomponents/clustering/merge.py | 1 - 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 0d1a8fccb5..090a91abcd 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -433,11 +433,14 @@ def compute_similarity_with_templates_array( # num_channels = templates_array.shape[2] # other_num_templates = other_templates_array.shape[0] - if sparsity is not None and other_sparsity is not None: + if sparsity is not None: sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity - other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity else: sparsity_mask = np.ones((templates_array.shape[0], templates_array.shape[2]), dtype=bool) + + if other_sparsity is not None: + other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity + else: other_sparsity_mask = np.ones((other_templates_array.shape[0], other_templates_array.shape[2]), dtype=bool) assert num_shifts < num_samples, "max_lag is too large" diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 9110fa37f0..b1956a0e12 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -541,7 +541,6 @@ def merge_peak_labels_from_templates( assert len(unit_ids) == templates_array.shape[0] from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array - from scipy.sparse.csgraph import connected_components similarity = compute_similarity_with_templates_array( templates_array, From f8e3ba9445106ad71d6a980cd44d3a2751f937fc Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 30 Sep 2025 08:47:27 +0200 Subject: [PATCH 09/47] WIP --- .../postprocessing/template_similarity.py | 21 +++++++++++-------- .../tests/test_template_similarity.py | 4 ++-- .../sortingcomponents/clustering/merge.py | 1 - 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 0d1a8fccb5..65f75bbb3d 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -209,7 +209,7 @@ def _get_data(self): def _compute_similarity_matrix_numpy( - templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union" + templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union" ): num_templates = templates_array.shape[0] @@ 
-234,7 +234,7 @@ def _compute_similarity_matrix_numpy( tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in range(num_templates): src_template = src_sliced_templates[i] - local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) + local_mask = get_mask_for_sparse_template(i, sparsity_mask, other_sparsity_mask, support=support) overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount, j in enumerate(overlapping_templates): @@ -277,11 +277,11 @@ def _compute_similarity_matrix_numpy( @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) def _compute_similarity_matrix_numba( - templates_array, other_templates_array, num_shifts, method, sparsity, other_sparsity, support="union" + templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union" ): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] - num_channels = sparsity.shape[1] + num_channels = templates_array.shape[2] other_num_templates = other_templates_array.shape[0] num_shifts_both_sides = 2 * num_shifts + 1 @@ -318,15 +318,15 @@ def _compute_similarity_matrix_numba( if support == "intersection": local_mask = np.logical_and( - sparsity[i], other_sparsity + sparsity_mask[i], other_sparsity_mask ) # shape (num_templates, other_num_templates, num_channels) elif support == "union": local_mask = np.logical_and( - sparsity[i], other_sparsity + sparsity_mask[i], other_sparsity_mask ) # shape (num_templates, other_num_templates, num_channels) units_overlaps = np.sum(local_mask, axis=1) > 0 local_mask = np.logical_or( - sparsity[i], other_sparsity + sparsity_mask[i], other_sparsity_mask ) # shape (num_templates, other_num_templates, num_channels) local_mask[~units_overlaps] = False @@ -433,11 +433,14 @@ def compute_similarity_with_templates_array( # num_channels = templates_array.shape[2] # other_num_templates = other_templates_array.shape[0] - if sparsity is not None and other_sparsity is not None: + if sparsity is not None: sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity - other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity else: sparsity_mask = np.ones((templates_array.shape[0], templates_array.shape[2]), dtype=bool) + + if other_sparsity is not None: + other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity + else: other_sparsity_mask = np.ones((other_templates_array.shape[0], other_templates_array.shape[2]), dtype=bool) assert num_shifts < num_samples, "max_lag is too large" diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 62d4be2318..c6663445f8 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -111,10 +111,10 @@ def test_equal_results_numba(params): other_sparsity_mask = np.ones((2, 5), dtype=bool) result_numpy = _compute_similarity_matrix_numba( - templates_array, other_templates_array, sparsity=sparsity_mask, other_sparsity=other_sparsity_mask, **params + templates_array, other_templates_array, sparsity_mask=sparsity_mask, other_sparsity_mask=other_sparsity_mask, **params ) result_numba = _compute_similarity_matrix_numpy( - 
templates_array, other_templates_array, sparsity=sparsity_mask, other_sparsity=other_sparsity_mask, **params + templates_array, other_templates_array, sparsity_mask=sparsity_mask, other_sparsity_mask=other_sparsity_mask, **params ) assert np.allclose(result_numpy, result_numba, 1e-3) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 9110fa37f0..b1956a0e12 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -541,7 +541,6 @@ def merge_peak_labels_from_templates( assert len(unit_ids) == templates_array.shape[0] from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array - from scipy.sparse.csgraph import connected_components similarity = compute_similarity_with_templates_array( templates_array, From 0aa76a3b679597f3e1fe934784f181114e967a7e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Sep 2025 07:01:33 +0000 Subject: [PATCH 10/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/tests/test_template_similarity.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index c6663445f8..fa7d19fcbc 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -111,10 +111,18 @@ def test_equal_results_numba(params): other_sparsity_mask = np.ones((2, 5), dtype=bool) result_numpy = _compute_similarity_matrix_numba( - templates_array, other_templates_array, sparsity_mask=sparsity_mask, other_sparsity_mask=other_sparsity_mask, **params + templates_array, + other_templates_array, + sparsity_mask=sparsity_mask, + other_sparsity_mask=other_sparsity_mask, + **params, ) result_numba = _compute_similarity_matrix_numpy( - templates_array, other_templates_array, sparsity_mask=sparsity_mask, other_sparsity_mask=other_sparsity_mask, **params + templates_array, + other_templates_array, + sparsity_mask=sparsity_mask, + other_sparsity_mask=other_sparsity_mask, + **params, ) assert np.allclose(result_numpy, result_numba, 1e-3) From 9858fc63518e37161a2d0b65cf069dc3b06b6a14 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 30 Sep 2025 10:13:54 +0200 Subject: [PATCH 11/47] WIP --- .../sortingcomponents/clustering/circus.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index e1bee8e9ff..7a5297aedb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -200,7 +200,16 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ms_after, **job_kwargs_local, ) - sparse_mask2 = sparse_mask + + from spikeinterface.core.sparsity import compute_sparsity + sparse_mask2 = compute_sparsity( + templates, + method="snr", + amplitude_mode="peak_to_peak", + noise_levels=params["noise_levels"], + threshold=0.25, + ).mask + else: from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd From 341d98009cd8c4bfb87816818f85df8823e58697 Mon Sep 17 
00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Sep 2025 08:16:39 +0000 Subject: [PATCH 12/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/circus.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 7a5297aedb..4555de8148 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -200,8 +200,9 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ms_after, **job_kwargs_local, ) - + from spikeinterface.core.sparsity import compute_sparsity + sparse_mask2 = compute_sparsity( templates, method="snr", From 76b9a7b2b409687b79e2b623a812c1e73167602b Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 3 Oct 2025 09:04:31 +0200 Subject: [PATCH 13/47] WIP --- .../sortingcomponents/clustering/circus.py | 279 ------------------ 1 file changed, 279 deletions(-) delete mode 100644 src/spikeinterface/sortingcomponents/clustering/circus.py diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py deleted file mode 100644 index 4555de8148..0000000000 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ /dev/null @@ -1,279 +0,0 @@ -from __future__ import annotations - -import importlib -from pathlib import Path - -import numpy as np -import random, string - -from spikeinterface.core import get_global_tmp_folder, Templates -from spikeinterface.core import get_global_tmp_folder -from .clustering_tools import remove_duplicates_via_matching -from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances -from spikeinterface.sortingcomponents.peak_selection import select_peaks -from spikeinterface.sortingcomponents.tools import _get_optimal_n_jobs -from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd -from spikeinterface.sortingcomponents.clustering.merge import merge_peak_labels_from_templates -from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel - - -class CircusClustering: - """ - Circus clustering is based on several local clustering achieved with a - divide-and-conquer strategy. It uses the `hdbscan` or `isosplit6` clustering algorithms to - perform the local clusterings with an iterative and greedy strategy. - More precisely, it first extracts waveforms from the recording, - then performs a Truncated SVD to reduce the dimensionality of the waveforms. - For every peak, it extracts the SVD features and performs local clustering, grouping the peaks - by channel indices. The clustering is done recursively, and the clusters are merged - based on a similarity metric. The final output is a set of labels for each peak, - indicating the cluster to which it belongs. 
- """ - - _default_params = { - "clusterer": "hdbscan", # 'isosplit6', 'hdbscan', 'isosplit' - "clusterer_kwargs": { - "min_cluster_size": 20, - "cluster_selection_epsilon": 0.5, - "cluster_selection_method": "leaf", - "allow_single_cluster": True, - }, - "cleaning_kwargs": {}, - "remove_mixtures": False, - "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, - "recursive_kwargs": { - "recursive": True, - "recursive_depth": 3, - "returns_split_count": True, - }, - "split_kwargs": {"projection_mode": "tsvd", "n_pca_features": 0.9}, - "radius_um": 100, - "neighbors_radius_um": 50, - "n_svd": 5, - "few_waveforms": None, - "ms_before": 2.0, - "ms_after": 2.0, - "seed": None, - "noise_threshold": 2, - "templates_from_svd": True, - "noise_levels": None, - "tmp_folder": None, - "do_merge_with_templates": True, - "merge_kwargs": { - "similarity_metric": "l1", - "num_shifts": 3, - "similarity_thresh": 0.8, - }, - "verbose": True, - "memory_limit": 0.25, - "debug": False, - } - - @classmethod - def main_function(cls, recording, peaks, params, job_kwargs=dict()): - - clusterer = params.get("clusterer", "hdbscan") - assert clusterer in [ - "isosplit6", - "hdbscan", - "isosplit", - ], "Circus clustering only supports isosplit6, isosplit or hdbscan" - if clusterer in ["isosplit6", "hdbscan"]: - have_dep = importlib.util.find_spec(clusterer) is not None - if not have_dep: - raise RuntimeError(f"using {clusterer} as a clusterer needs {clusterer} to be installed") - - d = params - verbose = d["verbose"] - - fs = recording.get_sampling_frequency() - ms_before = params["ms_before"] - ms_after = params["ms_after"] - radius_um = params["radius_um"] - neighbors_radius_um = params["neighbors_radius_um"] - nbefore = int(ms_before * fs / 1000.0) - nafter = int(ms_after * fs / 1000.0) - if params["tmp_folder"] is None: - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = get_global_tmp_folder() / name - else: - tmp_folder = Path(params["tmp_folder"]).absolute() - - tmp_folder.mkdir(parents=True, exist_ok=True) - - # SVD for time compression - if params["few_waveforms"] is None: - few_peaks = select_peaks( - peaks, - recording=recording, - method="uniform", - seed=params["seed"], - n_peaks=10000, - margin=(nbefore, nafter), - ) - few_wfs = extract_waveform_at_max_channel( - recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs - ) - wfs = few_wfs[:, :, 0] - else: - offset = int(params["waveforms"]["ms_before"] * fs / 1000) - wfs = params["few_waveforms"][:, offset - nbefore : offset + nafter] - - # Ensure all waveforms have a positive max - wfs *= np.sign(wfs[:, nbefore])[:, np.newaxis] - - # Remove outliers - valid = np.argmax(np.abs(wfs), axis=1) == nbefore - wfs = wfs[valid] - - from sklearn.decomposition import TruncatedSVD - - svd_model = TruncatedSVD(params["n_svd"], random_state=params["seed"]) - svd_model.fit(wfs) - if params["debug"]: - features_folder = tmp_folder / "tsvd_features" - features_folder.mkdir(exist_ok=True) - else: - features_folder = None - - peaks_svd, sparse_mask, svd_model = extract_peaks_svd( - recording, - peaks, - ms_before=ms_before, - ms_after=ms_after, - svd_model=svd_model, - radius_um=radius_um, - folder=features_folder, - seed=params["seed"], - **job_kwargs, - ) - - neighbours_mask = get_channel_distances(recording) <= neighbors_radius_um - - if params["debug"]: - np.save(features_folder / "sparse_mask.npy", sparse_mask) - 
np.save(features_folder / "peaks.npy", peaks) - - original_labels = peaks["channel_index"] - from spikeinterface.sortingcomponents.clustering.split import split_clusters - - split_kwargs = params["split_kwargs"].copy() - split_kwargs["neighbours_mask"] = neighbours_mask - split_kwargs["waveforms_sparse_mask"] = sparse_mask - split_kwargs["seed"] = params["seed"] - split_kwargs["min_size_split"] = 2 * params["clusterer_kwargs"].get("min_cluster_size", 50) - split_kwargs["clusterer_kwargs"] = params["clusterer_kwargs"] - split_kwargs["clusterer"] = params["clusterer"] - - if params["debug"]: - debug_folder = tmp_folder / "split" - else: - debug_folder = None - - peak_labels, _ = split_clusters( - original_labels, - recording, - {"peaks": peaks, "sparse_tsvd": peaks_svd}, - method="local_feature_clustering", - method_kwargs=split_kwargs, - debug_folder=debug_folder, - **params["recursive_kwargs"], - **job_kwargs, - ) - - if params["noise_levels"] is None: - params["noise_levels"] = get_noise_levels(recording, return_in_uV=False, **job_kwargs) - - if not params["templates_from_svd"]: - from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording - - job_kwargs_local = job_kwargs.copy() - unit_ids = np.unique(peak_labels) - ram_requested = recording.get_num_channels() * (nbefore + nafter) * len(unit_ids) * 4 - job_kwargs_local = _get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"]) - templates = get_templates_from_peaks_and_recording( - recording, - peaks, - peak_labels, - ms_before, - ms_after, - **job_kwargs_local, - ) - - from spikeinterface.core.sparsity import compute_sparsity - - sparse_mask2 = compute_sparsity( - templates, - method="snr", - amplitude_mode="peak_to_peak", - noise_levels=params["noise_levels"], - threshold=0.25, - ).mask - - else: - from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd - - templates, sparse_mask2 = get_templates_from_peaks_and_svd( - recording, - peaks, - peak_labels, - ms_before, - ms_after, - svd_model, - peaks_svd, - sparse_mask, - operator="median", - ) - - if params["do_merge_with_templates"]: - peak_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = merge_peak_labels_from_templates( - peaks, - peak_labels, - templates.unit_ids, - templates.templates_array, - sparse_mask2, - **params["merge_kwargs"], - ) - - templates = Templates( - templates_array=merge_template_array, - sampling_frequency=fs, - nbefore=templates.nbefore, - sparsity_mask=None, - channel_ids=recording.channel_ids, - unit_ids=new_unit_ids, - probe=recording.get_probe(), - is_in_uV=False, - ) - - labels = templates.unit_ids - - if params["debug"]: - templates_folder = tmp_folder / "dense_templates" - templates.to_zarr(folder_path=templates_folder) - - if params["remove_mixtures"]: - if verbose: - print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) - - cleaning_job_kwargs = job_kwargs.copy() - cleaning_job_kwargs["progress_bar"] = False - cleaning_params = params["cleaning_kwargs"].copy() - - labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params - ) - - if verbose: - print("Kept %d non-duplicated clusters" % len(labels)) - else: - if verbose: - print("Kept %d raw clusters" % len(labels)) - - more_outs = dict( - svd_model=svd_model, - peaks_svd=peaks_svd, - peak_svd_sparse_mask=sparse_mask, - ) - return labels, peak_labels, more_outs From 
6a29e3f3841562a43d93563b03bbecb20f3977bb Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 3 Oct 2025 09:25:02 +0200 Subject: [PATCH 14/47] Reducing memory footprint for large number of templates/channels --- .../sortingcomponents/matching/circus.py | 5 +++-- .../sortingcomponents/matching/wobble.py | 15 +++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 7c0f7e3dae..2e8949d800 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -248,10 +248,11 @@ def _prepare_templates(self): else: sparsity = self.templates.sparsity.mask - units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) - self.units_overlaps = units_overlaps > 0 + #units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) self.unit_overlaps_indices = {} + self.units_overlaps = {} for i in range(self.num_templates): + self.units_overlaps[i] = np.sum(np.logical_and(sparsity[i, :], sparsity), axis=1) > 0 self.unit_overlaps_indices[i] = np.flatnonzero(self.units_overlaps[i]) templates_array = self.templates.get_dense_templates().copy() diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 5c15f3e9c3..20509569c7 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -278,10 +278,17 @@ def from_templates(cls, params, templates): Dataclass object for aggregating channel sparsity variables together. """ visible_channels = templates.sparsity.mask - unit_overlap = np.sum( - np.logical_and(visible_channels[:, np.newaxis, :], visible_channels[np.newaxis, :, :]), axis=2 - ) - unit_overlap = unit_overlap > 0 + num_templates = templates.get_dense_templates().shape[0] + unit_overlap = np.zeros((num_templates, num_templates), dtype=bool) + + for i in range(num_templates): + unit_overlap[i] = np.sum(np.logical_and(visible_channels[i], visible_channels), axis=1) > 0 + + #unit_overlap = np.sum( + # np.logical_and(visible_channels[:, np.newaxis, :], visible_channels[np.newaxis, :, :]), axis=2 + #) + #unit_overlap = unit_overlap > 0 + unit_overlap = np.repeat(unit_overlap, params.jitter_factor, axis=0) sparsity = cls(visible_channels=visible_channels, unit_overlap=unit_overlap) return sparsity From 5f0e02bd598b4462a569bc80010f1561c2c217f0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 6 Oct 2025 17:56:27 +0200 Subject: [PATCH 15/47] improve iterative_isosplit and remove warnings --- .../clustering/isosplit_isocut.py | 1 + .../clustering/iterative_isosplit.py | 6 +++++- .../clustering/splitting_tools.py | 16 +++++++++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py b/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py index fa948c88d1..6501c3348f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py +++ b/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py @@ -308,6 +308,7 @@ def isosplit( with warnings.catch_warnings(): # sometimes the kmeans do not found enought cluster which should not be an issue + warnings.simplefilter("ignore") _, labels = kmeans2(X, n_init, minit="points", seed=seed) labels = 
ensure_continuous_labels(labels) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index a65a6f59cc..15aee9ee3c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -46,7 +46,7 @@ class IterativeISOSPLITClustering: "isocut_threshold": 2.0, }, "min_size_split": 25, - "n_pca_features": 3, + "n_pca_features": 6, }, }, "merge_from_templates": { @@ -141,7 +141,11 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): features, method="local_feature_clustering", debug_folder=debug_folder, + job_kwargs=job_kwargs, + # job_kwargs=dict(n_jobs=1), + + **split_params, # method_kwargs=dict( # clusterer=clusterer, diff --git a/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py b/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py index bcc6186f58..e38fecc35f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py @@ -301,7 +301,21 @@ def split( elif clusterer_method == "isosplit": from spikeinterface.sortingcomponents.clustering.isosplit_isocut import isosplit - possible_labels = isosplit(final_features, **clustering_kwargs) + min_cluster_size = clustering_kwargs["min_cluster_size"] + + # here the trick is that we do not except more than 4 to 5 clusters per iteration with a presplit of 5 + num_samples = final_features.shape[0] + n_init = int(num_samples / 5 * 5) + if n_init > (num_samples // min_cluster_size): + # avoid warning in isosplit when sample_size is too small + factor = min_cluster_size * 2 + n_init = max(1, num_samples // factor) + + clustering_kwargs_ = clustering_kwargs.copy() + clustering_kwargs_["n_init"] = n_init + + + possible_labels = isosplit(final_features, **clustering_kwargs_) # min_cluster_size = clusterer_kwargs.get("min_cluster_size", 25) # for i in np.unique(possible_labels): From 82223a92adff54a84b553243d0846f417df80a03 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 6 Oct 2025 18:00:17 +0200 Subject: [PATCH 16/47] "n_pca_features" 6 > 3 --- .../sortingcomponents/clustering/iterative_isosplit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 15aee9ee3c..604112bb82 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -46,7 +46,7 @@ class IterativeISOSPLITClustering: "isocut_threshold": 2.0, }, "min_size_split": 25, - "n_pca_features": 6, + "n_pca_features": 3, }, }, "merge_from_templates": { From b76552a3dcf559baa8f30755460eb82e0619bd93 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 13 Oct 2025 09:39:12 +0200 Subject: [PATCH 17/47] iterative isosplit params --- .../sortingcomponents/clustering/splitting_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py b/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py index e38fecc35f..6c55321185 100644 --- a/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py @@ -303,13 +303,13 @@ def split( 
min_cluster_size = clustering_kwargs["min_cluster_size"] - # here the trick is that we do not except more than 4 to 5 clusters per iteration with a presplit of 5 + # here the trick is that we do not except more than 4 to 5 clusters per iteration with a presplit of 10 num_samples = final_features.shape[0] - n_init = int(num_samples / 5 * 5) + n_init = 50 if n_init > (num_samples // min_cluster_size): # avoid warning in isosplit when sample_size is too small factor = min_cluster_size * 2 - n_init = max(1, num_samples // factor) + n_init = max(2, num_samples // factor) clustering_kwargs_ = clustering_kwargs.copy() clustering_kwargs_["n_init"] = n_init From 22aa5cd0716222ebc768276f7218c86cf6d89035 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 15 Oct 2025 17:35:12 +0200 Subject: [PATCH 18/47] oups --- .../sortingcomponents/clustering/iterative_isosplit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 604112bb82..41ee6b0d22 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -179,7 +179,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if params["merge_from_features"]: - merge_from_features_kwargs = params["merge_features_kwargs"].copy() + merge_from_features_kwargs = params["merge_from_features"].copy() merge_radius_um = merge_from_features_kwargs.pop("merge_radius_um") post_merge_label1, templates_array, template_sparse_mask, unit_ids = merge_peak_labels_from_features( From 61a570e328d8000f251069027916f5239ba377f9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 15 Oct 2025 21:47:15 +0200 Subject: [PATCH 19/47] wip --- .../clustering/isosplit_isocut.py | 4 +++- .../clustering/iterative_isosplit.py | 3 +++ .../clustering/merging_tools.py | 2 ++ .../clustering/splitting_tools.py | 19 ++++++++++++------- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py b/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py index 6501c3348f..9b64a1eea7 100644 --- a/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py +++ b/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py @@ -586,8 +586,10 @@ def compare_pairs(X, labels, pairs, centroids, covmats, min_cluster_size, isocut (inds2,) = np.nonzero(labels == label2) if (inds1.size > 0) and (inds2.size > 0): - if (inds1.size < min_cluster_size) and (inds2.size < min_cluster_size): + # if (inds1.size < min_cluster_size) and (inds2.size < min_cluster_size): + if (inds1.size < min_cluster_size) or (inds2.size < min_cluster_size): do_merge = True + # do_merge = False else: X1 = X[inds1, :] X2 = X[inds2, :] diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 41ee6b0d22..b60b76c51d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -47,6 +47,9 @@ class IterativeISOSPLITClustering: }, "min_size_split": 25, "n_pca_features": 3, + + # "projection_mode": "tsvd", + "projection_mode": "pca", }, }, "merge_from_templates": { diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py 
b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py index 23ec9d8e4c..d75111019c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py @@ -389,12 +389,14 @@ def merge( from sklearn.decomposition import PCA tsvd = PCA(n_pca_features, whiten=True) + elif projection_mode == "tsvd": from sklearn.decomposition import TruncatedSVD tsvd = TruncatedSVD(n_pca_features, random_state=seed) feat = tsvd.fit_transform(feat) + else: feat = feat tsvd = None diff --git a/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py b/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py index 6c55321185..23f405531f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py @@ -277,15 +277,20 @@ def split( from sklearn.decomposition import PCA tsvd = PCA(n_pca_features, whiten=True) + final_features = tsvd.fit_transform(flatten_features) elif projection_mode == "tsvd": from sklearn.decomposition import TruncatedSVD tsvd = TruncatedSVD(n_pca_features, random_state=seed) - - final_features = tsvd.fit_transform(flatten_features) + final_features = tsvd.fit_transform(flatten_features) + else: final_features = flatten_features tsvd = None + elif n_pca_features is None: + final_features = flatten_features + tsvd = None + if clusterer_method == "hdbscan": from hdbscan import HDBSCAN @@ -317,11 +322,11 @@ def split( possible_labels = isosplit(final_features, **clustering_kwargs_) - # min_cluster_size = clusterer_kwargs.get("min_cluster_size", 25) - # for i in np.unique(possible_labels): - # mask = possible_labels == i - # if np.sum(mask) < min_cluster_size: - # possible_labels[mask] = -1 + for i in np.unique(possible_labels): + mask = possible_labels == i + if np.sum(mask) < min_cluster_size: + possible_labels[mask] = -1 + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 elif clusterer_method == "isosplit6": # this use the official C++ isosplit6 from Jeremy Magland From 4ec84081d100829a4cf2047361710071285f7827 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 21 Oct 2025 18:20:58 +0200 Subject: [PATCH 20/47] various try on iterative_isosplit --- .../generation/splitting_tools.py | 518 +++++++++++++----- .../clustering/isosplit_isocut.py | 20 +- .../clustering/iterative_isosplit.py | 52 +- 3 files changed, 457 insertions(+), 133 deletions(-) diff --git a/src/spikeinterface/generation/splitting_tools.py b/src/spikeinterface/generation/splitting_tools.py index 5bfc7048b5..7c2f239157 100644 --- a/src/spikeinterface/generation/splitting_tools.py +++ b/src/spikeinterface/generation/splitting_tools.py @@ -1,147 +1,415 @@ +from __future__ import annotations + +import warnings + +from multiprocessing import get_context +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + + import numpy as np -from spikeinterface.core.numpyextractors import NumpySorting -from spikeinterface.core.sorting_tools import spike_vector_to_indices +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + +from .tools import aggregate_sparse_features, FeaturesLoader + +try: + import numba -def split_sorting_by_times( - sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None +except: + pass # isocut requires numba + +# important all DEBUG and matplotlib are left in the code intentionally + + +def split_clusters( + peak_labels, 
+ recording, + features_dict_or_folder, + method="local_feature_clustering", + method_kwargs={}, + recursive=False, + recursive_depth=None, + returns_split_count=False, + debug_folder=None, + job_kwargs=None, ): """ - Fonction used to split a sorting based on the times of the units. This - might be used for benchmarking meta merging step (see components) + Run recusrsively (or not) in a multi process pool a local split method. Parameters ---------- - sorting_analyzer : A sortingAnalyzer object - The sortingAnalyzer object whose sorting should be splitted - splitting_probability : float, default 0.5 - probability of being splitted, for any cell in the provided sorting - partial_split_prob : float, default 0.95 - The percentage of spikes that will belong to pre/post splits - unit_ids : list of unit_ids, default None - The list of unit_ids to be splitted, if prespecified - min_snr : float, default=None - If specified, only cells with a snr higher than min_snr might be splitted - seed : int | None, default: None - The seed for random generator. + peak_labels: numpy.array + Peak label before split + recording: Recording + Recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features + method: str, default: "local_feature_clustering" + The method name + method_kwargs: dict, default: dict() + The method option + recursive: bool, default: False + Recursive or not + recursive_depth: None or int, default: None + If recursive=True, then this is the max split per spikes + returns_split_count: bool, default: False + Optionally return the split count vector. Same size as labels Returns ------- - new_sorting, splitted_pairs : The new splitted sorting, and the pairs that have been splitted + new_labels: numpy.ndarray + The labels of peaks after split. 
+ split_count: numpy.ndarray + Optionally returned """ - sorting = sorting_analyzer.sorting - rng = np.random.RandomState(seed) - fs = sorting_analyzer.sampling_frequency + job_kwargs = fix_job_kwargs(job_kwargs) + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs.get("mp_context", None) + progress_bar = job_kwargs["progress_bar"] + max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1) + + original_labels = peak_labels + peak_labels = peak_labels.copy() + split_count = np.zeros(peak_labels.size, dtype=int) + recursion_level = 1 + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=split_worker_init, + mp_context=get_context(method=mp_context), + initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker), + ) as pool: + labels_set = np.setdiff1d(peak_labels, [-1]) + current_max_label = np.max(labels_set) + 1 + jobs = [] + + if debug_folder is not None: + if debug_folder.exists(): + import shutil + + shutil.rmtree(debug_folder) + debug_folder.mkdir(parents=True, exist_ok=True) + + for label in labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if debug_folder is not None: + sub_folder = str(debug_folder / f"split_{label}") + + else: + sub_folder = None + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level, sub_folder)) + + if progress_bar: + pbar = tqdm(desc=f"split_clusters with {method}", total=len(labels_set)) + + for res in jobs: + is_split, local_labels, peak_indices, sub_folder = res.result() + + if progress_bar: + pbar.update(1) + + if not is_split: + continue + + mask = local_labels >= 0 + peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label + peak_labels[peak_indices[~mask]] = local_labels[~mask] + split_count[peak_indices] += 1 + current_max_label += np.max(local_labels[mask]) + 1 + + if recursive: + recursion_level = np.max(split_count[peak_indices]) + if recursive_depth is not None: + # stop recursivity when recursive_depth is reach + extra_ball = recursion_level < recursive_depth + else: + # recursive always + extra_ball = True - nb_splits = int(splitting_probability * len(sorting.unit_ids)) - if unit_ids is None: - select_from = sorting.unit_ids - if min_snr is not None: - if sorting_analyzer.get_extension("noise_levels") is None: - sorting_analyzer.compute("noise_levels") - if sorting_analyzer.get_extension("quality_metrics") is None: - sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + if extra_ball: + new_labels_set = np.setdiff1d(peak_labels[peak_indices], [-1]) + for label in new_labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if sub_folder is not None: + new_sub_folder = sub_folder + f"_{label}" + else: + new_sub_folder = None + if peak_indices.size > 0: + # print('Relaunched', label, len(peak_indices), recursion_level) + jobs.append( + pool.submit(split_function_wrapper, peak_indices, recursion_level, new_sub_folder) + ) + if progress_bar: + pbar.total += 1 - snr = sorting_analyzer.get_extension("quality_metrics").get_data()["snr"].values - select_from = select_from[snr > min_snr] + if progress_bar: + pbar.close() + del pbar - to_split_ids = rng.choice(select_from, nb_splits, replace=False) + if returns_split_count: + return peak_labels, split_count else: - to_split_ids = unit_ids - - spikes = sorting_analyzer.sorting.to_spike_vector(concatenated=False) - new_spikes = spikes[0].copy() - max_index = np.max(new_spikes["unit_index"]) - 
new_unit_ids = list(sorting_analyzer.sorting.unit_ids.copy()) - spike_indices = spike_vector_to_indices(spikes, sorting_analyzer.unit_ids, absolute_index=True) - splitted_pairs = [] - for unit_id in to_split_ids: - ind_mask = spike_indices[0][unit_id] - m = np.median(spikes[0][ind_mask]["sample_index"]) - time_mask = spikes[0][ind_mask]["sample_index"] > m - mask = time_mask & (rng.rand(len(ind_mask)) <= partial_split_prob).astype(bool) - new_index = int(unit_id) * np.ones(len(mask), dtype=bool) - new_index[mask] = max_index + 1 - new_spikes["unit_index"][ind_mask] = new_index - new_unit_ids += [max_index + 1] - splitted_pairs += [(unit_id, new_unit_ids[-1])] - max_index += 1 - - new_sorting = NumpySorting(new_spikes, sampling_frequency=fs, unit_ids=new_unit_ids) - return new_sorting, splitted_pairs - - -def split_sorting_by_amplitudes( - sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None + return peak_labels + + +global _ctx + + +def split_worker_init( + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker ): - """ - Fonction used to split a sorting based on the amplitudes of the units. This - might be used for benchmarking meta merging step (see components) + global _ctx + _ctx = {} - Parameters - ---------- - sorting_analyzer : A sortingAnalyzer object - The sortingAnalyzer object whose sorting should be splitted - splitting_probability : float, default 0.5 - probability of being splitted, for any cell in the provided sorting - partial_split_prob : float, default 0.95 - The percentage of spikes that will belong to pre/post splits - unit_ids : list of unit_ids, default None - The list of unit_ids to be splitted, if prespecified - min_snr : float, default=None - If specified, only cells with a snr higher than min_snr might be splitted - seed : int | None, default: None - The seed for random generator. 
+ _ctx["recording"] = recording + features_dict_or_folder + _ctx["original_labels"] = original_labels + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = split_methods_dict[method] + _ctx["max_threads_per_worker"] = max_threads_per_worker + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + _ctx["peaks"] = _ctx["features"]["peaks"] - Returns - ------- - new_sorting, splitted_pairs : The new splitted sorting, and the pairs that have been splitted + +def split_function_wrapper(peak_indices, recursion_level, debug_folder): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_worker"]): + is_split, local_labels = _ctx["method_class"].split( + peak_indices, _ctx["peaks"], _ctx["features"], recursion_level, debug_folder, **_ctx["method_kwargs"] + ) + return is_split, local_labels, peak_indices, debug_folder + + +class LocalFeatureClustering: """ + This method is a refactorized mix between: + * old tridesclous code + * "herding_split()" in DART/spikepsvae by Charlie Windolf - if sorting_analyzer.get_extension("spike_amplitudes") is None: - sorting_analyzer.compute("spike_amplitudes") - - rng = np.random.RandomState(seed) - fs = sorting_analyzer.sampling_frequency - from spikeinterface.core.template_tools import get_template_extremum_channel - - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") - spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, concatenated=False) - new_spikes = spikes[0].copy() - amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() - nb_splits = int(splitting_probability * len(sorting_analyzer.sorting.unit_ids)) - - if unit_ids is None: - select_from = sorting_analyzer.sorting.unit_ids - if min_snr is not None: - if sorting_analyzer.get_extension("noise_levels") is None: - sorting_analyzer.compute("noise_levels") - if sorting_analyzer.get_extension("quality_metrics") is None: - sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) - - snr = sorting_analyzer.get_extension("quality_metrics").get_data()["snr"].values - select_from = select_from[snr > min_snr] - to_split_ids = rng.choice(select_from, nb_splits, replace=False) - else: - to_split_ids = unit_ids - - max_index = np.max(new_spikes["unit_index"]) - new_unit_ids = list(sorting_analyzer.sorting.unit_ids.copy()) - splitted_pairs = [] - spike_indices = spike_vector_to_indices(spikes, sorting_analyzer.unit_ids, absolute_index=True) - - for unit_id in to_split_ids: - ind_mask = spike_indices[0][unit_id] - thresh = np.median(amplitudes[ind_mask]) - amplitude_mask = amplitudes[ind_mask] > thresh - mask = amplitude_mask & (rng.rand(len(ind_mask)) <= partial_split_prob).astype(bool) - new_index = int(unit_id) * np.ones(len(mask)) - new_index[mask] = max_index + 1 - new_spikes["unit_index"][ind_mask] = new_index - new_unit_ids += [max_index + 1] - splitted_pairs += [(unit_id, new_unit_ids[-1])] - max_index += 1 - - new_sorting = NumpySorting(new_spikes, sampling_frequency=fs, unit_ids=new_unit_ids) - return new_sorting, splitted_pairs + The idea simple : + * agregate features (svd or even waveforms) with sparse channel. 
+ * run a local feature reduction (pca or svd) + * try a new split (hdscan or isosplit) + """ + + name = "local_feature_clustering" + + @staticmethod + def split( + peak_indices, + peaks, + features, + recursion_level=1, + debug_folder=None, + clusterer={"method": "hdbscan", "min_cluster_size": 25, "min_samples": 5}, + feature_name="sparse_tsvd", + neighbours_mask=None, + waveforms_sparse_mask=None, + min_size_split=25, + n_pca_features=3, + seed=None, + projection_mode="tsvd", + minimum_overlap_ratio=0.25, + ): + + clustering_kwargs = clusterer.copy() + clusterer_method = clustering_kwargs.pop("method") + + assert clusterer_method in ["hdbscan", "isosplit", "isosplit6"] + + local_labels = np.zeros(peak_indices.size, dtype=np.int64) + + # can be sparse_tsvd or sparse_wfs + sparse_features = features[feature_name] + + assert waveforms_sparse_mask is not None + + # target channel subset is done intersect local channels + neighbours + local_chans = np.unique(peaks["channel_index"][peak_indices]) + + target_intersection_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) + target_union_channels = np.flatnonzero(np.any(neighbours_mask[local_chans, :], axis=0)) + num_intersection = len(target_intersection_channels) + num_union = len(target_union_channels) + + # TODO fix this a better way, this when cluster have too few overlapping channels + if (num_intersection / num_union) < minimum_overlap_ratio: + return False, None + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_intersection_channels + ) + + local_labels[dont_have_channels] = -2 + kept = np.flatnonzero(~dont_have_channels) + + if kept.size < min_size_split: + return False, None + + aligned_wfs = aligned_wfs[kept, :, :] + flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1) + + is_split = False + + if isinstance(n_pca_features, float): + assert 0 < n_pca_features < 1, "n_components should be in ]0, 1[" + nb_dimensions = min(flatten_features.shape[0], flatten_features.shape[1]) + if projection_mode == "pca": + from sklearn.decomposition import PCA + + tsvd = PCA(nb_dimensions, whiten=True) + elif projection_mode == "tsvd": + from sklearn.decomposition import TruncatedSVD + + tsvd = TruncatedSVD(nb_dimensions, random_state=seed) + final_features = tsvd.fit_transform(flatten_features) + n_explain = np.sum(np.cumsum(tsvd.explained_variance_ratio_) <= n_pca_features) + 1 + final_features = final_features[:, :n_explain] + n_pca_features = final_features.shape[1] + elif isinstance(n_pca_features, int): + if flatten_features.shape[1] > n_pca_features: + if projection_mode == "pca": + from sklearn.decomposition import PCA + + tsvd = PCA(n_pca_features, whiten=True) + final_features = tsvd.fit_transform(flatten_features) + elif projection_mode == "tsvd": + from sklearn.decomposition import TruncatedSVD + + tsvd = TruncatedSVD(n_pca_features, random_state=seed) + final_features = tsvd.fit_transform(flatten_features) + + else: + final_features = flatten_features + tsvd = None + elif n_pca_features is None: + final_features = flatten_features + tsvd = None + + + if clusterer_method == "hdbscan": + from hdbscan import HDBSCAN + + clustering_kwargs.update(core_dist_n_jobs=1) + clust = HDBSCAN(**clustering_kwargs) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + clust.fit(final_features) + possible_labels = clust.labels_ + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + del clust + elif clusterer_method 
== "isosplit": + from spikeinterface.sortingcomponents.clustering.isosplit_isocut import isosplit + + min_cluster_size = clustering_kwargs["min_cluster_size"] + + # here the trick is that we do not except more than 4 to 5 clusters + num_samples = final_features.shape[0] + # n_init = 50 + n_init = 20 + if n_init > (num_samples // min_cluster_size): + # avoid warning in isosplit when sample_size is too small + factor = min_cluster_size * 4 + n_init = max(2, num_samples // factor) + + clustering_kwargs_ = clustering_kwargs.copy() + clustering_kwargs_["n_init"] = n_init + + + possible_labels = isosplit(final_features, **clustering_kwargs_) + + for i in np.unique(possible_labels): + mask = possible_labels == i + if np.sum(mask) < min_cluster_size: + possible_labels[mask] = -1 + + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + elif clusterer_method == "isosplit6": + # this use the official C++ isosplit6 from Jeremy Magland + import isosplit6 + + min_cluster_size = clustering_kwargs.get("min_cluster_size", 25) + possible_labels = isosplit6.isosplit6(final_features) + for i in np.unique(possible_labels): + mask = possible_labels == i + if np.sum(mask) < min_cluster_size: + possible_labels[mask] = -1 + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + else: + raise ValueError(f"wrong clusterer {clusterer}. Possible options are 'hdbscan/isosplit/isosplit6'.") + + DEBUG = False # only for Sam or dirty hacking + # DEBUG = True + # DEBUG = recursion_level > 2 + + if debug_folder is not None or DEBUG: + import matplotlib.pyplot as plt + + labels_set = np.setdiff1d(possible_labels, [-1]) + colors = plt.colormaps["tab10"].resampled(len(labels_set)) + colors = {k: colors(i) for i, k in enumerate(labels_set)} + colors[-1] = "k" + fig, axs = plt.subplots(nrows=4) + + flatten_wfs = aligned_wfs.swapaxes(1, 2).reshape(aligned_wfs.shape[0], -1) + + if final_features.shape[1] == 1: + final_features = np.hstack((final_features, np.zeros_like(final_features))) + + sl = slice(None, None, 100) + for k in np.unique(possible_labels): + mask = possible_labels == k + ax = axs[0] + ax.scatter(final_features[:, 0][mask], final_features[:, 1][mask], s=5, color=colors[k]) + if k > -1: + centroid = final_features[:, :2][mask].mean(axis=0) + ax.text(centroid[0], centroid[1], f"Label {k}", fontsize=10, color="k") + ax = axs[1] + ax.plot(flatten_wfs[mask].T, color=colors[k], alpha=0.1) + if k > -1: + ax.plot(np.median(flatten_wfs[mask].T, axis=1), color=colors[k], lw=2) + ax.set_xlabel(f"PCA features") + + ax = axs[3] + if n_pca_features == 1: + bins = np.linspace(final_features[:, 0].min(), final_features[:, 0].max(), 100) + ax.hist(final_features[mask, 0], bins, color=colors[k], alpha=0.1) + else: + ax.plot(final_features[mask].T, color=colors[k], alpha=0.1) + if k > -1 and n_pca_features > 1: + ax.plot(np.median(final_features[mask].T, axis=1), color=colors[k], lw=2) + ax.set_xlabel(f"Projected PCA features, dim{final_features.shape[1]}") + + if tsvd is not None: + ax = axs[2] + sorted_components = np.argsort(tsvd.explained_variance_ratio_)[::-1] + ax.plot(tsvd.explained_variance_ratio_[sorted_components], c="k") + del tsvd + + ymin, ymax = ax.get_ylim() + ax.plot([n_pca_features, n_pca_features], [ymin, ymax], "k--") + + axs[0].set_title(f"{clusterer} level={recursion_level}") + if not DEBUG: + fig.savefig(str(debug_folder) + ".png") + plt.close(fig) + else: + plt.show() + + if not is_split: + return is_split, None + + local_labels[kept] = possible_labels + + return is_split, local_labels + + 
+split_methods_list = [ + LocalFeatureClustering, +] +split_methods_dict = {e.name: e for e in split_methods_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py b/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py index 9b64a1eea7..e00b2e231a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py +++ b/src/spikeinterface/sortingcomponents/clustering/isosplit_isocut.py @@ -341,6 +341,22 @@ def isosplit( iteration_number = 0 + # import matplotlib.pyplot as plt + # fig, axs = plt.subplots(ncols=2) + # # cmap = plt.colormaps['nipy_spectral'].resampled(active_labels.size) + # cmap = plt.colormaps['nipy_spectral'].resampled(n_init) + # # colors = {l: cmap(i) for i, l in enumerate(active_labels)} + # colors = {i: cmap(i) for i in range(n_init)} + # ax = axs[0] + # ax.scatter(X[:, 0], X[:, 1], c=labels, cmap='nipy_spectral', s=4) + # ax.set_title(f'n={X.shape[0]} c={active_labels.size} n_init={n_init} min_cluster_size={min_cluster_size} final_pass={final_pass}') + # ax = axs[1] + # for i, l in enumerate(active_labels): + # mask = labels == l + # ax.plot(X[mask, :].T, color=colors[l], alpha=0.4) + # plt.show() + + while True: # iterations iteration_number += 1 # print(' iterations', iteration_number) @@ -618,7 +634,9 @@ def compare_pairs(X, labels, pairs, centroids, covmats, min_cluster_size, isocut (modified_inds2,) = np.nonzero(L12[inds1.size :] == 1) # protect against pure swaping between label1<>label2 - pure_swaping = modified_inds1.size != inds1.size and modified_inds2.size != inds2.size + # pure_swaping = modified_inds1.size == inds1.size and modified_inds2.size == inds2.size + pure_swaping = (modified_inds1.size / inds1.size + modified_inds2.size / inds2.size) >= 1.0 + if modified_inds1.size > 0 and not pure_swaping: something_was_redistributed = True diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index b60b76c51d..ed2dcec881 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -33,23 +33,32 @@ class IterativeISOSPLITClustering: "motion": None, "seed": None, "peaks_svd": {"n_components": 5, "ms_before": 0.5, "ms_after": 1.5, "radius_um": 120.0, "motion": None}, + "pre_label": { + "mode": "channel", + # "mode": "vertical_bin", + + }, "split": { - "split_radius_um": 40.0, + # "split_radius_um": 40.0, + "split_radius_um": 60.0, "recursive": True, "recursive_depth": 5, "method_kwargs": { "clusterer": { "method": "isosplit", - "n_init": 50, + # "method": "isosplit6", + # "n_init": 50, "min_cluster_size": 10, "max_iterations_per_pass": 500, - "isocut_threshold": 2.0, + # "isocut_threshold": 2.0, + "isocut_threshold": 2.5, }, "min_size_split": 25, - "n_pca_features": 3, + # "n_pca_features": 3, + "n_pca_features": 10, - # "projection_mode": "tsvd", - "projection_mode": "pca", + "projection_mode": "tsvd", + # "projection_mode": "pca", }, }, "merge_from_templates": { @@ -58,6 +67,7 @@ class IterativeISOSPLITClustering: "similarity_thresh": 0.8, }, "merge_from_features": None, + # "merge_from_features": {}, "clean": { "minimum_cluster_size": 10, }, @@ -122,7 +132,35 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): split_params["method_kwargs"]["waveforms_sparse_mask"] = sparse_mask split_params["method_kwargs"]["feature_name"] = "peaks_svd" - original_labels = peaks["channel_index"] + + if 
params["pre_label"]["mode"] == "channel": + original_labels = peaks["channel_index"] + elif params["pre_label"]["mode"] == "vertical_bin": + # 2 params + direction = "y" + bin_um = 40. + + channel_locations = recording.get_channel_locations() + dim = "xyz".index(direction) + channel_depth = channel_locations[:, dim] + + # bins + min_ = np.min(channel_depth) + max_ = np.max(channel_depth) + num_windows = int((max_ - min_) // bin_um) + num_windows = max(num_windows, 1) + border = ((max_ - min_) % bin_um) / 2 + vertical_bins = np.zeros(num_windows+3) + vertical_bins[1:-1] = np.arange(num_windows + 1) * bin_um + min_ + border + vertical_bins[0] = -np.inf + vertical_bins[-1] = np.inf + print(min_, max_) + print(vertical_bins) + print(vertical_bins.size) + # peak depth + peak_depths = channel_depth[peaks["channel_index"]] + # label by bin + original_labels = np.digitize(peak_depths, vertical_bins) # clusterer = params["split"]["clusterer"] # clusterer_kwargs = params["split"]["clusterer_kwargs"] From 19e77fa372ee9f2d68f5c99c1c2b427515bdff56 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 22 Oct 2025 08:49:26 +0200 Subject: [PATCH 21/47] fix git bug --- .../generation/splitting_tools.py | 518 +++++------------- 1 file changed, 125 insertions(+), 393 deletions(-) diff --git a/src/spikeinterface/generation/splitting_tools.py b/src/spikeinterface/generation/splitting_tools.py index 7c2f239157..5bfc7048b5 100644 --- a/src/spikeinterface/generation/splitting_tools.py +++ b/src/spikeinterface/generation/splitting_tools.py @@ -1,415 +1,147 @@ -from __future__ import annotations - -import warnings - -from multiprocessing import get_context -from threadpoolctl import threadpool_limits -from tqdm.auto import tqdm - - import numpy as np +from spikeinterface.core.numpyextractors import NumpySorting +from spikeinterface.core.sorting_tools import spike_vector_to_indices -from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs - -from .tools import aggregate_sparse_features, FeaturesLoader - -try: - import numba -except: - pass # isocut requires numba - -# important all DEBUG and matplotlib are left in the code intentionally - - -def split_clusters( - peak_labels, - recording, - features_dict_or_folder, - method="local_feature_clustering", - method_kwargs={}, - recursive=False, - recursive_depth=None, - returns_split_count=False, - debug_folder=None, - job_kwargs=None, +def split_sorting_by_times( + sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None ): """ - Run recusrsively (or not) in a multi process pool a local split method. + Fonction used to split a sorting based on the times of the units. This + might be used for benchmarking meta merging step (see components) Parameters ---------- - peak_labels: numpy.array - Peak label before split - recording: Recording - Recording object - features_dict_or_folder: dict or folder - A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features - method: str, default: "local_feature_clustering" - The method name - method_kwargs: dict, default: dict() - The method option - recursive: bool, default: False - Recursive or not - recursive_depth: None or int, default: None - If recursive=True, then this is the max split per spikes - returns_split_count: bool, default: False - Optionally return the split count vector. 
Same size as labels + sorting_analyzer : A sortingAnalyzer object + The sortingAnalyzer object whose sorting should be splitted + splitting_probability : float, default 0.5 + probability of being splitted, for any cell in the provided sorting + partial_split_prob : float, default 0.95 + The percentage of spikes that will belong to pre/post splits + unit_ids : list of unit_ids, default None + The list of unit_ids to be splitted, if prespecified + min_snr : float, default=None + If specified, only cells with a snr higher than min_snr might be splitted + seed : int | None, default: None + The seed for random generator. Returns ------- - new_labels: numpy.ndarray - The labels of peaks after split. - split_count: numpy.ndarray - Optionally returned + new_sorting, splitted_pairs : The new splitted sorting, and the pairs that have been splitted """ - job_kwargs = fix_job_kwargs(job_kwargs) - n_jobs = job_kwargs["n_jobs"] - mp_context = job_kwargs.get("mp_context", None) - progress_bar = job_kwargs["progress_bar"] - max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1) - - original_labels = peak_labels - peak_labels = peak_labels.copy() - split_count = np.zeros(peak_labels.size, dtype=int) - recursion_level = 1 - Executor = get_poolexecutor(n_jobs) - - with Executor( - max_workers=n_jobs, - initializer=split_worker_init, - mp_context=get_context(method=mp_context), - initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker), - ) as pool: - labels_set = np.setdiff1d(peak_labels, [-1]) - current_max_label = np.max(labels_set) + 1 - jobs = [] - - if debug_folder is not None: - if debug_folder.exists(): - import shutil - - shutil.rmtree(debug_folder) - debug_folder.mkdir(parents=True, exist_ok=True) - - for label in labels_set: - peak_indices = np.flatnonzero(peak_labels == label) - if debug_folder is not None: - sub_folder = str(debug_folder / f"split_{label}") - - else: - sub_folder = None - if peak_indices.size > 0: - jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level, sub_folder)) - - if progress_bar: - pbar = tqdm(desc=f"split_clusters with {method}", total=len(labels_set)) - - for res in jobs: - is_split, local_labels, peak_indices, sub_folder = res.result() - - if progress_bar: - pbar.update(1) - - if not is_split: - continue - - mask = local_labels >= 0 - peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label - peak_labels[peak_indices[~mask]] = local_labels[~mask] - split_count[peak_indices] += 1 - current_max_label += np.max(local_labels[mask]) + 1 - - if recursive: - recursion_level = np.max(split_count[peak_indices]) - if recursive_depth is not None: - # stop recursivity when recursive_depth is reach - extra_ball = recursion_level < recursive_depth - else: - # recursive always - extra_ball = True + sorting = sorting_analyzer.sorting + rng = np.random.RandomState(seed) + fs = sorting_analyzer.sampling_frequency - if extra_ball: - new_labels_set = np.setdiff1d(peak_labels[peak_indices], [-1]) - for label in new_labels_set: - peak_indices = np.flatnonzero(peak_labels == label) - if sub_folder is not None: - new_sub_folder = sub_folder + f"_{label}" - else: - new_sub_folder = None - if peak_indices.size > 0: - # print('Relaunched', label, len(peak_indices), recursion_level) - jobs.append( - pool.submit(split_function_wrapper, peak_indices, recursion_level, new_sub_folder) - ) - if progress_bar: - pbar.total += 1 + nb_splits = int(splitting_probability * len(sorting.unit_ids)) + if 
unit_ids is None: + select_from = sorting.unit_ids + if min_snr is not None: + if sorting_analyzer.get_extension("noise_levels") is None: + sorting_analyzer.compute("noise_levels") + if sorting_analyzer.get_extension("quality_metrics") is None: + sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) - if progress_bar: - pbar.close() - del pbar + snr = sorting_analyzer.get_extension("quality_metrics").get_data()["snr"].values + select_from = select_from[snr > min_snr] - if returns_split_count: - return peak_labels, split_count + to_split_ids = rng.choice(select_from, nb_splits, replace=False) else: - return peak_labels - - -global _ctx - - -def split_worker_init( - recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker + to_split_ids = unit_ids + + spikes = sorting_analyzer.sorting.to_spike_vector(concatenated=False) + new_spikes = spikes[0].copy() + max_index = np.max(new_spikes["unit_index"]) + new_unit_ids = list(sorting_analyzer.sorting.unit_ids.copy()) + spike_indices = spike_vector_to_indices(spikes, sorting_analyzer.unit_ids, absolute_index=True) + splitted_pairs = [] + for unit_id in to_split_ids: + ind_mask = spike_indices[0][unit_id] + m = np.median(spikes[0][ind_mask]["sample_index"]) + time_mask = spikes[0][ind_mask]["sample_index"] > m + mask = time_mask & (rng.rand(len(ind_mask)) <= partial_split_prob).astype(bool) + new_index = int(unit_id) * np.ones(len(mask), dtype=bool) + new_index[mask] = max_index + 1 + new_spikes["unit_index"][ind_mask] = new_index + new_unit_ids += [max_index + 1] + splitted_pairs += [(unit_id, new_unit_ids[-1])] + max_index += 1 + + new_sorting = NumpySorting(new_spikes, sampling_frequency=fs, unit_ids=new_unit_ids) + return new_sorting, splitted_pairs + + +def split_sorting_by_amplitudes( + sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None ): - global _ctx - _ctx = {} - - _ctx["recording"] = recording - features_dict_or_folder - _ctx["original_labels"] = original_labels - _ctx["method"] = method - _ctx["method_kwargs"] = method_kwargs - _ctx["method_class"] = split_methods_dict[method] - _ctx["max_threads_per_worker"] = max_threads_per_worker - _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) - _ctx["peaks"] = _ctx["features"]["peaks"] - - -def split_function_wrapper(peak_indices, recursion_level, debug_folder): - global _ctx - with threadpool_limits(limits=_ctx["max_threads_per_worker"]): - is_split, local_labels = _ctx["method_class"].split( - peak_indices, _ctx["peaks"], _ctx["features"], recursion_level, debug_folder, **_ctx["method_kwargs"] - ) - return is_split, local_labels, peak_indices, debug_folder - - -class LocalFeatureClustering: - """ - This method is a refactorized mix between: - * old tridesclous code - * "herding_split()" in DART/spikepsvae by Charlie Windolf - - The idea simple : - * agregate features (svd or even waveforms) with sparse channel. - * run a local feature reduction (pca or svd) - * try a new split (hdscan or isosplit) """ + Fonction used to split a sorting based on the amplitudes of the units. 
This + might be used for benchmarking meta merging step (see components) - name = "local_feature_clustering" - - @staticmethod - def split( - peak_indices, - peaks, - features, - recursion_level=1, - debug_folder=None, - clusterer={"method": "hdbscan", "min_cluster_size": 25, "min_samples": 5}, - feature_name="sparse_tsvd", - neighbours_mask=None, - waveforms_sparse_mask=None, - min_size_split=25, - n_pca_features=3, - seed=None, - projection_mode="tsvd", - minimum_overlap_ratio=0.25, - ): - - clustering_kwargs = clusterer.copy() - clusterer_method = clustering_kwargs.pop("method") - - assert clusterer_method in ["hdbscan", "isosplit", "isosplit6"] - - local_labels = np.zeros(peak_indices.size, dtype=np.int64) - - # can be sparse_tsvd or sparse_wfs - sparse_features = features[feature_name] - - assert waveforms_sparse_mask is not None - - # target channel subset is done intersect local channels + neighbours - local_chans = np.unique(peaks["channel_index"][peak_indices]) - - target_intersection_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) - target_union_channels = np.flatnonzero(np.any(neighbours_mask[local_chans, :], axis=0)) - num_intersection = len(target_intersection_channels) - num_union = len(target_union_channels) - - # TODO fix this a better way, this when cluster have too few overlapping channels - if (num_intersection / num_union) < minimum_overlap_ratio: - return False, None - - aligned_wfs, dont_have_channels = aggregate_sparse_features( - peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_intersection_channels - ) - - local_labels[dont_have_channels] = -2 - kept = np.flatnonzero(~dont_have_channels) - - if kept.size < min_size_split: - return False, None - - aligned_wfs = aligned_wfs[kept, :, :] - flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1) - - is_split = False - - if isinstance(n_pca_features, float): - assert 0 < n_pca_features < 1, "n_components should be in ]0, 1[" - nb_dimensions = min(flatten_features.shape[0], flatten_features.shape[1]) - if projection_mode == "pca": - from sklearn.decomposition import PCA - - tsvd = PCA(nb_dimensions, whiten=True) - elif projection_mode == "tsvd": - from sklearn.decomposition import TruncatedSVD - - tsvd = TruncatedSVD(nb_dimensions, random_state=seed) - final_features = tsvd.fit_transform(flatten_features) - n_explain = np.sum(np.cumsum(tsvd.explained_variance_ratio_) <= n_pca_features) + 1 - final_features = final_features[:, :n_explain] - n_pca_features = final_features.shape[1] - elif isinstance(n_pca_features, int): - if flatten_features.shape[1] > n_pca_features: - if projection_mode == "pca": - from sklearn.decomposition import PCA - - tsvd = PCA(n_pca_features, whiten=True) - final_features = tsvd.fit_transform(flatten_features) - elif projection_mode == "tsvd": - from sklearn.decomposition import TruncatedSVD - - tsvd = TruncatedSVD(n_pca_features, random_state=seed) - final_features = tsvd.fit_transform(flatten_features) - - else: - final_features = flatten_features - tsvd = None - elif n_pca_features is None: - final_features = flatten_features - tsvd = None - - - if clusterer_method == "hdbscan": - from hdbscan import HDBSCAN - - clustering_kwargs.update(core_dist_n_jobs=1) - clust = HDBSCAN(**clustering_kwargs) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore") - clust.fit(final_features) - possible_labels = clust.labels_ - is_split = np.setdiff1d(possible_labels, [-1]).size > 1 - del clust - elif clusterer_method == "isosplit": - from 
spikeinterface.sortingcomponents.clustering.isosplit_isocut import isosplit - - min_cluster_size = clustering_kwargs["min_cluster_size"] - - # here the trick is that we do not except more than 4 to 5 clusters - num_samples = final_features.shape[0] - # n_init = 50 - n_init = 20 - if n_init > (num_samples // min_cluster_size): - # avoid warning in isosplit when sample_size is too small - factor = min_cluster_size * 4 - n_init = max(2, num_samples // factor) - - clustering_kwargs_ = clustering_kwargs.copy() - clustering_kwargs_["n_init"] = n_init - - - possible_labels = isosplit(final_features, **clustering_kwargs_) - - for i in np.unique(possible_labels): - mask = possible_labels == i - if np.sum(mask) < min_cluster_size: - possible_labels[mask] = -1 - - is_split = np.setdiff1d(possible_labels, [-1]).size > 1 - elif clusterer_method == "isosplit6": - # this use the official C++ isosplit6 from Jeremy Magland - import isosplit6 - - min_cluster_size = clustering_kwargs.get("min_cluster_size", 25) - possible_labels = isosplit6.isosplit6(final_features) - for i in np.unique(possible_labels): - mask = possible_labels == i - if np.sum(mask) < min_cluster_size: - possible_labels[mask] = -1 - is_split = np.setdiff1d(possible_labels, [-1]).size > 1 - else: - raise ValueError(f"wrong clusterer {clusterer}. Possible options are 'hdbscan/isosplit/isosplit6'.") - - DEBUG = False # only for Sam or dirty hacking - # DEBUG = True - # DEBUG = recursion_level > 2 - - if debug_folder is not None or DEBUG: - import matplotlib.pyplot as plt - - labels_set = np.setdiff1d(possible_labels, [-1]) - colors = plt.colormaps["tab10"].resampled(len(labels_set)) - colors = {k: colors(i) for i, k in enumerate(labels_set)} - colors[-1] = "k" - fig, axs = plt.subplots(nrows=4) - - flatten_wfs = aligned_wfs.swapaxes(1, 2).reshape(aligned_wfs.shape[0], -1) - - if final_features.shape[1] == 1: - final_features = np.hstack((final_features, np.zeros_like(final_features))) - - sl = slice(None, None, 100) - for k in np.unique(possible_labels): - mask = possible_labels == k - ax = axs[0] - ax.scatter(final_features[:, 0][mask], final_features[:, 1][mask], s=5, color=colors[k]) - if k > -1: - centroid = final_features[:, :2][mask].mean(axis=0) - ax.text(centroid[0], centroid[1], f"Label {k}", fontsize=10, color="k") - ax = axs[1] - ax.plot(flatten_wfs[mask].T, color=colors[k], alpha=0.1) - if k > -1: - ax.plot(np.median(flatten_wfs[mask].T, axis=1), color=colors[k], lw=2) - ax.set_xlabel(f"PCA features") - - ax = axs[3] - if n_pca_features == 1: - bins = np.linspace(final_features[:, 0].min(), final_features[:, 0].max(), 100) - ax.hist(final_features[mask, 0], bins, color=colors[k], alpha=0.1) - else: - ax.plot(final_features[mask].T, color=colors[k], alpha=0.1) - if k > -1 and n_pca_features > 1: - ax.plot(np.median(final_features[mask].T, axis=1), color=colors[k], lw=2) - ax.set_xlabel(f"Projected PCA features, dim{final_features.shape[1]}") - - if tsvd is not None: - ax = axs[2] - sorted_components = np.argsort(tsvd.explained_variance_ratio_)[::-1] - ax.plot(tsvd.explained_variance_ratio_[sorted_components], c="k") - del tsvd - - ymin, ymax = ax.get_ylim() - ax.plot([n_pca_features, n_pca_features], [ymin, ymax], "k--") - - axs[0].set_title(f"{clusterer} level={recursion_level}") - if not DEBUG: - fig.savefig(str(debug_folder) + ".png") - plt.close(fig) - else: - plt.show() - - if not is_split: - return is_split, None - - local_labels[kept] = possible_labels - - return is_split, local_labels + Parameters + ---------- + 
sorting_analyzer : A sortingAnalyzer object + The sortingAnalyzer object whose sorting should be splitted + splitting_probability : float, default 0.5 + probability of being splitted, for any cell in the provided sorting + partial_split_prob : float, default 0.95 + The percentage of spikes that will belong to pre/post splits + unit_ids : list of unit_ids, default None + The list of unit_ids to be splitted, if prespecified + min_snr : float, default=None + If specified, only cells with a snr higher than min_snr might be splitted + seed : int | None, default: None + The seed for random generator. + Returns + ------- + new_sorting, splitted_pairs : The new splitted sorting, and the pairs that have been splitted + """ -split_methods_list = [ - LocalFeatureClustering, -] -split_methods_dict = {e.name: e for e in split_methods_list} + if sorting_analyzer.get_extension("spike_amplitudes") is None: + sorting_analyzer.compute("spike_amplitudes") + + rng = np.random.RandomState(seed) + fs = sorting_analyzer.sampling_frequency + from spikeinterface.core.template_tools import get_template_extremum_channel + + extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, concatenated=False) + new_spikes = spikes[0].copy() + amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() + nb_splits = int(splitting_probability * len(sorting_analyzer.sorting.unit_ids)) + + if unit_ids is None: + select_from = sorting_analyzer.sorting.unit_ids + if min_snr is not None: + if sorting_analyzer.get_extension("noise_levels") is None: + sorting_analyzer.compute("noise_levels") + if sorting_analyzer.get_extension("quality_metrics") is None: + sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + + snr = sorting_analyzer.get_extension("quality_metrics").get_data()["snr"].values + select_from = select_from[snr > min_snr] + to_split_ids = rng.choice(select_from, nb_splits, replace=False) + else: + to_split_ids = unit_ids + + max_index = np.max(new_spikes["unit_index"]) + new_unit_ids = list(sorting_analyzer.sorting.unit_ids.copy()) + splitted_pairs = [] + spike_indices = spike_vector_to_indices(spikes, sorting_analyzer.unit_ids, absolute_index=True) + + for unit_id in to_split_ids: + ind_mask = spike_indices[0][unit_id] + thresh = np.median(amplitudes[ind_mask]) + amplitude_mask = amplitudes[ind_mask] > thresh + mask = amplitude_mask & (rng.rand(len(ind_mask)) <= partial_split_prob).astype(bool) + new_index = int(unit_id) * np.ones(len(mask)) + new_index[mask] = max_index + 1 + new_spikes["unit_index"][ind_mask] = new_index + new_unit_ids += [max_index + 1] + splitted_pairs += [(unit_id, new_unit_ids[-1])] + max_index += 1 + + new_sorting = NumpySorting(new_spikes, sampling_frequency=fs, unit_ids=new_unit_ids) + return new_sorting, splitted_pairs From 8330277f886c226fdca93522882c25016f5901a0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 27 Oct 2025 12:36:57 +0100 Subject: [PATCH 22/47] improve isocut and tdc2 --- .../sorters/internal/tridesclous2.py | 19 +++++++++++++------ .../clustering/iterative_isosplit.py | 2 +- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index daaadd941d..84aa652d10 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -197,8 +197,12 @@ def 
_run_from_folder(cls, sorter_output_folder, params, verbose): # "You want to run tridesclous2 with the isosplit6 (the C++) implementation, but this is not installed, please `pip install isosplit6`" # ) + + # recording_w = whiten(recording, mode="global") + unit_ids, clustering_label, more_outs = find_clusters_from_peaks( recording, + # recording_w, peaks, method="iterative-isosplit", method_kwargs=clustering_kwargs, @@ -251,17 +255,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe=recording_for_peeler.get_probe(), is_in_uV=False, ) + + # sparsity is a mix between radius and sparsity_threshold = params["templates"]["sparsity_threshold"] - sparsity = compute_sparsity( - templates_dense, method="snr", noise_levels=noise_levels, threshold=sparsity_threshold - ) + radius_um = params["waveforms"]["radius_um"] + sparsity = compute_sparsity(templates_dense, method="radius", radius_um=radius_um) + sparsity_snr = compute_sparsity(templates_dense, method="snr", amplitude_mode="peak_to_peak", + noise_levels=noise_levels, threshold=sparsity_threshold) + sparsity.mask = sparsity.mask & sparsity_snr.mask templates = templates_dense.to_sparse(sparsity) - # templates = remove_empty_templates(templates) templates = clean_templates( - templates_dense, - sparsify_threshold=params["templates"]["sparsity_threshold"], + templates, + sparsify_threshold=None, noise_levels=noise_levels, min_snr=params["templates"]["min_snr"], max_jitter_ms=None, diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index ed2dcec881..74cb864c61 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -51,7 +51,7 @@ class IterativeISOSPLITClustering: "min_cluster_size": 10, "max_iterations_per_pass": 500, # "isocut_threshold": 2.0, - "isocut_threshold": 2.5, + "isocut_threshold": 2.2, }, "min_size_split": 25, # "n_pca_features": 3, From 84aeb9273fdae5bf6e276d75de30141484b4e15c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 28 Oct 2025 14:24:45 +0100 Subject: [PATCH 23/47] tdc2 improvement --- src/spikeinterface/sorters/internal/tridesclous2.py | 6 ++++-- .../sortingcomponents/clustering/iterative_isosplit.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 84aa652d10..da5e489460 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -47,13 +47,15 @@ class Tridesclous2Sorter(ComponentsBasedSorter): }, "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "svd": {"n_components": 4}, + "svd": {"n_components": 8}, "clustering": { "recursive_depth": 5, }, "templates": { "ms_before": 2.0, "ms_after": 3.0, + # "ms_before": 1.5, + # "ms_after": 2.5, "max_spikes_per_unit": 400, "sparsity_threshold": 1.5, "min_snr": 2.5, @@ -130,7 +132,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print("Done correct_motion()") - recording = bandpass_filter(recording_raw, **params["filtering"], dtype="float32") + recording = bandpass_filter(recording_raw, **params["filtering"], margin_ms=20., dtype="float32") if apply_cmr: recording = common_reference(recording) diff 
--git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 74cb864c61..edc0df89c5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -173,7 +173,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): split_params["returns_split_count"] = True if params["seed"] is not None: - split_params["method_kwargs"]["clusterer"] = params["seed"] + split_params["method_kwargs"]["clusterer"]["seed"] = params["seed"] post_split_label, split_count = split_clusters( original_labels, From 41b5d6b6167960f8a32791536cce3e49446cd9cc Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 28 Oct 2025 20:16:03 +0100 Subject: [PATCH 24/47] WIP --- .../postprocessing/template_similarity.py | 48 ------------------- 1 file changed, 48 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 2940b863ee..aed01b6a2c 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -234,11 +234,7 @@ def _compute_similarity_matrix_numpy( tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in range(num_templates): src_template = src_sliced_templates[i] -<<<<<<< HEAD - local_mask = get_mask_for_sparse_template(i, sparsity_mask, other_sparsity_mask, support=support) -======= local_mask = get_overlapping_mask_for_one_template(i, sparsity_mask, other_sparsity_mask, support=support) ->>>>>>> 8533a52d77f11188af8cb01eef358b6e7fa8bec7 overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] for gcount, j in enumerate(overlapping_templates): @@ -316,25 +312,6 @@ def _compute_similarity_matrix_numba( ## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays ## So we inline the function here -<<<<<<< HEAD - # local_mask = get_mask_for_sparse_template(i, sparsity, other_sparsity, support=support) - - local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_) - - if support == "intersection": - local_mask = np.logical_and( - sparsity_mask[i], other_sparsity_mask - ) # shape (num_templates, other_num_templates, num_channels) - elif support == "union": - local_mask = np.logical_and( - sparsity_mask[i], other_sparsity_mask - ) # shape (num_templates, other_num_templates, num_channels) - units_overlaps = np.sum(local_mask, axis=1) > 0 - local_mask = np.logical_or( - sparsity_mask[i], other_sparsity_mask - ) # shape (num_templates, other_num_templates, num_channels) - local_mask[~units_overlaps] = False -======= # local_mask = get_overlapping_mask_for_one_template(i, sparsity, other_sparsity, support=support) if support == "intersection": @@ -347,7 +324,6 @@ def _compute_similarity_matrix_numba( ) # shape (other_num_templates, num_channels) elif support == "dense": local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_) ->>>>>>> 8533a52d77f11188af8cb01eef358b6e7fa8bec7 overlapping_templates = np.flatnonzero(np.sum(local_mask, 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] @@ -405,29 +381,6 @@ def _compute_similarity_matrix_numba( _compute_similarity_matrix = _compute_similarity_matrix_numpy -<<<<<<< HEAD -def 
get_mask_for_sparse_template(template_index, sparsity, other_sparsity, support="union") -> np.ndarray: - - other_num_templates = other_sparsity.shape[0] - num_channels = sparsity.shape[1] - - mask = np.ones((other_num_templates, num_channels), dtype=np.bool_) - - if support == "intersection": - mask = np.logical_and( - sparsity[template_index], other_sparsity - ) # shape (num_templates, other_num_templates, num_channels) - elif support == "union": - mask = np.logical_and( - sparsity[template_index], other_sparsity - ) # shape (num_templates, other_num_templates, num_channels) - units_overlaps = np.sum(mask, axis=1) > 0 - mask = np.logical_or( - sparsity[template_index], other_sparsity - ) # shape (num_templates, other_num_templates, num_channels) - mask[~units_overlaps] = False - -======= def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsity, support="union") -> np.ndarray: if support == "intersection": @@ -436,7 +389,6 @@ def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsi mask = np.logical_or(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels) elif support == "dense": mask = np.ones(other_sparsity.shape, dtype=bool) ->>>>>>> 8533a52d77f11188af8cb01eef358b6e7fa8bec7 return mask From ab470f5fb1abf3efb27d91096f53545782295dc7 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 28 Oct 2025 20:19:07 +0100 Subject: [PATCH 25/47] WIP --- .../sorters/internal/spyking_circus2.py | 50 +++++++++++++------ .../clustering/iterative_hdbscan.py | 12 ++--- src/spikeinterface/sortingcomponents/tools.py | 15 ++++-- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 1e47051c5c..f3080427bc 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -15,15 +15,15 @@ _set_optimal_chunk_size, ) from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core import compute_sparsity class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 0.5, "ms_after": 1.5, "radius_um": 100}, - "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, - "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10}, + "general": {"ms_before": 0.5, "ms_after": 1.5, "radius_um": 100.0}, + "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 20}, "whitening": {"mode": "local", "regularize": False}, "detection": { "method": "matched_filtering", @@ -37,7 +37,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "apply_motion_correction": True, "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, - "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, + "clustering": {"method": "iterative-isosplit", "method_kwargs": dict()}, + "cleaning" : {"min_snr" : 5, "max_jitter_ms" : 0.1, "sparsify_threshold" : None}, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, @@ -114,7 +115,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): num_channels = recording.get_num_channels() ms_before = 
params["general"].get("ms_before", 0.5) ms_after = params["general"].get("ms_after", 1.5) - radius_um = params["general"].get("radius_um", 100) + radius_um = params["general"].get("radius_um", 100.0) detect_threshold = params["detection"]["method_kwargs"].get("detect_threshold", 5) peak_sign = params["detection"].get("peak_sign", "neg") deterministic = params["deterministic_peaks_detection"] @@ -130,7 +131,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print("Preprocessing the recording (bandpass filtering + CMR + whitening)") recording_f = bandpass_filter(recording, **filtering_params, dtype="float32") - if num_channels > 1: + if num_channels >= 32: recording_f = common_reference(recording_f) else: if verbose: @@ -325,7 +326,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if not clustering_from_svd: from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording - templates = get_templates_from_peaks_and_recording( + dense_templates = get_templates_from_peaks_and_recording( recording_w, selected_peaks, peak_labels, @@ -333,10 +334,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ms_after, job_kwargs=job_kwargs, ) + + sparsity = compute_sparsity(dense_templates, method="radius", radius_um=radius_um) + threshold = params["cleaning"].get("sparsify_threshold", None) + if threshold is not None: + sparsity_snr = compute_sparsity(dense_templates, method="snr", amplitude_mode="peak_to_peak", + noise_levels=noise_levels, threshold=threshold) + sparsity.mask = sparsity.mask & sparsity_snr.mask + + templates = dense_templates.to_sparse(sparsity) + else: from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd - templates, _ = get_templates_from_peaks_and_svd( + dense_templates, new_sparse_mask = get_templates_from_peaks_and_svd( recording_w, selected_peaks, peak_labels, @@ -348,15 +359,16 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): operator="median", ) # this release the peak_svd memmap file + templates = dense_templates.to_sparse(new_sparse_mask) del more_outs + cleaning_kwargs = params.get("cleaning", {}).copy() + cleaning_kwargs["noise_levels"] = noise_levels + cleaning_kwargs["remove_empty"] = True templates = clean_templates( templates, - noise_levels=noise_levels, - min_snr=detect_threshold, - max_jitter_ms=0.1, - remove_empty=True, + **cleaning_kwargs ) if verbose: @@ -416,7 +428,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if sorting.get_non_empty_unit_ids().size > 0: final_analyzer = final_cleaning_circus( - recording_w, sorting, templates, job_kwargs=job_kwargs, **merging_params + recording_w, + sorting, + templates, + noise_levels=noise_levels, + job_kwargs=job_kwargs, + **merging_params ) final_analyzer.save_as(format="binary_folder", folder=sorter_output_folder / "final_analyzer") @@ -451,14 +468,15 @@ def final_cleaning_circus( max_distance_um=50, template_diff_thresh=np.arange(0.05, 0.5, 0.05), debug_folder=None, - job_kwargs=None, + noise_levels=None, + job_kwargs=dict(), ): from spikeinterface.sortingcomponents.tools import create_sorting_analyzer_with_existing_templates from spikeinterface.curation.auto_merge import auto_merge_units # First we compute the needed extensions - analyzer = create_sorting_analyzer_with_existing_templates(sorting, recording, templates) + analyzer = create_sorting_analyzer_with_existing_templates(sorting, recording, templates, 
noise_levels=noise_levels) analyzer.compute("unit_locations", method="center_of_mass", **job_kwargs) analyzer.compute("template_similarity", **similarity_kwargs) @@ -480,4 +498,4 @@ def final_cleaning_circus( **job_kwargs, ) - return final_sa + return final_sa \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 130718b6f7..61598526f0 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -31,21 +31,19 @@ class IterativeHDBSCANClustering: "peaks_svd": {"n_components": 5, "ms_before": 0.5, "ms_after": 1.5, "radius_um": 100.0}, "seed": None, "split": { - "split_radius_um": 50.0, + "split_radius_um": 75.0, "recursive": True, "recursive_depth": 3, "method_kwargs": { "clusterer": { "method": "hdbscan", "min_cluster_size": 20, - "cluster_selection_epsilon": 0.5, - "cluster_selection_method": "leaf", "allow_single_cluster": True, }, - "n_pca_features": 0.9, + "n_pca_features": 3, }, }, - "merge_from_templates": dict(), + "merge_from_templates": dict(similarity_thresh=0.9), "merge_from_features": None, "debug_folder": None, "verbose": True, @@ -71,7 +69,7 @@ class IterativeHDBSCANClustering: @classmethod def main_function(cls, recording, peaks, params, job_kwargs=dict()): - split_radius_um = params["split"].pop("split_radius_um", 50) + split_radius_um = params["split"].pop("split_radius_um", 75) peaks_svd = params["peaks_svd"] ms_before = peaks_svd["ms_before"] ms_after = peaks_svd["ms_after"] @@ -169,4 +167,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): peaks_svd=peaks_svd, peak_svd_sparse_mask=sparse_mask, ) - return labels, peak_labels, more_outs + return labels, peak_labels, more_outs \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index fa8a86562f..5c945e1b06 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -10,13 +10,12 @@ HAVE_PSUTIL = False from spikeinterface.core.sparsity import ChannelSparsity -from spikeinterface.core.template import Templates from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.sparsity import ChannelSparsity from spikeinterface.core.sparsity import compute_sparsity -from spikeinterface.core.analyzer_extension_core import ComputeTemplates +from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeNoiseLevels from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift from spikeinterface.core.recording_tools import get_noise_levels @@ -437,7 +436,7 @@ def remove_empty_templates(templates): return templates.select_units(templates.unit_ids[not_empty]) -def create_sorting_analyzer_with_existing_templates(sorting, recording, templates, remove_empty=True): +def create_sorting_analyzer_with_existing_templates(sorting, recording, templates, remove_empty=True, noise_levels=None): sparsity = templates.sparsity templates_array = templates.get_dense_templates().copy() @@ -459,6 +458,14 @@ def create_sorting_analyzer_with_existing_templates(sorting, recording, template sa.extensions["templates"].data["std"] = 
np.zeros(templates_array.shape, dtype=np.float32) sa.extensions["templates"].run_info["run_completed"] = True sa.extensions["templates"].run_info["runtime_s"] = 0 + + if noise_levels is not None: + sa.extensions["noise_levels"] = ComputeNoiseLevels(sa) + sa.extensions["noise_levels"].params = {} + sa.extensions["noise_levels"].data["noise_levels"] = noise_levels + sa.extensions["noise_levels"].run_info["run_completed"] = True + sa.extensions["noise_levels"].run_info["runtime_s"] = 0 + return sa @@ -529,4 +536,4 @@ def clean_templates( to_select = templates.unit_ids[np.flatnonzero(sparsity.mask.sum(axis=1) > 0)] templates = templates.select_units(to_select) - return templates + return templates \ No newline at end of file From 936c31b9e3e54dc51002887c9879c4aa320f7128 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Oct 2025 19:20:51 +0000 Subject: [PATCH 26/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 30 ++++++++++--------- .../clustering/iterative_hdbscan.py | 2 +- src/spikeinterface/sortingcomponents/tools.py | 6 ++-- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index f3080427bc..8a3c516a15 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -38,7 +38,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-isosplit", "method_kwargs": dict()}, - "cleaning" : {"min_snr" : 5, "max_jitter_ms" : 0.1, "sparsify_threshold" : None}, + "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsify_threshold": None}, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, @@ -337,9 +337,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sparsity = compute_sparsity(dense_templates, method="radius", radius_um=radius_um) threshold = params["cleaning"].get("sparsify_threshold", None) - if threshold is not None: - sparsity_snr = compute_sparsity(dense_templates, method="snr", amplitude_mode="peak_to_peak", - noise_levels=noise_levels, threshold=threshold) + if threshold is not None: + sparsity_snr = compute_sparsity( + dense_templates, + method="snr", + amplitude_mode="peak_to_peak", + noise_levels=noise_levels, + threshold=threshold, + ) sparsity.mask = sparsity.mask & sparsity_snr.mask templates = dense_templates.to_sparse(sparsity) @@ -366,10 +371,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): cleaning_kwargs = params.get("cleaning", {}).copy() cleaning_kwargs["noise_levels"] = noise_levels cleaning_kwargs["remove_empty"] = True - templates = clean_templates( - templates, - **cleaning_kwargs - ) + templates = clean_templates(templates, **cleaning_kwargs) if verbose: print("Kept %d clean clusters" % len(templates.unit_ids)) @@ -428,12 +430,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if sorting.get_non_empty_unit_ids().size > 0: final_analyzer = final_cleaning_circus( - recording_w, - sorting, - templates, + recording_w, + sorting, + templates, noise_levels=noise_levels, - 
job_kwargs=job_kwargs, - **merging_params + job_kwargs=job_kwargs, + **merging_params, ) final_analyzer.save_as(format="binary_folder", folder=sorter_output_folder / "final_analyzer") @@ -498,4 +500,4 @@ def final_cleaning_circus( **job_kwargs, ) - return final_sa \ No newline at end of file + return final_sa diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 61598526f0..ddc6725a25 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -167,4 +167,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): peaks_svd=peaks_svd, peak_svd_sparse_mask=sparse_mask, ) - return labels, peak_labels, more_outs \ No newline at end of file + return labels, peak_labels, more_outs diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 5c945e1b06..d4d1f0df67 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -436,7 +436,9 @@ def remove_empty_templates(templates): return templates.select_units(templates.unit_ids[not_empty]) -def create_sorting_analyzer_with_existing_templates(sorting, recording, templates, remove_empty=True, noise_levels=None): +def create_sorting_analyzer_with_existing_templates( + sorting, recording, templates, remove_empty=True, noise_levels=None +): sparsity = templates.sparsity templates_array = templates.get_dense_templates().copy() @@ -536,4 +538,4 @@ def clean_templates( to_select = templates.unit_ids[np.flatnonzero(sparsity.mask.sum(axis=1) > 0)] templates = templates.select_units(to_select) - return templates \ No newline at end of file + return templates From 1b5ef48b4c80b3fd83a79746b25a83cb25faf8dc Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 28 Oct 2025 21:31:38 +0100 Subject: [PATCH 27/47] Alignment during merging --- .../postprocessing/template_similarity.py | 8 +++-- .../clustering/iterative_hdbscan.py | 2 +- .../clustering/merging_tools.py | 35 +++++++++++++++---- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index aed01b6a2c..f38368148a 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -393,7 +393,7 @@ def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsi def compute_similarity_with_templates_array( - templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None + templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None, return_lags=False ): if method == "cosine_similarity": @@ -432,10 +432,14 @@ def compute_similarity_with_templates_array( templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support ) + lags = np.argmin(distances, axis=0) - num_shifts distances = np.min(distances, axis=0) similarity = 1 - distances - return similarity + if return_lags: + return similarity, lags + else: + return similarity def compute_template_similarity_by_pair( diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 
61598526f0..f8dac0c9f3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -43,7 +43,7 @@ class IterativeHDBSCANClustering: "n_pca_features": 3, }, }, - "merge_from_templates": dict(similarity_thresh=0.9), + "merge_from_templates": dict(similarity_thresh=0.8, num_shifts=10), "merge_from_features": None, "debug_folder": None, "verbose": True, diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py index d64d0cae3b..8588150341 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py @@ -542,7 +542,7 @@ def merge_peak_labels_from_templates( from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array - similarity = compute_similarity_with_templates_array( + similarity, lags = compute_similarity_with_templates_array( templates_array, templates_array, method=similarity_metric, @@ -550,12 +550,14 @@ def merge_peak_labels_from_templates( support="union", sparsity=template_sparse_mask, other_sparsity=template_sparse_mask, + return_lags=True ) + pair_mask = similarity > similarity_thresh clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = ( _apply_pair_mask_on_labels_and_recompute_templates( - pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask + pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask, lags ) ) @@ -563,7 +565,7 @@ def merge_peak_labels_from_templates( def _apply_pair_mask_on_labels_and_recompute_templates( - pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask + pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask, lags=None ): """ Resolve pairs graph. 
@@ -604,9 +606,30 @@ def _apply_pair_mask_on_labels_and_recompute_templates( clean_labels[peak_labels == label] = unit_ids[g0] keep_template[l] = False weights /= weights.sum() - merge_template_array[g0, :, :] = np.sum( - merge_template_array[merge_group, :, :] * weights[:, np.newaxis, np.newaxis], axis=0 - ) + + if lags is None: + merge_template_array[g0, :, :] = np.sum( + merge_template_array[merge_group, :, :] * weights[:, np.newaxis, np.newaxis], axis=0 + ) + else: + # with shifts + accumulated_template = np.zeros_like(merge_template_array[g0, :, :]) + for i, l in enumerate(merge_group): + shift = lags[g0, l] + if shift > 0: + # template is shifted to right + temp = np.zeros_like(accumulated_template) + temp[shift:, :] = merge_template_array[l, :-shift, :] + elif shift < 0: + # template is shifted to left + temp = np.zeros_like(accumulated_template) + temp[:shift, :] = merge_template_array[l, -shift:, :] + else: + temp = merge_template_array[l, :, :] + + accumulated_template += temp * weights[i] + + merge_template_array[g0, :, :] = accumulated_template merge_sparsity_mask[g0, :] = np.all(template_sparse_mask[merge_group, :], axis=0) merge_template_array = merge_template_array[keep_template, :, :] From 8bff173be5646852b713d02983780dc5ef778389 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Oct 2025 20:32:19 +0000 Subject: [PATCH 28/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/template_similarity.py | 11 +++++++++-- .../sortingcomponents/clustering/merging_tools.py | 4 ++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index f38368148a..0b9793340c 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -393,7 +393,14 @@ def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsi def compute_similarity_with_templates_array( - templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None, return_lags=False + templates_array, + other_templates_array, + method, + support="union", + num_shifts=0, + sparsity=None, + other_sparsity=None, + return_lags=False, ): if method == "cosine_similarity": @@ -436,7 +443,7 @@ def compute_similarity_with_templates_array( distances = np.min(distances, axis=0) similarity = 1 - distances - if return_lags: + if return_lags: return similarity, lags else: return similarity diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py index 8588150341..54d4a5a1cf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py @@ -550,7 +550,7 @@ def merge_peak_labels_from_templates( support="union", sparsity=template_sparse_mask, other_sparsity=template_sparse_mask, - return_lags=True + return_lags=True, ) pair_mask = similarity > similarity_thresh @@ -606,7 +606,7 @@ def _apply_pair_mask_on_labels_and_recompute_templates( clean_labels[peak_labels == label] = unit_ids[g0] keep_template[l] = False weights /= weights.sum() - + if lags is None: merge_template_array[g0, :, :] = np.sum( merge_template_array[merge_group, :, :] * weights[:, np.newaxis, np.newaxis], axis=0 
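
Note on the lag bookkeeping added above: compute_similarity_with_templates_array now stacks one distance matrix per tested shift and recovers, for each template pair, both the best similarity and the signed lag (in samples) at which it is reached. The sketch below only illustrates that recovery step with made-up shapes; in the library the distance stack is produced by _compute_similarity_matrix.

import numpy as np

# Illustrative sizes only: 5 shifts tested on each side, 3 vs 4 templates.
num_shifts = 5
rng = np.random.default_rng(0)
distances = rng.random((2 * num_shifts + 1, 3, 4))  # one distance matrix per tested shift

# Index 0 along the shift axis corresponds to a shift of -num_shifts, so the
# argmin over that axis minus num_shifts gives a signed lag expressed in samples.
lags = np.argmin(distances, axis=0) - num_shifts
similarity = 1 - np.min(distances, axis=0)  # best (smallest) distance across shifts
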
From f5869d7ebe38ef602c720828e900c3d95e90059b Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Tue, 28 Oct 2025 21:34:31 +0100
Subject: [PATCH 29/47] WIP

---
 .../sortingcomponents/clustering/iterative_hdbscan.py | 2 +-
 .../sortingcomponents/clustering/merging_tools.py | 4 ++++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py
index 6e492a825d..8856ee3cc5 100644
--- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py
+++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py
@@ -43,7 +43,7 @@ class IterativeHDBSCANClustering:
                 "n_pca_features": 3,
             },
         },
-        "merge_from_templates": dict(similarity_thresh=0.8, num_shifts=10),
+        "merge_from_templates": dict(similarity_thresh=0.8, num_shifts=10, use_lags=True),
         "merge_from_features": None,
         "debug_folder": None,
         "verbose": True,
diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py
index 8588150341..93e43c7f6d 100644
--- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py
@@ -530,6 +530,7 @@ def merge_peak_labels_from_templates(
     similarity_metric="l1",
     similarity_thresh=0.8,
     num_shifts=3,
+    use_lags=False
 ):
     """
     Low level function used in sorting components for merging templates based on similarity metrics.
@@ -555,6 +556,9 @@ def merge_peak_labels_from_templates(
 
     pair_mask = similarity > similarity_thresh
 
+    if not use_lags:
+        lags = None
+
     clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = (
         _apply_pair_mask_on_labels_and_recompute_templates(
             pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask, lags

From 27eb0774e853ee084d6024c0a73f6a97ee610d29 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 28 Oct 2025 20:35:29 +0000
Subject: [PATCH 30/47] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../sortingcomponents/clustering/merging_tools.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py
index 1c9497bc23..77bb4948c1 100644
--- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py
@@ -530,7 +530,7 @@ def merge_peak_labels_from_templates(
     similarity_metric="l1",
     similarity_thresh=0.8,
     num_shifts=3,
-    use_lags=False
+    use_lags=False,
 ):
     """
     Low level function used in sorting components for merging templates based on similarity metrics.
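
When use_lags=True is propagated down to the merge step, the merged template becomes a weighted average of lag-aligned templates rather than of the raw arrays. The helper below is a hypothetical, self-contained sketch of that shift-and-accumulate logic for dense templates; the function name, shapes and the sign convention of the shift are assumptions for illustration, not the library API.

import numpy as np

def average_templates_with_lags(templates, weights, lags):
    # templates: (n_templates, n_samples, n_channels) dense array
    # weights: (n_templates,) non-negative weights summing to 1
    # lags: (n_templates,) signed shifts in samples relative to the reference template
    accumulated = np.zeros_like(templates[0])
    for template, weight, lag in zip(templates, weights, lags):
        shifted = np.zeros_like(template)
        if lag > 0:
            shifted[lag:, :] = template[:-lag, :]  # push the waveform to the right
        elif lag < 0:
            shifted[:lag, :] = template[-lag:, :]  # push the waveform to the left
        else:
            shifted = template
        accumulated += weight * shifted
    return accumulated
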
From 4026a079cc517ec3b0f5f7f57797a1b93cd60a95 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 28 Oct 2025 22:43:36 +0100 Subject: [PATCH 31/47] WIP --- .../clustering/iterative_hdbscan.py | 2 +- .../sortingcomponents/clustering/merging_tools.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 8856ee3cc5..592d0abc11 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -43,7 +43,7 @@ class IterativeHDBSCANClustering: "n_pca_features": 3, }, }, - "merge_from_templates": dict(similarity_thresh=0.8, num_shifts=10, use_lags=True), + "merge_from_templates": dict(similarity_thresh=0.5, num_shifts=3, use_lags=False), "merge_from_features": None, "debug_folder": None, "verbose": True, diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py index 1c9497bc23..cba814b20f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py @@ -588,6 +588,7 @@ def _apply_pair_mask_on_labels_and_recompute_templates( merge_template_array = templates_array.copy() merge_sparsity_mask = template_sparse_mask.copy() new_unit_ids = np.zeros(n_components, dtype=unit_ids.dtype) + for c in range(n_components): merge_group = np.flatnonzero(group_labels == c) g0 = merge_group[0] @@ -618,8 +619,10 @@ def _apply_pair_mask_on_labels_and_recompute_templates( else: # with shifts accumulated_template = np.zeros_like(merge_template_array[g0, :, :]) + #import matplotlib.pyplot as plt + #fig, ax = plt.subplots(1, 2) for i, l in enumerate(merge_group): - shift = lags[g0, l] + shift = -lags[g0, l] if shift > 0: # template is shifted to right temp = np.zeros_like(accumulated_template) @@ -631,7 +634,16 @@ def _apply_pair_mask_on_labels_and_recompute_templates( else: temp = merge_template_array[l, :, :] + #if l == g0: + # ax[0].plot(temp, c='r') + # ax[1].plot(temp, c='r') + #else: + # ax[0].plot(temp, c='gray', alpha=0.5) + # ax[1].plot(merge_template_array[l, :, :], c='gray', alpha=0.5) + #print(shift, lags[l, g0]) + accumulated_template += temp * weights[i] + #plt.show() merge_template_array[g0, :, :] = accumulated_template merge_sparsity_mask[g0, :] = np.all(template_sparse_mask[merge_group, :], axis=0) From c6f4708686037d50888d39dd7636b6d87fd8a11f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Oct 2025 21:44:30 +0000 Subject: [PATCH 32/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/merging_tools.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py index 235e8bbe54..b5f9808c92 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py @@ -619,8 +619,8 @@ def _apply_pair_mask_on_labels_and_recompute_templates( else: # with shifts accumulated_template = np.zeros_like(merge_template_array[g0, :, :]) - #import matplotlib.pyplot as plt - #fig, ax = plt.subplots(1, 
2) + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots(1, 2) for i, l in enumerate(merge_group): shift = -lags[g0, l] if shift > 0: @@ -634,16 +634,16 @@ def _apply_pair_mask_on_labels_and_recompute_templates( else: temp = merge_template_array[l, :, :] - #if l == g0: + # if l == g0: # ax[0].plot(temp, c='r') # ax[1].plot(temp, c='r') - #else: + # else: # ax[0].plot(temp, c='gray', alpha=0.5) # ax[1].plot(merge_template_array[l, :, :], c='gray', alpha=0.5) - #print(shift, lags[l, g0]) + # print(shift, lags[l, g0]) accumulated_template += temp * weights[i] - #plt.show() + # plt.show() merge_template_array[g0, :, :] = accumulated_template merge_sparsity_mask[g0, :] = np.all(template_sparse_mask[merge_group, :], axis=0) From 356508dff3ba83ffd641ab5f126e05b0334acbc9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 29 Oct 2025 08:28:23 +0100 Subject: [PATCH 33/47] fix nan in plot perf vs snr --- src/spikeinterface/benchmark/benchmark_plot_tools.py | 5 +++-- .../sortingcomponents/clustering/iterative_isosplit.py | 3 --- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index b32cf3df45..7bd81c1b2c 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -478,8 +478,9 @@ def _plot_performances_vs_metric( .get_performance()[performance_name] .to_numpy(dtype="float64") ) - all_xs.append(x) - all_ys.append(y) + mask = ~np.isnan(x) & ~np.isnan(y) + all_xs.append(x[mask]) + all_ys.append(y[mask]) if with_sigmoid_fit: max_snr = max(np.max(x) for x in all_xs) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index edc0df89c5..8111d4528c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -154,9 +154,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): vertical_bins[1:-1] = np.arange(num_windows + 1) * bin_um + min_ + border vertical_bins[0] = -np.inf vertical_bins[-1] = np.inf - print(min_, max_) - print(vertical_bins) - print(vertical_bins.size) # peak depth peak_depths = channel_depth[peaks["channel_index"]] # label by bin From 3f226195cefc2be2f8b450eeed4e3707e3c91211 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 29 Oct 2025 08:41:01 +0100 Subject: [PATCH 34/47] WIP --- .../sortingcomponents/clustering/iterative_hdbscan.py | 2 +- .../sortingcomponents/clustering/merging_tools.py | 11 ----------- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 592d0abc11..5a1ed5d64e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -43,7 +43,7 @@ class IterativeHDBSCANClustering: "n_pca_features": 3, }, }, - "merge_from_templates": dict(similarity_thresh=0.5, num_shifts=3, use_lags=False), + "merge_from_templates": dict(similarity_thresh=0.8, num_shifts=3, use_lags=True), "merge_from_features": None, "debug_folder": None, "verbose": True, diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py index 235e8bbe54..4e6206b5c6 
100644 --- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py @@ -619,8 +619,6 @@ def _apply_pair_mask_on_labels_and_recompute_templates( else: # with shifts accumulated_template = np.zeros_like(merge_template_array[g0, :, :]) - #import matplotlib.pyplot as plt - #fig, ax = plt.subplots(1, 2) for i, l in enumerate(merge_group): shift = -lags[g0, l] if shift > 0: @@ -634,16 +632,7 @@ def _apply_pair_mask_on_labels_and_recompute_templates( else: temp = merge_template_array[l, :, :] - #if l == g0: - # ax[0].plot(temp, c='r') - # ax[1].plot(temp, c='r') - #else: - # ax[0].plot(temp, c='gray', alpha=0.5) - # ax[1].plot(merge_template_array[l, :, :], c='gray', alpha=0.5) - #print(shift, lags[l, g0]) - accumulated_template += temp * weights[i] - #plt.show() merge_template_array[g0, :, :] = accumulated_template merge_sparsity_mask[g0, :] = np.all(template_sparse_mask[merge_group, :], axis=0) From 04f4c09f0d8217d8a8be763197c3490fb781ebac Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 29 Oct 2025 09:25:37 +0100 Subject: [PATCH 35/47] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 8a3c516a15..b7c6492f9c 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -37,7 +37,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "apply_motion_correction": True, "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, - "clustering": {"method": "iterative-isosplit", "method_kwargs": dict()}, + "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsify_threshold": None}, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, From 9ddc13810ea194d9f0131f4b7dc338e4bbf7f3bc Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 30 Oct 2025 11:47:34 +0100 Subject: [PATCH 36/47] rename spitting_tools to tersplit_tools to avoid double file with same name --- .../clustering/{splitting_tools.py => itersplit_tools.py} | 0 .../tests/{test_split_tools.py => test_itersplit_tool.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename src/spikeinterface/sortingcomponents/clustering/{splitting_tools.py => itersplit_tools.py} (100%) rename src/spikeinterface/sortingcomponents/clustering/tests/{test_split_tools.py => test_itersplit_tool.py} (100%) diff --git a/src/spikeinterface/sortingcomponents/clustering/splitting_tools.py b/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py similarity index 100% rename from src/spikeinterface/sortingcomponents/clustering/splitting_tools.py rename to src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py diff --git a/src/spikeinterface/sortingcomponents/clustering/tests/test_split_tools.py b/src/spikeinterface/sortingcomponents/clustering/tests/test_itersplit_tool.py similarity index 100% rename from src/spikeinterface/sortingcomponents/clustering/tests/test_split_tools.py rename to src/spikeinterface/sortingcomponents/clustering/tests/test_itersplit_tool.py From dce3b9618a74a3f19d65f28351dafc91270149c1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 30 Oct 2025 11:48:22 +0100 Subject: [PATCH 37/47] 
compute_similarity_with_templates_array returan lags always --- .../postprocessing/template_similarity.py | 14 +++++--------- .../tests/test_template_similarity.py | 7 ++++--- .../clustering/iterative_hdbscan.py | 2 +- .../sortingcomponents/clustering/merging_tools.py | 1 - .../clustering/tests/test_itersplit_tool.py | 2 +- src/spikeinterface/widgets/collision.py | 2 +- 6 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 0b9793340c..91923521f1 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -94,7 +94,7 @@ def _merge_extension_data( new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids, new_sorting_analyzer.channel_ids ) - new_similarity = compute_similarity_with_templates_array( + new_similarity, _ = compute_similarity_with_templates_array( new_templates_array, all_templates_array, method=self.params["method"], @@ -146,7 +146,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids_f, new_sorting_analyzer.channel_ids ) - new_similarity = compute_similarity_with_templates_array( + new_similarity, _ = compute_similarity_with_templates_array( new_templates_array, all_templates_array, method=self.params["method"], @@ -188,7 +188,7 @@ def _run(self, verbose=False): self.sorting_analyzer, return_in_uV=self.sorting_analyzer.return_in_uV ) sparsity = self.sorting_analyzer.sparsity - similarity = compute_similarity_with_templates_array( + similarity, _ = compute_similarity_with_templates_array( templates_array, templates_array, method=self.params["method"], @@ -400,7 +400,6 @@ def compute_similarity_with_templates_array( num_shifts=0, sparsity=None, other_sparsity=None, - return_lags=False, ): if method == "cosine_similarity": @@ -443,10 +442,7 @@ def compute_similarity_with_templates_array( distances = np.min(distances, axis=0) similarity = 1 - distances - if return_lags: - return similarity, lags - else: - return similarity + return similarity, lags def compute_template_similarity_by_pair( @@ -456,7 +452,7 @@ def compute_template_similarity_by_pair( templates_array_2 = get_dense_templates_array(sorting_analyzer_2, return_in_uV=True) sparsity_1 = sorting_analyzer_1.sparsity sparsity_2 = sorting_analyzer_2.sparsity - similarity = compute_similarity_with_templates_array( + similarity, _ = compute_similarity_with_templates_array( templates_array_1, templates_array_2, method=method, diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index fa7d19fcbc..9fa7a73fec 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -82,8 +82,9 @@ def test_compute_similarity_with_templates_array(params): templates_array = rng.random(size=(2, 20, 5)) other_templates_array = rng.random(size=(4, 20, 5)) - similarity = compute_similarity_with_templates_array(templates_array, other_templates_array, **params) + similarity, lags = compute_similarity_with_templates_array(templates_array, other_templates_array, **params) print(similarity.shape) + print(lags) pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") @@ -141,5 +142,5 @@ def test_equal_results_numba(params): test.cache_folder = Path("./cache_folder") 
test.test_extension(params=dict(method="l2")) - # params = dict(method="cosine", num_shifts=8) - # test_compute_similarity_with_templates_array(params) + params = dict(method="cosine", num_shifts=8) + test_compute_similarity_with_templates_array(params) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 5a1ed5d64e..0c089229ee 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -8,7 +8,7 @@ from spikeinterface.core.recording_tools import get_channel_distances from spikeinterface.sortingcomponents.waveforms.peak_svd import extract_peaks_svd from spikeinterface.sortingcomponents.clustering.merging_tools import merge_peak_labels_from_templates -from spikeinterface.sortingcomponents.clustering.splitting_tools import split_clusters +from spikeinterface.sortingcomponents.clustering.itersplit_tools import split_clusters from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py index ae7cfc88e6..4813b7e88a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py @@ -553,7 +553,6 @@ def merge_peak_labels_from_templates( support="union", sparsity=template_sparse_mask, other_sparsity=template_sparse_mask, - return_lags=True, ) pair_mask = similarity > similarity_thresh diff --git a/src/spikeinterface/sortingcomponents/clustering/tests/test_itersplit_tool.py b/src/spikeinterface/sortingcomponents/clustering/tests/test_itersplit_tool.py index 85fb13445c..3724c1c4f6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tests/test_itersplit_tool.py +++ b/src/spikeinterface/sortingcomponents/clustering/tests/test_itersplit_tool.py @@ -1,7 +1,7 @@ import pytest import numpy as np -from spikeinterface.sortingcomponents.clustering.splitting_tools import split_clusters +from spikeinterface.sortingcomponents.clustering.itersplit_tools import split_clusters # TODO diff --git a/src/spikeinterface/widgets/collision.py b/src/spikeinterface/widgets/collision.py index ab41bba931..377286459b 100644 --- a/src/spikeinterface/widgets/collision.py +++ b/src/spikeinterface/widgets/collision.py @@ -91,7 +91,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): templates_array = dp.templates_array[template_inds, :, :].copy() flat_templates = templates_array.reshape(templates_array.shape[0], -1) - similarity_matrix = compute_similarity_with_templates_array( + similarity_matrix, _ = compute_similarity_with_templates_array( templates_array, templates_array, method=dp.metric, From 055176e2bdbc516cfcbe336f759f8259988403fe Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 30 Oct 2025 11:48:35 +0100 Subject: [PATCH 38/47] tdc2 params ajustement --- .../sorters/internal/tridesclous2.py | 14 ++++---------- .../clustering/iterative_isosplit.py | 8 ++++---- .../clustering/itersplit_tools.py | 4 ++-- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index da5e489460..33c0d1bd66 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -41,21 +41,19 @@ 
class Tridesclous2Sorter(ComponentsBasedSorter): }, "filtering": { "freq_min": 150.0, - "freq_max": 5000.0, + "freq_max": 6000.0, "ftype": "bessel", "filter_order": 2, }, "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "svd": {"n_components": 8}, + "svd": {"n_components": 10}, "clustering": { - "recursive_depth": 5, + "recursive_depth": 3, }, "templates": { "ms_before": 2.0, "ms_after": 3.0, - # "ms_before": 1.5, - # "ms_after": 2.5, "max_spikes_per_unit": 400, "sparsity_threshold": 1.5, "min_snr": 2.5, @@ -199,12 +197,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # "You want to run tridesclous2 with the isosplit6 (the C++) implementation, but this is not installed, please `pip install isosplit6`" # ) - + # whitenning do not improve in tdc2 # recording_w = whiten(recording, mode="global") unit_ids, clustering_label, more_outs = find_clusters_from_peaks( recording, - # recording_w, peaks, method="iterative-isosplit", method_kwargs=clustering_kwargs, @@ -212,9 +209,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs=job_kwargs, ) - # peak_shifts = extra_out["peak_shifts"] - # new_peaks = peaks.copy() - # new_peaks["sample_index"] -= peak_shifts new_peaks = peaks mask = clustering_label >= 0 diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 8111d4528c..25f7644abe 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -3,7 +3,7 @@ import numpy as np from spikeinterface.core import get_channel_distances, Templates, ChannelSparsity -from spikeinterface.sortingcomponents.clustering.splitting_tools import split_clusters +from spikeinterface.sortingcomponents.clustering.itersplit_tools import split_clusters # from spikeinterface.sortingcomponents.clustering.merge import merge_clusters from spikeinterface.sortingcomponents.clustering.merging_tools import ( @@ -42,7 +42,7 @@ class IterativeISOSPLITClustering: # "split_radius_um": 40.0, "split_radius_um": 60.0, "recursive": True, - "recursive_depth": 5, + "recursive_depth": 3, "method_kwargs": { "clusterer": { "method": "isosplit", @@ -50,8 +50,8 @@ class IterativeISOSPLITClustering: # "n_init": 50, "min_cluster_size": 10, "max_iterations_per_pass": 500, - # "isocut_threshold": 2.0, - "isocut_threshold": 2.2, + "isocut_threshold": 2.0, + # "isocut_threshold": 2.2, }, "min_size_split": 25, # "n_pca_features": 3, diff --git a/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py b/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py index 23f405531f..3e883470f6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py @@ -308,9 +308,9 @@ def split( min_cluster_size = clustering_kwargs["min_cluster_size"] - # here the trick is that we do not except more than 4 to 5 clusters per iteration with a presplit of 10 + # here the trick is that we do not except more than 4 to 5 clusters per iteration, so n_init=15 is a good choice num_samples = final_features.shape[0] - n_init = 50 + n_init = 15 if n_init > (num_samples // min_cluster_size): # avoid warning in isosplit when sample_size is too small factor = min_cluster_size * 2 From 
9f7aa02eba45029c36d46e519d5fc4eca71f85b4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 30 Oct 2025 15:48:22 +0100 Subject: [PATCH 39/47] start lupin --- src/spikeinterface/sorters/internal/lupin.py | 339 ++++++++++++++++++ .../sorters/internal/tests/test_lupin.py | 18 + src/spikeinterface/sorters/sorterlist.py | 2 + 3 files changed, 359 insertions(+) create mode 100644 src/spikeinterface/sorters/internal/lupin.py create mode 100644 src/spikeinterface/sorters/internal/tests/test_lupin.py diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py new file mode 100644 index 0000000000..bd7708a5f3 --- /dev/null +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +from .si_based import ComponentsBasedSorter + +from copy import deepcopy + +from spikeinterface.core import ( + get_noise_levels, + NumpySorting, + estimate_templates_with_accumulator, + Templates, + compute_sparsity, +) + +from spikeinterface.core.job_tools import fix_job_kwargs + +from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten +from spikeinterface.core.basesorting import minimum_spike_dtype + +from spikeinterface.sortingcomponents.tools import cache_preprocessing + + +import numpy as np + + +class LupinSorter(ComponentsBasedSorter): + """ + Gentleman thief sorter + """ + sorter_name = "lupin" + + _default_params = { + "apply_preprocessing": True, + "apply_motion_correction": False, + "motion_correction": {"preset": "dredge_fast"}, + "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, + "waveforms": { + "ms_before": 0.5, + "ms_after": 1.5, + "radius_um": 120.0, + }, + "filtering": { + "freq_min": 150.0, + "freq_max": 6000.0, + "ftype": "bessel", + "filter_order": 2, + }, + "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, + "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, + "svd": {"n_components": 10}, + "clustering": { + "recursive_depth": 3, + }, + "templates": { + "ms_before": 2.0, + "ms_after": 3.0, + "max_spikes_per_unit": 400, + "sparsity_threshold": 1.5, + "min_snr": 2.5, + # "peak_shift_ms": 0.2, + }, + "matching": {"method": "tdc-peeler", "method_kwargs": {}, "gather_mode": "memory"}, + "job_kwargs": {}, + "save_array": True, + "debug": False, + } + + _params_description = { + "apply_preprocessing": "Apply internal preprocessing or not", + "cache_preprocessing": "A dict contaning how to cache the preprocessed recording. 
mode='memory' | 'folder | 'zarr' ", + "waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um", + "filtering": "A dictonary containing filtering params: freq_min, freq_max", + "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", + "selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks", + "svd": "A dictonary containing svd params: n_components", + "clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um", + "templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after", + "matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um", + "job_kwargs": "A dictionary containing job kwargs", + "save_array": "Save or not intermediate arrays", + } + + handle_multi_segment = True + + @classmethod + def get_sorter_version(cls): + return "2025.09" + + @classmethod + def _run_from_folder(cls, sorter_output_folder, params, verbose): + + from spikeinterface.sortingcomponents.matching import find_spikes_from_templates + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + from spikeinterface.sortingcomponents.peak_selection import select_peaks + from spikeinterface.sortingcomponents.clustering.main import find_clusters_from_peaks, clustering_methods + from spikeinterface.sortingcomponents.tools import remove_empty_templates + from spikeinterface.preprocessing import correct_motion + from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording + from spikeinterface.sortingcomponents.tools import clean_templates + + job_kwargs = params["job_kwargs"].copy() + job_kwargs = fix_job_kwargs(job_kwargs) + job_kwargs["progress_bar"] = verbose + + recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) + + num_chans = recording_raw.get_num_channels() + sampling_frequency = recording_raw.get_sampling_frequency() + + apply_cmr = num_chans >= 32 + + # preprocessing + if params["apply_preprocessing"]: + if params["apply_motion_correction"]: + rec_for_motion = recording_raw + if params["apply_preprocessing"]: + rec_for_motion = bandpass_filter( + rec_for_motion, freq_min=300.0, freq_max=6000.0, ftype="bessel", dtype="float32" + ) + if apply_cmr: + rec_for_motion = common_reference(rec_for_motion) + if verbose: + print("Start correct_motion()") + _, motion_info = correct_motion( + rec_for_motion, + folder=sorter_output_folder / "motion", + output_motion_info=True, + **params["motion_correction"], + ) + if verbose: + print("Done correct_motion()") + + recording = bandpass_filter(recording_raw, **params["filtering"], margin_ms=20., dtype="float32") + if apply_cmr: + recording = common_reference(recording) + + if params["apply_motion_correction"]: + interpolate_motion_kwargs = dict( + border_mode="force_extrapolate", + spatial_interpolation_method="kriging", + sigma_um=20.0, + p=2, + ) + + recording = InterpolateMotionRecording( + recording, + motion_info["motion"], + **interpolate_motion_kwargs, + ) + + recording = zscore(recording, dtype="float32") + # whitening is really bad when dirft correction is applied and this changd nothing when no dirft + # recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0) + + # used only if "folder" or "zarr" + cache_folder = sorter_output_folder / "cache_preprocessing" + recording = cache_preprocessing( + recording, folder=cache_folder, **job_kwargs, 
**params["cache_preprocessing"] + ) + + noise_levels = np.ones(num_chans, dtype="float32") + else: + recording = recording_raw + noise_levels = get_noise_levels(recording, return_in_uV=False) + + # detection + detection_params = params["detection"].copy() + detection_params["noise_levels"] = noise_levels + all_peaks = detect_peaks( + recording, method="locally_exclusive", method_kwargs=detection_params, job_kwargs=job_kwargs + ) + + if verbose: + print(f"detect_peaks(): {len(all_peaks)} peaks found") + + # selection + selection_params = params["selection"].copy() + n_peaks = params["selection"]["n_peaks_per_channel"] * num_chans + n_peaks = max(selection_params["min_n_peaks"], n_peaks) + peaks = select_peaks(all_peaks, method="uniform", n_peaks=n_peaks) + + if verbose: + print(f"select_peaks(): {len(peaks)} peaks kept for clustering") + + # routing clustering params into the big IterativeISOSPLITClustering params tree + clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params) + clustering_kwargs["peaks_svd"].update(params["waveforms"]) + clustering_kwargs["peaks_svd"].update(params["svd"]) + clustering_kwargs["split"].update(params["clustering"]) + if params["debug"]: + clustering_kwargs["debug_folder"] = sorter_output_folder + + # if clustering_kwargs["clustering"]["clusterer"] == "isosplit6": + # have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None + # if not have_sisosplit6: + # raise ValueError( + # "You want to run tridesclous2 with the isosplit6 (the C++) implementation, but this is not installed, please `pip install isosplit6`" + # ) + + # whitenning do not improve in tdc2 + # recording_w = whiten(recording, mode="global") + + unit_ids, clustering_label, more_outs = find_clusters_from_peaks( + recording, + peaks, + method="iterative-isosplit", + method_kwargs=clustering_kwargs, + extra_outputs=True, + job_kwargs=job_kwargs, + ) + + new_peaks = peaks + + mask = clustering_label >= 0 + sorting_pre_peeler = NumpySorting.from_samples_and_labels( + new_peaks["sample_index"][mask], + clustering_label[mask], + sampling_frequency, + unit_ids=unit_ids, + ) + if verbose: + print(f"find_clusters_from_peaks(): {sorting_pre_peeler.unit_ids.size} cluster found") + + recording_for_peeler = recording + + # if "templates" in more_outs: + # # No, bad idea because templates are too short + # # clustering also give templates + # templates = more_outs["templates"] + + # we recompute the template even if the clustering give it already because we use different ms_before/ms_after + nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0) + nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0) + + templates_array = estimate_templates_with_accumulator( + recording_for_peeler, + sorting_pre_peeler.to_spike_vector(), + sorting_pre_peeler.unit_ids, + nbefore, + nafter, + return_in_uV=False, + **job_kwargs, + ) + templates_dense = Templates( + templates_array=templates_array, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + channel_ids=recording_for_peeler.channel_ids, + unit_ids=sorting_pre_peeler.unit_ids, + sparsity_mask=None, + probe=recording_for_peeler.get_probe(), + is_in_uV=False, + ) + + + # sparsity is a mix between radius and + sparsity_threshold = params["templates"]["sparsity_threshold"] + radius_um = params["waveforms"]["radius_um"] + sparsity = compute_sparsity(templates_dense, method="radius", radius_um=radius_um) + sparsity_snr = compute_sparsity(templates_dense, method="snr", 
amplitude_mode="peak_to_peak", + noise_levels=noise_levels, threshold=sparsity_threshold) + sparsity.mask = sparsity.mask & sparsity_snr.mask + templates = templates_dense.to_sparse(sparsity) + + templates = clean_templates( + templates, + sparsify_threshold=None, + noise_levels=noise_levels, + min_snr=params["templates"]["min_snr"], + max_jitter_ms=None, + remove_empty=True, + ) + + ## peeler + matching_method = params["matching"].pop("method") + gather_mode = params["matching"].pop("gather_mode", "memory") + matching_params = params["matching"].get("matching_kwargs", {}).copy() + if matching_method in ("tdc-peeler",): + matching_params["noise_levels"] = noise_levels + + pipeline_kwargs = dict(gather_mode=gather_mode) + if gather_mode == "npy": + pipeline_kwargs["folder"] = sorter_output_folder / "matching" + spikes = find_spikes_from_templates( + recording_for_peeler, + templates, + method=matching_method, + method_kwargs=matching_params, + pipeline_kwargs=pipeline_kwargs, + job_kwargs=job_kwargs, + ) + + final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype) + final_spikes["sample_index"] = spikes["sample_index"] + final_spikes["unit_index"] = spikes["cluster_index"] + final_spikes["segment_index"] = spikes["segment_index"] + sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids) + + ## DEBUG auto merge + auto_merge = True + if auto_merge: + from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus + + # max_distance_um = merging_params.get("max_distance_um", 50) + # merging_params["max_distance_um"] = max(max_distance_um, 2 * max_motion) + + analyzer_final = final_cleaning_circus( + recording_for_peeler, + sorting, + templates, + similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1}, + sparsity_overlap=0.5, + censor_ms=3.0, + max_distance_um=50, + template_diff_thresh=np.arange(0.05, 0.4, 0.05), + debug_folder=None, + job_kwargs=job_kwargs, + ) + sorting = NumpySorting.from_sorting(analyzer_final.sorting) + + if params["save_array"]: + sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler") + + np.save(sorter_output_folder / "noise_levels.npy", noise_levels) + np.save(sorter_output_folder / "all_peaks.npy", all_peaks) + np.save(sorter_output_folder / "peaks.npy", peaks) + np.save(sorter_output_folder / "clustering_label.npy", clustering_label) + np.save(sorter_output_folder / "spikes.npy", spikes) + templates.to_zarr(sorter_output_folder / "templates.zarr") + + # final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype) + # final_spikes["sample_index"] = spikes["sample_index"] + # final_spikes["unit_index"] = spikes["cluster_index"] + # final_spikes["segment_index"] = spikes["segment_index"] + # sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids) + + sorting = sorting.save(folder=sorter_output_folder / "sorting") + + return sorting diff --git a/src/spikeinterface/sorters/internal/tests/test_lupin.py b/src/spikeinterface/sorters/internal/tests/test_lupin.py new file mode 100644 index 0000000000..df2666be1d --- /dev/null +++ b/src/spikeinterface/sorters/internal/tests/test_lupin.py @@ -0,0 +1,18 @@ +import unittest + +from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite + +from spikeinterface.sorters import LupinSorter, run_sorter + +from pathlib import Path + + +class LupinSorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): + SorterClass = LupinSorter + + +if __name__ == "__main__": + test = 
LupinSorterCommonTestSuite() + test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters" + test.setUp() + test.test_with_run() diff --git a/src/spikeinterface/sorters/sorterlist.py b/src/spikeinterface/sorters/sorterlist.py index ed8ba7b6bc..cb1577f0be 100644 --- a/src/spikeinterface/sorters/sorterlist.py +++ b/src/spikeinterface/sorters/sorterlist.py @@ -24,6 +24,7 @@ from .internal.spyking_circus2 import Spykingcircus2Sorter from .internal.tridesclous2 import Tridesclous2Sorter from .internal.simplesorter import SimpleSorter +from .internal.lupin import LupinSorter sorter_full_list = [ # external @@ -49,6 +50,7 @@ Spykingcircus2Sorter, Tridesclous2Sorter, SimpleSorter, + LupinSorter, ] # archived From 60c9782a6173b67f4b5315cd67c2a59d2d4043be Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 31 Oct 2025 11:28:46 +0100 Subject: [PATCH 40/47] lupin wip --- src/spikeinterface/sorters/internal/lupin.py | 49 +++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index bd7708a5f3..762dc7f0c4 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -59,8 +59,9 @@ class LupinSorter(ComponentsBasedSorter): "min_snr": 2.5, # "peak_shift_ms": 0.2, }, - "matching": {"method": "tdc-peeler", "method_kwargs": {}, "gather_mode": "memory"}, + "matching": {"method": "wobble", "method_kwargs": {}, "gather_mode": "memory"}, "job_kwargs": {}, + "seed": None, "save_array": True, "debug": False, } @@ -84,11 +85,12 @@ class LupinSorter(ComponentsBasedSorter): @classmethod def get_sorter_version(cls): - return "2025.09" + return "2025.11" @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): + from spikeinterface.sortingcomponents.tools import get_prototype_and_waveforms_from_recording from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks @@ -102,6 +104,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs = fix_job_kwargs(job_kwargs) job_kwargs["progress_bar"] = verbose + seed = params["seed"] + recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) num_chans = recording_raw.get_num_channels() @@ -150,7 +154,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording = zscore(recording, dtype="float32") # whitening is really bad when dirft correction is applied and this changd nothing when no dirft - # recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0) + recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0) # used only if "folder" or "zarr" cache_folder = sorter_output_folder / "cache_preprocessing" @@ -158,16 +162,32 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"] ) - noise_levels = np.ones(num_chans, dtype="float32") + noise_levels = get_noise_levels(recording, return_in_uV=False) else: recording = recording_raw noise_levels = get_noise_levels(recording, return_in_uV=False) # detection detection_params = params["detection"].copy() - detection_params["noise_levels"] = noise_levels + # detection_params["noise_levels"] = noise_levels + + ms_before = 
params["templates"]["ms_before"] + ms_after = params["templates"]["ms_after"] + + prototype, few_waveforms, few_peaks = get_prototype_and_waveforms_from_recording( + recording, + n_peaks=10_000, + ms_before=ms_before, + ms_after=ms_after, + seed=seed, + noise_levels=noise_levels, + job_kwargs=job_kwargs, + ) + detection_params_ = detection_params.copy() + detection_params_["prototype"] = prototype + detection_params_["ms_before"] = ms_before all_peaks = detect_peaks( - recording, method="locally_exclusive", method_kwargs=detection_params, job_kwargs=job_kwargs + recording, method="matched_filtering", method_kwargs=detection_params_, job_kwargs=job_kwargs ) if verbose: @@ -190,16 +210,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder - # if clustering_kwargs["clustering"]["clusterer"] == "isosplit6": - # have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None - # if not have_sisosplit6: - # raise ValueError( - # "You want to run tridesclous2 with the isosplit6 (the C++) implementation, but this is not installed, please `pip install isosplit6`" - # ) - - # whitenning do not improve in tdc2 - # recording_w = whiten(recording, mode="global") - unit_ids, clustering_label, more_outs = find_clusters_from_peaks( recording, peaks, @@ -252,7 +262,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): is_in_uV=False, ) - # sparsity is a mix between radius and sparsity_threshold = params["templates"]["sparsity_threshold"] radius_um = params["waveforms"]["radius_um"] @@ -328,12 +337,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(sorter_output_folder / "spikes.npy", spikes) templates.to_zarr(sorter_output_folder / "templates.zarr") - # final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype) - # final_spikes["sample_index"] = spikes["sample_index"] - # final_spikes["unit_index"] = spikes["cluster_index"] - # final_spikes["segment_index"] = spikes["segment_index"] - # sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids) - sorting = sorting.save(folder=sorter_output_folder / "sorting") return sorting From 85eabb371dba79f5db35a7e1d60f0c0cd3094ff6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 31 Oct 2025 11:29:48 +0100 Subject: [PATCH 41/47] tdc sc versions --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- src/spikeinterface/sorters/internal/tridesclous2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index b7c6492f9c..e4a18d6248 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -86,7 +86,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): @classmethod def get_sorter_version(cls): - return "2025.09" + return "2025.10" @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 33c0d1bd66..b6bab7a1a1 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -84,7 +84,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): @classmethod def get_sorter_version(cls): - return "2025.09" + return "2025.10" @classmethod def _run_from_folder(cls, sorter_output_folder, params, 
verbose): From f546475e045ad8f22b4a22b6264c7c6b6da83780 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 31 Oct 2025 16:40:40 +0100 Subject: [PATCH 42/47] Clean lupin parameters --- src/spikeinterface/sorters/internal/lupin.py | 198 +++++++++--------- .../sorters/internal/spyking_circus2.py | 13 +- .../sorters/internal/tridesclous2.py | 10 +- src/spikeinterface/sortingcomponents/tools.py | 79 +++++-- 4 files changed, 162 insertions(+), 138 deletions(-) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index 762dc7f0c4..834d309ebd 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -17,7 +17,7 @@ from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten from spikeinterface.core.basesorting import minimum_spike_dtype -from spikeinterface.sortingcomponents.tools import cache_preprocessing +from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing import numpy as np @@ -25,60 +25,62 @@ class LupinSorter(ComponentsBasedSorter): """ - Gentleman thief sorter + Gentleman thief spike sorter. + + This sorter is composed by pieces of code and ideas stolen everywhere : yass, tridesclous, spkyking-circus, kilosort. + It should be the best sorter we can build using spikeinterface.sortingcomponents """ sorter_name = "lupin" _default_params = { "apply_preprocessing": True, "apply_motion_correction": False, - "motion_correction": {"preset": "dredge_fast"}, - "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, - "waveforms": { - "ms_before": 0.5, - "ms_after": 1.5, - "radius_um": 120.0, - }, - "filtering": { - "freq_min": 150.0, - "freq_max": 6000.0, - "ftype": "bessel", - "filter_order": 2, - }, - "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, - "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "svd": {"n_components": 10}, - "clustering": { - "recursive_depth": 3, - }, - "templates": { - "ms_before": 2.0, - "ms_after": 3.0, - "max_spikes_per_unit": 400, - "sparsity_threshold": 1.5, - "min_snr": 2.5, - # "peak_shift_ms": 0.2, - }, - "matching": {"method": "wobble", "method_kwargs": {}, "gather_mode": "memory"}, + "motion_correction_preset" : "dredge_fast", + "clustering_ms_before": 0.3, + "clustering_ms_after": 1.3, + "radius_um": 120., + "freq_min": 150.0, + "freq_max": 6000.0, + "cache_preprocessing_mode" : "auto", + "peak_sign": "neg", + "detect_threshold": 5, + "n_peaks_per_channel": 5000, + "n_svd_components": 10, + "clustering_recursive_depth": 3, + "ms_before": 2.0, + "ms_after": 3.0, + "sparsity_threshold": 1.5, + "template_min_snr": 2.5, + "gather_mode": "memory", "job_kwargs": {}, "seed": None, - "save_array": True, + "save_array": False, "debug": False, } _params_description = { "apply_preprocessing": "Apply internal preprocessing or not", - "cache_preprocessing": "A dict contaning how to cache the preprocessed recording. 
mode='memory' | 'folder | 'zarr' ", - "waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um", - "filtering": "A dictonary containing filtering params: freq_min, freq_max", - "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", - "selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks", - "svd": "A dictonary containing svd params: n_components", - "clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um", - "templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after", - "matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um", - "job_kwargs": "A dictionary containing job kwargs", - "save_array": "Save or not intermediate arrays", + "apply_motion_correction": "Apply motion correction or not", + "motion_correction_preset": "Motion correction preset", + "clustering_ms_before": "Milliseconds before the spike peak for clustering", + "clustering_ms_after": "Milliseconds after the spike peak for clustering", + "radius_um": "Radius for sparsity", + "freq_min": "Low frequency", + "freq_max": "High frequency", + "peak_sign": "Sign of peaks neg/pos/both", + "detect_threshold": "Treshold for peak detection", + "n_peaks_per_channel": "Number of spike per channel for clustering", + "n_svd_components": "Number of SVD components for clustering", + "clustering_recursive_depth": "Clustering recussivity", + "ms_before": "Milliseconds before the spike peak for template matching", + "ms_after": "Milliseconds after the spike peak for template matching", + "sparsity_threshold": "Threshold to sparsify templates before template matching", + "template_min_snr": "Threshold to remove templates before template matching", + "gather_mode": "How to accumalte spike in matching : memory/npy", + "job_kwargs": "The famous and fabulous job_kwargs", + "seed": "Seed for random number", + "save_array": "Save or not intermediate arrays in the folder", + "debug": "Save debug files", } handle_multi_segment = True @@ -105,6 +107,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs["progress_bar"] = verbose seed = params["seed"] + radius_um = params["radius_um"] recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) @@ -129,12 +132,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): rec_for_motion, folder=sorter_output_folder / "motion", output_motion_info=True, - **params["motion_correction"], + preset=params["motion_correction_preset"], ) if verbose: print("Done correct_motion()") - recording = bandpass_filter(recording_raw, **params["filtering"], margin_ms=20., dtype="float32") + recording = bandpass_filter(recording_raw, freq_min=params["freq_min"], freq_max=params["freq_max"], + ftype="bessel", filter_order=2, margin_ms=20., dtype="float32") + if apply_cmr: recording = common_reference(recording) @@ -152,28 +157,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): **interpolate_motion_kwargs, ) - recording = zscore(recording, dtype="float32") - # whitening is really bad when dirft correction is applied and this changd nothing when no dirft - recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0) + recording = whiten(recording, dtype="float32", mode="local", radius_um=radius_um) # used only if "folder" or "zarr" cache_folder = sorter_output_folder / "cache_preprocessing" - recording = 
cache_preprocessing( - recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"] + recording, cache_info = cache_preprocessing( + recording, mode=params["cache_preprocessing_mode"], folder=cache_folder, job_kwargs=job_kwargs, ) noise_levels = get_noise_levels(recording, return_in_uV=False) else: recording = recording_raw noise_levels = get_noise_levels(recording, return_in_uV=False) + cache_info = None # detection - detection_params = params["detection"].copy() - # detection_params["noise_levels"] = noise_levels - - ms_before = params["templates"]["ms_before"] - ms_after = params["templates"]["ms_after"] - + ms_before = params["ms_before"] + ms_after = params["ms_after"] prototype, few_waveforms, few_peaks = get_prototype_and_waveforms_from_recording( recording, n_peaks=10_000, @@ -183,33 +183,36 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): noise_levels=noise_levels, job_kwargs=job_kwargs, ) - detection_params_ = detection_params.copy() - detection_params_["prototype"] = prototype - detection_params_["ms_before"] = ms_before + detection_params = dict( + peak_sign=params["peak_sign"], + detect_threshold=params["detect_threshold"], + exclude_sweep_ms=1.5, + radius_um=radius_um/2., # half the svd radius is enough for detection + prototype=prototype, + ms_before=ms_before, + ) all_peaks = detect_peaks( - recording, method="matched_filtering", method_kwargs=detection_params_, job_kwargs=job_kwargs + recording, method="matched_filtering", method_kwargs=detection_params, job_kwargs=job_kwargs ) if verbose: print(f"detect_peaks(): {len(all_peaks)} peaks found") # selection - selection_params = params["selection"].copy() - n_peaks = params["selection"]["n_peaks_per_channel"] * num_chans - n_peaks = max(selection_params["min_n_peaks"], n_peaks) + n_peaks = max(params["n_peaks_per_channel"] * num_chans, 20_000) peaks = select_peaks(all_peaks, method="uniform", n_peaks=n_peaks) - if verbose: print(f"select_peaks(): {len(peaks)} peaks kept for clustering") - # routing clustering params into the big IterativeISOSPLITClustering params tree + # Clustering clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params) - clustering_kwargs["peaks_svd"].update(params["waveforms"]) - clustering_kwargs["peaks_svd"].update(params["svd"]) - clustering_kwargs["split"].update(params["clustering"]) + clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"] + clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"] + clustering_kwargs["peaks_svd"]["radius_um"] = params["radius_um"] + clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components"] + clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"] if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder - unit_ids, clustering_label, more_outs = find_clusters_from_peaks( recording, peaks, @@ -218,7 +221,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): extra_outputs=True, job_kwargs=job_kwargs, ) - new_peaks = peaks mask = clustering_label >= 0 @@ -231,19 +233,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print(f"find_clusters_from_peaks(): {sorting_pre_peeler.unit_ids.size} cluster found") - recording_for_peeler = recording - - # if "templates" in more_outs: - # # No, bad idea because templates are too short - # # clustering also give templates - # templates = more_outs["templates"] - - # we recompute the template even if the clustering 
give it already because we use different ms_before/ms_after - nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0) - nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0) - + # Template + nbefore = int(ms_before * sampling_frequency / 1000.0) + nafter = int(ms_after * sampling_frequency / 1000.0) templates_array = estimate_templates_with_accumulator( - recording_for_peeler, + recording, sorting_pre_peeler.to_spike_vector(), sorting_pre_peeler.unit_ids, nbefore, @@ -255,16 +249,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, - channel_ids=recording_for_peeler.channel_ids, + channel_ids=recording.channel_ids, unit_ids=sorting_pre_peeler.unit_ids, sparsity_mask=None, - probe=recording_for_peeler.get_probe(), + probe=recording.get_probe(), is_in_uV=False, ) - # sparsity is a mix between radius and - sparsity_threshold = params["templates"]["sparsity_threshold"] - radius_um = params["waveforms"]["radius_um"] + sparsity_threshold = params["sparsity_threshold"] + radius_um = params["radius_um"] sparsity = compute_sparsity(templates_dense, method="radius", radius_um=radius_um) sparsity_snr = compute_sparsity(templates_dense, method="snr", amplitude_mode="peak_to_peak", noise_levels=noise_levels, threshold=sparsity_threshold) @@ -275,26 +268,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates, sparsify_threshold=None, noise_levels=noise_levels, - min_snr=params["templates"]["min_snr"], + min_snr=params["template_min_snr"], max_jitter_ms=None, remove_empty=True, ) - ## peeler - matching_method = params["matching"].pop("method") - gather_mode = params["matching"].pop("gather_mode", "memory") - matching_params = params["matching"].get("matching_kwargs", {}).copy() - if matching_method in ("tdc-peeler",): - matching_params["noise_levels"] = noise_levels - + # Template matching + gather_mode = params["gather_mode"] pipeline_kwargs = dict(gather_mode=gather_mode) if gather_mode == "npy": pipeline_kwargs["folder"] = sorter_output_folder / "matching" + spikes = find_spikes_from_templates( - recording_for_peeler, + recording, templates, - method=matching_method, - method_kwargs=matching_params, + method="wobble", + method_kwargs={}, pipeline_kwargs=pipeline_kwargs, job_kwargs=job_kwargs, ) @@ -305,16 +294,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): final_spikes["segment_index"] = spikes["segment_index"] sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids) - ## DEBUG auto merge auto_merge = True + analyzer_final = None if auto_merge: + # TODO expose some of theses parameters from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus - # max_distance_um = merging_params.get("max_distance_um", 50) - # merging_params["max_distance_um"] = max(max_distance_um, 2 * max_motion) - analyzer_final = final_cleaning_circus( - recording_for_peeler, + recording, sorting, templates, similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1}, @@ -329,14 +316,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["save_array"]: sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler") - np.save(sorter_output_folder / "noise_levels.npy", noise_levels) np.save(sorter_output_folder / "all_peaks.npy", all_peaks) np.save(sorter_output_folder / "peaks.npy", peaks) 
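For context, the flat parameters introduced in this patch can be exercised through the standard run_sorter() entry point once LupinSorter is registered in sorterlist.py. The snippet below is a usage sketch, not part of the patch: it assumes these patches are applied, uses a synthetic recording from generate_ground_truth_recording(), and the overridden values simply reuse keys from _default_params above.

    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.sorters import run_sorter

    # synthetic recording standing in for real data
    recording, gt_sorting = generate_ground_truth_recording(durations=[30.0], num_channels=32, seed=42)

    sorting = run_sorter(
        "lupin",
        recording,
        folder="lupin_output",
        verbose=True,
        detect_threshold=5,            # flat key from _default_params
        n_peaks_per_channel=5000,      # spikes per channel kept for clustering
        apply_motion_correction=False,
    )
    print(sorting.unit_ids)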
np.save(sorter_output_folder / "clustering_label.npy", clustering_label) np.save(sorter_output_folder / "spikes.npy", spikes) templates.to_zarr(sorter_output_folder / "templates.zarr") + if analyzer_final is not None: + analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer") sorting = sorting.save(folder=sorter_output_folder / "sorting") + + del recording + clean_cache_preprocessing(cache_info) + return sorting diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index e4a18d6248..637dc5bd66 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -11,6 +11,7 @@ from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion from spikeinterface.sortingcomponents.tools import ( cache_preprocessing, + clean_cache_preprocessing, get_shuffled_recording_slices, _set_optimal_chunk_size, ) @@ -182,7 +183,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): elif recording_w.check_serializability("pickle"): recording_w.dump(sorter_output_folder / "preprocessed_recording.pickle", relative_to=None) - recording_w = cache_preprocessing(recording_w, **job_kwargs, **params["cache_preprocessing"]) + recording_w, cache_info = cache_preprocessing(recording_w, job_kwargs=job_kwargs, **params["cache_preprocessing"]) ## Then, we are detecting peaks with a locally_exclusive method detection_method = params["detection"].get("method", "matched_filtering") @@ -444,16 +445,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print(f"Kept {len(sorting.unit_ids)} units after final merging") - folder_to_delete = None - cache_mode = params["cache_preprocessing"].get("mode", "memory") - delete_cache = params["cache_preprocessing"].get("delete_cache", True) - - if cache_mode in ["folder", "zarr"] and delete_cache: - folder_to_delete = recording_w._kwargs["folder_path"] - del recording_w - if folder_to_delete is not None: - shutil.rmtree(folder_to_delete) + clean_cache_preprocessing(cache_info) sorting = sorting.save(folder=sorting_folder) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index b6bab7a1a1..09cbe72d72 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -20,7 +20,7 @@ from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten from spikeinterface.core.basesorting import minimum_spike_dtype -from spikeinterface.sortingcomponents.tools import cache_preprocessing +from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing import numpy as np @@ -154,14 +154,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # used only if "folder" or "zarr" cache_folder = sorter_output_folder / "cache_preprocessing" - recording = cache_preprocessing( - recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"] + recording, cache_info = cache_preprocessing( + recording, folder=cache_folder, job_kwargs=job_kwargs, **params["cache_preprocessing"] ) noise_levels = np.ones(num_chans, dtype="float32") else: recording = recording_raw noise_levels = get_noise_levels(recording, return_in_uV=False) + cache_info = None # detection detection_params = params["detection"].copy() @@ -336,4 +337,7 @@ def _run_from_folder(cls, 
sorter_output_folder, params, verbose): sorting = sorting.save(folder=sorter_output_folder / "sorting") + del recording, recording_for_peeler + clean_cache_preprocessing(cache_info) + return sorting diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index d4d1f0df67..b1cd9f835d 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import shutil try: import psutil @@ -361,8 +362,22 @@ def _get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25): return job_kwargs +def _check_cache_memory(recording, memory_limit, total_memory): + if total_memory is None: + if HAVE_PSUTIL: + assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1[" + memory_usage = memory_limit * psutil.virtual_memory().available + return recording.get_total_memory_size() < memory_usage + else: + return False + else: + return recording.get_total_memory_size() < total_memory + + + + def cache_preprocessing( - recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs + recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, job_kwargs=None, folder=None, ): """ Cache the preprocessing of a recording object @@ -380,50 +395,70 @@ def cache_preprocessing( The total memory to use for the job in bytes delete_cache: bool If True, delete the cache after the job - **extra_kwargs: dict - The extra kwargs for the job Returns ------- recording: Recording The cached recording object + cache_info: dict + Dict containing info for cleaning cache """ - save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) + + cache_info = dict( + mode=mode + ) if mode == "memory": if total_memory is None: - if HAVE_PSUTIL: - assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1[" - memory_usage = memory_limit * psutil.virtual_memory().available - if recording.get_total_memory_size() < memory_usage: - recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) - else: - import warnings - - warnings.warn("Recording too large to be preloaded in RAM...") - else: - import warnings - - warnings.warn("psutil is required to preload in memory given only a fraction of available memory") - else: - if recording.get_total_memory_size() < total_memory: + mem_ok = _check_cache_memory(recording, memory_limit, total_memory) + if mem_ok: recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) else: import warnings warnings.warn("Recording too large to be preloaded in RAM...") + cache_info["mode"] = "no-cache" + elif mode == "folder": - recording = recording.save_to_folder(**extra_kwargs) + assert folder is not None, "cache_preprocessing(): folder must be given" + recording = recording.save_to_folder(folder=folder) + cache_info["folder"] = folder elif mode == "zarr": - recording = recording.save_to_zarr(**extra_kwargs) + assert folder is not None, "cache_preprocessing(): folder must be given" + recording = recording.save_to_zarr(folder=folder) + cache_info["folder"] = folder elif mode == "no-cache": recording = recording + elif mode == "auto": + mem_ok = _check_cache_memory(recording, memory_limit, total_memory) + if mem_ok: + # first try memory first + recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) + cache_info["mode"] = "memory" + elif folder is not None: + # then try folder + recording = 
recording.save_to_folder(folder=folder) + cache_info["mode"] = "folder" + else: + recording = recording + cache_info["mode"] = "no-cache" else: raise ValueError(f"cache_preprocessing() wrong mode={mode}") - return recording + return recording, cache_info + +def clean_cache_preprocessing(cache_info): + """ + Delete folder eventually created by cache_preprocessing(). + Important : the cached recording must be deleted first. + """ + if cache_info is None or "mode" not in cache_info: + return + if cache_info["mode"] in ("folder", "zarr"): + shutil.rmtree(cache_info["folder"]) def remove_empty_templates(templates): From 93888082433b7068c2a347177fd9b17ab9dc6365 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 3 Nov 2025 18:56:29 +0100 Subject: [PATCH 43/47] lupin whitten before motion correction --- src/spikeinterface/sorters/internal/lupin.py | 12 +++--------- src/spikeinterface/sortingcomponents/tools.py | 1 + 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index c64dac9dc6..82e131ac1e 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -143,6 +143,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if apply_cmr: recording = common_reference(recording) + recording = whiten(recording, dtype="float32", mode="local", radius_um=radius_um) + if params["apply_motion_correction"]: interpolate_motion_kwargs = dict( border_mode="force_extrapolate", @@ -157,8 +159,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): **interpolate_motion_kwargs, ) - recording = whiten(recording, dtype="float32", mode="local", radius_um=radius_um) - # used only if "folder" or "zarr" cache_folder = sorter_output_folder / "cache_preprocessing" recording, cache_info = cache_preprocessing( @@ -234,10 +234,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): print(f"find_clusters_from_peaks(): {sorting_pre_peeler.unit_ids.size} cluster found") # Template - ## DEBUG - job_kwargs2 = job_kwargs.copy() - job_kwargs2['n_jobs'] = 5 - ## DEBUG nbefore = int(ms_before * sampling_frequency / 1000.0) nafter = int(ms_after * sampling_frequency / 1000.0) @@ -248,9 +244,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore, nafter, return_in_uV=False, - ## DEBUG - **job_kwargs2, - ## DEBUG + **job_kwargs, ) templates_dense = Templates( templates_array=templates_array, diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index b1cd9f835d..1296d2fd47 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -442,6 +442,7 @@ def cache_preprocessing( # then try folder recording = recording.save_to_folder(folder=folder) cache_info["mode"] = "folder" + cache_info["folder"] = folder else: recording = recording cache_info["mode"] = "no-cache" From 02668d3fff9a1edfbe84b00dfdd1108f5de90866 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 6 Nov 2025 13:47:40 +0100 Subject: [PATCH 44/47] better sparsity for template in lupin sorter --- src/spikeinterface/core/__init__.py | 1 + src/spikeinterface/sorters/internal/lupin.py | 71 ++++++++++++------- .../sortingcomponents/clustering/tools.py | 54 +------------- src/spikeinterface/sortingcomponents/tools.py | 33 ++++++++- 4 files changed, 77 insertions(+), 82 deletions(-) diff --git a/src/spikeinterface/core/__init__.py 
b/src/spikeinterface/core/__init__.py index 44d805377f..24c64162ee 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -111,6 +111,7 @@ ) from .sorting_tools import ( spike_vector_to_spike_trains, + spike_vector_to_indices, random_spikes_selection, apply_merges_to_sorting, apply_splits_to_sorting, diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index 82e131ac1e..a8a4f6750a 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -38,23 +38,26 @@ class LupinSorter(ComponentsBasedSorter): "motion_correction_preset" : "dredge_fast", "clustering_ms_before": 0.3, "clustering_ms_after": 1.3, - "radius_um": 120., + "whitening_radius_um": 100., + "detection_radius_um": 50., + "features_radius_um": 75., + "template_radius_um" : 100., "freq_min": 150.0, - "freq_max": 6000.0, + "freq_max": 7000.0, "cache_preprocessing_mode" : "auto", "peak_sign": "neg", "detect_threshold": 5, "n_peaks_per_channel": 5000, - "n_svd_components": 10, + "n_svd_components": 3, "clustering_recursive_depth": 3, - "ms_before": 2.0, - "ms_after": 3.0, + "ms_before": 1.0, + "ms_after": 2.5, "sparsity_threshold": 1.5, "template_min_snr": 2.5, "gather_mode": "memory", "job_kwargs": {}, "seed": None, - "save_array": False, + "save_array": True, "debug": False, } @@ -100,14 +103,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.preprocessing import correct_motion from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording - from spikeinterface.sortingcomponents.tools import clean_templates + from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label job_kwargs = params["job_kwargs"].copy() job_kwargs = fix_job_kwargs(job_kwargs) job_kwargs["progress_bar"] = verbose seed = params["seed"] - radius_um = params["radius_um"] recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) @@ -143,7 +145,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if apply_cmr: recording = common_reference(recording) - recording = whiten(recording, dtype="float32", mode="local", radius_um=radius_um) + recording = whiten(recording, dtype="float32", mode="local", radius_um=params["whitening_radius_um"], + # chunk_duration="2s", + # apply_mean=True, + # regularize=True, + # regularize_kwargs=dict(method="LedoitWolf"), + ) + if params["apply_motion_correction"]: interpolate_motion_kwargs = dict( @@ -187,7 +195,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): peak_sign=params["peak_sign"], detect_threshold=params["detect_threshold"], exclude_sweep_ms=1.5, - radius_um=radius_um/2., # half the svd radius is enough for detection + radius_um=params["detection_radius_um"], prototype=prototype, ms_before=ms_before, ) @@ -208,7 +216,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params) clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"] clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"] - clustering_kwargs["peaks_svd"]["radius_um"] = params["radius_um"] + clustering_kwargs["peaks_svd"]["radius_um"] = params["features_radius_um"] clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components"] 
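The caching changes in this series replace the old one-shot helper with an explicit contract: cache_preprocessing() now takes job_kwargs and an optional folder, returns (recording, cache_info), and clean_cache_preprocessing(cache_info) removes any folder/zarr cache once the cached recording has been released. A minimal sketch of that pattern, assuming these patches are applied; the generated recording and the output folder are placeholders:

    from pathlib import Path
    from spikeinterface.core import generate_recording
    from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing

    recording = generate_recording(durations=[10.0], num_channels=8)   # placeholder recording
    cache_folder = Path("sorter_output") / "cache_preprocessing"
    job_kwargs = dict(n_jobs=1, progress_bar=False)

    # "auto" caches to shared memory when the recording fits, falls back to the folder otherwise
    recording, cache_info = cache_preprocessing(
        recording, mode="auto", folder=cache_folder, job_kwargs=job_kwargs
    )

    # ... detection / clustering / template matching would run on the cached recording here ...

    del recording                          # release the cached recording first
    clean_cache_preprocessing(cache_info)  # removes the folder/zarr cache if one was created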
clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"] if params["debug"]: @@ -221,20 +229,28 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): extra_outputs=True, job_kwargs=job_kwargs, ) - new_peaks = peaks + mask = clustering_label >= 0 + kept_peaks = peaks[mask] + kept_labels = clustering_label[mask] + sorting_pre_peeler = NumpySorting.from_samples_and_labels( - new_peaks["sample_index"][mask], - clustering_label[mask], + kept_peaks["sample_index"], + kept_labels, sampling_frequency, unit_ids=unit_ids, ) if verbose: - print(f"find_clusters_from_peaks(): {sorting_pre_peeler.unit_ids.size} cluster found") + print(f"find_clusters_from_peaks(): {unit_ids.size} cluster found") - # Template + # preestimate the sparsity unsing peaks channel + spike_vector = sorting_pre_peeler.to_spike_vector(concatenated=True) + sparsity, unit_locations = compute_sparsity_from_peaks_and_label(kept_peaks, spike_vector["unit_index"], + sorting_pre_peeler.unit_ids, recording, params["template_radius_um"]) + + # Template are sparse from radius using unit_location nbefore = int(ms_before * sampling_frequency / 1000.0) nafter = int(ms_after * sampling_frequency / 1000.0) templates_array = estimate_templates_with_accumulator( @@ -244,30 +260,31 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore, nafter, return_in_uV=False, + sparsity_mask=sparsity.mask, **job_kwargs, ) - templates_dense = Templates( + templates = Templates( templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, channel_ids=recording.channel_ids, unit_ids=sorting_pre_peeler.unit_ids, - sparsity_mask=None, + sparsity_mask=sparsity.mask, probe=recording.get_probe(), is_in_uV=False, ) - sparsity_threshold = params["sparsity_threshold"] - radius_um = params["radius_um"] - sparsity = compute_sparsity(templates_dense, method="radius", radius_um=radius_um) - sparsity_snr = compute_sparsity(templates_dense, method="snr", amplitude_mode="peak_to_peak", - noise_levels=noise_levels, threshold=sparsity_threshold) - sparsity.mask = sparsity.mask & sparsity_snr.mask - templates = templates_dense.to_sparse(sparsity) - + # sparsity_threshold = params["sparsity_threshold"] + # sparsity = compute_sparsity(templates_dense, method="radius", radius_um=params["features_radius_um"]) + # sparsity_snr = compute_sparsity(templates_dense, method="snr", amplitude_mode="peak_to_peak", + # noise_levels=noise_levels, threshold=sparsity_threshold) + # sparsity.mask = sparsity.mask & sparsity_snr.mask + # templates = templates_dense.to_sparse(sparsity) + + # this spasify more templates = clean_templates( templates, - sparsify_threshold=None, + sparsify_threshold=params["sparsity_threshold"], noise_levels=noise_levels, min_snr=params["template_min_snr"], max_jitter_ms=None, diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 9134ff1c5c..8d9f585237 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -97,59 +97,6 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_target return aligned_features, dont_have_channels - -# def compute_template_from_sparse( -# peaks, labels, labels_set, sparse_waveforms, sparse_target_mask, total_channels, peak_shifts=None -# ): -# """ -# Compute template average from single sparse waveforms buffer. 
- -# Parameters -# ---------- -# peaks - -# labels - -# labels_set - -# sparse_waveforms (or features) - -# sparse_target_mask - -# total_channels - -# peak_shifts - -# Returns -# ------- -# templates: numpy.array -# Templates shape : (len(labels_set), num_samples, total_channels) -# """ - -# # NOTE SAM I think this is wrong, we should remove - -# n = len(labels_set) - -# templates = np.zeros((n, sparse_waveforms.shape[1], total_channels), dtype=sparse_waveforms.dtype) - -# for i, label in enumerate(labels_set): -# peak_indices = np.flatnonzero(labels == label) - -# local_chans = np.unique(peaks["channel_index"][peak_indices]) -# target_channels = np.flatnonzero(np.all(sparse_target_mask[local_chans, :], axis=0)) - -# aligned_wfs, dont_have_channels = aggregate_sparse_features( -# peaks, peak_indices, sparse_waveforms, sparse_target_mask, target_channels -# ) - -# if peak_shifts is not None: -# apply_waveforms_shift(aligned_wfs, peak_shifts[peak_indices], inplace=True) - -# templates[i, :, :][:, target_channels] = np.mean(aligned_wfs[~dont_have_channels], axis=0) - -# return templates - - def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): """ Apply a shift a spike level to realign waveforms buffers. @@ -362,3 +309,4 @@ def get_templates_from_peaks_and_svd( ) return dense_templates, final_sparsity_mask + diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 1296d2fd47..c2a20dbc84 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -19,6 +19,7 @@ from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeNoiseLevels from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.sorting_tools import spike_vector_to_indices, get_numba_vector_to_list_of_spiketrain def make_multi_method_doc(methods, ident=" "): @@ -535,6 +536,8 @@ def clean_templates( ## First we sparsify the templates (using peak-to-peak amplitude avoid sign issues) if sparsify_threshold is not None: + if templates.are_templates_sparse(): + templates = templates.to_dense() sparsity = compute_sparsity( templates, method="snr", @@ -542,8 +545,6 @@ def clean_templates( noise_levels=noise_levels, threshold=sparsify_threshold, ) - if templates.are_templates_sparse(): - templates = templates.to_dense() templates = templates.to_sparse(sparsity) ## We removed non empty templates @@ -575,3 +576,31 @@ def clean_templates( templates = templates.select_units(to_select) return templates + +def compute_sparsity_from_peaks_and_label(peaks, unit_indices, unit_ids, recording, radius_um): + """ + Compute the sparisty after clustering. + This uses the peak channel to compute the baricenter of cluster. + Then make a radius around it. 
+ """ + # handle only 2D channels + channel_locations = recording.get_channel_locations()[:, :2] + num_units = unit_ids.size + num_chans = recording.channel_ids.size + + vector_to_list_of_spiketrain = get_numba_vector_to_list_of_spiketrain() + indices = np.arange(unit_indices.size, dtype=np.int64) + list_of_spike_indices = vector_to_list_of_spiketrain(indices, unit_indices, num_units) + unit_locations = np.zeros((num_units, 2), dtype=float) + sparsity_mask = np.zeros((num_units, num_chans), dtype=bool) + for unit_ind in range(num_units): + spike_inds = list_of_spike_indices[unit_ind] + unit_chans, count = np.unique(peaks[spike_inds]["channel_index"], return_counts=True) + weights = count / np.sum(count) + unit_loc = np.average(channel_locations[unit_chans, :], weights=weights, axis=0) + unit_locations[unit_ind, :] = unit_loc + (chan_inds,) = np.nonzero(np.linalg.norm(channel_locations - unit_loc[None, :], axis=1) <= radius_um) + sparsity_mask[unit_ind, chan_inds] = True + + sparsity = ChannelSparsity(sparsity_mask, unit_ids, recording.channel_ids) + return sparsity, unit_locations \ No newline at end of file From 9ecf0f3ff6bfb7a945aad1abf0a0b64a4fe06c39 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 6 Nov 2025 14:33:57 +0100 Subject: [PATCH 45/47] lupin n_pca control --- src/spikeinterface/sorters/internal/lupin.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index a8a4f6750a..e071abe736 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -48,7 +48,8 @@ class LupinSorter(ComponentsBasedSorter): "peak_sign": "neg", "detect_threshold": 5, "n_peaks_per_channel": 5000, - "n_svd_components": 3, + "n_svd_components_per_channel": 5, + "n_pca_features": 3, "clustering_recursive_depth": 3, "ms_before": 1.0, "ms_after": 2.5, @@ -73,7 +74,8 @@ class LupinSorter(ComponentsBasedSorter): "peak_sign": "Sign of peaks neg/pos/both", "detect_threshold": "Treshold for peak detection", "n_peaks_per_channel": "Number of spike per channel for clustering", - "n_svd_components": "Number of SVD components for clustering", + "n_svd_components_per_channel": "Number of SVD components per channel for clustering", + "n_pca_features" : "Secondary PCA features reducation before local isosplit", "clustering_recursive_depth": "Clustering recussivity", "ms_before": "Milliseconds before the spike peak for template matching", "ms_after": "Milliseconds after the spike peak for template matching", @@ -217,8 +219,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"] clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"] clustering_kwargs["peaks_svd"]["radius_um"] = params["features_radius_um"] - clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components"] + clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"] clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"] + clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"] + + + if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder unit_ids, clustering_label, more_outs = find_clusters_from_peaks( From 301eb4c64eaaca35a22672621d9ff30199c87ba7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 14 Nov 2025 15:21:36 +0100 Subject: [PATCH 
46/47] fix rmtree when errors for recording cache in lupin --- src/spikeinterface/sorters/internal/tridesclous2.py | 2 +- src/spikeinterface/sortingcomponents/tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 9d3964f8c1..3d88a027f6 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -47,7 +47,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): }, "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "svd": {"n_components": 10}, + "svd": {"n_components": 5}, "clustering": { "recursive_depth": 3, }, diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index c2a20dbc84..f0b9058717 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -460,7 +460,7 @@ def clean_cache_preprocessing(cache_info): if cache_info is None or "mode" not in cache_info: return if cache_info["mode"] in ("folder", "zarr"): - shutil.rmtree(cache_info["folder"]) + shutil.rmtree(cache_info["folder"], ignore_errors=True) def remove_empty_templates(templates): From 1a4821671b87297639d895151c8c65d9bfa8a48a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 14 Nov 2025 18:01:51 +0100 Subject: [PATCH 47/47] propagate template spasity from lupin to tdc2 --- src/spikeinterface/sorters/internal/lupin.py | 2 +- .../sorters/internal/tridesclous2.py | 92 ++++++++++++------- 2 files changed, 61 insertions(+), 33 deletions(-) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index e071abe736..53511701d6 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -169,7 +169,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): **interpolate_motion_kwargs, ) - # used only if "folder" or "zarr" + # Cache in mem or folder cache_folder = sorter_output_folder / "cache_preprocessing" recording, cache_info = cache_preprocessing( recording, mode=params["cache_preprocessing_mode"], folder=cache_folder, job_kwargs=job_kwargs, diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 3d88a027f6..cdb595f6ca 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -33,7 +33,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "apply_preprocessing": True, "apply_motion_correction": False, "motion_correction": {"preset": "dredge_fast"}, - "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, + "cache_preprocessing_mode" : "auto", "waveforms": { "ms_before": 0.5, "ms_after": 1.5, @@ -58,6 +58,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "sparsity_threshold": 1.5, "min_snr": 2.5, # "peak_shift_ms": 0.2, + "radius_um":100., }, "matching": {"method": "tdc-peeler", "method_kwargs": {}, "gather_mode": "memory"}, "job_kwargs": {}, @@ -96,7 +97,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.preprocessing import correct_motion from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording - 
from spikeinterface.sortingcomponents.tools import clean_templates + from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label job_kwargs = params["job_kwargs"].copy() job_kwargs = fix_job_kwargs(job_kwargs) @@ -153,10 +154,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # whitening is really bad when dirft correction is applied and this changd nothing when no dirft # recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0) - # used only if "folder" or "zarr" + # Cache in mem or folder cache_folder = sorter_output_folder / "cache_preprocessing" recording, cache_info = cache_preprocessing( - recording, folder=cache_folder, job_kwargs=job_kwargs, **params["cache_preprocessing"] + recording, mode=params["cache_preprocessing_mode"], folder=cache_folder, job_kwargs=job_kwargs, ) noise_levels = np.ones(num_chans, dtype="float32") @@ -211,24 +212,26 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs=job_kwargs, ) - new_peaks = peaks - mask = clustering_label >= 0 + kept_peaks = peaks[mask] + kept_labels = clustering_label[mask] + sorting_pre_peeler = NumpySorting.from_samples_and_labels( - new_peaks["sample_index"][mask], - clustering_label[mask], + kept_peaks["sample_index"], + kept_labels, sampling_frequency, unit_ids=unit_ids, ) if verbose: - print(f"find_clusters_from_peaks(): {sorting_pre_peeler.unit_ids.size} cluster found") + print(f"find_clusters_from_peaks(): {unit_ids.size} cluster found") recording_for_peeler = recording - # if "templates" in more_outs: - # # No, bad idea because templates are too short - # # clustering also give templates - # templates = more_outs["templates"] + # preestimate the sparsity unsing peaks channel + spike_vector = sorting_pre_peeler.to_spike_vector(concatenated=True) + sparsity, unit_locations = compute_sparsity_from_peaks_and_label(kept_peaks, spike_vector["unit_index"], + sorting_pre_peeler.unit_ids, recording, params["templates"]["radius_um"]) + # we recompute the template even if the clustering give it already because we use different ms_before/ms_after nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0) @@ -241,37 +244,58 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore, nafter, return_in_uV=False, + sparsity_mask=sparsity.mask, **job_kwargs, ) - templates_dense = Templates( + # templates_dense = Templates( + # templates_array=templates_array, + # sampling_frequency=sampling_frequency, + # nbefore=nbefore, + # channel_ids=recording_for_peeler.channel_ids, + # unit_ids=sorting_pre_peeler.unit_ids, + # sparsity_mask=None, + # probe=recording_for_peeler.get_probe(), + # is_in_uV=False, + # ) + templates = Templates( templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, - channel_ids=recording_for_peeler.channel_ids, + channel_ids=recording.channel_ids, unit_ids=sorting_pre_peeler.unit_ids, - sparsity_mask=None, - probe=recording_for_peeler.get_probe(), + sparsity_mask=sparsity.mask, + probe=recording.get_probe(), is_in_uV=False, ) - - # sparsity is a mix between radius and - sparsity_threshold = params["templates"]["sparsity_threshold"] - radius_um = params["waveforms"]["radius_um"] - sparsity = compute_sparsity(templates_dense, method="radius", radius_um=radius_um) - sparsity_snr = compute_sparsity( - templates_dense, - method="snr", - amplitude_mode="peak_to_peak", - noise_levels=noise_levels, - threshold=sparsity_threshold, - ) - sparsity.mask 
= sparsity.mask & sparsity_snr.mask - templates = templates_dense.to_sparse(sparsity) + # sparsity is a mix between radius and + # sparsity_threshold = params["templates"]["sparsity_threshold"] + # radius_um = params["waveforms"]["radius_um"] + # sparsity = compute_sparsity(templates_dense, method="radius", radius_um=radius_um) + # sparsity_snr = compute_sparsity( + # templates_dense, + # method="snr", + # amplitude_mode="peak_to_peak", + # noise_levels=noise_levels, + # threshold=sparsity_threshold, + # ) + # sparsity.mask = sparsity.mask & sparsity_snr.mask + # templates = templates_dense.to_sparse(sparsity) + + # templates = clean_templates( + # templates, + # sparsify_threshold=None, + # noise_levels=noise_levels, + # min_snr=params["templates"]["min_snr"], + # max_jitter_ms=None, + # remove_empty=True, + # ) + + # this spasify more templates = clean_templates( templates, - sparsify_threshold=None, + sparsify_threshold=params["templates"]["sparsity_threshold"], noise_levels=noise_levels, min_snr=params["templates"]["min_snr"], max_jitter_ms=None, @@ -305,6 +329,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## DEBUG auto merge auto_merge = True + analyzer_final = None if auto_merge: from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus @@ -334,6 +359,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(sorter_output_folder / "clustering_label.npy", clustering_label) np.save(sorter_output_folder / "spikes.npy", spikes) templates.to_zarr(sorter_output_folder / "templates.zarr") + if analyzer_final is not None: + analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer") + # final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype) # final_spikes["sample_index"] = spikes["sample_index"]
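Patches 44 and 47 route the same sparsity pre-estimation into both lupin and tridesclous2: compute_sparsity_from_peaks_and_label() turns the peak channels of each cluster into a per-unit channel mask (weighted barycenter of peak channels plus a radius), and that mask is then passed as sparsity_mask to estimate_templates_with_accumulator(). The sketch below only exercises the helper itself, assuming the patches are applied and numba is available; the peaks array is a hand-built stand-in for detect_peaks() output.

    import numpy as np
    from spikeinterface.core import generate_recording
    from spikeinterface.sortingcomponents.tools import compute_sparsity_from_peaks_and_label

    recording = generate_recording(durations=[5.0], num_channels=16)  # probe with channel locations attached

    # Hand-built peaks: only "channel_index" is used by the helper, plus a per-peak unit label.
    peak_dtype = [("sample_index", "int64"), ("channel_index", "int64"),
                  ("amplitude", "float64"), ("segment_index", "int64")]
    num_peaks = 300
    rng = np.random.default_rng(0)
    unit_ids = np.array([0, 1, 2])
    unit_indices = np.arange(num_peaks, dtype="int64") % unit_ids.size
    peaks = np.zeros(num_peaks, dtype=peak_dtype)
    # each fake unit lives on a small group of neighboring channels
    peaks["channel_index"] = unit_indices * 5 + rng.integers(0, 3, num_peaks)

    sparsity, unit_locations = compute_sparsity_from_peaks_and_label(
        peaks, unit_indices, unit_ids, recording, radius_um=100.0
    )
    print(sparsity.mask.shape)   # (num_units, num_channels)
    print(unit_locations)        # weighted barycenter of the peak channels for each unit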