Skip to content

Commit 2a9e082

Browse files
committed
More tidy ups, remove TODOs
1 parent 6d569a2 commit 2a9e082

File tree

2 files changed

+36
-40
lines changed

2 files changed

+36
-40
lines changed

src/spikeinterface/working/load_kilosort_utils.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77

88
from scipy import stats
99

10-
# TODO: spike_times -> spike_indices
1110
"""
1211
Notes
1312
-----
1413
- not everything is used for current purposes
1514
- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
15+
16+
TODO: testing against different ks versions
1617
"""
1718

1819
########################################################################################################################
@@ -21,22 +22,21 @@
2122

2223

2324
def compute_spike_amplitude_and_depth(
24-
sorter_output: str | Path,
25+
params: dict,
2526
localised_spikes_only,
26-
exclude_noise,
2727
gain: float | None = None,
2828
localised_spikes_channel_cutoff: int = None,
2929
) -> tuple[np.ndarray, ...]:
3030
"""
31-
Compute the amplitude and depth of all detected spikes from the kilosort output.
31+
Compute the indices, amplitudes and locations for all detected spikes from the kilosort output.
3232
3333
This function is based on code in Nick Steinmetz's `spikes` repository,
3434
https://github.com/cortex-lab/spikes
3535
3636
Parameters
3737
----------
38-
sorter_output : str | Path
39-
Path to the kilosort run sorting output.
38+
params : dict
39+
`params` as loaded from the kilosort output directory (see `load_ks_dir()`)
4040
localised_spikes_only : bool
4141
If `True`, only spikes with small spatial footprint (i.e. 20 channels within 1/2 of the
4242
amplitude of the maximum loading channel) and which are close to the average depth for
@@ -54,23 +54,16 @@ def compute_spike_amplitude_and_depth(
5454
(num_spikes,) array of spike indices.
5555
spike_amplitudes : np.ndarray
5656
(num_spikes,) array of corresponding spike amplitudes.
57-
spike_depths : np.ndarray
58-
(num_spikes,) array of corresponding depths (probe y-axis location).
59-
60-
Notes
61-
-----
62-
In `get_template_info_and_spike_amplitudes` spike depths is calculated as simply the template
63-
depth, for each spike (so it is the same for all spikes in a cluster). Here we need
64-
to find the depth of each individual spike, using its low-dimensional projection.
65-
`pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
66-
Taking the first component, the subset of 32 channels associated with this
67-
spike are indexed to get the actual channel locations (in um). Then, the channel
68-
locations are weighted by their PC values.
57+
spike_locations : np.ndarray
58+
(num_spikes, 2) array of corresponding spike locations (x, y) estimated using
59+
center of mass from the first PC (or, second PC if no signal on first PC).
60+
See `_get_locations_from_pc_features()` for details.
6961
"""
7062
if isinstance(sorter_output, str):
7163
sorter_output = Path(sorter_output)
7264

73-
params = load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise)
65+
if not params["pc_features"]:
66+
raise ValueError("`pc_features` must be loaded into params. Use `load_ks_dir` with `load_pcs=True`.")
7467

7568
if localised_spikes_only:
7669
localised_templates = []
@@ -91,10 +84,11 @@ def compute_spike_amplitude_and_depth(
9184
params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike]
9285
params["pc_features"] = params["pc_features"][localised_template_by_spike]
9386

87+
# Compute the spike locations and maximum-loading channel per spike
9488
spike_locations, spike_max_sites = _get_locations_from_pc_features(params)
9589

