77
88from scipy import stats
99
10- # TODO: spike_times -> spike_indexes
10+ # TODO: spike_times -> spike_indices
1111"""
1212Notes
1313-----
1414- not everything is used for current purposes
1515- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
1616"""
1717
18+ ########################################################################################################################
19+ # Get Spike Data
20+ ########################################################################################################################
21+
1822
1923def compute_spike_amplitude_and_depth (
2024 sorter_output : str | Path ,
2125 localised_spikes_only ,
2226 exclude_noise ,
2327 gain : float | None = None ,
24- localised_spikes_channel_cutoff : int = None , # TODO
28+ localised_spikes_channel_cutoff : int = None ,
2529) -> tuple [np .ndarray , ...]:
2630 """
2731 Compute the amplitude and depth of all detected spikes from the kilosort output.
2832
29- This function was ported from Nick Steinmetz's `spikes` repository
30- MATLAB code, https://github.com/cortex-lab/spikes
33+ This function is based on code in Nick Steinmetz's `spikes` repository,
34+ https://github.com/cortex-lab/spikes
3135
3236 Parameters
3337 ----------
@@ -46,8 +50,8 @@ def compute_spike_amplitude_and_depth(
4650
4751 Returns
4852 -------
49- spike_indexes : np.ndarray
50- (num_spikes,) array of spike indexes .
53+ spike_indices : np.ndarray
54+ (num_spikes,) array of spike indices .
5155 spike_amplitudes : np.ndarray
5256 (num_spikes,) array of corresponding spike amplitudes.
5357 spike_depths : np.ndarray
@@ -66,7 +70,7 @@ def compute_spike_amplitude_and_depth(
6670 if isinstance (sorter_output , str ):
6771 sorter_output = Path (sorter_output )
6872
69- params = _load_ks_dir (sorter_output , load_pcs = True , exclude_noise = exclude_noise )
73+ params = load_ks_dir (sorter_output , load_pcs = True , exclude_noise = exclude_noise )
7074
7175 if localised_spikes_only :
7276 localised_templates = []
@@ -81,10 +85,56 @@ def compute_spike_amplitude_and_depth(
8185
8286 localised_template_by_spike = np .isin (params ["spike_templates" ], localised_templates )
8387
84- _strip_spikes (params , localised_template_by_spike )
88+ params ["spike_templates" ] = params ["spike_templates" ][localised_template_by_spike ]
89+ params ["spike_indices" ] = params ["spike_indices" ][localised_template_by_spike ]
90+ params ["spike_clusters" ] = params ["spike_clusters" ][localised_template_by_spike ]
91+ params ["temp_scaling_amplitudes" ] = params ["temp_scaling_amplitudes" ][localised_template_by_spike ]
92+ params ["pc_features" ] = params ["pc_features" ][localised_template_by_spike ]
93+
94+ spike_locations , spike_max_sites = _get_locations_from_pc_features (params )
95+
96+ # Amplitude is calculated for each spike as the template amplitude
97+ # multiplied by the `template_scaling_amplitudes`.
98+ template_amplitudes_unscaled , * _ = get_unwhite_template_info (
99+ params ["templates" ],
100+ params ["whitening_matrix_inv" ],
101+ params ["channel_positions" ],
102+ )
103+ spike_amplitudes = template_amplitudes_unscaled [params ["spike_templates" ]] * params ["temp_scaling_amplitudes" ]
104+
105+ if gain is not None :
106+ spike_amplitudes *= gain
107+
108+ compute_template_amplitudes_from_spikes (params ["templates" ], params ["spike_templates" ], spike_amplitudes )
109+
110+ if localised_spikes_only :
111+ # Interpolate the channel ids to location.
112+ # Remove spikes > 5 um from average position
113+ # Above we already removed non-localized templates, but that on its own is insufficient.
114+ # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
115+ # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
116+ # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
117+ # 3) just use depth. Probably go for that. check with others.
118+ spike_depths = spike_locations [:, 1 ]
119+ b = stats .linregress (spike_depths , spike_max_sites ).slope
120+ i = np .abs (spike_max_sites - b * spike_depths ) <= 5
85121
122+ params ["spike_indices" ] = params ["spike_indices" ][i ]
123+ spike_amplitudes = spike_amplitudes [i ]
124+ spike_locations = spike_locations [i , :]
125+ spike_max_sites = spike_max_sites [i ]
126+
127+ return params ["spike_indices" ], spike_amplitudes , spike_locations , spike_max_sites
128+
129+
130+ def _get_locations_from_pc_features (params ):
131+ """
132+
133+ This function is based on code in Nick Steinmetz's `spikes` repository,
134+ https://github.com/cortex-lab/spikes
135+ """
86136 # Compute spike depths
87- pc_features = params ["pc_features" ][:, 0 , :] # Do this compute
137+ pc_features = params ["pc_features" ][:, 0 , :]
88138 pc_features [pc_features < 0 ] = 0
89139
90140 # Some spikes do not load at all onto the first PC. To avoid biasing the
@@ -109,58 +159,28 @@ def compute_spike_amplitude_and_depth(
109159 "to extend this code section to handle more components."
110160 )
111161
112- # Get the channel indexes corresponding to the 32 channels from the PC.
162+ # Get the channel indices corresponding to the 32 channels from the PC.
113163 spike_features_indices = params ["pc_features_indices" ][params ["spike_templates" ], :]
114164
115165 # Compute the spike locations as the center of mass of the PC scores
116166 spike_feature_coords = params ["channel_positions" ][spike_features_indices , :]
117- norm_weights = pc_features / np .sum (pc_features , axis = 1 )[:, np .newaxis ] # TOOD: see why they use square
167+ norm_weights = (
168+ pc_features / np .sum (pc_features , axis = 1 )[:, np .newaxis ]
169+ ) # TOOD: discuss use of square. Probbaly do not use to keep in line with COM in SI.
118170 spike_locations = spike_feature_coords * norm_weights [:, :, np .newaxis ]
119171 spike_locations = np .sum (spike_locations , axis = 1 )
120172
121173 # TODO: now max site per spike is computed from PCs, not as the channel max site as previous
122- spike_sites = spike_features_indices [np .arange (spike_features_indices .shape [0 ]), np .argmax (norm_weights , axis = 1 )]
174+ spike_max_sites = spike_features_indices [
175+ np .arange (spike_features_indices .shape [0 ]), np .argmax (norm_weights , axis = 1 )
176+ ]
123177
124- # Amplitude is calculated for each spike as the template amplitude
125- # multiplied by the `template_scaling_amplitudes`.
126- template_amplitudes_unscaled , * _ = get_unwhite_template_info (
127- params ["templates" ],
128- params ["whitening_matrix_inv" ],
129- params ["channel_positions" ],
130- )
131- spike_amplitudes = template_amplitudes_unscaled [params ["spike_templates" ]] * params ["temp_scaling_amplitudes" ]
178+ return spike_locations , spike_max_sites
132179
133- if gain is not None :
134- spike_amplitudes *= gain
135-
136- if localised_spikes_only :
137- # Interpolate the channel ids to location.
138- # Remove spikes > 5 um from average position
139- # Above we already removed non-localized templates, but that on its own is insufficient.
140- # Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
141- # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
142- # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
143- # 3) just use depth. Probably go for that. check with others.
144- spike_depths = spike_locations [:, 1 ]
145- b = stats .linregress (spike_depths , spike_sites ).slope
146- i = np .abs (spike_sites - b * spike_depths ) <= 5 # TODO: need to expose this
147-
148- params ["spike_indexes" ] = params ["spike_indexes" ][i ]
149- spike_amplitudes = spike_amplitudes [i ]
150- spike_locations = spike_locations [i , :]
151-
152- return params ["spike_indexes" ], spike_amplitudes , spike_locations , spike_sites
153180
154-
155- def _strip_spikes_in_place (params , indices ):
156- """ """
157- params ["spike_templates" ] = params ["spike_templates" ][
158- indices
159- ] # TODO: make an function for this. because we do this a lot
160- params ["spike_indexes" ] = params ["spike_indexes" ][indices ]
161- params ["spike_clusters" ] = params ["spike_clusters" ][indices ]
162- params ["temp_scaling_amplitudes" ] = params ["temp_scaling_amplitudes" ][indices ]
163- params ["pc_features" ] = params ["pc_features" ][indices ] # TODO: be conciststetn! change indees to indices
181+ ########################################################################################################################
182+ # Get Template Data
183+ ########################################################################################################################
164184
165185
166186def get_unwhite_template_info (
@@ -173,8 +193,8 @@ def get_unwhite_template_info(
173193 Amplitude is calculated for each spike as the template amplitude
174194 multiplied by the `template_scaling_amplitudes`.
175195
176- This function was ported from Nick Steinmetz's `spikes` repository
177- MATLAB code, https://github.com/cortex-lab/spikes
196+ This function is based on code in Nick Steinmetz's `spikes` repository,
197+ https://github.com/cortex-lab/spikes
178198
179199 Parameters
180200 ----------
@@ -213,7 +233,7 @@ def get_unwhite_template_info(
213233
214234 template_amplitudes_unscaled = np .max (template_amplitudes_per_channel , axis = 1 )
215235
216- # Zero any small channel amplitudes
236+ # Zero any small channel amplitudes TODO: removed this.
217237 # threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree?
218238 # template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
219239
@@ -253,9 +273,14 @@ def get_unwhite_template_info(
253273 )
254274
255275
256- def compute_template_amplitudes_from_spikes ():
257- # Take the average of all spike amplitudes to get actual template amplitudes
258- # (since tempScalingAmps are equal mean for all templates)
276+ def compute_template_amplitudes_from_spikes (templates , spike_templates , spike_amplitudes ):
277+ """
278+ Take the average of all spike amplitudes to get actual template amplitudes
279+ (since tempScalingAmps are equal mean for all templates)
280+
281+ This function is ported from Nick Steinmetz's `spikes` repository,
282+ https://github.com/cortex-lab/spikes
283+ """
259284 num_indices = templates .shape [0 ]
260285 sum_per_index = np .zeros (num_indices , dtype = np .float64 )
261286 np .add .at (sum_per_index , spike_templates , spike_amplitudes )
@@ -264,7 +289,12 @@ def compute_template_amplitudes_from_spikes():
264289 return template_amplitudes
265290
266291
267- def _load_ks_dir (sorter_output : Path , exclude_noise : bool = True , load_pcs : bool = False ) -> dict :
292+ ########################################################################################################################
293+ # Load Parameters from KS Directory
294+ ########################################################################################################################
295+
296+
297+ def load_ks_dir (sorter_output : Path , exclude_noise : bool = True , load_pcs : bool = False ) -> dict :
268298 """
269299 Loads the output of Kilosort into a `params` dict.
270300
@@ -300,7 +330,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
300330
301331 params = read_python (sorter_output / "params.py" )
302332
303- spike_indexes = np .load (sorter_output / "spike_times.npy" )
333+ spike_indices = np .load (sorter_output / "spike_times.npy" )
304334 spike_templates = np .load (sorter_output / "spike_templates.npy" )
305335
306336 if (clusters_path := sorter_output / "spike_clusters.csv" ).is_dir ():
@@ -328,7 +358,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
328358 noise_cluster_ids = cluster_ids [cluster_groups == 0 ]
329359 not_noise_clusters_by_spike = ~ np .isin (spike_clusters .ravel (), noise_cluster_ids )
330360
331- spike_indexes = spike_indexes [not_noise_clusters_by_spike ]
361+ spike_indices = spike_indices [not_noise_clusters_by_spike ]
332362 spike_templates = spike_templates [not_noise_clusters_by_spike ]
333363 temp_scaling_amplitudes = temp_scaling_amplitudes [not_noise_clusters_by_spike ]
334364
@@ -343,7 +373,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
343373 cluster_groups = 3 * np .ones (cluster_ids .size )
344374
345375 new_params = {
346- "spike_indexes " : spike_indexes .squeeze (),
376+ "spike_indices " : spike_indices .squeeze (),
347377 "spike_templates" : spike_templates .squeeze (),
348378 "spike_clusters" : spike_clusters .squeeze (),
349379 "pc_features" : pc_features ,
0 commit comments