2222
2323class ComputeRandomSpikes (AnalyzerExtension ):
2424 """
25- AnalyzerExtension that select some random spikes.
25+ AnalyzerExtension that select somes random spikes.
26+ This allows for a subsampling of spikes for further calculations and is important
27+ for managing that amount of memory and speed of computation in the analyzer.
2628
2729 This will be used by the `waveforms`/`templates` extensions.
2830
29- This internally use `random_spikes_selection()` parameters are the same .
31+ This internally uses `random_spikes_selection()` parameters.
3032
3133 Parameters
3234 ----------
33- method: "uniform" | "all", default: "uniform"
35+ method : "uniform" | "all", default: "uniform"
3436 The method to select the spikes
35- max_spikes_per_unit: int, default: 500
37+ max_spikes_per_unit : int, default: 500
3638 The maximum number of spikes per unit, ignored if method="all"
37- margin_size: int, default: None
39+ margin_size : int, default: None
3840 A margin on each border of segments to avoid border spikes, ignored if method="all"
39- seed: int or None, default: None
41+ seed : int or None, default: None
4042 A seed for the random generator, ignored if method="all"
4143
4244 Returns
@@ -104,7 +106,7 @@ def get_random_spikes(self):
104106 return self ._some_spikes
105107
106108 def get_selected_indices_in_spike_train (self , unit_id , segment_index ):
107- # usefull for Waveforms extractor backwars compatibility
109+ # useful for WaveformExtractor backwards compatibility
108110 # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain
109111 sorting = self .sorting_analyzer .sorting
110112 random_spikes_indices = self .data ["random_spikes_indices" ]
@@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension):
133135
134136 Parameters
135137 ----------
136- ms_before: float, default: 1.0
138+ ms_before : float, default: 1.0
137139 The number of ms to extract before the spike events
138- ms_after: float, default: 2.0
140+ ms_after : float, default: 2.0
139141 The number of ms to extract after the spike events
140- dtype: None | dtype, default: None
142+ dtype : None | dtype, default: None
141143 The dtype of the waveforms. If None, the dtype of the recording is used.
142144
143145 Returns
144146 -------
145- waveforms: np.ndarray
147+ waveforms : np.ndarray
146148 Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels)
147149 """
148150
@@ -380,7 +382,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N
380382 assert isinstance (operators , list )
381383 for operator in operators :
382384 if isinstance (operator , str ):
383- assert operator in ("average" , "std" , "median" , "mad" )
385+ if operator not in ("average" , "std" , "median" , "mad" ):
386+ error_msg = (
387+ f"You have entered an operator { operator } in your `operators` argument which is "
388+ f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead."
389+ )
390+ raise ValueError (error_msg )
384391 else :
385392 assert isinstance (operator , (list , tuple ))
386393 assert len (operator ) == 2
@@ -405,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs):
405412 self ._compute_and_append_from_waveforms (self .params ["operators" ])
406413
407414 else :
408- for operator in self .params ["operators" ]:
409- if operator not in ("average" , "std" ):
410- raise ValueError (f"Computing templates with operators { operator } needs the 'waveforms' extension" )
415+ bad_operator_list = [
416+ operator for operator in self .params ["operators" ] if operator not in ("average" , "std" )
417+ ]
418+ if len (bad_operator_list ) > 0 :
419+ raise ValueError (
420+ f"Computing templates with operators { bad_operator_list } requires the 'waveforms' extension"
421+ )
411422
412423 recording = self .sorting_analyzer .recording
413424 sorting = self .sorting_analyzer .sorting
@@ -441,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs):
441452
442453 def _compute_and_append_from_waveforms (self , operators ):
443454 if not self .sorting_analyzer .has_extension ("waveforms" ):
444- raise ValueError (f"Computing templates with operators { operators } needs the 'waveforms' extension" )
455+ raise ValueError (f"Computing templates with operators { operators } requires the 'waveforms' extension" )
445456
446457 unit_ids = self .sorting_analyzer .unit_ids
447458 channel_ids = self .sorting_analyzer .channel_ids
@@ -466,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators):
466477
467478 assert self .sorting_analyzer .has_extension (
468479 "random_spikes"
469- ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes( )"
480+ ), "compute ' templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes' )"
470481 some_spikes = self .sorting_analyzer .get_extension ("random_spikes" ).get_random_spikes ()
471482 for unit_index , unit_id in enumerate (unit_ids ):
472483 spike_mask = some_spikes ["unit_index" ] == unit_index
@@ -549,9 +560,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
549560 if operator != "percentile" :
550561 key = operator
551562 else :
552- assert percentile is not None , "You must provide percentile=..."
563+ assert percentile is not None , "You must provide percentile=... if `operator=percentile` "
553564 key = f"percentile_{ percentile } "
554565
566+ if key not in self .data .keys ():
567+ error_msg = (
568+ f"You have entered `operator={ key } `, but the only operators calculated are "
569+ f"{ list (self .data .keys ())} . Please use one of these as your `operator` in the "
570+ f"`get_data` function."
571+ )
572+ raise ValueError (error_msg )
573+
555574 templates_array = self .data [key ]
556575
557576 if outputs == "numpy" :
@@ -566,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
566585 probe = self .sorting_analyzer .get_probe (),
567586 )
568587 else :
569- raise ValueError ("outputs must be numpy or Templates" )
588+ raise ValueError ("outputs must be ` numpy` or ` Templates` " )
570589
571590 def get_templates (self , unit_ids = None , operator = "average" , percentile = None , save = True , outputs = "numpy" ):
572591 """
@@ -576,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
576595
577596 Parameters
578597 ----------
579- unit_ids: list or None
598+ unit_ids : list or None
580599 Unit ids to retrieve waveforms for
581- operator: "average" | "median" | "std" | "percentile", default: "average"
600+ operator : "average" | "median" | "std" | "percentile", default: "average"
582601 The operator to compute the templates
583- percentile: float, default: None
602+ percentile : float, default: None
584603 Percentile to use for operator="percentile"
585- save: bool, default True
604+ save : bool, default: True
586605 In case, the operator is not computed yet it can be saved to folder or zarr
587- outputs: "numpy" | "Templates"
606+ outputs : "numpy" | "Templates", default: "numpy "
588607 Whether to return a numpy array or a Templates object
589608
590609 Returns
591610 -------
592- templates: np.array
611+ templates : np.array | Templates
593612 The returned templates (num_units, num_samples, num_channels)
594613 """
595614 if operator != "percentile" :
596615 key = operator
597616 else :
598- assert percentile is not None , "You must provide percentile=..."
617+ assert percentile is not None , "You must provide percentile=... if `operator='percentile'` "
599618 key = f"pencentile_{ percentile } "
600619
601620 if key in self .data :
@@ -632,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
632651 is_scaled = self .sorting_analyzer .return_scaled ,
633652 )
634653 else :
635- raise ValueError ("outputs must be numpy or Templates" )
654+ raise ValueError ("` outputs` must be ' numpy' or ' Templates' " )
636655
637656 def get_unit_template (self , unit_id , operator = "average" ):
638657 """
@@ -642,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"):
642661 ----------
643662 unit_id: str | int
644663 Unit id to retrieve waveforms for
645- operator: str
664+ operator: str, default: "average"
646665 The operator to compute the templates
647666
648667 Returns
@@ -701,13 +720,13 @@ def _set_params(self, **noise_level_params):
701720 return params
702721
703722 def _select_extension_data (self , unit_ids ):
704- # this do not depend on units
723+ # this does not depend on units
705724 return self .data
706725
707726 def _merge_extension_data (
708727 self , merge_unit_groups , new_unit_ids , new_sorting_analyzer , keep_mask = None , verbose = False , ** job_kwargs
709728 ):
710- # this do not depend on units
729+ # this does not depend on units
711730 return self .data .copy ()
712731
713732 def _run (self , verbose = False ):
0 commit comments