Skip to content

Commit 0c5e8bb

Browse files
authored
Merge branch 'main' into basesorter
2 parents 67f50a4 + 4a1a45a commit 0c5e8bb

File tree

3 files changed

+66
-5
lines changed

3 files changed

+66
-5
lines changed

src/spikeinterface/core/template_tools.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc
3131
)
3232
ext = one_object.get_extension("templates")
3333
if ext is not None:
34-
templates_array = ext.data["average"]
34+
if "average" in ext.data:
35+
templates_array = ext.data.get("average")
36+
elif "median" in ext.data:
37+
templates_array = ext.data.get("median")
38+
else:
39+
raise ValueError("Average or median templates have not been computed.")
3540
else:
3641
raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates")
3742
else:

src/spikeinterface/postprocessing/localization_tools.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ def compute_monopolar_triangulation(
7777

7878
contact_locations = sorting_analyzer_or_templates.get_channel_locations()
7979

80-
sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um)
80+
if sorting_analyzer_or_templates.sparsity is None:
81+
sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um)
82+
else:
83+
sparsity = sorting_analyzer_or_templates.sparsity
84+
8185
templates = get_dense_templates_array(
8286
sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates)
8387
)
@@ -157,9 +161,13 @@ def compute_center_of_mass(
157161

158162
assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature"
159163

160-
sparsity = compute_sparsity(
161-
sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um
162-
)
164+
if sorting_analyzer_or_templates.sparsity is None:
165+
sparsity = compute_sparsity(
166+
sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um
167+
)
168+
else:
169+
sparsity = sorting_analyzer_or_templates.sparsity
170+
163171
templates = get_dense_templates_array(
164172
sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates)
165173
)
@@ -650,8 +658,55 @@ def get_convolution_weights(
650658
enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True)
651659

652660

661+
def compute_location_max_channel(
662+
templates_or_sorting_analyzer: SortingAnalyzer | Templates,
663+
unit_ids=None,
664+
peak_sign: "neg" | "pos" | "both" = "neg",
665+
mode: "extremum" | "at_index" | "peak_to_peak" = "extremum",
666+
) -> np.ndarray:
667+
"""
668+
Localize a unit using max channel.
669+
670+
This uses internally `get_template_extremum_channel()`
671+
672+
673+
Parameters
674+
----------
675+
templates_or_sorting_analyzer : SortingAnalyzer | Templates
676+
A SortingAnalyzer or Templates object
677+
unit_ids: list[str] | list[int] | None
678+
A list of unit_id to restrict the computation
679+
peak_sign : "neg" | "pos" | "both"
680+
Sign of the template to find extremum channels
681+
mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index"
682+
Where the amplitude is computed
683+
* "extremum" : take the peak value (max or min depending on `peak_sign`)
684+
* "at_index" : take value at `nbefore` index
685+
* "peak_to_peak" : take the peak-to-peak amplitude
686+
687+
Returns
688+
-------
689+
unit_locations: np.ndarray
690+
2d
691+
"""
692+
extremum_channels_index = get_template_extremum_channel(
693+
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index"
694+
)
695+
contact_locations = templates_or_sorting_analyzer.get_channel_locations()
696+
if unit_ids is None:
697+
unit_ids = templates_or_sorting_analyzer.unit_ids
698+
else:
699+
unit_ids = np.asarray(unit_ids)
700+
unit_locations = np.zeros((unit_ids.size, 2), dtype="float32")
701+
for i, unit_id in enumerate(unit_ids):
702+
unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id]]
703+
704+
return unit_locations
705+
706+
653707
_unit_location_methods = {
654708
"center_of_mass": compute_center_of_mass,
655709
"grid_convolution": compute_grid_convolution,
656710
"monopolar_triangulation": compute_monopolar_triangulation,
711+
"max_channel": compute_location_max_channel,
657712
}

src/spikeinterface/postprocessing/tests/test_unit_locations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite):
1313
dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}),
1414
dict(method="monopolar_triangulation", radius_um=150),
1515
dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"),
16+
dict(method="max_channel"),
1617
],
1718
)
1819
def test_extension(self, params):

0 commit comments

Comments
 (0)