From ec0c1958dece2eea764235ea93dfedcc1ca74b45 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 7 Nov 2025 12:49:35 +0000 Subject: [PATCH 01/10] deal with batch_downsampling --- .../sorters/external/kilosort4.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index e8f8a4d9b0..4023f54207 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -307,9 +307,31 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if version.parse(ks_version) >= version.parse("4.0.34"): ops = ops[0] - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( - get_run_parameters(ops) - ) + ( + n_chan_bin, + fs, + NT, + nt, + twav_min, + chan_map, + dtype, + do_CAR, + invert, + _, + _, + tmin, + tmax, + artifact, + _, + _, + *possibly_batch_downsampling, + ) = get_run_parameters(ops) + + batch_downsample_dict = {} + if len(possibly_batch_downsampling) > 0: + batch_downsample_dict["batch_downsampling"] = possibly_batch_downsampling[0] + + print(f"{batch_downsample_dict=}") # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: @@ -334,6 +356,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): tmax=tmax, artifact_threshold=artifact, file_object=file_object, + **batch_downsample_dict, ) ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None) ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels())) From a892bd61c2e0951a32fd2dc1667bc05b74bcd458 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 7 Nov 2025 12:53:01 +0000 Subject: [PATCH 02/10] remove print --- src/spikeinterface/sorters/external/kilosort4.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 4023f54207..ecedf92efb 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -331,8 +331,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if len(possibly_batch_downsampling) > 0: batch_downsample_dict["batch_downsampling"] = possibly_batch_downsampling[0] - print(f"{batch_downsample_dict=}") - # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: ops = compute_preprocessing(ops=ops, device=device, tic0=tic0, file_object=file_object) From fe6b53ed562bcf7e9a6ff5f593149a53abb6301d Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 7 Nov 2025 16:18:13 +0000 Subject: [PATCH 03/10] update ks4 tests --- .github/scripts/test_kilosort4_ci.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 47bbd1f4d1..a0c0eaecae 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -116,6 +116,13 @@ PARAMS_TO_TEST_DICT.update({"max_cluster_subset": 20}) PARAMETERS_NOT_AFFECTING_RESULTS.append("max_cluster_subset") +if parse(kilosort.__version__) >= parse("4.1.2"): + PARAMS_TO_TEST_DICT.update({"batch_downsampling": 2}) + PARAMETERS_NOT_AFFECTING_RESULTS.append("batch_downsampling") + + PARAMS_TO_TEST_DICT.update({"cluster_init_seed": 2}) + PARAMETERS_NOT_AFFECTING_RESULTS.append("cluster_init_seed") + PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys()) @@ -328,6 +335,8 @@ def test_binary_filtered_arguments(self): "scale", "file_object", ] + if parse(kilosort.__version__) >= parse("4.1.2"): + expected_arguments += ["batch_downsampling"] self._check_arguments(BinaryFiltered, expected_arguments) From 1441fda0544c493f0ec5fcdf0cea032859cc5770 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 10 Nov 2025 10:39:49 +0000 Subject: [PATCH 04/10] batch downsampling default to 1 --- .github/scripts/test_kilosort4_ci.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index a0c0eaecae..4599178d7b 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -117,7 +117,7 @@ PARAMETERS_NOT_AFFECTING_RESULTS.append("max_cluster_subset") if parse(kilosort.__version__) >= parse("4.1.2"): - PARAMS_TO_TEST_DICT.update({"batch_downsampling": 2}) + PARAMS_TO_TEST_DICT.update({"batch_downsampling": 1}) PARAMETERS_NOT_AFFECTING_RESULTS.append("batch_downsampling") PARAMS_TO_TEST_DICT.update({"cluster_init_seed": 2}) From c881a62c1174c846d5d0945b3faea70b8dc479e6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 10 Nov 2025 13:16:19 +0100 Subject: [PATCH 05/10] Update .github/scripts/test_kilosort4_ci.py --- .github/scripts/test_kilosort4_ci.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 4599178d7b..a0c0eaecae 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -117,7 +117,7 @@ PARAMETERS_NOT_AFFECTING_RESULTS.append("max_cluster_subset") if parse(kilosort.__version__) >= parse("4.1.2"): - PARAMS_TO_TEST_DICT.update({"batch_downsampling": 1}) + PARAMS_TO_TEST_DICT.update({"batch_downsampling": 2}) PARAMETERS_NOT_AFFECTING_RESULTS.append("batch_downsampling") PARAMS_TO_TEST_DICT.update({"cluster_init_seed": 2}) From be50cb640c5471d6455ee580f5c9211c2171bf41 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 10 Nov 2025 13:03:47 +0000 Subject: [PATCH 06/10] check to see if batch downsampling works with longer recs --- .github/scripts/test_kilosort4_ci.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index a0c0eaecae..6948614e42 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -177,7 +177,7 @@ def _get_ground_truth_recording(self): """ num_channels = 32 recording, _ = si.generate_ground_truth_recording( - durations=[5], + durations=[10], seed=0, num_channels=num_channels, num_units=5, From fb370fa23f1cf0f4c4531c03dcafe2fe1e58f8d2 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 10 Nov 2025 14:06:40 +0000 Subject: [PATCH 07/10] skip test_kilosort4_main batch_downsample test --- .github/scripts/test_kilosort4_ci.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 6948614e42..9014251a21 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -81,6 +81,7 @@ "acg_threshold": 1e12, "cluster_downsampling": 2, "duplicate_spike_ms": 0.3, + "batch_downsampling": 2, } PARAMETERS_NOT_AFFECTING_RESULTS = [ @@ -177,7 +178,7 @@ def _get_ground_truth_recording(self): """ num_channels = 32 recording, _ = si.generate_ground_truth_recording( - durations=[10], + durations=[5], seed=0, num_channels=num_channels, num_units=5, @@ -360,6 +361,10 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp """ recording, paths = recording_and_paths param_key = parameter + + if param_key == "batch_downsampling": + return + param_value = PARAMS_TO_TEST_DICT[param_key] # Setup parameters for KS4 and run it natively From e9431c551db626e42de0416ba2ba8f6e400349b3 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 10 Nov 2025 14:09:45 +0000 Subject: [PATCH 08/10] re-add cluster_init_seed --- .github/scripts/test_kilosort4_ci.py | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 9014251a21..68706e1bf1 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -82,6 +82,7 @@ "cluster_downsampling": 2, "duplicate_spike_ms": 0.3, "batch_downsampling": 2, + "cluster_init_seed": 2, } PARAMETERS_NOT_AFFECTING_RESULTS = [ From 37358a434e9ee384fe5c186caefe7b6a36bc6c9a Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 10 Nov 2025 14:41:37 +0000 Subject: [PATCH 09/10] dont give new params to old versions --- .github/scripts/test_kilosort4_ci.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 68706e1bf1..686a1def34 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -81,8 +81,6 @@ "acg_threshold": 1e12, "cluster_downsampling": 2, "duplicate_spike_ms": 0.3, - "batch_downsampling": 2, - "cluster_init_seed": 2, } PARAMETERS_NOT_AFFECTING_RESULTS = [ From 9715a2ecbcedfba692ec668f7111898031d23c97 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 11 Nov 2025 09:47:30 +0000 Subject: [PATCH 10/10] Add comment about batch_downsampling --- .github/scripts/test_kilosort4_ci.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 686a1def34..13b6e886f0 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -361,6 +361,8 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp recording, paths = recording_and_paths param_key = parameter + # Non-default batch_downsampling fails for short recordings, as there aren't + # enough batches. Since we test on a 5s recording, we skip it. if param_key == "batch_downsampling": return