Skip to content

Commit 8533a52

Browse files
Reducing memory footprint for Overlaps during matching (#4157)
* Reducing memory footprint * WIP * Reducing memory footprint for large number of templates/channels * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleaning --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9efde62 commit 8533a52

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

src/spikeinterface/sortingcomponents/matching/circus.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,11 @@ def _prepare_templates(self):
247247
else:
248248
sparsity = self.templates.sparsity.mask
249249

250-
units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2)
251-
self.units_overlaps = units_overlaps > 0
250+
# units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2)
252251
self.unit_overlaps_indices = {}
252+
self.units_overlaps = {}
253253
for i in range(self.num_templates):
254+
self.units_overlaps[i] = np.sum(np.logical_and(sparsity[i, :], sparsity), axis=1) > 0
254255
self.unit_overlaps_indices[i] = np.flatnonzero(self.units_overlaps[i])
255256

256257
templates_array = self.templates.get_dense_templates().copy()

src/spikeinterface/sortingcomponents/matching/wobble.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,17 @@ def from_templates(cls, params, templates):
278278
Dataclass object for aggregating channel sparsity variables together.
279279
"""
280280
visible_channels = templates.sparsity.mask
281-
unit_overlap = np.sum(
282-
np.logical_and(visible_channels[:, np.newaxis, :], visible_channels[np.newaxis, :, :]), axis=2
283-
)
284-
unit_overlap = unit_overlap > 0
281+
num_templates = templates.get_dense_templates().shape[0]
282+
unit_overlap = np.zeros((num_templates, num_templates), dtype=bool)
283+
284+
for i in range(num_templates):
285+
unit_overlap[i] = np.sum(np.logical_and(visible_channels[i], visible_channels), axis=1) > 0
286+
287+
# unit_overlap = np.sum(
288+
# np.logical_and(visible_channels[:, np.newaxis, :], visible_channels[np.newaxis, :, :]), axis=2
289+
# )
290+
# unit_overlap = unit_overlap > 0
291+
285292
unit_overlap = np.repeat(unit_overlap, params.jitter_factor, axis=0)
286293
sparsity = cls(visible_channels=visible_channels, unit_overlap=unit_overlap)
287294
return sparsity

src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ def __init__(
216216
device=None,
217217
radius_um=50,
218218
return_tensor=False,
219-
random_chunk_kwargs={},
220219
return_output=True,
221220
):
222221
if not HAVE_TORCH:
@@ -288,14 +287,13 @@ def __init__(
288287
exclude_sweep_ms=0.1,
289288
radius_um=50,
290289
noise_levels=None,
291-
random_chunk_kwargs={},
292290
opencl_context_kwargs={},
293291
):
294292
if not HAVE_PYOPENCL:
295293
raise ModuleNotFoundError('"locally_exclusive_cl" needs pyopencl which is not installed')
296294

297295
LocallyExclusivePeakDetector.__init__(
298-
self, recording, peak_sign, detect_threshold, exclude_sweep_ms, radius_um, noise_levels, random_chunk_kwargs
296+
self, recording, peak_sign, detect_threshold, exclude_sweep_ms, radius_um, noise_levels
299297
)
300298

301299
self.executor = OpenCLDetectPeakExecutor(

0 commit comments

Comments
 (0)