
Commit 630b6c1

Merge branch 'main' into prepare_release
2 parents 08122dc + 04b67e2 commit 630b6c1

File tree: 18 files changed, +267 additions, −107 deletions

.github/scripts/test_kilosort4_ci.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@
     "x_centers": 5,
     "binning_depth": 1,
     "drift_smoothing": [250, 250, 250],
-    "artifact_threshold": 200,
+    "artifact_threshold": 500,
     "ccg_threshold": 1e12,
     "acg_threshold": 1e12,
     "cluster_downsampling": 2,

src/spikeinterface/benchmark/benchmark_plot_tools.py

Lines changed: 3 additions & 2 deletions
@@ -478,8 +478,9 @@ def _plot_performances_vs_metric(
             .get_performance()[performance_name]
             .to_numpy(dtype="float64")
         )
-        all_xs.append(x)
-        all_ys.append(y)
+        mask = ~np.isnan(x) & ~np.isnan(y)
+        all_xs.append(x[mask])
+        all_ys.append(y[mask])

     if with_sigmoid_fit:
         max_snr = max(np.max(x) for x in all_xs)
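Note: the change above drops any (x, y) pair where either value is NaN before the sigmoid fit. A minimal standalone sketch of the same masking idea in plain NumPy (not the spikeinterface function itself):

import numpy as np

x = np.array([1.0, np.nan, 3.0, 4.0])
y = np.array([0.2, 0.5, np.nan, 0.9])

# keep a pair only when both entries are valid
mask = ~np.isnan(x) & ~np.isnan(y)
print(x[mask], y[mask])  # [1. 4.] [0.2 0.9]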

src/spikeinterface/core/sortinganalyzer.py

Lines changed: 8 additions & 2 deletions
@@ -1634,11 +1634,17 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs):
             return self.compute_one_extension(extension_name=input, save=save, verbose=verbose, **kwargs)
         elif isinstance(input, dict):
             params_, job_kwargs = split_job_kwargs(kwargs)
-            assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()"
+            assert len(params_) == 0, (
+                "Too many arguments for SortingAnalyzer.compute_several_extensions(), "
+                f"please remove the arguments {set(params_)} from the compute function."
+            )
             self.compute_several_extensions(extensions=input, save=save, verbose=verbose, **job_kwargs)
         elif isinstance(input, list):
             params_, job_kwargs = split_job_kwargs(kwargs)
-            assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()"
+            assert len(params_) == 0, (
+                "Too many arguments for SortingAnalyzer.compute_several_extensions(), "
+                f"please remove the arguments {set(params_)} from the compute function."
+            )
             extensions = {k: {} for k in input}
             if extension_params is not None:
                 for ext_name, ext_params in extension_params.items():
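Note: the reworked assertion names the offending keyword arguments instead of failing opaquely. A hedged, self-contained sketch of the pattern (the job-kwarg names below are assumptions for illustration, not the real split_job_kwargs logic):

def check_no_extra_kwargs(**kwargs):
    job_keys = {"n_jobs", "chunk_duration", "progress_bar"}  # assumed job kwargs
    params_ = {k: v for k, v in kwargs.items() if k not in job_keys}
    assert len(params_) == 0, (
        "Too many arguments, "
        f"please remove the arguments {set(params_)} from the compute function."
    )

check_no_extra_kwargs(n_jobs=4)        # passes
# check_no_extra_kwargs(typo_param=1)  # AssertionError naming {'typo_param'}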

src/spikeinterface/postprocessing/template_similarity.py

Lines changed: 13 additions & 6 deletions
@@ -94,7 +94,7 @@ def _merge_extension_data(
             new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids, new_sorting_analyzer.channel_ids
         )

-        new_similarity = compute_similarity_with_templates_array(
+        new_similarity, _ = compute_similarity_with_templates_array(
             new_templates_array,
             all_templates_array,
             method=self.params["method"],

@@ -146,7 +146,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
             new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids_f, new_sorting_analyzer.channel_ids
         )

-        new_similarity = compute_similarity_with_templates_array(
+        new_similarity, _ = compute_similarity_with_templates_array(
             new_templates_array,
             all_templates_array,
             method=self.params["method"],

@@ -188,7 +188,7 @@ def _run(self, verbose=False):
             self.sorting_analyzer, return_in_uV=self.sorting_analyzer.return_in_uV
         )
         sparsity = self.sorting_analyzer.sparsity
-        similarity = compute_similarity_with_templates_array(
+        similarity, _ = compute_similarity_with_templates_array(
             templates_array,
             templates_array,
             method=self.params["method"],

@@ -393,7 +393,13 @@ def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsity):


 def compute_similarity_with_templates_array(
-    templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
+    templates_array,
+    other_templates_array,
+    method,
+    support="union",
+    num_shifts=0,
+    sparsity=None,
+    other_sparsity=None,
 ):

     if method == "cosine_similarity":

@@ -432,10 +438,11 @@ def compute_similarity_with_templates_array(
         templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support
     )

+    lags = np.argmin(distances, axis=0) - num_shifts
     distances = np.min(distances, axis=0)
     similarity = 1 - distances

-    return similarity
+    return similarity, lags


 def compute_template_similarity_by_pair(

@@ -445,7 +452,7 @@ def compute_template_similarity_by_pair(
     templates_array_2 = get_dense_templates_array(sorting_analyzer_2, return_in_uV=True)
     sparsity_1 = sorting_analyzer_1.sparsity
     sparsity_2 = sorting_analyzer_2.sparsity
-    similarity = compute_similarity_with_templates_array(
+    similarity, _ = compute_similarity_with_templates_array(
         templates_array_1,
         templates_array_2,
         method=method,
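Note: compute_similarity_with_templates_array now also returns the best-alignment lag for each template pair. A toy sketch of that recentring step, assuming (as in the code above) a distance stack with one row per tested shift:

import numpy as np

num_shifts = 2
rng = np.random.default_rng(0)
# shape (2 * num_shifts + 1, n_templates, n_other): one distance matrix per shift
distances = rng.random((2 * num_shifts + 1, 3, 4))

# argmin over the shift axis, recentred so lag 0 means "best match at zero shift"
lags = np.argmin(distances, axis=0) - num_shifts
similarity = 1 - np.min(distances, axis=0)
print(lags.shape, similarity.shape)  # (3, 4) (3, 4), lags within [-2, 2]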

src/spikeinterface/postprocessing/tests/test_template_similarity.py

Lines changed: 4 additions & 3 deletions
@@ -82,8 +82,9 @@ def test_compute_similarity_with_templates_array(params):
     templates_array = rng.random(size=(2, 20, 5))
     other_templates_array = rng.random(size=(4, 20, 5))

-    similarity = compute_similarity_with_templates_array(templates_array, other_templates_array, **params)
+    similarity, lags = compute_similarity_with_templates_array(templates_array, other_templates_array, **params)
     print(similarity.shape)
+    print(lags)


 pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")

@@ -141,5 +142,5 @@ def test_equal_results_numba(params):
     test.cache_folder = Path("./cache_folder")
     test.test_extension(params=dict(method="l2"))

-    # params = dict(method="cosine", num_shifts=8)
-    # test_compute_similarity_with_templates_array(params)
+    params = dict(method="cosine", num_shifts=8)
+    test_compute_similarity_with_templates_array(params)

src/spikeinterface/sorters/external/kilosort4.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def _dynamic_params(cls):
         # we skip some parameters that are not relevant for the user
         # n_chan_bin/sampling_frequency: retrieved from the recording
         # tmin/tmax: same ase time/frame_slice in SpikeInterface
-        skip_main = ["n_chan_bin", "sampling_frequency", "tmin", "tmax"]
+        skip_main = ["fs", "n_chan_bin", "tmin", "tmax"]
         default_params = {}
         default_params_descriptions = {}
         ks_params = ks.parameters.MAIN_PARAMETERS.copy()

src/spikeinterface/sorters/internal/spyking_circus2.py

Lines changed: 62 additions & 35 deletions
@@ -15,15 +15,15 @@
     _set_optimal_chunk_size,
 )
 from spikeinterface.core.basesorting import minimum_spike_dtype
+from spikeinterface.core import compute_sparsity


 class Spykingcircus2Sorter(ComponentsBasedSorter):
     sorter_name = "spykingcircus2"

     _default_params = {
-        "general": {"ms_before": 0.5, "ms_after": 1.5, "radius_um": 100},
-        "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
-        "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10},
+        "general": {"ms_before": 0.5, "ms_after": 1.5, "radius_um": 100.0},
+        "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 20},
         "whitening": {"mode": "local", "regularize": False},
         "detection": {
             "method": "matched_filtering",

@@ -38,8 +38,10 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "motion_correction": {"preset": "dredge_fast"},
         "merging": {"max_distance_um": 50},
         "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()},
+        "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsify_threshold": None},
         "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()},
         "apply_preprocessing": True,
+        "apply_whitening": True,
         "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
         "chunk_preprocessing": {"memory_limit": None},
         "multi_units_only": False,

@@ -85,7 +87,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):

     @classmethod
     def get_sorter_version(cls):
-        return "2025.09"
+        return "2025.10"

     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):

@@ -114,30 +116,50 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         num_channels = recording.get_num_channels()
         ms_before = params["general"].get("ms_before", 0.5)
         ms_after = params["general"].get("ms_after", 1.5)
-        radius_um = params["general"].get("radius_um", 100)
+        radius_um = params["general"].get("radius_um", 100.0)
         detect_threshold = params["detection"]["method_kwargs"].get("detect_threshold", 5)
         peak_sign = params["detection"].get("peak_sign", "neg")
         deterministic = params["deterministic_peaks_detection"]
         debug = params["debug"]
         seed = params["seed"]
         apply_preprocessing = params["apply_preprocessing"]
+        apply_whitening = params["apply_whitening"]
         apply_motion_correction = params["apply_motion_correction"]
         exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after))

         ## First, we are filtering the data
         filtering_params = params["filtering"].copy()
         if apply_preprocessing:
             if verbose:
-                print("Preprocessing the recording (bandpass filtering + CMR + whitening)")
+                if apply_whitening:
+                    print("Preprocessing the recording (bandpass filtering + CMR + whitening)")
+                else:
+                    print("Preprocessing the recording (bandpass filtering + CMR)")
             recording_f = bandpass_filter(recording, **filtering_params, dtype="float32")
-            if num_channels > 1:
+            if num_channels >= 32:
                 recording_f = common_reference(recording_f)
         else:
             if verbose:
                 print("Skipping preprocessing (whitening only)")
             recording_f = recording
             recording_f.annotate(is_filtered=True)

+        if apply_whitening:
+            ## We need to whiten before the template matching step, to boost the results
+            # TODO add , regularize=True chen ready
+            whitening_kwargs = params["whitening"].copy()
+            whitening_kwargs["dtype"] = "float32"
+            whitening_kwargs["seed"] = params["seed"]
+            whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False)
+            if num_channels == 1:
+                whitening_kwargs["regularize"] = False
+            if whitening_kwargs["regularize"]:
+                whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"}
+                whitening_kwargs["apply_mean"] = True
+            recording_w = whiten(recording_f, **whitening_kwargs)
+        else:
+            recording_w = recording_f
+
         valid_geometry = check_probe_for_drift_correction(recording_f)
         if apply_motion_correction:
             if not valid_geometry:

@@ -151,27 +173,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             motion_correction_kwargs = params["motion_correction"].copy()
             motion_correction_kwargs.update({"folder": motion_folder})
             noise_levels = get_noise_levels(
-                recording_f, return_in_uV=False, random_slices_kwargs={"seed": seed}, **job_kwargs
+                recording_w, return_in_uV=False, random_slices_kwargs={"seed": seed}, **job_kwargs
             )
             motion_correction_kwargs["detect_kwargs"] = {"noise_levels": noise_levels}
-            recording_f = correct_motion(recording_f, **motion_correction_kwargs, **job_kwargs)
+            recording_w = correct_motion(recording_w, **motion_correction_kwargs, **job_kwargs)
         else:
             motion_folder = None

-        ## We need to whiten before the template matching step, to boost the results
-        # TODO add , regularize=True chen ready
-        whitening_kwargs = params["whitening"].copy()
-        whitening_kwargs["dtype"] = "float32"
-        whitening_kwargs["seed"] = params["seed"]
-        whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False)
-        if num_channels == 1:
-            whitening_kwargs["regularize"] = False
-        if whitening_kwargs["regularize"]:
-            whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"}
-            whitening_kwargs["apply_mean"] = True
-
-        recording_w = whiten(recording_f, **whitening_kwargs)
-
         noise_levels = get_noise_levels(
             recording_w, return_in_uV=False, random_slices_kwargs={"seed": seed}, **job_kwargs
         )

@@ -325,18 +333,33 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         if not clustering_from_svd:
             from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording

-            templates = get_templates_from_peaks_and_recording(
+            dense_templates = get_templates_from_peaks_and_recording(
                 recording_w,
                 selected_peaks,
                 peak_labels,
                 ms_before,
                 ms_after,
                 job_kwargs=job_kwargs,
             )
+
+            sparsity = compute_sparsity(dense_templates, method="radius", radius_um=radius_um)
+            threshold = params["cleaning"].get("sparsify_threshold", None)
+            if threshold is not None:
+                sparsity_snr = compute_sparsity(
+                    dense_templates,
+                    method="snr",
+                    amplitude_mode="peak_to_peak",
+                    noise_levels=noise_levels,
+                    threshold=threshold,
+                )
+                sparsity.mask = sparsity.mask & sparsity_snr.mask
+
+            templates = dense_templates.to_sparse(sparsity)
+
         else:
             from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd

-            templates, _ = get_templates_from_peaks_and_svd(
+            dense_templates, new_sparse_mask = get_templates_from_peaks_and_svd(
                 recording_w,
                 selected_peaks,
                 peak_labels,

@@ -348,16 +371,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 operator="median",
             )
             # this release the peak_svd memmap file
+            templates = dense_templates.to_sparse(new_sparse_mask)

         del more_outs

-        templates = clean_templates(
-            templates,
-            noise_levels=noise_levels,
-            min_snr=detect_threshold,
-            max_jitter_ms=0.1,
-            remove_empty=True,
-        )
+        cleaning_kwargs = params.get("cleaning", {}).copy()
+        cleaning_kwargs["noise_levels"] = noise_levels
+        cleaning_kwargs["remove_empty"] = True
+        templates = clean_templates(templates, **cleaning_kwargs)

         if verbose:
             print("Kept %d clean clusters" % len(templates.unit_ids))
@@ -416,7 +437,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

         if sorting.get_non_empty_unit_ids().size > 0:
             final_analyzer = final_cleaning_circus(
-                recording_w, sorting, templates, job_kwargs=job_kwargs, **merging_params
+                recording_w,
+                sorting,
+                templates,
+                noise_levels=noise_levels,
+                job_kwargs=job_kwargs,
+                **merging_params,
             )
             final_analyzer.save_as(format="binary_folder", folder=sorter_output_folder / "final_analyzer")

@@ -451,14 +477,15 @@ def final_cleaning_circus(
     max_distance_um=50,
     template_diff_thresh=np.arange(0.05, 0.5, 0.05),
     debug_folder=None,
-    job_kwargs=None,
+    noise_levels=None,
+    job_kwargs=dict(),
 ):

     from spikeinterface.sortingcomponents.tools import create_sorting_analyzer_with_existing_templates
     from spikeinterface.curation.auto_merge import auto_merge_units

     # First we compute the needed extensions
-    analyzer = create_sorting_analyzer_with_existing_templates(sorting, recording, templates)
+    analyzer = create_sorting_analyzer_with_existing_templates(sorting, recording, templates, noise_levels=noise_levels)
     analyzer.compute("unit_locations", method="center_of_mass", **job_kwargs)
     analyzer.compute("template_similarity", **similarity_kwargs)
