Skip to content

Commit b9f50e3

Browse files
authored
Merge pull request #3403 from h-mayorquin/add_get_channel_locations_to_the_api
Add `get_channel_locations` to the base recording api
2 parents 9d7832c + 895288c commit b9f50e3

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/spikeinterface/core/baserecording.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,30 @@ def _select_segments(self, segment_indices):
746746

747747
return SelectSegmentRecording(self, segment_indices=segment_indices)
748748

749+
def get_channel_locations(
750+
self,
751+
channel_ids: list | np.ndarray | tuple | None = None,
752+
axes: "xy" | "yz" | "xz" | "xyz" = "xy",
753+
) -> np.ndarray:
754+
"""
755+
Get the physical locations of specified channels.
756+
757+
Parameters
758+
----------
759+
channel_ids : array-like, optional
760+
The IDs of the channels for which to retrieve locations. If None, retrieves locations
761+
for all available channels. Default is None.
762+
axes : "xy" | "yz" | "xz" | "xyz", default: "xy"
763+
The spatial axes to return, specified as a string (e.g., "xy", "xyz"). Default is "xy".
764+
765+
Returns
766+
-------
767+
np.ndarray
768+
A 2D or 3D array of shape (n_channels, n_dimensions) containing the locations of the channels.
769+
The number of dimensions depends on the `axes` argument (e.g., 2 for "xy", 3 for "xyz").
770+
"""
771+
return super().get_channel_locations(channel_ids=channel_ids, axes=axes)
772+
749773
def is_binary_compatible(self) -> bool:
750774
"""
751775
Checks if the recording is "binary" compatible.

src/spikeinterface/core/baserecordingsnippets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def set_channel_locations(self, locations, channel_ids=None):
349349
raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)")
350350
self.set_property("location", locations, ids=channel_ids)
351351

352-
def get_channel_locations(self, channel_ids=None, axes: str = "xy"):
352+
def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray:
353353
if channel_ids is None:
354354
channel_ids = self.get_channel_ids()
355355
channel_indices = self.ids_to_indices(channel_ids)

0 commit comments

Comments
 (0)