Skip to content

Commit 2c34c91

Browse files
committed
add channel recording to the base recording api
1 parent 48b2131 commit 2c34c91

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/spikeinterface/core/baserecording.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,30 @@ def _select_segments(self, segment_indices):
746746

747747
return SelectSegmentRecording(self, segment_indices=segment_indices)
748748

749+
def get_channel_locations(
750+
self,
751+
channel_ids: list | np.ndarray | tuple | None = None,
752+
axes: "xy" | "yz" | "xz" = "xy",
753+
) -> np.ndarray:
754+
"""
755+
Get the physical locations of specified channels.
756+
757+
Parameters
758+
----------
759+
channel_ids : array-like, optional
760+
The IDs of the channels for which to retrieve locations. If None, retrieves locations
761+
for all available channels. Default is None.
762+
axes : str, optional
763+
The spatial axes to return, specified as a string (e.g., "xy", "xyz"). Default is "xy".
764+
765+
Returns
766+
-------
767+
np.ndarray
768+
A 2D or 3D array of shape (n_channels, n_dimensions) containing the locations of the channels.
769+
The number of dimensions depends on the `axes` argument (e.g., 2 for "xy", 3 for "xyz").
770+
"""
771+
return super().get_channel_locations(channel_ids=channel_ids, axes=axes)
772+
749773
def is_binary_compatible(self) -> bool:
750774
"""
751775
Checks if the recording is "binary" compatible.

src/spikeinterface/core/baserecordingsnippets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def set_channel_locations(self, locations, channel_ids=None):
344344
raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)")
345345
self.set_property("location", locations, ids=channel_ids)
346346

347-
def get_channel_locations(self, channel_ids=None, axes: str = "xy"):
347+
def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray:
348348
if channel_ids is None:
349349
channel_ids = self.get_channel_ids()
350350
channel_indices = self.ids_to_indices(channel_ids)

0 commit comments

Comments
 (0)