96-
# Amplitude is calculated for each spike as the template amplitude
97-
# multiplied by the `template_scaling_amplitudes`.
90+
# Amplitude is calculated for each spike as the template amplitude
91+
# multiplied by the `template_scaling_amplitudes`.
9892
template_amplitudes_unscaled, *_ = get_unwhite_template_info(
9993
params["templates"],
10094
params["whitening_matrix_inv"],
@@ -105,16 +99,11 @@ def compute_spike_amplitude_and_depth(
10599
if gain is not None:
106100
spike_amplitudes *= gain
107101

108-
compute_template_amplitudes_from_spikes(params["templates"], params["spike_templates"], spike_amplitudes)
109-
110102
if localised_spikes_only:
111103
# Interpolate the channel ids to location.
112104
# Remove spikes > 5 um from average position
113105
# Above we already removed non-localized templates, but that on its own is insufficient.
114106
# Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
115-
# TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
116-
# 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
117-
# 3) just use depth. Probably go for that. check with others.
118107
spike_depths = spike_locations[:, 1]
119108
b = stats.linregress(spike_depths, spike_max_sites).slope
120109
i = np.abs(spike_max_sites - b * spike_depths) <= 5
@@ -130,6 +119,14 @@ def compute_spike_amplitude_and_depth(
130119
def _get_locations_from_pc_features(params):
131120
"""
132121
122+
Notes
123+
-----
124+
Location of each individual spike is computed from its low-dimensional projection.
125+
`pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
126+
Taking the first component, the subset of 32 channels associated with this
127+
spike are indexed to get the actual channel locations (in um). Then, the channel
128+
locations are weighted by their PC values.
129+
133130
This function is based on code in Nick Steinmetz's `spikes` repository,
134131
https://github.com/cortex-lab/spikes
135132
"""
@@ -145,6 +142,10 @@ def _get_locations_from_pc_features(params):
145142
# Then recompute the estimated waveform peak on each channel by
146143
# summing the PCs by their respective weights. However, the PC basis
147144
# vectors themselves do not appear to be output by KS.
145+
146+
# We include the (n_channels i.e. features) from the second PC
147+
# into the `pc_features` mostly containing the first PC. As all
148+
# operations are per-spike (i.e. row-wise)
148149
no_pc1_signal_spikes = np.where(np.sum(pc_features, axis=1) == 0)
149150

150151
pc_features_2 = params["pc_features"][:, 1, :]
@@ -164,13 +165,12 @@ def _get_locations_from_pc_features(params):
164165

165166
# Compute the spike locations as the center of mass of the PC scores
166167
spike_feature_coords = params["channel_positions"][spike_features_indices, :]
167-
norm_weights = (
168-
pc_features / np.sum(pc_features, axis=1)[:, np.newaxis]
169-
) # TOOD: discuss use of square. Probbaly do not use to keep in line with COM in SI.
168+
norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis]
169+
170170
spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis]
171171
spike_locations = np.sum(spike_locations, axis=1)
172172

173-
# TODO: now max site per spike is computed from PCs, not as the channel max site as previous
173+
# Find the max site as the channel with the largest PC weight.
174174
spike_max_sites = spike_features_indices[
175175
np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)
176176
]
@@ -233,23 +233,16 @@ def get_unwhite_template_info(
233233

234234
template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1)
235235

236-
# Zero any small channel amplitudes TODO: removed this.
237-
# threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree?
238-
# template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
239-
240236
# Calculate the template depth as the center of mass based on channel amplitudes
241237
weights = template_amplitudes_per_channel / np.sum(template_amplitudes_per_channel, axis=1)[:, np.newaxis]
242238
template_locations = weights @ channel_positions
243239

244240
# Get channel with the largest amplitude (take that as the waveform)
245-
template_max_site = np.argmax(
246-
np.max(np.abs(unwhite_templates), axis=1), axis=1
247-
) # TODO: i changed this to use unwhitened templates instead of templates. This okay?
241+
template_max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1)
248242

249243
# Use template channel with max signal as waveform
250-
waveforms = np.empty(
251-
unwhite_templates.shape[:2]
252-
) # TODO: i changed this to use unwhitened templates instead of templates. This okay?
244+
waveforms = np.empty(unwhite_templates.shape[:2])
245+
253246
for idx, template in enumerate(unwhite_templates):
254247
waveforms[idx, :] = unwhite_templates[idx, :, template_max_site[idx]]
255248

src/spikeinterface/working/test_peaks_from_ks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
dtype=[('x', '<f8'), ('y', '<f8')])
4343
"""
4444

45+
sorter_output_path = (r"D:\data\New folder\CA_528_1\imec0_ks2",)
46+
47+
params = load_ks_dir(sorter_output_path, load_pcs=True, exclude_noise=exclude_noise)
4548

4649
spike_indexes, spike_amplitudes, weighted_locs, max_sites = compute_spike_amplitude_and_depth(
4750
r"D:\data\New folder\CA_528_1\imec0_ks2",

0 commit comments

Comments
 (0)