diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py
index 44d805377f..24c64162ee 100644
--- a/src/spikeinterface/core/__init__.py
+++ b/src/spikeinterface/core/__init__.py
@@ -111,6 +111,7 @@
 )
 from .sorting_tools import (
     spike_vector_to_spike_trains,
+    spike_vector_to_indices,
     random_spikes_selection,
     apply_merges_to_sorting,
     apply_splits_to_sorting,
diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py
new file mode 100644
index 0000000000..53511701d6
--- /dev/null
+++ b/src/spikeinterface/sorters/internal/lupin.py
@@ -0,0 +1,358 @@
+from __future__ import annotations
+
+from .si_based import ComponentsBasedSorter
+
+from copy import deepcopy
+
+from spikeinterface.core import (
+    get_noise_levels,
+    NumpySorting,
+    estimate_templates_with_accumulator,
+    Templates,
+    compute_sparsity,
+)
+
+from spikeinterface.core.job_tools import fix_job_kwargs
+
+from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten
+from spikeinterface.core.basesorting import minimum_spike_dtype
+
+from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing
+
+
+import numpy as np
+
+
+class LupinSorter(ComponentsBasedSorter):
+    """
+    Gentleman thief spike sorter.
+
+    This sorter is composed of pieces of code and ideas stolen from everywhere: yass, tridesclous, spyking-circus, kilosort.
+    It should be the best sorter we can build using spikeinterface.sortingcomponents.
+    """
+    sorter_name = "lupin"
+
+    _default_params = {
+        "apply_preprocessing": True,
+        "apply_motion_correction": False,
+        "motion_correction_preset": "dredge_fast",
+        "clustering_ms_before": 0.3,
+        "clustering_ms_after": 1.3,
+        "whitening_radius_um": 100.,
+        "detection_radius_um": 50.,
+        "features_radius_um": 75.,
+        "template_radius_um": 100.,
+        "freq_min": 150.0,
+        "freq_max": 7000.0,
+        "cache_preprocessing_mode": "auto",
+        "peak_sign": "neg",
+        "detect_threshold": 5,
+        "n_peaks_per_channel": 5000,
+        "n_svd_components_per_channel": 5,
+        "n_pca_features": 3,
+        "clustering_recursive_depth": 3,
+        "ms_before": 1.0,
+        "ms_after": 2.5,
+        "sparsity_threshold": 1.5,
+        "template_min_snr": 2.5,
+        "gather_mode": "memory",
+        "job_kwargs": {},
+        "seed": None,
+        "save_array": True,
+        "debug": False,
+    }
+
+    _params_description = {
+        "apply_preprocessing": "Apply internal preprocessing or not",
+        "apply_motion_correction": "Apply motion correction or not",
+        "motion_correction_preset": "Motion correction preset",
+        "clustering_ms_before": "Milliseconds before the spike peak for clustering",
+        "clustering_ms_after": "Milliseconds after the spike peak for clustering",
+        "template_radius_um": "Radius in um for the template sparsity",
+        "freq_min": "Low cutoff frequency of the bandpass filter",
+        "freq_max": "High cutoff frequency of the bandpass filter",
+        "peak_sign": "Sign of peaks: neg/pos/both",
+        "detect_threshold": "Threshold for peak detection",
+        "n_peaks_per_channel": "Number of spikes per channel for clustering",
+        "n_svd_components_per_channel": "Number of SVD components per channel for clustering",
+        "n_pca_features": "Secondary PCA feature reduction before local isosplit",
+        "clustering_recursive_depth": "Recursion depth of the iterative clustering splits",
+        "ms_before": "Milliseconds before the spike peak for template matching",
+        "ms_after": "Milliseconds after the spike peak for template matching",
+        "sparsity_threshold": "Threshold to sparsify templates before template matching",
+        "template_min_snr": "SNR threshold to remove templates before template matching",
+        "gather_mode": "How to accumulate spikes in
matching : memory/npy", + "job_kwargs": "The famous and fabulous job_kwargs", + "seed": "Seed for random number", + "save_array": "Save or not intermediate arrays in the folder", + "debug": "Save debug files", + } + + handle_multi_segment = True + + @classmethod + def get_sorter_version(cls): + return "2025.11" + + @classmethod + def _run_from_folder(cls, sorter_output_folder, params, verbose): + + from spikeinterface.sortingcomponents.tools import get_prototype_and_waveforms_from_recording + from spikeinterface.sortingcomponents.matching import find_spikes_from_templates + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + from spikeinterface.sortingcomponents.peak_selection import select_peaks + from spikeinterface.sortingcomponents.clustering.main import find_clusters_from_peaks, clustering_methods + from spikeinterface.sortingcomponents.tools import remove_empty_templates + from spikeinterface.preprocessing import correct_motion + from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording + from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label + + job_kwargs = params["job_kwargs"].copy() + job_kwargs = fix_job_kwargs(job_kwargs) + job_kwargs["progress_bar"] = verbose + + seed = params["seed"] + + recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) + + num_chans = recording_raw.get_num_channels() + sampling_frequency = recording_raw.get_sampling_frequency() + + apply_cmr = num_chans >= 32 + + # preprocessing + if params["apply_preprocessing"]: + if params["apply_motion_correction"]: + rec_for_motion = recording_raw + if params["apply_preprocessing"]: + rec_for_motion = bandpass_filter( + rec_for_motion, freq_min=300.0, freq_max=6000.0, ftype="bessel", dtype="float32" + ) + if apply_cmr: + rec_for_motion = common_reference(rec_for_motion) + if verbose: + print("Start correct_motion()") + _, motion_info = correct_motion( + rec_for_motion, + folder=sorter_output_folder / "motion", + output_motion_info=True, + preset=params["motion_correction_preset"], + ) + if verbose: + print("Done correct_motion()") + + recording = bandpass_filter(recording_raw, freq_min=params["freq_min"], freq_max=params["freq_max"], + ftype="bessel", filter_order=2, margin_ms=20., dtype="float32") + + if apply_cmr: + recording = common_reference(recording) + + recording = whiten(recording, dtype="float32", mode="local", radius_um=params["whitening_radius_um"], + # chunk_duration="2s", + # apply_mean=True, + # regularize=True, + # regularize_kwargs=dict(method="LedoitWolf"), + ) + + + if params["apply_motion_correction"]: + interpolate_motion_kwargs = dict( + border_mode="force_extrapolate", + spatial_interpolation_method="kriging", + sigma_um=20.0, + p=2, + ) + + recording = InterpolateMotionRecording( + recording, + motion_info["motion"], + **interpolate_motion_kwargs, + ) + + # Cache in mem or folder + cache_folder = sorter_output_folder / "cache_preprocessing" + recording, cache_info = cache_preprocessing( + recording, mode=params["cache_preprocessing_mode"], folder=cache_folder, job_kwargs=job_kwargs, + ) + + noise_levels = get_noise_levels(recording, return_in_uV=False) + else: + recording = recording_raw + noise_levels = get_noise_levels(recording, return_in_uV=False) + cache_info = None + + # detection + ms_before = params["ms_before"] + ms_after = params["ms_after"] + prototype, few_waveforms, few_peaks = get_prototype_and_waveforms_from_recording( + recording, + 
n_peaks=10_000,
+            ms_before=ms_before,
+            ms_after=ms_after,
+            seed=seed,
+            noise_levels=noise_levels,
+            job_kwargs=job_kwargs,
+        )
+        detection_params = dict(
+            peak_sign=params["peak_sign"],
+            detect_threshold=params["detect_threshold"],
+            exclude_sweep_ms=1.5,
+            radius_um=params["detection_radius_um"],
+            prototype=prototype,
+            ms_before=ms_before,
+        )
+        all_peaks = detect_peaks(
+            recording, method="matched_filtering", method_kwargs=detection_params, job_kwargs=job_kwargs
+        )
+
+        if verbose:
+            print(f"detect_peaks(): {len(all_peaks)} peaks found")
+
+        # selection
+        n_peaks = max(params["n_peaks_per_channel"] * num_chans, 20_000)
+        peaks = select_peaks(all_peaks, method="uniform", n_peaks=n_peaks)
+        if verbose:
+            print(f"select_peaks(): {len(peaks)} peaks kept for clustering")
+
+        # Clustering
+        clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params)
+        clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"]
+        clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"]
+        clustering_kwargs["peaks_svd"]["radius_um"] = params["features_radius_um"]
+        clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"]
+        clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"]
+        clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"]
+
+
+
+        if params["debug"]:
+            clustering_kwargs["debug_folder"] = sorter_output_folder
+        unit_ids, clustering_label, more_outs = find_clusters_from_peaks(
+            recording,
+            peaks,
+            method="iterative-isosplit",
+            method_kwargs=clustering_kwargs,
+            extra_outputs=True,
+            job_kwargs=job_kwargs,
+        )
+
+
+        mask = clustering_label >= 0
+        kept_peaks = peaks[mask]
+        kept_labels = clustering_label[mask]
+
+        sorting_pre_peeler = NumpySorting.from_samples_and_labels(
+            kept_peaks["sample_index"],
+            kept_labels,
+            sampling_frequency,
+            unit_ids=unit_ids,
+        )
+        if verbose:
+            print(f"find_clusters_from_peaks(): {unit_ids.size} clusters found")
+
+
+        # pre-estimate the sparsity using the peak channels
+        spike_vector = sorting_pre_peeler.to_spike_vector(concatenated=True)
+        sparsity, unit_locations = compute_sparsity_from_peaks_and_label(kept_peaks, spike_vector["unit_index"],
+            sorting_pre_peeler.unit_ids, recording, params["template_radius_um"])
+
+        # templates are sparsified by radius around the unit location
+        nbefore = int(ms_before * sampling_frequency / 1000.0)
+        nafter = int(ms_after * sampling_frequency / 1000.0)
+        templates_array = estimate_templates_with_accumulator(
+            recording,
+            sorting_pre_peeler.to_spike_vector(),
+            sorting_pre_peeler.unit_ids,
+            nbefore,
+            nafter,
+            return_in_uV=False,
+            sparsity_mask=sparsity.mask,
+            **job_kwargs,
+        )
+        templates = Templates(
+            templates_array=templates_array,
+            sampling_frequency=sampling_frequency,
+            nbefore=nbefore,
+            channel_ids=recording.channel_ids,
+            unit_ids=sorting_pre_peeler.unit_ids,
+            sparsity_mask=sparsity.mask,
+            probe=recording.get_probe(),
+            is_in_uV=False,
+        )
+
+        # sparsity_threshold = params["sparsity_threshold"]
+        # sparsity = compute_sparsity(templates_dense, method="radius", radius_um=params["features_radius_um"])
+        # sparsity_snr = compute_sparsity(templates_dense, method="snr", amplitude_mode="peak_to_peak",
+        #                                 noise_levels=noise_levels, threshold=sparsity_threshold)
+        # sparsity.mask = sparsity.mask & sparsity_snr.mask
+        # templates = templates_dense.to_sparse(sparsity)
+
+        # this sparsifies the templates further
+        templates = clean_templates(
+            templates,
+
sparsify_threshold=params["sparsity_threshold"], + noise_levels=noise_levels, + min_snr=params["template_min_snr"], + max_jitter_ms=None, + remove_empty=True, + ) + + # Template matching + gather_mode = params["gather_mode"] + pipeline_kwargs = dict(gather_mode=gather_mode) + if gather_mode == "npy": + pipeline_kwargs["folder"] = sorter_output_folder / "matching" + + spikes = find_spikes_from_templates( + recording, + templates, + method="wobble", + method_kwargs={}, + pipeline_kwargs=pipeline_kwargs, + job_kwargs=job_kwargs, + ) + + final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype) + final_spikes["sample_index"] = spikes["sample_index"] + final_spikes["unit_index"] = spikes["cluster_index"] + final_spikes["segment_index"] = spikes["segment_index"] + sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids) + + auto_merge = True + analyzer_final = None + if auto_merge: + # TODO expose some of theses parameters + from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus + + analyzer_final = final_cleaning_circus( + recording, + sorting, + templates, + similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1}, + sparsity_overlap=0.5, + censor_ms=3.0, + max_distance_um=50, + template_diff_thresh=np.arange(0.05, 0.4, 0.05), + debug_folder=None, + job_kwargs=job_kwargs, + ) + sorting = NumpySorting.from_sorting(analyzer_final.sorting) + + if params["save_array"]: + sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler") + np.save(sorter_output_folder / "noise_levels.npy", noise_levels) + np.save(sorter_output_folder / "all_peaks.npy", all_peaks) + np.save(sorter_output_folder / "peaks.npy", peaks) + np.save(sorter_output_folder / "clustering_label.npy", clustering_label) + np.save(sorter_output_folder / "spikes.npy", spikes) + templates.to_zarr(sorter_output_folder / "templates.zarr") + if analyzer_final is not None: + analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer") + + sorting = sorting.save(folder=sorter_output_folder / "sorting") + + + del recording + clean_cache_preprocessing(cache_info) + + return sorting diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4ed6548ca1..7c1ea44f48 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -11,6 +11,7 @@ from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion from spikeinterface.sortingcomponents.tools import ( cache_preprocessing, + clean_cache_preprocessing, get_shuffled_recording_slices, _set_optimal_chunk_size, ) @@ -189,7 +190,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): elif recording_w.check_serializability("pickle"): recording_w.dump(sorter_output_folder / "preprocessed_recording.pickle", relative_to=None) - recording_w = cache_preprocessing(recording_w, **job_kwargs, **params["cache_preprocessing"]) + recording_w, cache_info = cache_preprocessing(recording_w, job_kwargs=job_kwargs, **params["cache_preprocessing"]) ## Then, we are detecting peaks with a locally_exclusive method detection_method = params["detection"].get("method", "matched_filtering") @@ -455,16 +456,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print(f"Kept {len(sorting.unit_ids)} units after final merging") - folder_to_delete = None - cache_mode = 
params["cache_preprocessing"].get("mode", "memory") - delete_cache = params["cache_preprocessing"].get("delete_cache", True) - - if cache_mode in ["folder", "zarr"] and delete_cache: - folder_to_delete = recording_w._kwargs["folder_path"] - del recording_w - if folder_to_delete is not None: - shutil.rmtree(folder_to_delete) + clean_cache_preprocessing(cache_info) sorting = sorting.save(folder=sorting_folder) diff --git a/src/spikeinterface/sorters/internal/tests/test_lupin.py b/src/spikeinterface/sorters/internal/tests/test_lupin.py new file mode 100644 index 0000000000..df2666be1d --- /dev/null +++ b/src/spikeinterface/sorters/internal/tests/test_lupin.py @@ -0,0 +1,18 @@ +import unittest + +from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite + +from spikeinterface.sorters import LupinSorter, run_sorter + +from pathlib import Path + + +class LupinSorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): + SorterClass = LupinSorter + + +if __name__ == "__main__": + test = LupinSorterCommonTestSuite() + test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters" + test.setUp() + test.test_with_run() diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 22ec7dac5e..cdb595f6ca 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -20,7 +20,7 @@ from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten from spikeinterface.core.basesorting import minimum_spike_dtype -from spikeinterface.sortingcomponents.tools import cache_preprocessing +from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing import numpy as np @@ -33,7 +33,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "apply_preprocessing": True, "apply_motion_correction": False, "motion_correction": {"preset": "dredge_fast"}, - "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, + "cache_preprocessing_mode" : "auto", "waveforms": { "ms_before": 0.5, "ms_after": 1.5, @@ -47,7 +47,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): }, "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "svd": {"n_components": 10}, + "svd": {"n_components": 5}, "clustering": { "recursive_depth": 3, }, @@ -58,6 +58,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "sparsity_threshold": 1.5, "min_snr": 2.5, # "peak_shift_ms": 0.2, + "radius_um":100., }, "matching": {"method": "tdc-peeler", "method_kwargs": {}, "gather_mode": "memory"}, "job_kwargs": {}, @@ -96,7 +97,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.preprocessing import correct_motion from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording - from spikeinterface.sortingcomponents.tools import clean_templates + from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label job_kwargs = params["job_kwargs"].copy() job_kwargs = fix_job_kwargs(job_kwargs) @@ -131,6 +132,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): print("Done correct_motion()") recording = bandpass_filter(recording_raw, **params["filtering"], margin_ms=20.0, dtype="float32") + if apply_cmr: recording 
= common_reference(recording) @@ -152,16 +154,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # whitening is really bad when dirft correction is applied and this changd nothing when no dirft # recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0) - # used only if "folder" or "zarr" + # Cache in mem or folder cache_folder = sorter_output_folder / "cache_preprocessing" - recording = cache_preprocessing( - recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"] + recording, cache_info = cache_preprocessing( + recording, mode=params["cache_preprocessing_mode"], folder=cache_folder, job_kwargs=job_kwargs, ) noise_levels = np.ones(num_chans, dtype="float32") else: recording = recording_raw noise_levels = get_noise_levels(recording, return_in_uV=False) + cache_info = None # detection detection_params = params["detection"].copy() @@ -209,24 +212,26 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs=job_kwargs, ) - new_peaks = peaks - mask = clustering_label >= 0 + kept_peaks = peaks[mask] + kept_labels = clustering_label[mask] + sorting_pre_peeler = NumpySorting.from_samples_and_labels( - new_peaks["sample_index"][mask], - clustering_label[mask], + kept_peaks["sample_index"], + kept_labels, sampling_frequency, unit_ids=unit_ids, ) if verbose: - print(f"find_clusters_from_peaks(): {sorting_pre_peeler.unit_ids.size} cluster found") + print(f"find_clusters_from_peaks(): {unit_ids.size} cluster found") recording_for_peeler = recording - # if "templates" in more_outs: - # # No, bad idea because templates are too short - # # clustering also give templates - # templates = more_outs["templates"] + # preestimate the sparsity unsing peaks channel + spike_vector = sorting_pre_peeler.to_spike_vector(concatenated=True) + sparsity, unit_locations = compute_sparsity_from_peaks_and_label(kept_peaks, spike_vector["unit_index"], + sorting_pre_peeler.unit_ids, recording, params["templates"]["radius_um"]) + # we recompute the template even if the clustering give it already because we use different ms_before/ms_after nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0) @@ -239,36 +244,58 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore, nafter, return_in_uV=False, + sparsity_mask=sparsity.mask, **job_kwargs, ) - templates_dense = Templates( + # templates_dense = Templates( + # templates_array=templates_array, + # sampling_frequency=sampling_frequency, + # nbefore=nbefore, + # channel_ids=recording_for_peeler.channel_ids, + # unit_ids=sorting_pre_peeler.unit_ids, + # sparsity_mask=None, + # probe=recording_for_peeler.get_probe(), + # is_in_uV=False, + # ) + templates = Templates( templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, - channel_ids=recording_for_peeler.channel_ids, + channel_ids=recording.channel_ids, unit_ids=sorting_pre_peeler.unit_ids, - sparsity_mask=None, - probe=recording_for_peeler.get_probe(), + sparsity_mask=sparsity.mask, + probe=recording.get_probe(), is_in_uV=False, ) - # sparsity is a mix between radius and - sparsity_threshold = params["templates"]["sparsity_threshold"] - radius_um = params["waveforms"]["radius_um"] - sparsity = compute_sparsity(templates_dense, method="radius", radius_um=radius_um) - sparsity_snr = compute_sparsity( - templates_dense, - method="snr", - amplitude_mode="peak_to_peak", - noise_levels=noise_levels, - threshold=sparsity_threshold, - ) - sparsity.mask = sparsity.mask & 
sparsity_snr.mask - templates = templates_dense.to_sparse(sparsity) + # sparsity is a mix between radius and + # sparsity_threshold = params["templates"]["sparsity_threshold"] + # radius_um = params["waveforms"]["radius_um"] + # sparsity = compute_sparsity(templates_dense, method="radius", radius_um=radius_um) + # sparsity_snr = compute_sparsity( + # templates_dense, + # method="snr", + # amplitude_mode="peak_to_peak", + # noise_levels=noise_levels, + # threshold=sparsity_threshold, + # ) + # sparsity.mask = sparsity.mask & sparsity_snr.mask + # templates = templates_dense.to_sparse(sparsity) + + # templates = clean_templates( + # templates, + # sparsify_threshold=None, + # noise_levels=noise_levels, + # min_snr=params["templates"]["min_snr"], + # max_jitter_ms=None, + # remove_empty=True, + # ) + + # this spasify more templates = clean_templates( templates, - sparsify_threshold=None, + sparsify_threshold=params["templates"]["sparsity_threshold"], noise_levels=noise_levels, min_snr=params["templates"]["min_snr"], max_jitter_ms=None, @@ -302,6 +329,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## DEBUG auto merge auto_merge = True + analyzer_final = None if auto_merge: from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus @@ -331,6 +359,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(sorter_output_folder / "clustering_label.npy", clustering_label) np.save(sorter_output_folder / "spikes.npy", spikes) templates.to_zarr(sorter_output_folder / "templates.zarr") + if analyzer_final is not None: + analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer") + # final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype) # final_spikes["sample_index"] = spikes["sample_index"] @@ -340,4 +371,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=sorter_output_folder / "sorting") + del recording, recording_for_peeler + clean_cache_preprocessing(cache_info) + return sorting diff --git a/src/spikeinterface/sorters/sorterlist.py b/src/spikeinterface/sorters/sorterlist.py index ed8ba7b6bc..cb1577f0be 100644 --- a/src/spikeinterface/sorters/sorterlist.py +++ b/src/spikeinterface/sorters/sorterlist.py @@ -24,6 +24,7 @@ from .internal.spyking_circus2 import Spykingcircus2Sorter from .internal.tridesclous2 import Tridesclous2Sorter from .internal.simplesorter import SimpleSorter +from .internal.lupin import LupinSorter sorter_full_list = [ # external @@ -49,6 +50,7 @@ Spykingcircus2Sorter, Tridesclous2Sorter, SimpleSorter, + LupinSorter, ] # archived diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index d4a7ad8b8b..746c3b784e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -176,6 +176,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): features, method="local_feature_clustering", debug_folder=debug_folder, + job_kwargs=job_kwargs, # job_kwargs=dict(n_jobs=1), **split_params, diff --git a/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py b/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py index da1233771c..0411166de6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py @@ -283,7 
+283,6 @@ def split( tsvd = TruncatedSVD(n_pca_features, random_state=seed) final_features = tsvd.fit_transform(flatten_features) - else: final_features = flatten_features tsvd = None diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py index 0608733fbb..2bc9701da9 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py @@ -396,7 +396,6 @@ def merge( tsvd = TruncatedSVD(n_pca_features, random_state=seed) feat = tsvd.fit_transform(feat) - else: feat = feat tsvd = None diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 9134ff1c5c..8d9f585237 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -97,59 +97,6 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_target return aligned_features, dont_have_channels - -# def compute_template_from_sparse( -# peaks, labels, labels_set, sparse_waveforms, sparse_target_mask, total_channels, peak_shifts=None -# ): -# """ -# Compute template average from single sparse waveforms buffer. - -# Parameters -# ---------- -# peaks - -# labels - -# labels_set - -# sparse_waveforms (or features) - -# sparse_target_mask - -# total_channels - -# peak_shifts - -# Returns -# ------- -# templates: numpy.array -# Templates shape : (len(labels_set), num_samples, total_channels) -# """ - -# # NOTE SAM I think this is wrong, we should remove - -# n = len(labels_set) - -# templates = np.zeros((n, sparse_waveforms.shape[1], total_channels), dtype=sparse_waveforms.dtype) - -# for i, label in enumerate(labels_set): -# peak_indices = np.flatnonzero(labels == label) - -# local_chans = np.unique(peaks["channel_index"][peak_indices]) -# target_channels = np.flatnonzero(np.all(sparse_target_mask[local_chans, :], axis=0)) - -# aligned_wfs, dont_have_channels = aggregate_sparse_features( -# peaks, peak_indices, sparse_waveforms, sparse_target_mask, target_channels -# ) - -# if peak_shifts is not None: -# apply_waveforms_shift(aligned_wfs, peak_shifts[peak_indices], inplace=True) - -# templates[i, :, :][:, target_channels] = np.mean(aligned_wfs[~dont_have_channels], axis=0) - -# return templates - - def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): """ Apply a shift a spike level to realign waveforms buffers. 
@@ -362,3 +309,4 @@ def get_templates_from_peaks_and_svd( ) return dense_templates, final_sparsity_mask + diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 7680321722..1314905aeb 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import shutil try: import psutil @@ -18,6 +19,7 @@ from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeNoiseLevels from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.sorting_tools import spike_vector_to_indices, get_numba_vector_to_list_of_spiketrain def make_multi_method_doc(methods, ident=" "): @@ -361,8 +363,22 @@ def _get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25): return job_kwargs +def _check_cache_memory(recording, memory_limit, total_memory): + if total_memory is None: + if HAVE_PSUTIL: + assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1[" + memory_usage = memory_limit * psutil.virtual_memory().available + return recording.get_total_memory_size() < memory_usage + else: + return False + else: + return recording.get_total_memory_size() < total_memory + + + + def cache_preprocessing( - recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs + recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, job_kwargs=None, folder=None, ): """ Cache the preprocessing of a recording object @@ -380,50 +396,71 @@ def cache_preprocessing( The total memory to use for the job in bytes delete_cache: bool If True, delete the cache after the job - **extra_kwargs: dict - The extra kwargs for the job Returns ------- recording: Recording The cached recording object + cache_info: dict + Dict containing info for cleaning cache """ - save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) + + cache_info = dict( + mode=mode + ) if mode == "memory": if total_memory is None: - if HAVE_PSUTIL: - assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1[" - memory_usage = memory_limit * psutil.virtual_memory().available - if recording.get_total_memory_size() < memory_usage: - recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) - else: - import warnings - - warnings.warn("Recording too large to be preloaded in RAM...") - else: - import warnings - - warnings.warn("psutil is required to preload in memory given only a fraction of available memory") - else: - if recording.get_total_memory_size() < total_memory: + mem_ok = _check_cache_memory(recording, memory_limit, total_memory) + if mem_ok: recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) else: import warnings warnings.warn("Recording too large to be preloaded in RAM...") + cache_info["mode"] = "no-cache" + elif mode == "folder": - recording = recording.save_to_folder(**extra_kwargs) + assert folder is not None, "cache_preprocessing(): folder must be given" + recording = recording.save_to_folder(folder=folder) + cache_info["folder"] = folder elif mode == "zarr": - recording = recording.save_to_zarr(**extra_kwargs) + assert folder is not None, "cache_preprocessing(): folder must be given" + recording = recording.save_to_zarr(folder=folder) + cache_info["folder"] = folder elif mode == "no-cache": 
recording = recording
+    elif mode == "auto":
+        mem_ok = _check_cache_memory(recording, memory_limit, total_memory)
+        if mem_ok:
+            # try memory first
+            recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
+            cache_info["mode"] = "memory"
+        elif folder is not None:
+            # then fall back to the folder
+            recording = recording.save_to_folder(folder=folder)
+            cache_info["mode"] = "folder"
+            cache_info["folder"] = folder
+        else:
+            recording = recording
+            cache_info["mode"] = "no-cache"
     else:
         raise ValueError(f"cache_preprocessing() wrong mode={mode}")
 
-    return recording
+    return recording, cache_info
+
+def clean_cache_preprocessing(cache_info):
+    """
+    Delete the folder possibly created by cache_preprocessing().
+    Important: the cached recording must be deleted first.
+    """
+    if cache_info is None or "mode" not in cache_info:
+        return
+    if cache_info["mode"] in ("folder", "zarr"):
+        shutil.rmtree(cache_info["folder"], ignore_errors=True)
 
 
 def remove_empty_templates(templates):
@@ -503,6 +540,8 @@ def clean_templates(
     ## First we sparsify the templates (using peak-to-peak amplitude avoid sign issues)
     if sparsify_threshold is not None:
+        if templates.are_templates_sparse():
+            templates = templates.to_dense()
         sparsity = compute_sparsity(
             templates,
             method="snr",
             amplitude_mode="peak_to_peak",
             noise_levels=noise_levels,
             threshold=sparsify_threshold,
         )
-        if templates.are_templates_sparse():
-            templates = templates.to_dense()
         templates = templates.to_sparse(sparsity)
 
     ## We removed non empty templates
@@ -543,3 +580,31 @@
         templates = templates.select_units(to_select)
 
     return templates
+
+def compute_sparsity_from_peaks_and_label(peaks, unit_indices, unit_ids, recording, radius_um):
+    """
+    Compute the sparsity after clustering.
+    This uses the peak channels to compute the barycenter of each cluster.
+    Then the channels within a radius around it are kept.
+    """
+    # handle only 2D channel locations
+    channel_locations = recording.get_channel_locations()[:, :2]
+    num_units = unit_ids.size
+    num_chans = recording.channel_ids.size
+
+    vector_to_list_of_spiketrain = get_numba_vector_to_list_of_spiketrain()
+    indices = np.arange(unit_indices.size, dtype=np.int64)
+    list_of_spike_indices = vector_to_list_of_spiketrain(indices, unit_indices, num_units)
+    unit_locations = np.zeros((num_units, 2), dtype=float)
+    sparsity_mask = np.zeros((num_units, num_chans), dtype=bool)
+    for unit_ind in range(num_units):
+        spike_inds = list_of_spike_indices[unit_ind]
+        unit_chans, count = np.unique(peaks[spike_inds]["channel_index"], return_counts=True)
+        weights = count / np.sum(count)
+        unit_loc = np.average(channel_locations[unit_chans, :], weights=weights, axis=0)
+        unit_locations[unit_ind, :] = unit_loc
+        (chan_inds,) = np.nonzero(np.linalg.norm(channel_locations - unit_loc[None, :], axis=1) <= radius_um)
+        sparsity_mask[unit_ind, chan_inds] = True
+
+    sparsity = ChannelSparsity(sparsity_mask, unit_ids, recording.channel_ids)
+    return sparsity, unit_locations
\ No newline at end of file
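
For reviewers who want to try the new sorter, here is a minimal usage sketch (not part of the diff). It assumes this branch is installed; `run_sorter` and `generate_ground_truth_recording` are existing spikeinterface APIs, and the folder name and parameter values are only illustrative defaults taken from `LupinSorter._default_params`.

```python
# Minimal sketch: run the new "lupin" sorter end to end on a toy recording.
from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter

# small synthetic recording, enough to exercise the full pipeline
recording, gt_sorting = generate_ground_truth_recording(durations=[30.0], num_channels=32, seed=42)

sorting = run_sorter(
    "lupin",                          # sorter_name registered in sorterlist.py
    recording,
    folder="lupin_output",            # hypothetical output folder
    apply_preprocessing=True,         # internal filter + CMR + whitening
    detect_threshold=5,
    cache_preprocessing_mode="auto",
    job_kwargs=dict(n_jobs=-1, chunk_duration="1s"),
)
print(sorting.unit_ids)
```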
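The caching helpers now return a `cache_info` dict instead of letting each sorter inspect recording kwargs and delete folders by hand. The sketch below shows the intended call pattern under stated assumptions: a small generated recording stands in for a real preprocessed one, and the cache folder path is illustrative.

```python
# Sketch of the new cache_preprocessing() / clean_cache_preprocessing() contract.
from pathlib import Path

from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import bandpass_filter
from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing

rec = bandpass_filter(generate_recording(durations=[10.0], num_channels=16), freq_min=300.0)

cache_folder = Path("cache_preprocessing")        # hypothetical scratch location
rec_cached, cache_info = cache_preprocessing(
    rec,
    mode="auto",                                  # memory if it fits, else folder, else no-cache
    folder=cache_folder,
    job_kwargs=dict(n_jobs=4),
)

# ... run detection / clustering / matching on rec_cached ...

del rec_cached                                    # release the cached recording first
clean_cache_preprocessing(cache_info)             # removes the folder/zarr cache if one was created
```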
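To make the new `compute_sparsity_from_peaks_and_label()` helper easier to review, here is a tiny numpy-only illustration of the barycenter-plus-radius idea it implements: the unit location is the count-weighted average of the channels on which its peaks were detected, and only channels within `radius_um` of that location enter the sparsity mask. The channel layout and counts below are made up.

```python
# Toy illustration (not part of the PR) of the barycenter + radius sparsity rule.
import numpy as np

channel_locations = np.array([[0.0, 0.0], [0.0, 40.0], [0.0, 80.0], [0.0, 120.0]])

# peak channels observed for one unit and how often each was the detected channel
unit_chans = np.array([1, 2])
counts = np.array([30, 10])
weights = counts / counts.sum()

# count-weighted barycenter of the peak-channel positions -> [0., 50.]
unit_loc = np.average(channel_locations[unit_chans], weights=weights, axis=0)

radius_um = 50.0
mask = np.linalg.norm(channel_locations - unit_loc, axis=1) <= radius_um
print(unit_loc, mask)   # channels 0..2 are kept, channel 3 (70 um away) is excluded
```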