diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 47bbd1f4d1..13b6e886f0 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -116,6 +116,13 @@ PARAMS_TO_TEST_DICT.update({"max_cluster_subset": 20}) PARAMETERS_NOT_AFFECTING_RESULTS.append("max_cluster_subset") +if parse(kilosort.__version__) >= parse("4.1.2"): + PARAMS_TO_TEST_DICT.update({"batch_downsampling": 2}) + PARAMETERS_NOT_AFFECTING_RESULTS.append("batch_downsampling") + + PARAMS_TO_TEST_DICT.update({"cluster_init_seed": 2}) + PARAMETERS_NOT_AFFECTING_RESULTS.append("cluster_init_seed") + PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys()) @@ -328,6 +335,8 @@ def test_binary_filtered_arguments(self): "scale", "file_object", ] + if parse(kilosort.__version__) >= parse("4.1.2"): + expected_arguments += ["batch_downsampling"] self._check_arguments(BinaryFiltered, expected_arguments) @@ -351,6 +360,12 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp """ recording, paths = recording_and_paths param_key = parameter + + # Non-default batch_downsampling fails for short recordings, as there aren't + # enough batches. Since we test on a 5s recording, we skip it. + if param_key == "batch_downsampling": + return + param_value = PARAMS_TO_TEST_DICT[param_key] # Setup parameters for KS4 and run it natively diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index e8f8a4d9b0..ecedf92efb 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -307,9 +307,29 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if version.parse(ks_version) >= version.parse("4.0.34"): ops = ops[0] - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( - get_run_parameters(ops) - ) + ( + n_chan_bin, + fs, + NT, + nt, + twav_min, + chan_map, + dtype, + do_CAR, + invert, + _, + _, + tmin, + tmax, + artifact, + _, + _, + *possibly_batch_downsampling, + ) = get_run_parameters(ops) + + batch_downsample_dict = {} + if len(possibly_batch_downsampling) > 0: + batch_downsample_dict["batch_downsampling"] = possibly_batch_downsampling[0] # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: @@ -334,6 +354,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): tmax=tmax, artifact_threshold=artifact, file_object=file_object, + **batch_downsample_dict, ) ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None) ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels()))