@@ -77,7 +77,11 @@ def compute_monopolar_triangulation(
7777
7878 contact_locations = sorting_analyzer_or_templates .get_channel_locations ()
7979
80- sparsity = compute_sparsity (sorting_analyzer_or_templates , method = "radius" , radius_um = radius_um )
80+ if sorting_analyzer_or_templates .sparsity is None :
81+ sparsity = compute_sparsity (sorting_analyzer_or_templates , method = "radius" , radius_um = radius_um )
82+ else :
83+ sparsity = sorting_analyzer_or_templates .sparsity
84+
8185 templates = get_dense_templates_array (
8286 sorting_analyzer_or_templates , return_scaled = get_return_scaled (sorting_analyzer_or_templates )
8387 )
@@ -157,9 +161,13 @@ def compute_center_of_mass(
157161
158162 assert feature in ["ptp" , "mean" , "energy" , "peak_voltage" ], f"{ feature } is not a valid feature"
159163
160- sparsity = compute_sparsity (
161- sorting_analyzer_or_templates , peak_sign = peak_sign , method = "radius" , radius_um = radius_um
162- )
164+ if sorting_analyzer_or_templates .sparsity is None :
165+ sparsity = compute_sparsity (
166+ sorting_analyzer_or_templates , peak_sign = peak_sign , method = "radius" , radius_um = radius_um
167+ )
168+ else :
169+ sparsity = sorting_analyzer_or_templates .sparsity
170+
163171 templates = get_dense_templates_array (
164172 sorting_analyzer_or_templates , return_scaled = get_return_scaled (sorting_analyzer_or_templates )
165173 )
@@ -650,8 +658,55 @@ def get_convolution_weights(
650658 enforce_decrease_shells = numba .jit (enforce_decrease_shells_data , nopython = True )
651659
652660
661+ def compute_location_max_channel (
662+ templates_or_sorting_analyzer : SortingAnalyzer | Templates ,
663+ unit_ids = None ,
664+ peak_sign : "neg" | "pos" | "both" = "neg" ,
665+ mode : "extremum" | "at_index" | "peak_to_peak" = "extremum" ,
666+ ) -> np .ndarray :
667+ """
668+ Localize a unit using max channel.
669+
670+ This uses internally `get_template_extremum_channel()`
671+
672+
673+ Parameters
674+ ----------
675+ templates_or_sorting_analyzer : SortingAnalyzer | Templates
676+ A SortingAnalyzer or Templates object
677+ unit_ids: list[str] | list[int] | None
678+ A list of unit_id to restrict the computation
679+ peak_sign : "neg" | "pos" | "both"
680+ Sign of the template to find extremum channels
681+ mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index"
682+ Where the amplitude is computed
683+ * "extremum" : take the peak value (max or min depending on `peak_sign`)
684+ * "at_index" : take value at `nbefore` index
685+ * "peak_to_peak" : take the peak-to-peak amplitude
686+
687+ Returns
688+ -------
689+ unit_locations: np.ndarray
690+ 2d
691+ """
692+ extremum_channels_index = get_template_extremum_channel (
693+ templates_or_sorting_analyzer , peak_sign = peak_sign , mode = mode , outputs = "index"
694+ )
695+ contact_locations = templates_or_sorting_analyzer .get_channel_locations ()
696+ if unit_ids is None :
697+ unit_ids = templates_or_sorting_analyzer .unit_ids
698+ else :
699+ unit_ids = np .asarray (unit_ids )
700+ unit_locations = np .zeros ((unit_ids .size , 2 ), dtype = "float32" )
701+ for i , unit_id in enumerate (unit_ids ):
702+ unit_locations [i , :] = contact_locations [extremum_channels_index [unit_id ]]
703+
704+ return unit_locations
705+
706+
653707_unit_location_methods = {
654708 "center_of_mass" : compute_center_of_mass ,
655709 "grid_convolution" : compute_grid_convolution ,
656710 "monopolar_triangulation" : compute_monopolar_triangulation ,
711+ "max_channel" : compute_location_max_channel ,
657712}
0 commit comments