Skip to content

Commit 22d19d5

Browse files
committed
more docs stuff
1 parent b5bd2fb commit 22d19d5

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,23 @@
2222

2323
class ComputeRandomSpikes(AnalyzerExtension):
2424
"""
25-
AnalyzerExtension that select some random spikes.
25+
AnalyzerExtension that select somes random spikes.
26+
This is allows for a subsampling of spikes for further calculations and is important
27+
for managing that amount of memory and speed of computation in the analyzer.
2628
2729
This will be used by the `waveforms`/`templates` extensions.
2830
29-
This internally use `random_spikes_selection()` parameters are the same.
31+
This internally use `random_spikes_selection()` parameters.
3032
3133
Parameters
3234
----------
33-
method: "uniform" | "all", default: "uniform"
35+
method : "uniform" | "all", default: "uniform"
3436
The method to select the spikes
35-
max_spikes_per_unit: int, default: 500
37+
max_spikes_per_unit : int, default: 500
3638
The maximum number of spikes per unit, ignored if method="all"
37-
margin_size: int, default: None
39+
margin_size : int, default: None
3840
A margin on each border of segments to avoid border spikes, ignored if method="all"
39-
seed: int or None, default: None
41+
seed : int or None, default: None
4042
A seed for the random generator, ignored if method="all"
4143
4244
Returns
@@ -104,7 +106,7 @@ def get_random_spikes(self):
104106
return self._some_spikes
105107

106108
def get_selected_indices_in_spike_train(self, unit_id, segment_index):
107-
# usefull for Waveforms extractor backwars compatibility
109+
# useful for Waveforms extractor backwars compatibility
108110
# In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain
109111
sorting = self.sorting_analyzer.sorting
110112
random_spikes_indices = self.data["random_spikes_indices"]
@@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension):
133135
134136
Parameters
135137
----------
136-
ms_before: float, default: 1.0
138+
ms_before : float, default: 1.0
137139
The number of ms to extract before the spike events
138-
ms_after: float, default: 2.0
140+
ms_after : float, default: 2.0
139141
The number of ms to extract after the spike events
140-
dtype: None | dtype, default: None
142+
dtype : None | dtype, default: None
141143
The dtype of the waveforms. If None, the dtype of the recording is used.
142144
143145
Returns
144146
-------
145-
waveforms: np.ndarray
147+
waveforms : np.ndarray
146148
Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels)
147149
"""
148150

@@ -410,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs):
410412
self._compute_and_append_from_waveforms(self.params["operators"])
411413

412414
else:
413-
for operator in self.params["operators"]:
414-
if operator not in ("average", "std"):
415-
raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension")
415+
bad_operator_list = [
416+
operator for operator in self.params["operators"] if operator not in ("average", "std")
417+
]
418+
if len(bad_operator_list) > 0:
419+
raise ValueError(
420+
f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension"
421+
)
416422

417423
recording = self.sorting_analyzer.recording
418424
sorting = self.sorting_analyzer.sorting
@@ -446,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs):
446452

447453
def _compute_and_append_from_waveforms(self, operators):
448454
if not self.sorting_analyzer.has_extension("waveforms"):
449-
raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension")
455+
raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension")
450456

451457
unit_ids = self.sorting_analyzer.unit_ids
452458
channel_ids = self.sorting_analyzer.channel_ids
@@ -471,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators):
471477

472478
assert self.sorting_analyzer.has_extension(
473479
"random_spikes"
474-
), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()"
480+
), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')"
475481
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
476482
for unit_index, unit_id in enumerate(unit_ids):
477483
spike_mask = some_spikes["unit_index"] == unit_index
@@ -579,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
579585
probe=self.sorting_analyzer.get_probe(),
580586
)
581587
else:
582-
raise ValueError("outputs must be numpy or Templates")
588+
raise ValueError("outputs must be `numpy` or `Templates`")
583589

584590
def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"):
585591
"""
@@ -589,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
589595
590596
Parameters
591597
----------
592-
unit_ids: list or None
598+
unit_ids : list or None
593599
Unit ids to retrieve waveforms for
594-
operator: "average" | "median" | "std" | "percentile", default: "average"
600+
operator : "average" | "median" | "std" | "percentile", default: "average"
595601
The operator to compute the templates
596-
percentile: float, default: None
602+
percentile : float, default: None
597603
Percentile to use for operator="percentile"
598-
save: bool, default True
604+
save : bool, default: True
599605
In case, the operator is not computed yet it can be saved to folder or zarr
600-
outputs: "numpy" | "Templates"
606+
outputs : "numpy" | "Templates", default: "numpy"
601607
Whether to return a numpy array or a Templates object
602608
603609
Returns
604610
-------
605-
templates: np.array
611+
templates : np.array | Templates
606612
The returned templates (num_units, num_samples, num_channels)
607613
"""
608614
if operator != "percentile":
609615
key = operator
610616
else:
611-
assert percentile is not None, "You must provide percentile=..."
617+
assert percentile is not None, "You must provide percentile=... if `operator='percentile'`"
612618
key = f"pencentile_{percentile}"
613619

614620
if key in self.data:
@@ -645,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
645651
is_scaled=self.sorting_analyzer.return_scaled,
646652
)
647653
else:
648-
raise ValueError("outputs must be numpy or Templates")
654+
raise ValueError("`outputs` must be 'numpy' or 'Templates'")
649655

650656
def get_unit_template(self, unit_id, operator="average"):
651657
"""
@@ -655,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"):
655661
----------
656662
unit_id: str | int
657663
Unit id to retrieve waveforms for
658-
operator: str
664+
operator: str, default: "average"
659665
The operator to compute the templates
660666
661667
Returns
@@ -713,13 +719,13 @@ def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None):
713719
return params
714720

715721
def _select_extension_data(self, unit_ids):
716-
# this do not depend on units
722+
# this does not depend on units
717723
return self.data
718724

719725
def _merge_extension_data(
720726
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
721727
):
722-
# this do not depend on units
728+
# this does not depend on units
723729
return self.data.copy()
724730

725731
def _run(self, verbose=False):

0 commit comments

Comments
 (0)