Skip to content
Merged
9 changes: 9 additions & 0 deletions .github/scripts/test_kilosort4_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@
PARAMS_TO_TEST_DICT.update({"max_cluster_subset": 20})
PARAMETERS_NOT_AFFECTING_RESULTS.append("max_cluster_subset")

if parse(kilosort.__version__) >= parse("4.1.2"):
PARAMS_TO_TEST_DICT.update({"batch_downsampling": 1})
PARAMETERS_NOT_AFFECTING_RESULTS.append("batch_downsampling")

PARAMS_TO_TEST_DICT.update({"cluster_init_seed": 2})
PARAMETERS_NOT_AFFECTING_RESULTS.append("cluster_init_seed")


PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys())

Expand Down Expand Up @@ -328,6 +335,8 @@ def test_binary_filtered_arguments(self):
"scale",
"file_object",
]
if parse(kilosort.__version__) >= parse("4.1.2"):
expected_arguments += ["batch_downsampling"]

self._check_arguments(BinaryFiltered, expected_arguments)

Expand Down
27 changes: 24 additions & 3 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,29 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if version.parse(ks_version) >= version.parse("4.0.34"):
ops = ops[0]

n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (
get_run_parameters(ops)
)
(
n_chan_bin,
fs,
NT,
nt,
twav_min,
chan_map,
dtype,
do_CAR,
invert,
_,
_,
tmin,
tmax,
artifact,
_,
_,
*possibly_batch_downsampling,
) = get_run_parameters(ops)

batch_downsample_dict = {}
if len(possibly_batch_downsampling) > 0:
batch_downsample_dict["batch_downsampling"] = possibly_batch_downsampling[0]

# Set preprocessing and drift correction parameters
if not params["skip_kilosort_preprocessing"]:
Expand All @@ -334,6 +354,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
tmax=tmax,
artifact_threshold=artifact,
file_object=file_object,
**batch_downsample_dict,
)
ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None)
ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels()))
Expand Down
Loading