
class ComputeRandomSpikes(AnalyzerExtension):
    """
-    AnalyzerExtension that select some random spikes.
+    AnalyzerExtension that selects some random spikes.
+    This allows for a subsampling of spikes for further calculations and is important
+    for managing the amount of memory and the speed of computation in the analyzer.

    This will be used by the `waveforms`/`templates` extensions.

-    This internally use `random_spikes_selection()` parameters are the same.
+    This internally uses `random_spikes_selection()`, with the same parameters.
3032
3133 Parameters
3234 ----------
33- method: "uniform" | "all", default: "uniform"
35+ method : "uniform" | "all", default: "uniform"
3436 The method to select the spikes
35- max_spikes_per_unit: int, default: 500
37+ max_spikes_per_unit : int, default: 500
3638 The maximum number of spikes per unit, ignored if method="all"
37- margin_size: int, default: None
39+ margin_size : int, default: None
3840 A margin on each border of segments to avoid border spikes, ignored if method="all"
39- seed: int or None, default: None
41+ seed : int or None, default: None
4042 A seed for the random generator, ignored if method="all"
4143
4244 Returns
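
Taken together, these parameters are what `SortingAnalyzer.compute("random_spikes", ...)` forwards to the selection routine. A minimal usage sketch, assuming a `recording` and a `sorting` are already loaded (the seed value below is arbitrary):

```python
from spikeinterface.core import create_sorting_analyzer

# build an analyzer from an existing recording/sorting pair
sorting_analyzer = create_sorting_analyzer(sorting, recording)

# subsample at most 500 spikes per unit, reproducibly
sorting_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500, seed=2205)
```
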
@@ -104,7 +106,7 @@ def get_random_spikes(self):
        return self._some_spikes

    def get_selected_indices_in_spike_train(self, unit_id, segment_index):
-        # usefull for Waveforms extractor backwars compatibility
+        # useful for Waveforms extractor backward compatibility
        # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain
        sorting = self.sorting_analyzer.sorting
        random_spikes_indices = self.data["random_spikes_indices"]
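
A sketch of how this backward-compatibility helper maps the flat selection back to per-spike-train indices, assuming `random_spikes` has already been computed on the analyzer from the previous example:

```python
ext = sorting_analyzer.get_extension("random_spikes")

# flat indices into the sorting's concatenated spike vector
random_spikes_indices = ext.get_data()

# indices of the selected spikes inside one unit's spike train for one segment,
# mirroring the old WaveformExtractor "selected_spikes" dict-of-lists layout
unit_id = sorting_analyzer.unit_ids[0]
selected_in_train = ext.get_selected_indices_in_spike_train(unit_id, segment_index=0)
```
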
@@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension):

    Parameters
    ----------
-    ms_before: float, default: 1.0
+    ms_before : float, default: 1.0
        The number of ms to extract before the spike events
-    ms_after: float, default: 2.0
+    ms_after : float, default: 2.0
        The number of ms to extract after the spike events
-    dtype: None | dtype, default: None
+    dtype : None | dtype, default: None
        The dtype of the waveforms. If None, the dtype of the recording is used.

    Returns
    -------
-    waveforms: np.ndarray
+    waveforms : np.ndarray
        Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels)
    """

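Assuming the same analyzer, a sketch of computing waveforms with these parameters; `waveforms` only extracts the spikes selected by `random_spikes`, so that extension must be computed first:

```python
sorting_analyzer.compute("random_spikes")
sorting_analyzer.compute("waveforms", ms_before=1.0, ms_after=2.0, dtype=None)

# shape (num_random_spikes, num_samples, num_channels)
wfs = sorting_analyzer.get_extension("waveforms").get_data()
```
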
@@ -410,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs):
            self._compute_and_append_from_waveforms(self.params["operators"])

        else:
-            for operator in self.params["operators"]:
-                if operator not in ("average", "std"):
-                    raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension")
+            bad_operator_list = [
+                operator for operator in self.params["operators"] if operator not in ("average", "std")
+            ]
+            if len(bad_operator_list) > 0:
+                raise ValueError(
+                    f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension"
+                )

            recording = self.sorting_analyzer.recording
            sorting = self.sorting_analyzer.sorting
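
The rewritten check collects every unsupported operator before raising instead of failing on the first one. In practice, without the `waveforms` extension only "average" and "std" can be computed straight from the recording; a sketch of both paths (operator names as in the `get_templates` docstring further down):

```python
# works without waveforms: only "average" and "std" are accepted
sorting_analyzer.compute("templates", operators=["average", "std"])

# any other operator needs waveforms to be computed first
sorting_analyzer.compute("waveforms")
sorting_analyzer.compute("templates", operators=["average", "median"])
```
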
@@ -446,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs):

    def _compute_and_append_from_waveforms(self, operators):
        if not self.sorting_analyzer.has_extension("waveforms"):
-            raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension")
+            raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension")

        unit_ids = self.sorting_analyzer.unit_ids
        channel_ids = self.sorting_analyzer.channel_ids
@@ -471,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators):

        assert self.sorting_analyzer.has_extension(
            "random_spikes"
-        ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()"
+        ), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')"
        some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
        for unit_index, unit_id in enumerate(unit_ids):
            spike_mask = some_spikes["unit_index"] == unit_index
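
`get_random_spikes()` returns the selected spikes as a structured array in the same layout as the sorting's spike vector, so per-unit filtering is just a mask on the "unit_index" field, as in the loop above; a small sketch:

```python
some_spikes = sorting_analyzer.get_extension("random_spikes").get_random_spikes()

for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids):
    spike_mask = some_spikes["unit_index"] == unit_index
    print(unit_id, spike_mask.sum(), "selected spikes")
```
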
@@ -579,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
                probe=self.sorting_analyzer.get_probe(),
            )
        else:
-            raise ValueError("outputs must be numpy or Templates")
+            raise ValueError("outputs must be `numpy` or `Templates`")

    def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"):
        """
@@ -589,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save

        Parameters
        ----------
-        unit_ids: list or None
+        unit_ids : list or None
            Unit ids to retrieve waveforms for
-        operator: "average" | "median" | "std" | "percentile", default: "average"
+        operator : "average" | "median" | "std" | "percentile", default: "average"
            The operator to compute the templates
-        percentile: float, default: None
+        percentile : float, default: None
            Percentile to use for operator="percentile"
-        save: bool, default True
+        save : bool, default: True
            In case the operator is not computed yet, it can be saved to folder or zarr
-        outputs: "numpy" | "Templates"
+        outputs : "numpy" | "Templates", default: "numpy"
            Whether to return a numpy array or a Templates object

        Returns
        -------
-        templates: np.array
+        templates : np.array | Templates
            The returned templates (num_units, num_samples, num_channels)
        """
        if operator != "percentile":
            key = operator
        else:
-            assert percentile is not None, "You must provide percentile=..."
+            assert percentile is not None, "You must provide percentile=... if `operator='percentile'`"
            key = f"pencentile_{percentile}"

        if key in self.data:
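
A sketch of retrieving templates through the documented options, assuming the `templates` extension has been computed (the percentile call additionally needs `waveforms`):

```python
templates_ext = sorting_analyzer.get_extension("templates")

# plain numpy array, shape (num_units, num_samples, num_channels)
mean_templates = templates_ext.get_templates(operator="average")

# operator="percentile" requires an explicit percentile
p90_templates = templates_ext.get_templates(operator="percentile", percentile=90)

# return a Templates object instead of a raw array
templates_obj = templates_ext.get_templates(outputs="Templates")
```
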
@@ -645,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
                is_scaled=self.sorting_analyzer.return_scaled,
            )
        else:
-            raise ValueError("outputs must be numpy or Templates")
+            raise ValueError("`outputs` must be 'numpy' or 'Templates'")

    def get_unit_template(self, unit_id, operator="average"):
        """
@@ -655,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"):
        ----------
        unit_id: str | int
            Unit id to retrieve waveforms for
-        operator: str
+        operator: str, default: "average"
            The operator to compute the templates

        Returns
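
For a single unit, the convenience method above returns just that unit's template; a sketch reusing the extension object from the previous example:

```python
unit_id = sorting_analyzer.unit_ids[0]
# shape (num_samples, num_channels)
unit_template = templates_ext.get_unit_template(unit_id, operator="average")
```
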
@@ -713,13 +719,13 @@ def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None):
        return params

    def _select_extension_data(self, unit_ids):
-        # this do not depend on units
+        # this does not depend on units
        return self.data

    def _merge_extension_data(
        self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
    ):
-        # this do not depend on units
+        # this does not depend on units
        return self.data.copy()

    def _run(self, verbose=False):
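
This last hunk appears to belong to the noise-levels extension (the `num_chunks_per_segment` / `chunk_size` / `seed` parameters match spikeinterface's `get_noise_levels`); assuming that, a minimal sketch. Because the result is one value per channel, selecting or merging units leaves the stored data unchanged, which is exactly what the two fixed comments state:

```python
# assumed to be the "noise_levels" extension; parameter values are the defaults shown above
sorting_analyzer.compute("noise_levels", num_chunks_per_segment=20, chunk_size=10000, seed=None)
noise_levels = sorting_analyzer.get_extension("noise_levels").get_data()  # one value per channel
```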