Skip to content

Commit 5f81566

Browse files
authored
Merge pull request #3501 from zm711/error-get-data
Add error messaging around use of get data in templates
2 parents c4d2eaa + 7379a96 commit 5f81566

File tree

1 file changed

+49
-30
lines changed

1 file changed

+49
-30
lines changed

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,23 @@
2222

2323
class ComputeRandomSpikes(AnalyzerExtension):
2424
"""
25-
AnalyzerExtension that select some random spikes.
25+
AnalyzerExtension that select somes random spikes.
26+
This allows for a subsampling of spikes for further calculations and is important
27+
for managing that amount of memory and speed of computation in the analyzer.
2628
2729
This will be used by the `waveforms`/`templates` extensions.
2830
29-
This internally use `random_spikes_selection()` parameters are the same.
31+
This internally uses `random_spikes_selection()` parameters.
3032
3133
Parameters
3234
----------
33-
method: "uniform" | "all", default: "uniform"
35+
method : "uniform" | "all", default: "uniform"
3436
The method to select the spikes
35-
max_spikes_per_unit: int, default: 500
37+
max_spikes_per_unit : int, default: 500
3638
The maximum number of spikes per unit, ignored if method="all"
37-
margin_size: int, default: None
39+
margin_size : int, default: None
3840
A margin on each border of segments to avoid border spikes, ignored if method="all"
39-
seed: int or None, default: None
41+
seed : int or None, default: None
4042
A seed for the random generator, ignored if method="all"
4143
4244
Returns
@@ -104,7 +106,7 @@ def get_random_spikes(self):
104106
return self._some_spikes
105107

106108
def get_selected_indices_in_spike_train(self, unit_id, segment_index):
107-
# usefull for Waveforms extractor backwars compatibility
109+
# useful for WaveformExtractor backwards compatibility
108110
# In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain
109111
sorting = self.sorting_analyzer.sorting
110112
random_spikes_indices = self.data["random_spikes_indices"]
@@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension):
133135
134136
Parameters
135137
----------
136-
ms_before: float, default: 1.0
138+
ms_before : float, default: 1.0
137139
The number of ms to extract before the spike events
138-
ms_after: float, default: 2.0
140+
ms_after : float, default: 2.0
139141
The number of ms to extract after the spike events
140-
dtype: None | dtype, default: None
142+
dtype : None | dtype, default: None
141143
The dtype of the waveforms. If None, the dtype of the recording is used.
142144
143145
Returns
144146
-------
145-
waveforms: np.ndarray
147+
waveforms : np.ndarray
146148
Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels)
147149
"""
148150

@@ -380,7 +382,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N
380382
assert isinstance(operators, list)
381383
for operator in operators:
382384
if isinstance(operator, str):
383-
assert operator in ("average", "std", "median", "mad")
385+
if operator not in ("average", "std", "median", "mad"):
386+
error_msg = (
387+
f"You have entered an operator {operator} in your `operators` argument which is "
388+
f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead."
389+
)
390+
raise ValueError(error_msg)
384391
else:
385392
assert isinstance(operator, (list, tuple))
386393
assert len(operator) == 2
@@ -405,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs):
405412
self._compute_and_append_from_waveforms(self.params["operators"])
406413

407414
else:
408-
for operator in self.params["operators"]:
409-
if operator not in ("average", "std"):
410-
raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension")
415+
bad_operator_list = [
416+
operator for operator in self.params["operators"] if operator not in ("average", "std")
417+
]
418+
if len(bad_operator_list) > 0:
419+
raise ValueError(
420+
f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension"
421+
)
411422

412423
recording = self.sorting_analyzer.recording
413424
sorting = self.sorting_analyzer.sorting
@@ -441,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs):
441452

442453
def _compute_and_append_from_waveforms(self, operators):
443454
if not self.sorting_analyzer.has_extension("waveforms"):
444-
raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension")
455+
raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension")
445456

446457
unit_ids = self.sorting_analyzer.unit_ids
447458
channel_ids = self.sorting_analyzer.channel_ids
@@ -466,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators):
466477

467478
assert self.sorting_analyzer.has_extension(
468479
"random_spikes"
469-
), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()"
480+
), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')"
470481
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
471482
for unit_index, unit_id in enumerate(unit_ids):
472483
spike_mask = some_spikes["unit_index"] == unit_index
@@ -549,9 +560,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
549560
if operator != "percentile":
550561
key = operator
551562
else:
552-
assert percentile is not None, "You must provide percentile=..."
563+
assert percentile is not None, "You must provide percentile=... if `operator=percentile`"
553564
key = f"percentile_{percentile}"
554565

566+
if key not in self.data.keys():
567+
error_msg = (
568+
f"You have entered `operator={key}`, but the only operators calculated are "
569+
f"{list(self.data.keys())}. Please use one of these as your `operator` in the "
570+
f"`get_data` function."
571+
)
572+
raise ValueError(error_msg)
573+
555574
templates_array = self.data[key]
556575

557576
if outputs == "numpy":
@@ -566,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
566585
probe=self.sorting_analyzer.get_probe(),
567586
)
568587
else:
569-
raise ValueError("outputs must be numpy or Templates")
588+
raise ValueError("outputs must be `numpy` or `Templates`")
570589

571590
def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"):
572591
"""
@@ -576,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
576595
577596
Parameters
578597
----------
579-
unit_ids: list or None
598+
unit_ids : list or None
580599
Unit ids to retrieve waveforms for
581-
operator: "average" | "median" | "std" | "percentile", default: "average"
600+
operator : "average" | "median" | "std" | "percentile", default: "average"
582601
The operator to compute the templates
583-
percentile: float, default: None
602+
percentile : float, default: None
584603
Percentile to use for operator="percentile"
585-
save: bool, default True
604+
save : bool, default: True
586605
In case, the operator is not computed yet it can be saved to folder or zarr
587-
outputs: "numpy" | "Templates"
606+
outputs : "numpy" | "Templates", default: "numpy"
588607
Whether to return a numpy array or a Templates object
589608
590609
Returns
591610
-------
592-
templates: np.array
611+
templates : np.array | Templates
593612
The returned templates (num_units, num_samples, num_channels)
594613
"""
595614
if operator != "percentile":
596615
key = operator
597616
else:
598-
assert percentile is not None, "You must provide percentile=..."
617+
assert percentile is not None, "You must provide percentile=... if `operator='percentile'`"
599618
key = f"pencentile_{percentile}"
600619

601620
if key in self.data:
@@ -632,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
632651
is_scaled=self.sorting_analyzer.return_scaled,
633652
)
634653
else:
635-
raise ValueError("outputs must be numpy or Templates")
654+
raise ValueError("`outputs` must be 'numpy' or 'Templates'")
636655

637656
def get_unit_template(self, unit_id, operator="average"):
638657
"""
@@ -642,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"):
642661
----------
643662
unit_id: str | int
644663
Unit id to retrieve waveforms for
645-
operator: str
664+
operator: str, default: "average"
646665
The operator to compute the templates
647666
648667
Returns
@@ -701,13 +720,13 @@ def _set_params(self, **noise_level_params):
701720
return params
702721

703722
def _select_extension_data(self, unit_ids):
704-
# this do not depend on units
723+
# this does not depend on units
705724
return self.data
706725

707726
def _merge_extension_data(
708727
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
709728
):
710-
# this do not depend on units
729+
# this does not depend on units
711730
return self.data.copy()
712731

713732
def _run(self, verbose=False):

0 commit comments

Comments
 (0)