77
88from scipy import stats
99
10- # TODO: spike_times -> spike_indices
1110"""
1211Notes
1312-----
1413- not everything is used for current purposes
1514- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
15+
16+ TODO: testing against diferent ks versions
1617"""
1718
1819########################################################################################################################
2122
2223
2324def compute_spike_amplitude_and_depth (
24- sorter_output : str | Path ,
25+ params : dict ,
2526 localised_spikes_only ,
26- exclude_noise ,
2727 gain : float | None = None ,
2828 localised_spikes_channel_cutoff : int = None ,
2929) -> tuple [np .ndarray , ...]:
3030 """
31- Compute the amplitude and depth of all detected spikes from the kilosort output.
31+ Compute the indicies, amplitudes and locations for all detected spikes from the kilosort output.
3232
3333 This function is based on code in Nick Steinmetz's `spikes` repository,
3434 https://github.com/cortex-lab/spikes
3535
3636 Parameters
3737 ----------
38- sorter_output : str | Path
39- Path to the kilosort run sorting output.
38+ params : dict
39+ `params` as loaded from the kilosort output directory (see `load_ks_dir()`)
4040 localised_spikes_only : bool
4141 If `True`, only spikes with small spatial footprint (i.e. 20 channels within 1/2 of the
4242 amplitude of the maximum loading channel) and which are close to the average depth for
@@ -54,23 +54,16 @@ def compute_spike_amplitude_and_depth(
5454 (num_spikes,) array of spike indices.
5555 spike_amplitudes : np.ndarray
5656 (num_spikes,) array of corresponding spike amplitudes.
57- spike_depths : np.ndarray
58- (num_spikes,) array of corresponding depths (probe y-axis location).
59-
60- Notes
61- -----
62- In `get_template_info_and_spike_amplitudes` spike depths is calculated as simply the template
63- depth, for each spike (so it is the same for all spikes in a cluster). Here we need
64- to find the depth of each individual spike, using its low-dimensional projection.
65- `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
66- Taking the first component, the subset of 32 channels associated with this
67- spike are indexed to get the actual channel locations (in um). Then, the channel
68- locations are weighted by their PC values.
57+ spike_locations : np.ndarray
58+ (num_spikes, 2) array of corresponding spike locations (x, y) estimated using
59+ center of mass from the first PC (or, second PC if no signal on first PC).
60+ See `_get_locations_from_pc_features()` for details.
6961 """
7062 if isinstance (sorter_output , str ):
7163 sorter_output = Path (sorter_output )
7264
73- params = load_ks_dir (sorter_output , load_pcs = True , exclude_noise = exclude_noise )
65+ if not params ["pc_features" ]:
66+ raise ValueError ("`pc_features` must be loaded into params. Use `load_ks_dir` with `load_pcs=True`." )
7467
7568 if localised_spikes_only :
7669 localised_templates = []
@@ -91,10 +84,11 @@ def compute_spike_amplitude_and_depth(
9184 params ["temp_scaling_amplitudes" ] = params ["temp_scaling_amplitudes" ][localised_template_by_spike ]
9285 params ["pc_features" ] = params ["pc_features" ][localised_template_by_spike ]
9386
87+ # Compute the spike locations and maximum-loading channel per spike
9488 spike_locations , spike_max_sites = _get_locations_from_pc_features (params )
9589
96- # Amplitude is calculated for each spike as the template amplitude
97- # multiplied by the `template_scaling_amplitudes`.
90+ # Amplitude is calculated for each spike as the template amplitude
91+ # multiplied by the `template_scaling_amplitudes`.
9892 template_amplitudes_unscaled , * _ = get_unwhite_template_info (
9993 params ["templates" ],
10094 params ["whitening_matrix_inv" ],
@@ -105,16 +99,11 @@ def compute_spike_amplitude_and_depth(
10599 if gain is not None :
106100 spike_amplitudes *= gain
107101
108- compute_template_amplitudes_from_spikes (params ["templates" ], params ["spike_templates" ], spike_amplitudes )
109-
110102 if localised_spikes_only :
111103 # Interpolate the channel ids to location.
112104 # Remove spikes > 5 um from average position
113105 # Above we already removed non-localized templates, but that on its own is insufficient.
114106 # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
115- # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
116- # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
117- # 3) just use depth. Probably go for that. check with others.
118107 spike_depths = spike_locations [:, 1 ]
119108 b = stats .linregress (spike_depths , spike_max_sites ).slope
120109 i = np .abs (spike_max_sites - b * spike_depths ) <= 5
@@ -130,6 +119,14 @@ def compute_spike_amplitude_and_depth(
130119def _get_locations_from_pc_features (params ):
131120 """
132121
122+ Notes
123+ -----
124+ Location of of each individual spike is computed from its low-dimensional projection.
125+ `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
126+ Taking the first component, the subset of 32 channels associated with this
127+ spike are indexed to get the actual channel locations (in um). Then, the channel
128+ locations are weighted by their PC values.
129+
133130 This function is based on code in Nick Steinmetz's `spikes` repository,
134131 https://github.com/cortex-lab/spikes
135132 """
@@ -145,6 +142,10 @@ def _get_locations_from_pc_features(params):
145142 # Then recompute the estimated waveform peak on each channel by
146143 # summing the PCs by their respective weights. However, the PC basis
147144 # vectors themselves do not appear to be output by KS.
145+
146+ # We include the (n_channels i.e. features) from the second PC
147+ # into the `pc_features` mostly containing the first PC. As all
148+ # operations are per-spike (i.e. row-wise)
148149 no_pc1_signal_spikes = np .where (np .sum (pc_features , axis = 1 ) == 0 )
149150
150151 pc_features_2 = params ["pc_features" ][:, 1 , :]
@@ -164,13 +165,12 @@ def _get_locations_from_pc_features(params):
164165
165166 # Compute the spike locations as the center of mass of the PC scores
166167 spike_feature_coords = params ["channel_positions" ][spike_features_indices , :]
167- norm_weights = (
168- pc_features / np .sum (pc_features , axis = 1 )[:, np .newaxis ]
169- ) # TOOD: discuss use of square. Probbaly do not use to keep in line with COM in SI.
168+ norm_weights = pc_features / np .sum (pc_features , axis = 1 )[:, np .newaxis ]
169+
170170 spike_locations = spike_feature_coords * norm_weights [:, :, np .newaxis ]
171171 spike_locations = np .sum (spike_locations , axis = 1 )
172172
173- # TODO: now max site per spike is computed from PCs, not as the channel max site as previous
173+ # Find the max site as the channel with the largest PC weight.
174174 spike_max_sites = spike_features_indices [
175175 np .arange (spike_features_indices .shape [0 ]), np .argmax (norm_weights , axis = 1 )
176176 ]
@@ -233,23 +233,16 @@ def get_unwhite_template_info(
233233
234234 template_amplitudes_unscaled = np .max (template_amplitudes_per_channel , axis = 1 )
235235
236- # Zero any small channel amplitudes TODO: removed this.
237- # threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree?
238- # template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
239-
240236 # Calculate the template depth as the center of mass based on channel amplitudes
241237 weights = template_amplitudes_per_channel / np .sum (template_amplitudes_per_channel , axis = 1 )[:, np .newaxis ]
242238 template_locations = weights @ channel_positions
243239
244240 # Get channel with the largest amplitude (take that as the waveform)
245- template_max_site = np .argmax (
246- np .max (np .abs (unwhite_templates ), axis = 1 ), axis = 1
247- ) # TODO: i changed this to use unwhitened templates instead of templates. This okay?
241+ template_max_site = np .argmax (np .max (np .abs (unwhite_templates ), axis = 1 ), axis = 1 )
248242
249243 # Use template channel with max signal as waveform
250- waveforms = np .empty (
251- unwhite_templates .shape [:2 ]
252- ) # TODO: i changed this to use unwhitened templates instead of templates. This okay?
244+ waveforms = np .empty (unwhite_templates .shape [:2 ])
245+
253246 for idx , template in enumerate (unwhite_templates ):
254247 waveforms [idx , :] = unwhite_templates [idx , :, template_max_site [idx ]]
255248
0 commit comments