From 84b9a1779e832b89ec281d6a0ab89b5e8c5d7703 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 20:48:28 +0000 Subject: [PATCH 1/7] Add first version. --- .../working/load_kilosort_utils.py | 426 ++++++++++++++++++ .../working/test_peaks_from_ks.py | 74 +++ 2 files changed, 500 insertions(+) create mode 100644 src/spikeinterface/working/load_kilosort_utils.py create mode 100644 src/spikeinterface/working/test_peaks_from_ks.py diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py new file mode 100644 index 0000000000..aa3fb2babd --- /dev/null +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -0,0 +1,426 @@ +from __future__ import annotations + +from pathlib import Path +from spikeinterface.core import read_python +import numpy as np +import pandas as pd + +from scipy import stats + +# TODO: spike_times -> spike_indexes + + +def compute_spike_amplitude_and_depth( + sorter_output: str | Path, + localised_spikes_only, + exclude_noise, + gain: float | None = None, + localised_spikes_channel_cutoff: int = None, # TODO +) -> tuple[np.ndarray, ...]: + """ + Compute the amplitude and depth of all detected spikes from the kilosort output. + + This function was ported from Nick Steinmetz's `spikes` repository + MATLAB code, https://github.com/cortex-lab/spikes + + Parameters + ---------- + sorter_output : str | Path + Path to the kilosort run sorting output. + localised_spikes_only : bool + If `True`, only spikes with small spatial footprint (i.e. 20 channels within 1/2 of the + amplitude of the maximum loading channel) and which are close to the average depth for + the cluster are returned. + gain: float | None + If a float provided, the `spike_amplitudes` will be scaled by this gain. + localised_spikes_channel_cutoff : int + If `localised_spikes_only` is `True`, spikes that have less than half of the + maximum loading channel over a range of n channels are removed. + This sets the number of channels. + + Returns + ------- + spike_indexes : np.ndarray + (num_spikes,) array of spike indexes. + spike_amplitudes : np.ndarray + (num_spikes,) array of corresponding spike amplitudes. + spike_depths : np.ndarray + (num_spikes,) array of corresponding depths (probe y-axis location). + + Notes + ----- + In `_template_positions_amplitudes` spike depths is calculated as simply the template + depth, for each spike (so it is the same for all spikes in a cluster). Here we need + to find the depth of each individual spike, using its low-dimensional projection. + `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. + Taking the first component, the subset of 32 channels associated with this + spike are indexed to get the actual channel locations (in um). Then, the channel + locations are weighted by their PC values. + """ + if isinstance(sorter_output, str): + sorter_output = Path(sorter_output) + + params = _load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise) + + if localised_spikes_only: + localised_templates = [] + + for idx, template in enumerate(params["templates"]): + max_channel = np.max(np.abs(params["templates"][idx, :, :])) + channels_over_threshold = np.max(np.abs(params["templates"][idx, :, :]), axis=0) > 0.5 * max_channel + channel_ids_over_threshold = np.where(channels_over_threshold)[0] + + if np.ptp(channel_ids_over_threshold) <= localised_spikes_channel_cutoff: + localised_templates.append(idx) + + localised_template_by_spike = np.isin(params["spike_templates"], localised_templates) + + params["spike_templates"] = params["spike_templates"][localised_template_by_spike] + params["spike_indexes"] = params["spike_indexes"][localised_template_by_spike] + params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike] + params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike] + params["pc_features"] = params["pc_features"][localised_template_by_spike] + + # Compute spike depths + pc_features = params["pc_features"][:, 0, :] + pc_features[pc_features < 0] = 0 + + # Get the channel indexes corresponding to the 32 channels from the PC. + spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] + + ycoords = params["channel_positions"][:, 1] + spike_feature_ycoords = ycoords[spike_features_indices] + + spike_depths = np.sum(spike_feature_ycoords * pc_features**2, axis=1) / np.sum(pc_features**2, axis=1) + + spike_feature_coords = params["channel_positions"][spike_features_indices, :] + norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] # TOOD: see why they use square + weighted_locs = spike_feature_coords * norm_weights[:, :, np.newaxis] + weighted_locs = np.sum(weighted_locs, axis=1) + # Amplitude is calculated for each spike as the template amplitude + # multiplied by the `template_scaling_amplitudes`. + + # Compute amplitudes, scale if required and drop un-localised spikes before returning. + spike_amplitudes, _, _, _, unwhite_templates, *_ = _template_positions_amplitudes( + params["templates"], + params["whitening_matrix_inv"], + ycoords, + params["spike_templates"], + params["temp_scaling_amplitudes"], + ) + + if gain is not None: + spike_amplitudes *= gain + + max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1) + spike_sites = max_site[params["spike_templates"]] + + if localised_spikes_only: + # Interpolate the channel ids to location. + # Remove spikes > 5 um from average position + # Above we already removed non-localized templates, but that on its own is insufficient. + # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient + # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere + # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold. + # 3) just use depth. Probably go for that. check with others. + spike_depths = weighted_locs[:, 1] + b = stats.linregress(spike_depths, spike_sites).slope + i = np.abs(spike_sites - b * spike_depths) <= 5 # TODO: need to expose this + + params["spike_indexes"] = params["spike_indexes"][i] + spike_amplitudes = spike_amplitudes[i] + weighted_locs = weighted_locs[i, :] + + return params["spike_indexes"], spike_amplitudes, weighted_locs, spike_sites # TODO: rename everything + + +def _filter_large_amplitude_spikes( + spike_times: np.ndarray, + spike_amplitudes: np.ndarray, + spike_depths: np.ndarray, + large_amplitude_only_segment_size, +) -> tuple[np.ndarray, ...]: + """ + Return spike properties with only the largest-amplitude spikes included. The probe + is split into egments, and within each segment the mean and std computed. + Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded + Splitting the probe is only done for the exclusion step, the returned array are flat. + + Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns + copies of these arrays containing only the large amplitude spikes. + """ + spike_bool = np.zeros_like(spike_amplitudes, dtype=bool) + + segment_size_um = large_amplitude_only_segment_size + probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um + + for segment_left_edge in probe_segments_left_edges: + segment_right_edge = segment_left_edge + segment_size_um + + spikes_in_seg = np.where(np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge))[ + 0 + ] + spike_amps_in_seg = spike_amplitudes[spikes_in_seg] + is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1) + + spike_bool[spikes_in_seg] = is_high_amplitude + + spike_times = spike_times[spike_bool] + spike_amplitudes = spike_amplitudes[spike_bool] + spike_depths = spike_depths[spike_bool] + + return spike_times, spike_amplitudes, spike_depths + + +def _template_positions_amplitudes( + templates: np.ndarray, + inverse_whitening_matrix: np.ndarray, + ycoords: np.ndarray, + spike_templates: np.ndarray, + template_scaling_amplitudes: np.ndarray, +) -> tuple[np.ndarray, ...]: + """ + Calculate the amplitude and depths of (unwhitened) templates and spikes. + Amplitude is calculated for each spike as the template amplitude + multiplied by the `template_scaling_amplitudes`. + + This function was ported from Nick Steinmetz's `spikes` repository + MATLAB code, https://github.com/cortex-lab/spikes + + Parameters + ---------- + templates : np.ndarray + (num_clusters, num_samples, num_channels) array of templates. + inverse_whitening_matrix: np.ndarray + Inverse of the whitening matrix used in KS preprocessing, used to + unwhiten templates. + ycoords : np.ndarray + (num_channels,) array of the y-axis (depth) channel positions. + spike_templates : np.ndarray + (num_spikes,) array indicating the template associated with each spike. + template_scaling_amplitudes : np.ndarray + (num_spikes,) array holding the scaling amplitudes, by which the + template was scaled to match each spike. + + Returns + ------- + spike_amplitudes : np.ndarray + (num_spikes,) array of the amplitude of each spike. + spike_depths : np.ndarray + (num_spikes,) array of the depth (probe y-axis) of each spike. Note + this is just the template depth for each spike (i.e. depth of all spikes + from the same cluster are identical). + template_amplitudes : np.ndarray + (num_templates,) Amplitude of each template, calculated as average of spike amplitudes. + template_depths : np.ndarray + (num_templates,) array of the depth of each template. + unwhite_templates : np.ndarray + Unwhitened templates (num_clusters, num_samples, num_channels). + trough_peak_durations : np.ndarray + (num_templates, ) array of durations from trough to peak for each template waveform + waveforms : np.ndarray + (num_templates, num_samples) Waveform of each template, taken as the signal on the maximum loading channel. + """ + # Unwhiten the template waveforms + unwhite_templates = np.zeros_like(templates) + for idx, template in enumerate(templates): + unwhite_templates[idx, :, :] = templates[idx, :, :] @ inverse_whitening_matrix + + # First, calculate the depth of each template from the amplitude + # on each channel by the center of mass method. + + # Take the max amplitude for each channel, then use the channel + # with most signal as template amplitude. Zero any small channel amplitudes. + template_amplitudes_per_channel = np.max(unwhite_templates, axis=1) - np.min(unwhite_templates, axis=1) + + template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1) + + threshold_values = 0.3 * template_amplitudes_unscaled + template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0 + + # Calculate the template depth as the center of mass based on channel amplitudes + template_depths = np.sum(template_amplitudes_per_channel * ycoords[np.newaxis, :], axis=1) / np.sum( + template_amplitudes_per_channel, axis=1 + ) + + # Next, find the depth of each spike based on its template. Recompute the template + # amplitudes as the average of the spike amplitudes ('since + # tempScalingAmps are equal mean for all templates') + spike_amplitudes = template_amplitudes_unscaled[spike_templates] * template_scaling_amplitudes + + # Take the average of all spike amplitudes to get actual template amplitudes + # (since tempScalingAmps are equal mean for all templates) + num_indices = templates.shape[0] + sum_per_index = np.zeros(num_indices, dtype=np.float64) + np.add.at(sum_per_index, spike_templates, spike_amplitudes) + counts = np.bincount(spike_templates, minlength=num_indices) + template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0) + + # Each spike's depth is the depth of its template + spike_depths = template_depths[spike_templates] + + # Get channel with the largest amplitude (take that as the waveform) + max_site = np.argmax(np.max(np.abs(templates), axis=1), axis=1) + + # Use template channel with max signal as waveform + waveforms = np.empty(templates.shape[:2]) + for idx, template in enumerate(templates): + waveforms[idx, :] = templates[idx, :, max_site[idx]] + + # Get trough-to-peak time for each template. Find the trough as the + # minimum signal for the template waveform. The duration (in + # samples) is the num samples from trough to the largest value + # following the trough. + waveform_trough = np.argmin(waveforms, axis=1) + + trough_peak_durations = np.zeros(waveforms.shape[0]) + for idx, tmp_max in enumerate(waveforms): + trough_peak_durations[idx] = np.argmax(tmp_max[waveform_trough[idx] :]) + + return ( + spike_amplitudes, + spike_depths, + template_depths, + template_amplitudes, + unwhite_templates, + trough_peak_durations, + waveforms, + ) + + +def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict: + """ + Loads the output of Kilosort into a `params` dict. + + This function was ported from Nick Steinmetz's `spikes` repository MATLAB + code, https://github.com/cortex-lab/spikes + + Parameters + ---------- + sorter_output : Path + Path to the kilosort run sorting output. + exclude_noise : bool + If `True`, units labelled as "noise` are removed from all + returned arrays (i.e. both units and associated spikes are dropped). + load_pcs : bool + If `True`, principal component (PC) features are loaded. + + Parameters + ---------- + params : dict + A dictionary of parameters combining both the kilosort `params.py` + file as data loaded from `npy` files. The contents of the `npy` + files can be found in the Phy documentation. + + Notes + ----- + When merging and splitting in `Phy`, all changes are made to the + `spike_clusters.npy` (cluster assignment per spike) and `cluster_groups` + csv/tsv which contains the quality assignment (e.g. "noise") for each cluster. + As this function strips the spikes and units based on only these two + data structures, they will work following manual reassignment in Phy. + """ + sorter_output = Path(sorter_output) + + params = read_python(sorter_output / "params.py") + + spike_indexes = np.load(sorter_output / "spike_times.npy") + spike_templates = np.load(sorter_output / "spike_templates.npy") + + if (clusters_path := sorter_output / "spike_clusters.csv").is_dir(): + spike_clusters = np.load(clusters_path) + else: + spike_clusters = spike_templates.copy() + + temp_scaling_amplitudes = np.load(sorter_output / "amplitudes.npy") + + if load_pcs: + pc_features = np.load(sorter_output / "pc_features.npy") + pc_features_indices = np.load(sorter_output / "pc_feature_ind.npy") + else: + pc_features = pc_features_indices = None + + # This makes the assumption that there will never be different .csv and .tsv files + # in the same sorter output (this should never happen, there will never even be two). + # Though can be saved as .tsv, it seems the .csv is also tab formatted as far as pandas is concerned. + if exclude_noise and ( + (cluster_path := sorter_output / "cluster_groups.csv").is_file() + or (cluster_path := sorter_output / "cluster_group.tsv").is_file() + ): + cluster_ids, cluster_groups = _load_cluster_groups(cluster_path) + + noise_cluster_ids = cluster_ids[cluster_groups == 0] + not_noise_clusters_by_spike = ~np.isin(spike_clusters.ravel(), noise_cluster_ids) + + spike_indexes = spike_indexes[not_noise_clusters_by_spike] + spike_templates = spike_templates[not_noise_clusters_by_spike] + temp_scaling_amplitudes = temp_scaling_amplitudes[not_noise_clusters_by_spike] + + if load_pcs: + pc_features = pc_features[not_noise_clusters_by_spike, :, :] + + spike_clusters = spike_clusters[not_noise_clusters_by_spike] + cluster_ids = cluster_ids[cluster_groups != 0] + cluster_groups = cluster_groups[cluster_groups != 0] + else: + cluster_ids = np.unique(spike_clusters) + cluster_groups = 3 * np.ones(cluster_ids.size) + + new_params = { + "spike_indexes": spike_indexes.squeeze(), + "spike_templates": spike_templates.squeeze(), + "spike_clusters": spike_clusters.squeeze(), + "pc_features": pc_features, + "pc_features_indices": pc_features_indices, + "temp_scaling_amplitudes": temp_scaling_amplitudes.squeeze(), + "cluster_ids": cluster_ids, + "cluster_groups": cluster_groups, + "channel_positions": np.load(sorter_output / "channel_positions.npy"), + "templates": np.load(sorter_output / "templates.npy"), + "whitening_matrix_inv": np.load(sorter_output / "whitening_mat_inv.npy"), + } + params.update(new_params) + + return params + + +def _load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, ...]: + """ + Load kilosort `cluster_groups` file, that contains a table of + quality assignments, one per unit. These can be "noise", "mua", "good" + or "unsorted". + + There is some slight formatting differences between the `.tsv` and `.csv` + versions, presumably from different kilosort versions. + + This function was ported from Nick Steinmetz's `spikes` repository MATLAB code, + https://github.com/cortex-lab/spikes + + Parameters + ---------- + cluster_path : Path + The full filepath to the `cluster_groups` tsv or csv file. + + Returns + ------- + cluster_ids : np.ndarray + (num_clusters,) Array of (integer) unit IDs. + + cluster_groups : np.ndarray + (num_clusters,) Array of (integer) unit quality assignments, see code + below for mapping to "noise", "mua", "good" and "unsorted". + """ + cluster_groups_table = pd.read_csv(cluster_path, sep="\t") + + group_key = cluster_groups_table.columns[1] # "groups" (csv) or "KSLabel" (tsv) + + for key, _id in zip( + ["noise", "mua", "good", "unsorted"], + ["0", "1", "2", "3"], # required as str to avoid pandas replace downcast FutureWarning + ): + cluster_groups_table[group_key] = cluster_groups_table[group_key].replace(key, _id) + + cluster_ids = cluster_groups_table["cluster_id"].to_numpy() + cluster_groups = cluster_groups_table[group_key].astype(int).to_numpy() + + return cluster_ids, cluster_groups diff --git a/src/spikeinterface/working/test_peaks_from_ks.py b/src/spikeinterface/working/test_peaks_from_ks.py new file mode 100644 index 0000000000..586f98e9c9 --- /dev/null +++ b/src/spikeinterface/working/test_peaks_from_ks.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import spikeinterface.full as si +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.sortingcomponents.peak_localization import localize_peaks +import numpy as np +from spikeinterface.core.node_pipeline import ( + base_peak_dtype, +) +from spikeinterface.postprocessing.unit_locations import ( + dtype_localize_by_method, +) +import matplotlib.pyplot as plt +from load_kilosort_utils import compute_spike_amplitude_and_depth + + +recording, sorting = si.generate_ground_truth_recording( + durations=[30.0], + sampling_frequency=30000.0, +) +# job_kwargs = dict(n_jobs=2, chunk_size=10000, progress_bar=True) +job_kwargs = dict(n_jobs=1, chunk_size=10000, progress_bar=True) + +if False: + peaks_ = detect_peaks( + recording, method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, **job_kwargs + ) + + list_locations = [] + + peak_locations = localize_peaks(recording, peaks_, method="center_of_mass", **job_kwargs) +""" +dtype=[('sample_index', ' Date: Wed, 13 Nov 2024 20:58:26 +0000 Subject: [PATCH 2/7] Add some notes. --- .../working/load_kilosort_utils.py | 55 ++++--------------- 1 file changed, 10 insertions(+), 45 deletions(-) diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py index aa3fb2babd..3f50700d66 100644 --- a/src/spikeinterface/working/load_kilosort_utils.py +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -49,7 +49,7 @@ def compute_spike_amplitude_and_depth( Notes ----- - In `_template_positions_amplitudes` spike depths is calculated as simply the template + In `get_template_info_and_spike_amplitudes` spike depths is calculated as simply the template depth, for each spike (so it is the same for all spikes in a cluster). Here we need to find the depth of each individual spike, using its low-dimensional projection. `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. @@ -101,7 +101,7 @@ def compute_spike_amplitude_and_depth( # multiplied by the `template_scaling_amplitudes`. # Compute amplitudes, scale if required and drop un-localised spikes before returning. - spike_amplitudes, _, _, _, unwhite_templates, *_ = _template_positions_amplitudes( + spike_amplitudes, _, _, _, unwhite_templates, *_ = get_template_info_and_spike_amplitudes( params["templates"], params["whitening_matrix_inv"], ycoords, @@ -112,9 +112,16 @@ def compute_spike_amplitude_and_depth( if gain is not None: spike_amplitudes *= gain + max_site = np.argmax( + np.max(np.abs(templates), axis=1), axis=1 + ) # TODO: combine this with above function. Maybe the above function can be templates only, and everything spike-related is here. max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1) spike_sites = max_site[params["spike_templates"]] + # TODO: here the max site is the same for all spikes from the same template. + # is this the case for spikeinterface? Should we estimate max-site per spike from + # the PCs? + if localised_spikes_only: # Interpolate the channel ids to location. # Remove spikes > 5 um from average position @@ -134,45 +141,7 @@ def compute_spike_amplitude_and_depth( return params["spike_indexes"], spike_amplitudes, weighted_locs, spike_sites # TODO: rename everything -def _filter_large_amplitude_spikes( - spike_times: np.ndarray, - spike_amplitudes: np.ndarray, - spike_depths: np.ndarray, - large_amplitude_only_segment_size, -) -> tuple[np.ndarray, ...]: - """ - Return spike properties with only the largest-amplitude spikes included. The probe - is split into egments, and within each segment the mean and std computed. - Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded - Splitting the probe is only done for the exclusion step, the returned array are flat. - - Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns - copies of these arrays containing only the large amplitude spikes. - """ - spike_bool = np.zeros_like(spike_amplitudes, dtype=bool) - - segment_size_um = large_amplitude_only_segment_size - probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um - - for segment_left_edge in probe_segments_left_edges: - segment_right_edge = segment_left_edge + segment_size_um - - spikes_in_seg = np.where(np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge))[ - 0 - ] - spike_amps_in_seg = spike_amplitudes[spikes_in_seg] - is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1) - - spike_bool[spikes_in_seg] = is_high_amplitude - - spike_times = spike_times[spike_bool] - spike_amplitudes = spike_amplitudes[spike_bool] - spike_depths = spike_depths[spike_bool] - - return spike_times, spike_amplitudes, spike_depths - - -def _template_positions_amplitudes( +def get_template_info_and_spike_amplitudes( templates: np.ndarray, inverse_whitening_matrix: np.ndarray, ycoords: np.ndarray, @@ -256,9 +225,6 @@ def _template_positions_amplitudes( counts = np.bincount(spike_templates, minlength=num_indices) template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0) - # Each spike's depth is the depth of its template - spike_depths = template_depths[spike_templates] - # Get channel with the largest amplitude (take that as the waveform) max_site = np.argmax(np.max(np.abs(templates), axis=1), axis=1) @@ -279,7 +245,6 @@ def _template_positions_amplitudes( return ( spike_amplitudes, - spike_depths, template_depths, template_amplitudes, unwhite_templates, From 02ec55a3006f97d929cc1ba7325691faf7967b53 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 14 Nov 2024 13:41:24 +0000 Subject: [PATCH 3/7] Some refactoring, tidying up. --- .../working/load_kilosort_utils.py | 175 ++++---- .../working/plot_kilosort_drift_map.py | 403 ++++++++++++++++++ 2 files changed, 496 insertions(+), 82 deletions(-) create mode 100644 src/spikeinterface/working/plot_kilosort_drift_map.py diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py index 3f50700d66..2423bc02a0 100644 --- a/src/spikeinterface/working/load_kilosort_utils.py +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -8,6 +8,12 @@ from scipy import stats # TODO: spike_times -> spike_indexes +""" +Notes +----- +- not everything is used for current purposes +- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude. +""" def compute_spike_amplitude_and_depth( @@ -75,53 +81,58 @@ def compute_spike_amplitude_and_depth( localised_template_by_spike = np.isin(params["spike_templates"], localised_templates) - params["spike_templates"] = params["spike_templates"][localised_template_by_spike] - params["spike_indexes"] = params["spike_indexes"][localised_template_by_spike] - params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike] - params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike] - params["pc_features"] = params["pc_features"][localised_template_by_spike] + _strip_spikes(params, localised_template_by_spike) # Compute spike depths - pc_features = params["pc_features"][:, 0, :] + pc_features = params["pc_features"][:, 0, :] # Do this compute pc_features[pc_features < 0] = 0 - # Get the channel indexes corresponding to the 32 channels from the PC. - spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] + # Some spikes do not load at all onto the first PC. To avoid biasing the + # dataset by removing these, we repeat the above for the next PC, + # to compute distances for neurons that do not load onto the 1st PC. + # This is not ideal at all, it would be much better to a) find the + # max value for each channel on each of the PCs (i.e. basis vectors). + # Then recompute the estimated waveform peak on each channel by + # summing the PCs by their respective weights. However, the PC basis + # vectors themselves do not appear to be output by KS. + no_pc1_signal_spikes = np.where(np.sum(pc_features, axis=1) == 0) + + pc_features_2 = params["pc_features"][:, 1, :] + pc_features_2[pc_features_2 < 0] = 0 - ycoords = params["channel_positions"][:, 1] - spike_feature_ycoords = ycoords[spike_features_indices] + pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes] - spike_depths = np.sum(spike_feature_ycoords * pc_features**2, axis=1) / np.sum(pc_features**2, axis=1) + if any(np.sum(pc_features, axis=1) == 0): + raise RuntimeError( + "Some spikes do not load at all onto the first" + "or second principal component. It is necessary" + "to extend this code section to handle more components." + ) + # Get the channel indexes corresponding to the 32 channels from the PC. + spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] + + # Compute the spike locations as the center of mass of the PC scores spike_feature_coords = params["channel_positions"][spike_features_indices, :] norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] # TOOD: see why they use square - weighted_locs = spike_feature_coords * norm_weights[:, :, np.newaxis] - weighted_locs = np.sum(weighted_locs, axis=1) + spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis] + spike_locations = np.sum(spike_locations, axis=1) + + # TODO: now max site per spike is computed from PCs, not as the channel max site as previous + spike_sites = spike_features_indices[np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)] + # Amplitude is calculated for each spike as the template amplitude # multiplied by the `template_scaling_amplitudes`. - - # Compute amplitudes, scale if required and drop un-localised spikes before returning. - spike_amplitudes, _, _, _, unwhite_templates, *_ = get_template_info_and_spike_amplitudes( + template_amplitudes_unscaled, *_ = get_unwhite_template_info( params["templates"], params["whitening_matrix_inv"], - ycoords, - params["spike_templates"], - params["temp_scaling_amplitudes"], + params["channel_positions"], ) + spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"] if gain is not None: spike_amplitudes *= gain - max_site = np.argmax( - np.max(np.abs(templates), axis=1), axis=1 - ) # TODO: combine this with above function. Maybe the above function can be templates only, and everything spike-related is here. - max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1) - spike_sites = max_site[params["spike_templates"]] - - # TODO: here the max site is the same for all spikes from the same template. - # is this the case for spikeinterface? Should we estimate max-site per spike from - # the PCs? - if localised_spikes_only: # Interpolate the channel ids to location. # Remove spikes > 5 um from average position @@ -130,23 +141,32 @@ def compute_spike_amplitude_and_depth( # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold. # 3) just use depth. Probably go for that. check with others. - spike_depths = weighted_locs[:, 1] + spike_depths = spike_locations[:, 1] b = stats.linregress(spike_depths, spike_sites).slope i = np.abs(spike_sites - b * spike_depths) <= 5 # TODO: need to expose this params["spike_indexes"] = params["spike_indexes"][i] spike_amplitudes = spike_amplitudes[i] - weighted_locs = weighted_locs[i, :] + spike_locations = spike_locations[i, :] + + return params["spike_indexes"], spike_amplitudes, spike_locations, spike_sites - return params["spike_indexes"], spike_amplitudes, weighted_locs, spike_sites # TODO: rename everything +def _strip_spikes_in_place(params, indices): + """ """ + params["spike_templates"] = params["spike_templates"][ + indices + ] # TODO: make an function for this. because we do this a lot + params["spike_indexes"] = params["spike_indexes"][indices] + params["spike_clusters"] = params["spike_clusters"][indices] + params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][indices] + params["pc_features"] = params["pc_features"][indices] # TODO: be conciststetn! change indees to indices -def get_template_info_and_spike_amplitudes( + +def get_unwhite_template_info( templates: np.ndarray, inverse_whitening_matrix: np.ndarray, - ycoords: np.ndarray, - spike_templates: np.ndarray, - template_scaling_amplitudes: np.ndarray, + channel_positions: np.ndarray, ) -> tuple[np.ndarray, ...]: """ Calculate the amplitude and depths of (unwhitened) templates and spikes. @@ -163,28 +183,20 @@ def get_template_info_and_spike_amplitudes( inverse_whitening_matrix: np.ndarray Inverse of the whitening matrix used in KS preprocessing, used to unwhiten templates. - ycoords : np.ndarray - (num_channels,) array of the y-axis (depth) channel positions. - spike_templates : np.ndarray - (num_spikes,) array indicating the template associated with each spike. - template_scaling_amplitudes : np.ndarray - (num_spikes,) array holding the scaling amplitudes, by which the - template was scaled to match each spike. + channel_positions : np.ndarray + (num_channels, 2) array of the x, y channel positions. Returns ------- - spike_amplitudes : np.ndarray - (num_spikes,) array of the amplitude of each spike. - spike_depths : np.ndarray - (num_spikes,) array of the depth (probe y-axis) of each spike. Note - this is just the template depth for each spike (i.e. depth of all spikes - from the same cluster are identical). - template_amplitudes : np.ndarray - (num_templates,) Amplitude of each template, calculated as average of spike amplitudes. - template_depths : np.ndarray - (num_templates,) array of the depth of each template. + template_amplitudes_unscaled : np.ndarray + (num_templates,) array of the unscaled tempalte amplitudes. These can be + used to calculate spike amplitude with `template_amplitude_scalings`. + template_locations : np.ndarray + (num_templates, 2) array of the x, y positions (center of mass) of each template. unwhite_templates : np.ndarray Unwhitened templates (num_clusters, num_samples, num_channels). + template_max_site : np.array + The maximum loading spike for the unwhitened template. trough_peak_durations : np.ndarray (num_templates, ) array of durations from trough to peak for each template waveform waveforms : np.ndarray @@ -195,43 +207,31 @@ def get_template_info_and_spike_amplitudes( for idx, template in enumerate(templates): unwhite_templates[idx, :, :] = templates[idx, :, :] @ inverse_whitening_matrix - # First, calculate the depth of each template from the amplitude - # on each channel by the center of mass method. - # Take the max amplitude for each channel, then use the channel - # with most signal as template amplitude. Zero any small channel amplitudes. + # with most signal as template amplitude. template_amplitudes_per_channel = np.max(unwhite_templates, axis=1) - np.min(unwhite_templates, axis=1) template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1) - threshold_values = 0.3 * template_amplitudes_unscaled - template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0 + # Zero any small channel amplitudes + # threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree? + # template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0 # Calculate the template depth as the center of mass based on channel amplitudes - template_depths = np.sum(template_amplitudes_per_channel * ycoords[np.newaxis, :], axis=1) / np.sum( - template_amplitudes_per_channel, axis=1 - ) - - # Next, find the depth of each spike based on its template. Recompute the template - # amplitudes as the average of the spike amplitudes ('since - # tempScalingAmps are equal mean for all templates') - spike_amplitudes = template_amplitudes_unscaled[spike_templates] * template_scaling_amplitudes - - # Take the average of all spike amplitudes to get actual template amplitudes - # (since tempScalingAmps are equal mean for all templates) - num_indices = templates.shape[0] - sum_per_index = np.zeros(num_indices, dtype=np.float64) - np.add.at(sum_per_index, spike_templates, spike_amplitudes) - counts = np.bincount(spike_templates, minlength=num_indices) - template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0) + weights = template_amplitudes_per_channel / np.sum(template_amplitudes_per_channel, axis=1)[:, np.newaxis] + template_locations = weights @ channel_positions # Get channel with the largest amplitude (take that as the waveform) - max_site = np.argmax(np.max(np.abs(templates), axis=1), axis=1) + template_max_site = np.argmax( + np.max(np.abs(unwhite_templates), axis=1), axis=1 + ) # TODO: i changed this to use unwhitened templates instead of templates. This okay? # Use template channel with max signal as waveform - waveforms = np.empty(templates.shape[:2]) - for idx, template in enumerate(templates): - waveforms[idx, :] = templates[idx, :, max_site[idx]] + waveforms = np.empty( + unwhite_templates.shape[:2] + ) # TODO: i changed this to use unwhitened templates instead of templates. This okay? + for idx, template in enumerate(unwhite_templates): + waveforms[idx, :] = unwhite_templates[idx, :, template_max_site[idx]] # Get trough-to-peak time for each template. Find the trough as the # minimum signal for the template waveform. The duration (in @@ -244,15 +244,26 @@ def get_template_info_and_spike_amplitudes( trough_peak_durations[idx] = np.argmax(tmp_max[waveform_trough[idx] :]) return ( - spike_amplitudes, - template_depths, - template_amplitudes, + template_amplitudes_unscaled, + template_locations, + template_max_site, unwhite_templates, trough_peak_durations, waveforms, ) +def compute_template_amplitudes_from_spikes(): + # Take the average of all spike amplitudes to get actual template amplitudes + # (since tempScalingAmps are equal mean for all templates) + num_indices = templates.shape[0] + sum_per_index = np.zeros(num_indices, dtype=np.float64) + np.add.at(sum_per_index, spike_templates, spike_amplitudes) + counts = np.bincount(spike_templates, minlength=num_indices) + template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0) + return template_amplitudes + + def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict: """ Loads the output of Kilosort into a `params` dict. diff --git a/src/spikeinterface/working/plot_kilosort_drift_map.py b/src/spikeinterface/working/plot_kilosort_drift_map.py new file mode 100644 index 0000000000..449403dba9 --- /dev/null +++ b/src/spikeinterface/working/plot_kilosort_drift_map.py @@ -0,0 +1,403 @@ +from pathlib import Path +from spikeinterface.widgets.base import BaseWidget, to_attr +import matplotlib.axis +import scipy.signal +from spikeinterface.core import read_python +import numpy as np +import pandas as pd + +import matplotlib.pyplot as plt +from scipy import stats +import load_kilosort_utils + + +class KilosortDriftMapWidget(BaseWidget): + """ + Create a drift map plot in the kilosort style. This is ported from Nick Steinmetz's + `spikes` repository MATLAB code, https://github.com/cortex-lab/spikes. + By default, a raster plot is drawn with the y-axis is spike depth and + x-axis is time. Optionally, a corresponding 2D activity histogram can be + added as a subplot (spatial bins, spike counts) with optional + peak coloring and drift event detection (see below). + Parameters + ---------- + sorter_output : str | Path, + Path to the kilosort output folder. + only_include_large_amplitude_spikes : bool + If `True`, only spikes with larger amplitudes are included. For + details, see `_filter_large_amplitude_spikes()`. + decimate : None | int + If an integer n, only every nth spike is kept from the plot. Useful for improving + performance when there are many spikes. If `None`, spikes will not be decimated. + add_histogram_plot : bool + If `True`, an activity histogram will be added to a new subplot to the + left of the drift map. + add_histogram_peaks_and_boundaries : bool + If `True`, activity histogram peaks are detected and colored red if + isolated according to start/end boundaries of the peak (blue otherwise). + add_drift_events : bool + If `True`, drift events will be plot on the raster map. Required + `add_histogram_plot` and `add_histogram_peaks_and_boundaries` to run. + weight_histogram_by_amplitude : bool + If `True`, histogram counts will be weighted by spike amplitude. + localised_spikes_only : bool + If `True`, only spatially isolated spikes will be included. + exclude_noise : bool + If `True`, units labelled as noise in the `cluster_groups` file + will be excluded. + gain : float | None + If not `None`, amplitudes will be scaled by the supplied gain. + large_amplitude_only_segment_size: float + If `only_include_large_amplitude_spikes` is `True`, the probe is split into + segments to compute mean and std used as threshold. This sets the size of the + segments in um. + localised_spikes_channel_cutoff: int + If `localised_spikes_only` is `True`, spikes that have more than half of the + maximum loading channel over a range of > n channels are removed. + This sets the number of channels. + """ + + def __init__( + self, + sorter_output: str | Path, + only_include_large_amplitude_spikes: bool = True, + decimate: None | int = None, + add_histogram_plot: bool = False, + add_histogram_peaks_and_boundaries: bool = True, + add_drift_events: bool = True, + weight_histogram_by_amplitude: bool = False, + localised_spikes_only: bool = False, + exclude_noise: bool = False, + gain: float | None = None, + large_amplitude_only_segment_size: float = 800.0, + localised_spikes_channel_cutoff: int = 20, + ): + if not isinstance(sorter_output, Path): + sorter_output = Path(sorter_output) + + if not sorter_output.is_dir(): + raise ValueError(f"No output folder found at {sorter_output}") + + if not (sorter_output / "params.py").is_file(): + raise ValueError( + "The `sorting_output` path is not a valid kilosort output" + "folder. It does not contain a `params.py` file`." + ) + + plot_data = dict( + sorter_output=sorter_output, + only_include_large_amplitude_spikes=only_include_large_amplitude_spikes, + decimate=decimate, + add_histogram_plot=add_histogram_plot, + add_histogram_peaks_and_boundaries=add_histogram_peaks_and_boundaries, + add_drift_events=add_drift_events, + weight_histogram_by_amplitude=weight_histogram_by_amplitude, + localised_spikes_only=localised_spikes_only, + exclude_noise=exclude_noise, + gain=gain, + large_amplitude_only_segment_size=large_amplitude_only_segment_size, + localised_spikes_channel_cutoff=localised_spikes_channel_cutoff, + ) + BaseWidget.__init__(self, plot_data, backend="matplotlib") + + def plot_matplotlib(self, data_plot: dict, **unused_kwargs) -> None: + + dp = to_attr(data_plot) + + spike_indexes, spike_amplitudes, spike_locations, _ = load_kilosort_utils.compute_spike_amplitude_and_depth( + dp.sorter_output, dp.localised_spikes_only, dp.exclude_noise, dp.gain, dp.localised_spikes_channel_cutoff + ) + spike_times = spike_indexes / 30000 + spike_depths = spike_locations[:, 1] + + # Calculate the amplitude range for plotting first, so the scale is always the + # same across all options (e.g. decimation) which helps with interpretability. + if dp.only_include_large_amplitude_spikes: + amplitude_range_all_spikes = ( + spike_amplitudes.min(), + spike_amplitudes.max(), + ) + else: + amplitude_range_all_spikes = np.percentile(spike_amplitudes, (1, 90)) + + if dp.decimate: + spike_times = spike_times[:: dp.decimate] + spike_amplitudes = spike_amplitudes[:: dp.decimate] + spike_depths = spike_depths[:: dp.decimate] + + if dp.only_include_large_amplitude_spikes: + spike_times, spike_amplitudes, spike_depths = self._filter_large_amplitude_spikes( + spike_times, spike_amplitudes, spike_depths, dp.large_amplitude_only_segment_size + ) + + # Setup axis and plot the raster drift map + fig = plt.figure(figsize=(10, 10 * (6 / 8))) + + if dp.add_histogram_plot: + gs = fig.add_gridspec(1, 2, width_ratios=[1, 5]) + hist_axis = fig.add_subplot(gs[0]) + raster_axis = fig.add_subplot(gs[1], sharey=hist_axis) + else: + raster_axis = fig.add_subplot() + + self._plot_kilosort_drift_map_raster( + spike_times, + spike_amplitudes, + spike_depths, + amplitude_range_all_spikes, + axis=raster_axis, + ) + + if not dp.add_histogram_plot: + raster_axis.set_xlabel("time") + raster_axis.set_ylabel("y position") + self.axes = [raster_axis] + return + + # If the histogram plot is requested, plot it alongside + # it's peak colouring, bounds display and drift point display. + hist_axis.set_xlabel("count") + raster_axis.set_xlabel("time") + hist_axis.set_ylabel("y position") + + bin_centers, counts = self._compute_activity_histogram( + spike_amplitudes, spike_depths, dp.weight_histogram_by_amplitude + ) + hist_axis.plot(counts, bin_centers, color="black", linewidth=1) + + if dp.add_histogram_peaks_and_boundaries: + drift_events = self._color_histogram_peaks_and_detect_drift_events( + spike_times, spike_depths, counts, bin_centers, hist_axis + ) + + if dp.add_drift_events and np.any(drift_events): + raster_axis.scatter(drift_events[:, 0], drift_events[:, 1], facecolors="r", edgecolors="none") + for i, _ in enumerate(drift_events): + raster_axis.text( + drift_events[i, 0] + 1, drift_events[i, 1], str(np.round(drift_events[i, 2])), color="r" + ) + self.axes = [hist_axis, raster_axis] + + def _plot_kilosort_drift_map_raster( + self, + spike_times: np.ndarray, + spike_amplitudes: np.ndarray, + spike_depths: np.ndarray, + amplitude_range: np.ndarray | tuple, + axis: matplotlib.axes.Axes, + ) -> None: + """ + Plot a drift raster plot in the kilosort style. + This function was ported from Nick Steinmetz's `spikes` repository + MATLAB code, https://github.com/cortex-lab/spikes + Parameters + ---------- + spike_times : np.ndarray + (num_spikes,) array of spike times. + spike_amplitudes : np.ndarray + (num_spikes,) array of corresponding spike amplitudes. + spike_depths : np.ndarray + (num_spikes,) array of corresponding spike depths. + amplitude_range : np.ndarray | tuple + (2,) array of min, max amplitude values for color binning. + axis : matplotlib.axes.Axes + Matplotlib axes object on which to plot the drift map. + """ + n_color_bins = 20 + marker_size = 0.5 + + color_bins = np.linspace(amplitude_range[0], amplitude_range[1], n_color_bins) + + colors = plt.get_cmap("gray")(np.linspace(0, 1, n_color_bins))[::-1] + + for bin_idx in range(n_color_bins - 1): + + spikes_in_amplitude_bin = np.logical_and( + spike_amplitudes >= color_bins[bin_idx], spike_amplitudes <= color_bins[bin_idx + 1] + ) + axis.scatter( + spike_times[spikes_in_amplitude_bin], + spike_depths[spikes_in_amplitude_bin], + color=colors[bin_idx], + s=marker_size, + antialiased=True, + ) + + def _compute_activity_histogram( + self, spike_amplitudes: np.ndarray, spike_depths: np.ndarray, weight_histogram_by_amplitude: bool + ) -> tuple[np.ndarray, ...]: + """ + Compute the activity histogram for the kilosort drift map's left-side plot. + Parameters + ---------- + spike_amplitudes : np.ndarray + (num_spikes,) array of spike amplitudes. + spike_depths : np.ndarray + (num_spikes,) array of spike depths. + weight_histogram_by_amplitude : bool + If `True`, the spike amplitudes are taken into consideration when generating the + histogram. The amplitudes are scaled to the range [0, 1] then summed for each bin, + to generate the histogram values. If `False`, counts (i.e. num spikes per bin) + are used. + Returns + ------- + bin_centers : np.ndarray + The spatial bin centers (probe depth) for the histogram. + values : np.ndarray + The histogram values. If `weight_histogram_by_amplitude` is `False`, these + values represent are counts, otherwise they are counts weighted by amplitude. + """ + assert ( + spike_amplitudes.dtype == np.float64 + ), "`spike amplitudes should be high precision as many values are summed." + + bin_um = 2 + bins = np.arange(spike_depths.min() - bin_um, spike_depths.max() + bin_um, bin_um) + values, bins = np.histogram(spike_depths, bins=bins) + bin_centers = (bins[:-1] + bins[1:]) / 2 + + if weight_histogram_by_amplitude: + bin_indices = np.digitize(spike_depths, bins, right=True) - 1 + values = np.zeros(bin_indices.max() + 1, dtype=np.float64) + scaled_spike_amplitudes = (spike_amplitudes - spike_amplitudes.min()) / np.ptp(spike_amplitudes) + np.add.at(values, bin_indices, scaled_spike_amplitudes) + + return bin_centers, values + + def _color_histogram_peaks_and_detect_drift_events( + self, + spike_times: np.ndarray, + spike_depths: np.ndarray, + counts: np.ndarray, + bin_centers: np.ndarray, + hist_axis: matplotlib.axes.Axes, + ) -> np.ndarray: + """ + Given an activity histogram, color the peaks red (isolated peak) or + blue (peak overlaps with other peaks) and compute spatial drift + events for isolated peaks across time bins. + This function was ported from Nick Steinmetz's `spikes` repository + MATLAB code, https://github.com/cortex-lab/spikes + Parameters + ---------- + spike_times : np.ndarray + (num_spikes,) array of spike times. + spike_depths : np.ndarray + (num_spikes,) array of corresponding spike depths. + counts : np.ndarray + (num_bins,) array of histogram bin counts. + bin_centers : np.ndarray + (num_bins,) array of histogram bin centers. + hist_axis : matplotlib.axes.Axes + Axes on which the histogram is plot, to add peaks. + Returns + ------- + drift_events : np.ndarray + A (num_drift_events, 3) array of drift events. The columns are + (time_position, spatial_position, drift_value). The drift + value is computed per time, spatial bin as the difference between + the median position of spikes in the bin, and the bin center. + """ + all_peak_indexes = scipy.signal.find_peaks( + counts, + )[0] + + # Filter low-frequency peaks, so they are not included in the + # step to determine whether peaks are overlapping (new step + # introduced in the port to python) + bin_above_freq_threshold = counts[all_peak_indexes] > 0.3 * spike_times[-1] + filtered_peak_indexes = all_peak_indexes[bin_above_freq_threshold] + + drift_events = [] + for idx, peak_index in enumerate(filtered_peak_indexes): + + peak_count = counts[peak_index] + + # Find the start and end of peak min/max bounds (5% of amplitude) + start_position = np.where(counts[:peak_index] < peak_count * 0.05)[0].max() + end_position = np.where(counts[peak_index:] < peak_count * 0.05)[0].min() + peak_index + + if ( # bounds include another, different histogram peak + idx > 0 + and start_position < filtered_peak_indexes[idx - 1] + or idx < filtered_peak_indexes.size - 1 + and end_position > filtered_peak_indexes[idx + 1] + ): + hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="blue") + continue + + else: + for position in [start_position, end_position]: + hist_axis.axhline(bin_centers[position], 0, counts.max(), color="grey", linestyle="--") + hist_axis.scatter(peak_count, bin_centers[peak_index], facecolors="none", edgecolors="red") + + # For isolated histogram peaks, detect the drift events, defined as + # difference between spatial bin center and median spike depth in the bin + # over 6 um (in time / spatial bins with at least 10 spikes). + depth_in_window = np.logical_and( + spike_depths > bin_centers[start_position], + spike_depths < bin_centers[end_position], + ) + current_spike_depths = spike_depths[depth_in_window] + current_spike_times = spike_times[depth_in_window] + + window_s = 10 + + all_time_bins = np.arange(0, np.ceil(spike_times[-1]).astype(int), window_s) + for time_bin in all_time_bins: + + spike_in_time_bin = np.logical_and( + current_spike_times >= time_bin, current_spike_times <= time_bin + window_s + ) + drift_size = bin_centers[peak_index] - np.median(current_spike_depths[spike_in_time_bin]) + + # 6 um is the hardcoded threshold for drift, and we want at least 10 spikes for the median calculation + bin_has_drift = np.abs(drift_size) > 6 and np.sum(spike_in_time_bin, dtype=np.int16) > 10 + if bin_has_drift: + drift_events.append((time_bin + window_s / 2, bin_centers[peak_index], drift_size)) + + drift_events = np.array(drift_events) + + return drift_events + + def _filter_large_amplitude_spikes( + self, + spike_times: np.ndarray, + spike_amplitudes: np.ndarray, + spike_depths: np.ndarray, + large_amplitude_only_segment_size, + ) -> tuple[np.ndarray, ...]: + """ + Return spike properties with only the largest-amplitude spikes included. The probe + is split into egments, and within each segment the mean and std computed. + Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded + Splitting the probe is only done for the exclusion step, the returned array are flat. + Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns + copies of these arrays containing only the large amplitude spikes. + """ + spike_bool = np.zeros_like(spike_amplitudes, dtype=bool) + + segment_size_um = large_amplitude_only_segment_size + + probe_segments_left_edges = np.arange(np.floor(spike_depths.max() / segment_size_um) + 1) * segment_size_um + + for segment_left_edge in probe_segments_left_edges: + segment_right_edge = segment_left_edge + segment_size_um + + spikes_in_seg = np.where( + np.logical_and(spike_depths >= segment_left_edge, spike_depths < segment_right_edge) + )[0] + spike_amps_in_seg = spike_amplitudes[spikes_in_seg] + is_high_amplitude = spike_amps_in_seg > np.mean(spike_amps_in_seg) + 1.5 * np.std(spike_amps_in_seg, ddof=1) + + spike_bool[spikes_in_seg] = is_high_amplitude + + spike_times = spike_times[spike_bool] + spike_amplitudes = spike_amplitudes[spike_bool] + spike_depths = spike_depths[spike_bool] + + return spike_times, spike_amplitudes, spike_depths + + +KilosortDriftMapWidget(r"D:\data\New folder\CA_528_1\imec0_ks2") +plt.show() From 6d569a2943d3dd09a5236233ae3f35765cb11f31 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 14 Nov 2024 14:21:27 +0000 Subject: [PATCH 4/7] Tidy up, general checks. --- .../working/load_kilosort_utils.py | 150 +++++++++++------- .../working/plot_kilosort_drift_map.py | 27 +++- 2 files changed, 114 insertions(+), 63 deletions(-) diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py index 2423bc02a0..3efdabf7fe 100644 --- a/src/spikeinterface/working/load_kilosort_utils.py +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -7,7 +7,7 @@ from scipy import stats -# TODO: spike_times -> spike_indexes +# TODO: spike_times -> spike_indices """ Notes ----- @@ -15,19 +15,23 @@ - things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude. """ +######################################################################################################################## +# Get Spike Data +######################################################################################################################## + def compute_spike_amplitude_and_depth( sorter_output: str | Path, localised_spikes_only, exclude_noise, gain: float | None = None, - localised_spikes_channel_cutoff: int = None, # TODO + localised_spikes_channel_cutoff: int = None, ) -> tuple[np.ndarray, ...]: """ Compute the amplitude and depth of all detected spikes from the kilosort output. - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes + This function is based on code in Nick Steinmetz's `spikes` repository, + https://github.com/cortex-lab/spikes Parameters ---------- @@ -46,8 +50,8 @@ def compute_spike_amplitude_and_depth( Returns ------- - spike_indexes : np.ndarray - (num_spikes,) array of spike indexes. + spike_indices : np.ndarray + (num_spikes,) array of spike indices. spike_amplitudes : np.ndarray (num_spikes,) array of corresponding spike amplitudes. spike_depths : np.ndarray @@ -66,7 +70,7 @@ def compute_spike_amplitude_and_depth( if isinstance(sorter_output, str): sorter_output = Path(sorter_output) - params = _load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise) + params = load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise) if localised_spikes_only: localised_templates = [] @@ -81,10 +85,56 @@ def compute_spike_amplitude_and_depth( localised_template_by_spike = np.isin(params["spike_templates"], localised_templates) - _strip_spikes(params, localised_template_by_spike) + params["spike_templates"] = params["spike_templates"][localised_template_by_spike] + params["spike_indices"] = params["spike_indices"][localised_template_by_spike] + params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike] + params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike] + params["pc_features"] = params["pc_features"][localised_template_by_spike] + + spike_locations, spike_max_sites = _get_locations_from_pc_features(params) + + # Amplitude is calculated for each spike as the template amplitude + # multiplied by the `template_scaling_amplitudes`. + template_amplitudes_unscaled, *_ = get_unwhite_template_info( + params["templates"], + params["whitening_matrix_inv"], + params["channel_positions"], + ) + spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"] + + if gain is not None: + spike_amplitudes *= gain + + compute_template_amplitudes_from_spikes(params["templates"], params["spike_templates"], spike_amplitudes) + + if localised_spikes_only: + # Interpolate the channel ids to location. + # Remove spikes > 5 um from average position + # Above we already removed non-localized templates, but that on its own is insufficient. + # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient + # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere + # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold. + # 3) just use depth. Probably go for that. check with others. + spike_depths = spike_locations[:, 1] + b = stats.linregress(spike_depths, spike_max_sites).slope + i = np.abs(spike_max_sites - b * spike_depths) <= 5 + params["spike_indices"] = params["spike_indices"][i] + spike_amplitudes = spike_amplitudes[i] + spike_locations = spike_locations[i, :] + spike_max_sites = spike_max_sites[i] + + return params["spike_indices"], spike_amplitudes, spike_locations, spike_max_sites + + +def _get_locations_from_pc_features(params): + """ + + This function is based on code in Nick Steinmetz's `spikes` repository, + https://github.com/cortex-lab/spikes + """ # Compute spike depths - pc_features = params["pc_features"][:, 0, :] # Do this compute + pc_features = params["pc_features"][:, 0, :] pc_features[pc_features < 0] = 0 # Some spikes do not load at all onto the first PC. To avoid biasing the @@ -109,58 +159,28 @@ def compute_spike_amplitude_and_depth( "to extend this code section to handle more components." ) - # Get the channel indexes corresponding to the 32 channels from the PC. + # Get the channel indices corresponding to the 32 channels from the PC. spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] # Compute the spike locations as the center of mass of the PC scores spike_feature_coords = params["channel_positions"][spike_features_indices, :] - norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] # TOOD: see why they use square + norm_weights = ( + pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] + ) # TOOD: discuss use of square. Probbaly do not use to keep in line with COM in SI. spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis] spike_locations = np.sum(spike_locations, axis=1) # TODO: now max site per spike is computed from PCs, not as the channel max site as previous - spike_sites = spike_features_indices[np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)] + spike_max_sites = spike_features_indices[ + np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1) + ] - # Amplitude is calculated for each spike as the template amplitude - # multiplied by the `template_scaling_amplitudes`. - template_amplitudes_unscaled, *_ = get_unwhite_template_info( - params["templates"], - params["whitening_matrix_inv"], - params["channel_positions"], - ) - spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"] + return spike_locations, spike_max_sites - if gain is not None: - spike_amplitudes *= gain - - if localised_spikes_only: - # Interpolate the channel ids to location. - # Remove spikes > 5 um from average position - # Above we already removed non-localized templates, but that on its own is insufficient. - # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient - # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere - # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold. - # 3) just use depth. Probably go for that. check with others. - spike_depths = spike_locations[:, 1] - b = stats.linregress(spike_depths, spike_sites).slope - i = np.abs(spike_sites - b * spike_depths) <= 5 # TODO: need to expose this - - params["spike_indexes"] = params["spike_indexes"][i] - spike_amplitudes = spike_amplitudes[i] - spike_locations = spike_locations[i, :] - - return params["spike_indexes"], spike_amplitudes, spike_locations, spike_sites - -def _strip_spikes_in_place(params, indices): - """ """ - params["spike_templates"] = params["spike_templates"][ - indices - ] # TODO: make an function for this. because we do this a lot - params["spike_indexes"] = params["spike_indexes"][indices] - params["spike_clusters"] = params["spike_clusters"][indices] - params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][indices] - params["pc_features"] = params["pc_features"][indices] # TODO: be conciststetn! change indees to indices +######################################################################################################################## +# Get Template Data +######################################################################################################################## def get_unwhite_template_info( @@ -173,8 +193,8 @@ def get_unwhite_template_info( Amplitude is calculated for each spike as the template amplitude multiplied by the `template_scaling_amplitudes`. - This function was ported from Nick Steinmetz's `spikes` repository - MATLAB code, https://github.com/cortex-lab/spikes + This function is based on code in Nick Steinmetz's `spikes` repository, + https://github.com/cortex-lab/spikes Parameters ---------- @@ -213,7 +233,7 @@ def get_unwhite_template_info( template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1) - # Zero any small channel amplitudes + # Zero any small channel amplitudes TODO: removed this. # threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree? # template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0 @@ -253,9 +273,14 @@ def get_unwhite_template_info( ) -def compute_template_amplitudes_from_spikes(): - # Take the average of all spike amplitudes to get actual template amplitudes - # (since tempScalingAmps are equal mean for all templates) +def compute_template_amplitudes_from_spikes(templates, spike_templates, spike_amplitudes): + """ + Take the average of all spike amplitudes to get actual template amplitudes + (since tempScalingAmps are equal mean for all templates) + + This function is ported from Nick Steinmetz's `spikes` repository, + https://github.com/cortex-lab/spikes + """ num_indices = templates.shape[0] sum_per_index = np.zeros(num_indices, dtype=np.float64) np.add.at(sum_per_index, spike_templates, spike_amplitudes) @@ -264,7 +289,12 @@ def compute_template_amplitudes_from_spikes(): return template_amplitudes -def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict: +######################################################################################################################## +# Load Parameters from KS Directory +######################################################################################################################## + + +def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict: """ Loads the output of Kilosort into a `params` dict. @@ -300,7 +330,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool params = read_python(sorter_output / "params.py") - spike_indexes = np.load(sorter_output / "spike_times.npy") + spike_indices = np.load(sorter_output / "spike_times.npy") spike_templates = np.load(sorter_output / "spike_templates.npy") if (clusters_path := sorter_output / "spike_clusters.csv").is_dir(): @@ -328,7 +358,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool noise_cluster_ids = cluster_ids[cluster_groups == 0] not_noise_clusters_by_spike = ~np.isin(spike_clusters.ravel(), noise_cluster_ids) - spike_indexes = spike_indexes[not_noise_clusters_by_spike] + spike_indices = spike_indices[not_noise_clusters_by_spike] spike_templates = spike_templates[not_noise_clusters_by_spike] temp_scaling_amplitudes = temp_scaling_amplitudes[not_noise_clusters_by_spike] @@ -343,7 +373,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool cluster_groups = 3 * np.ones(cluster_ids.size) new_params = { - "spike_indexes": spike_indexes.squeeze(), + "spike_indices": spike_indices.squeeze(), "spike_templates": spike_templates.squeeze(), "spike_clusters": spike_clusters.squeeze(), "pc_features": pc_features, diff --git a/src/spikeinterface/working/plot_kilosort_drift_map.py b/src/spikeinterface/working/plot_kilosort_drift_map.py index 449403dba9..ecac38495f 100644 --- a/src/spikeinterface/working/plot_kilosort_drift_map.py +++ b/src/spikeinterface/working/plot_kilosort_drift_map.py @@ -1,8 +1,8 @@ from pathlib import Path -from spikeinterface.widgets.base import BaseWidget, to_attr import matplotlib.axis import scipy.signal -from spikeinterface.core import read_python + +# from spikeinterface.core import read_python import numpy as np import pandas as pd @@ -10,6 +10,8 @@ from scipy import stats import load_kilosort_utils +from spikeinterface.widgets.base import BaseWidget, to_attr + class KilosortDriftMapWidget(BaseWidget): """ @@ -399,5 +401,24 @@ def _filter_large_amplitude_spikes( return spike_times, spike_amplitudes, spike_depths -KilosortDriftMapWidget(r"D:\data\New folder\CA_528_1\imec0_ks2") +KilosortDriftMapWidget( + "/Users/joeziminski/data/bombcelll/sorter_output", + only_include_large_amplitude_spikes=False, + localised_spikes_only=True, +) plt.show() + +""" + sorter_output: str | Path, + only_include_large_amplitude_spikes: bool = True, + decimate: None | int = None, + add_histogram_plot: bool = False, + add_histogram_peaks_and_boundaries: bool = True, + add_drift_events: bool = True, + weight_histogram_by_amplitude: bool = False, + localised_spikes_only: bool = False, + exclude_noise: bool = False, + gain: float | None = None, + large_amplitude_only_segment_size: float = 800.0, + localised_spikes_channel_cutoff: int = 20, +""" From 2a9e08296c7aac1db46d83175de5fbff5f52f3f5 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 14 Nov 2024 14:58:55 +0000 Subject: [PATCH 5/7] More tidy ups, remove TODOs --- .../working/load_kilosort_utils.py | 73 +++++++++---------- .../working/test_peaks_from_ks.py | 3 + 2 files changed, 36 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py index 3efdabf7fe..01b2ac2b81 100644 --- a/src/spikeinterface/working/load_kilosort_utils.py +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -7,12 +7,13 @@ from scipy import stats -# TODO: spike_times -> spike_indices """ Notes ----- - not everything is used for current purposes - things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude. + +TODO: testing against diferent ks versions """ ######################################################################################################################## @@ -21,22 +22,21 @@ def compute_spike_amplitude_and_depth( - sorter_output: str | Path, + params: dict, localised_spikes_only, - exclude_noise, gain: float | None = None, localised_spikes_channel_cutoff: int = None, ) -> tuple[np.ndarray, ...]: """ - Compute the amplitude and depth of all detected spikes from the kilosort output. + Compute the indicies, amplitudes and locations for all detected spikes from the kilosort output. This function is based on code in Nick Steinmetz's `spikes` repository, https://github.com/cortex-lab/spikes Parameters ---------- - sorter_output : str | Path - Path to the kilosort run sorting output. + params : dict + `params` as loaded from the kilosort output directory (see `load_ks_dir()`) localised_spikes_only : bool If `True`, only spikes with small spatial footprint (i.e. 20 channels within 1/2 of the amplitude of the maximum loading channel) and which are close to the average depth for @@ -54,23 +54,16 @@ def compute_spike_amplitude_and_depth( (num_spikes,) array of spike indices. spike_amplitudes : np.ndarray (num_spikes,) array of corresponding spike amplitudes. - spike_depths : np.ndarray - (num_spikes,) array of corresponding depths (probe y-axis location). - - Notes - ----- - In `get_template_info_and_spike_amplitudes` spike depths is calculated as simply the template - depth, for each spike (so it is the same for all spikes in a cluster). Here we need - to find the depth of each individual spike, using its low-dimensional projection. - `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. - Taking the first component, the subset of 32 channels associated with this - spike are indexed to get the actual channel locations (in um). Then, the channel - locations are weighted by their PC values. + spike_locations : np.ndarray + (num_spikes, 2) array of corresponding spike locations (x, y) estimated using + center of mass from the first PC (or, second PC if no signal on first PC). + See `_get_locations_from_pc_features()` for details. """ if isinstance(sorter_output, str): sorter_output = Path(sorter_output) - params = load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise) + if not params["pc_features"]: + raise ValueError("`pc_features` must be loaded into params. Use `load_ks_dir` with `load_pcs=True`.") if localised_spikes_only: localised_templates = [] @@ -91,10 +84,11 @@ def compute_spike_amplitude_and_depth( params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike] params["pc_features"] = params["pc_features"][localised_template_by_spike] + # Compute the spike locations and maximum-loading channel per spike spike_locations, spike_max_sites = _get_locations_from_pc_features(params) - # Amplitude is calculated for each spike as the template amplitude - # multiplied by the `template_scaling_amplitudes`. + # Amplitude is calculated for each spike as the template amplitude + # multiplied by the `template_scaling_amplitudes`. template_amplitudes_unscaled, *_ = get_unwhite_template_info( params["templates"], params["whitening_matrix_inv"], @@ -105,16 +99,11 @@ def compute_spike_amplitude_and_depth( if gain is not None: spike_amplitudes *= gain - compute_template_amplitudes_from_spikes(params["templates"], params["spike_templates"], spike_amplitudes) - if localised_spikes_only: # Interpolate the channel ids to location. # Remove spikes > 5 um from average position # Above we already removed non-localized templates, but that on its own is insufficient. # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient - # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere - # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold. - # 3) just use depth. Probably go for that. check with others. spike_depths = spike_locations[:, 1] b = stats.linregress(spike_depths, spike_max_sites).slope i = np.abs(spike_max_sites - b * spike_depths) <= 5 @@ -130,6 +119,14 @@ def compute_spike_amplitude_and_depth( def _get_locations_from_pc_features(params): """ + Notes + ----- + Location of of each individual spike is computed from its low-dimensional projection. + `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. + Taking the first component, the subset of 32 channels associated with this + spike are indexed to get the actual channel locations (in um). Then, the channel + locations are weighted by their PC values. + This function is based on code in Nick Steinmetz's `spikes` repository, https://github.com/cortex-lab/spikes """ @@ -145,6 +142,10 @@ def _get_locations_from_pc_features(params): # Then recompute the estimated waveform peak on each channel by # summing the PCs by their respective weights. However, the PC basis # vectors themselves do not appear to be output by KS. + + # We include the (n_channels i.e. features) from the second PC + # into the `pc_features` mostly containing the first PC. As all + # operations are per-spike (i.e. row-wise) no_pc1_signal_spikes = np.where(np.sum(pc_features, axis=1) == 0) pc_features_2 = params["pc_features"][:, 1, :] @@ -164,13 +165,12 @@ def _get_locations_from_pc_features(params): # Compute the spike locations as the center of mass of the PC scores spike_feature_coords = params["channel_positions"][spike_features_indices, :] - norm_weights = ( - pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] - ) # TOOD: discuss use of square. Probbaly do not use to keep in line with COM in SI. + norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] + spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis] spike_locations = np.sum(spike_locations, axis=1) - # TODO: now max site per spike is computed from PCs, not as the channel max site as previous + # Find the max site as the channel with the largest PC weight. spike_max_sites = spike_features_indices[ np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1) ] @@ -233,23 +233,16 @@ def get_unwhite_template_info( template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1) - # Zero any small channel amplitudes TODO: removed this. - # threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree? - # template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0 - # Calculate the template depth as the center of mass based on channel amplitudes weights = template_amplitudes_per_channel / np.sum(template_amplitudes_per_channel, axis=1)[:, np.newaxis] template_locations = weights @ channel_positions # Get channel with the largest amplitude (take that as the waveform) - template_max_site = np.argmax( - np.max(np.abs(unwhite_templates), axis=1), axis=1 - ) # TODO: i changed this to use unwhitened templates instead of templates. This okay? + template_max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1) # Use template channel with max signal as waveform - waveforms = np.empty( - unwhite_templates.shape[:2] - ) # TODO: i changed this to use unwhitened templates instead of templates. This okay? + waveforms = np.empty(unwhite_templates.shape[:2]) + for idx, template in enumerate(unwhite_templates): waveforms[idx, :] = unwhite_templates[idx, :, template_max_site[idx]] diff --git a/src/spikeinterface/working/test_peaks_from_ks.py b/src/spikeinterface/working/test_peaks_from_ks.py index 586f98e9c9..438268baf6 100644 --- a/src/spikeinterface/working/test_peaks_from_ks.py +++ b/src/spikeinterface/working/test_peaks_from_ks.py @@ -42,6 +42,9 @@ dtype=[('x', ' Date: Thu, 14 Nov 2024 15:50:29 +0000 Subject: [PATCH 6/7] Working on the private PC decomposition. --- .../working/load_kilosort_utils.py | 19 +++++++++++++------ .../working/plot_kilosort_drift_map.py | 4 +++- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py index 01b2ac2b81..084c3b37ea 100644 --- a/src/spikeinterface/working/load_kilosort_utils.py +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -59,10 +59,7 @@ def compute_spike_amplitude_and_depth( center of mass from the first PC (or, second PC if no signal on first PC). See `_get_locations_from_pc_features()` for details. """ - if isinstance(sorter_output, str): - sorter_output = Path(sorter_output) - - if not params["pc_features"]: + if params["pc_features"] is None: raise ValueError("`pc_features` must be loaded into params. Use `load_ks_dir` with `load_pcs=True`.") if localised_spikes_only: @@ -118,10 +115,12 @@ def compute_spike_amplitude_and_depth( def _get_locations_from_pc_features(params): """ + Compute locations from the waveform principal component scores. Notes ----- Location of of each individual spike is computed from its low-dimensional projection. + During sorting, kilosort computes the ' `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. Taking the first component, the subset of 32 channels associated with this spike are indexed to get the actual channel locations (in um). Then, the channel @@ -131,6 +130,13 @@ def _get_locations_from_pc_features(params): https://github.com/cortex-lab/spikes """ # Compute spike depths + + # for each spike, a PCA is computed just on that spike (n samples x n channels). + # the components are all different between spikes, so are not saved. + # This gives a (n pc = 3, num channels) set of scores. + # but then how it is possible for some spikes to have zero score onto the principal channel? + + breakpoint() pc_features = params["pc_features"][:, 0, :] pc_features[pc_features < 0] = 0 @@ -153,7 +159,7 @@ def _get_locations_from_pc_features(params): pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes] - if any(np.sum(pc_features, axis=1) == 0): + if np.any(np.sum(pc_features, axis=1) == 0): raise RuntimeError( "Some spikes do not load at all onto the first" "or second principal component. It is necessary" @@ -319,7 +325,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool As this function strips the spikes and units based on only these two data structures, they will work following manual reassignment in Phy. """ - sorter_output = Path(sorter_output) + if isinstance(sorter_output, str): + sorter_output = Path(sorter_output) params = read_python(sorter_output / "params.py") diff --git a/src/spikeinterface/working/plot_kilosort_drift_map.py b/src/spikeinterface/working/plot_kilosort_drift_map.py index ecac38495f..e61b7bddd9 100644 --- a/src/spikeinterface/working/plot_kilosort_drift_map.py +++ b/src/spikeinterface/working/plot_kilosort_drift_map.py @@ -106,8 +106,10 @@ def plot_matplotlib(self, data_plot: dict, **unused_kwargs) -> None: dp = to_attr(data_plot) + params = load_kilosort_utils.load_ks_dir(dp.sorter_output, load_pcs=True, exclude_noise=dp.exclude_noise) + spike_indexes, spike_amplitudes, spike_locations, _ = load_kilosort_utils.compute_spike_amplitude_and_depth( - dp.sorter_output, dp.localised_spikes_only, dp.exclude_noise, dp.gain, dp.localised_spikes_channel_cutoff + params, dp.localised_spikes_only, dp.gain, dp.localised_spikes_channel_cutoff ) spike_times = spike_indexes / 30000 spike_depths = spike_locations[:, 1] From f33f80cab89df0261132e398a8e7b3c0572e3033 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 14 Nov 2024 18:21:51 +0000 Subject: [PATCH 7/7] Look into the template option. --- .../working/load_kilosort_utils.py | 68 +++++++------------ 1 file changed, 26 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py index 084c3b37ea..1bd9178d34 100644 --- a/src/spikeinterface/working/load_kilosort_utils.py +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -30,7 +30,7 @@ def compute_spike_amplitude_and_depth( """ Compute the indicies, amplitudes and locations for all detected spikes from the kilosort output. - This function is based on code in Nick Steinmetz's `spikes` repository, + This function is based on code in Cortex Lab's `spikes` repository, https://github.com/cortex-lab/spikes Parameters @@ -119,54 +119,27 @@ def _get_locations_from_pc_features(params): Notes ----- - Location of of each individual spike is computed from its low-dimensional projection. - During sorting, kilosort computes the ' - `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. - Taking the first component, the subset of 32 channels associated with this - spike are indexed to get the actual channel locations (in um). Then, the channel - locations are weighted by their PC values. - - This function is based on code in Nick Steinmetz's `spikes` repository, + My understanding so far. KS1 paper; The individual spike waveforms are decomposed into + 'private PCs'. Let the waveform matrix W be time (t) x channel (c). PCA + decompoisition is performed to compute c basis waveforms. Scores for each + channel onto the top three PCs are stored (these recover the waveform well. + + This function is based on code in Cortex Lab's `spikes` repository, https://github.com/cortex-lab/spikes """ - # Compute spike depths - - # for each spike, a PCA is computed just on that spike (n samples x n channels). - # the components are all different between spikes, so are not saved. - # This gives a (n pc = 3, num channels) set of scores. - # but then how it is possible for some spikes to have zero score onto the principal channel? - - breakpoint() - pc_features = params["pc_features"][:, 0, :] + pc_features = params["pc_features"][:, 0, :].copy() pc_features[pc_features < 0] = 0 - # Some spikes do not load at all onto the first PC. To avoid biasing the - # dataset by removing these, we repeat the above for the next PC, - # to compute distances for neurons that do not load onto the 1st PC. - # This is not ideal at all, it would be much better to a) find the - # max value for each channel on each of the PCs (i.e. basis vectors). - # Then recompute the estimated waveform peak on each channel by - # summing the PCs by their respective weights. However, the PC basis - # vectors themselves do not appear to be output by KS. - - # We include the (n_channels i.e. features) from the second PC - # into the `pc_features` mostly containing the first PC. As all - # operations are per-spike (i.e. row-wise) - no_pc1_signal_spikes = np.where(np.sum(pc_features, axis=1) == 0) - - pc_features_2 = params["pc_features"][:, 1, :] - pc_features_2[pc_features_2 < 0] = 0 - - pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes] - if np.any(np.sum(pc_features, axis=1) == 0): + # TODO: 1) handle this case for pc_features + # 2) instead use the template_features for all other versions. raise RuntimeError( "Some spikes do not load at all onto the first" "or second principal component. It is necessary" "to extend this code section to handle more components." ) - # Get the channel indices corresponding to the 32 channels from the PC. + # Get the channel indices corresponding to the channels from the PC. spike_features_indices = params["pc_features_indices"][params["spike_templates"], :] # Compute the spike locations as the center of mass of the PC scores @@ -199,7 +172,7 @@ def get_unwhite_template_info( Amplitude is calculated for each spike as the template amplitude multiplied by the `template_scaling_amplitudes`. - This function is based on code in Nick Steinmetz's `spikes` repository, + This function is based on code in Cortex Lab's `spikes` repository, https://github.com/cortex-lab/spikes Parameters @@ -277,7 +250,7 @@ def compute_template_amplitudes_from_spikes(templates, spike_templates, spike_am Take the average of all spike amplitudes to get actual template amplitudes (since tempScalingAmps are equal mean for all templates) - This function is ported from Nick Steinmetz's `spikes` repository, + This function is ported from Cortex Lab's `spikes` repository, https://github.com/cortex-lab/spikes """ num_indices = templates.shape[0] @@ -297,7 +270,7 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool """ Loads the output of Kilosort into a `params` dict. - This function was ported from Nick Steinmetz's `spikes` repository MATLAB + This function was ported from Cortex Lab's `spikes` repository MATLAB code, https://github.com/cortex-lab/spikes Parameters @@ -343,8 +316,15 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool if load_pcs: pc_features = np.load(sorter_output / "pc_features.npy") pc_features_indices = np.load(sorter_output / "pc_feature_ind.npy") + + if (sorter_output / "template_features.npy").is_file(): + template_features = np.load(sorter_output / "template_features.npy") + template_features_indices = np.load(sorter_output / "templates_ind.npy") + else: + template_features = template_features_indices = None else: pc_features = pc_features_indices = None + template_features = template_features_indices = None # This makes the assumption that there will never be different .csv and .tsv files # in the same sorter output (this should never happen, there will never even be two). @@ -364,6 +344,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool if load_pcs: pc_features = pc_features[not_noise_clusters_by_spike, :, :] + if template_features is not None: + template_features = template_features[not_noise_clusters_by_spike, :, :] spike_clusters = spike_clusters[not_noise_clusters_by_spike] cluster_ids = cluster_ids[cluster_groups != 0] @@ -378,6 +360,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool "spike_clusters": spike_clusters.squeeze(), "pc_features": pc_features, "pc_features_indices": pc_features_indices, + "template_features": template_features, + "template_features_indices": template_features_indices, "temp_scaling_amplitudes": temp_scaling_amplitudes.squeeze(), "cluster_ids": cluster_ids, "cluster_groups": cluster_groups, @@ -399,7 +383,7 @@ def _load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, ...]: There is some slight formatting differences between the `.tsv` and `.csv` versions, presumably from different kilosort versions. - This function was ported from Nick Steinmetz's `spikes` repository MATLAB code, + This function was ported from Cortex Lab's `spikes` repository MATLAB code, https://github.com/cortex-lab/spikes Parameters