Commit cba0f3b

Merge branch 'main' into ibl-extractor-pid

2 parents d9b169d + 3df1083
File tree: 5 files changed, +96 -52 lines

5 files changed

+96
-52
lines changed

src/spikeinterface/postprocessing/principal_component.py

Lines changed: 31 additions & 12 deletions
@@ -1,12 +1,14 @@
 from __future__ import annotations

-import shutil
-import pickle
 import warnings
-import tempfile
+import platform
 from pathlib import Path
 from tqdm.auto import tqdm

+from concurrent.futures import ProcessPoolExecutor
+import multiprocessing as mp
+from threadpoolctl import threadpool_limits
+
 import numpy as np

 from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension
@@ -314,11 +316,13 @@ def _run(self, verbose=False, **job_kwargs):
         job_kwargs = fix_job_kwargs(job_kwargs)
         n_jobs = job_kwargs["n_jobs"]
         progress_bar = job_kwargs["progress_bar"]
+        max_threads_per_process = job_kwargs["max_threads_per_process"]
+        mp_context = job_kwargs["mp_context"]

         # fit model/models
         # TODO : make parralel for by_channel_global and concatenated
         if mode == "by_channel_local":
-            pca_models = self._fit_by_channel_local(n_jobs, progress_bar)
+            pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_process, mp_context)
             for chan_ind, chan_id in enumerate(self.sorting_analyzer.channel_ids):
                 self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind]
             pca_model = pca_models
@@ -411,12 +415,16 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
         )
         processor.run()

-    def _fit_by_channel_local(self, n_jobs, progress_bar):
+    def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, mp_context):
         from sklearn.decomposition import IncrementalPCA
-        from concurrent.futures import ProcessPoolExecutor

         p = self.params

+        if mp_context is not None and platform.system() == "Windows":
+            assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
+        elif mp_context == "fork" and platform.system() == "Darwin":
+            warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
+
         unit_ids = self.sorting_analyzer.unit_ids
         channel_ids = self.sorting_analyzer.channel_ids
         # there is one PCA per channel for independent fit per channel
@@ -436,13 +444,18 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
                     pca = pca_models[chan_ind]
                     pca.partial_fit(wfs[:, :, wf_ind])
             else:
-                # parallel
+                # create list of args to parallelize. For convenience, the max_threads_per_process is passed
+                # as last argument
                 items = [
-                    (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds)
+                    (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_process)
+                    for wf_ind, chan_ind in enumerate(channel_inds)
                 ]
                 n_jobs = min(n_jobs, len(items))

-                with ProcessPoolExecutor(max_workers=n_jobs) as executor:
+                with ProcessPoolExecutor(
+                    max_workers=n_jobs,
+                    mp_context=mp.get_context(mp_context),
+                ) as executor:
                     results = executor.map(_partial_fit_one_channel, items)
                     for chan_ind, pca_model_updated in results:
                         pca_models[chan_ind] = pca_model_updated
@@ -674,6 +687,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte


 def _partial_fit_one_channel(args):
-    chan_ind, pca_model, wf_chan = args
-    pca_model.partial_fit(wf_chan)
-    return chan_ind, pca_model
+    chan_ind, pca_model, wf_chan, max_threads_per_process = args
+
+    if max_threads_per_process is None:
+        pca_model.partial_fit(wf_chan)
+        return chan_ind, pca_model
+    else:
+        with threadpool_limits(limits=int(max_threads_per_process)):
+            pca_model.partial_fit(wf_chan)
+            return chan_ind, pca_model
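The worker-side pattern introduced above is worth seeing in isolation: each process caps its BLAS/OpenMP thread pool with threadpoolctl before calling partial_fit, so n_jobs processes times max_threads_per_process threads cannot oversubscribe the machine. A minimal, self-contained sketch of that pattern (synthetic waveform shapes and the helper name _fit_one_channel are illustrative, not SpikeInterface's real data layout):

import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

import numpy as np
from sklearn.decomposition import IncrementalPCA
from threadpoolctl import threadpool_limits


def _fit_one_channel(args):
    # Unpack one work item; max_threads_per_process rides along as the last element.
    chan_ind, pca_model, wf_chan, max_threads_per_process = args
    if max_threads_per_process is None:
        pca_model.partial_fit(wf_chan)
    else:
        # Cap BLAS/OpenMP threads in this worker process.
        with threadpool_limits(limits=int(max_threads_per_process)):
            pca_model.partial_fit(wf_chan)
    return chan_ind, pca_model


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    wfs = rng.normal(size=(500, 30, 4))  # (num_spikes, num_samples, num_channels)
    items = [(ci, IncrementalPCA(n_components=3), wfs[:, :, ci], 1) for ci in range(4)]
    with ProcessPoolExecutor(max_workers=2, mp_context=mp.get_context("spawn")) as executor:
        for chan_ind, fitted in executor.map(_fit_one_channel, items):
            print(chan_ind, fitted.components_.shape)  # -> (3, 30) per channel

With "spawn", the worker function must be importable at module scope, which matches how _partial_fit_one_channel is defined in the file above.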

src/spikeinterface/postprocessing/tests/test_principal_component.py

Lines changed: 12 additions & 0 deletions
@@ -18,6 +18,18 @@ class TestPrincipalComponentsExtension(AnalyzerExtensionCommonTestSuite):
     def test_extension(self, params):
         self.run_extension_tests(ComputePrincipalComponents, params=params)

+    def test_multi_processing(self):
+        """
+        Test the extension works with multiple processes.
+        """
+        sorting_analyzer = self._prepare_sorting_analyzer(
+            format="memory", sparse=False, extension_class=ComputePrincipalComponents
+        )
+        sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2)
+        sorting_analyzer.compute(
+            "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_process=4, mp_context="spawn"
+        )
+
     def test_mode_concatenated(self):
         """
         Replicate the "extension_function_params_list" test outside of
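For reference, the same options as seen from user code. A hedged end-to-end sketch, assuming SpikeInterface's generator utilities (generate_ground_truth_recording, create_sorting_analyzer) and the random_spikes/waveforms extensions that principal_components depends on:

import spikeinterface.full as si

recording, sorting = si.generate_ground_truth_recording(durations=[10.0], seed=0)
analyzer = si.create_sorting_analyzer(sorting, recording, sparse=False)
analyzer.compute("random_spikes")  # principal_components needs these two
analyzer.compute("waveforms")      # extensions computed first

analyzer.compute(
    "principal_components",
    mode="by_channel_local",
    n_jobs=2,                   # worker processes
    max_threads_per_process=4,  # BLAS/OpenMP threads per worker
    mp_context="spawn",         # start method; "fork" is rejected on Windows
)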

src/spikeinterface/preprocessing/tests/test_filter.py

Lines changed: 4 additions & 4 deletions
@@ -46,7 +46,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data):

         filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces()

-        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6)
+        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-4)

         # Then, change all kwargs to ensure they are propagated
         # and check the backwards version.
@@ -66,7 +66,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data):

         filt_data = causal_filter(recording, direction="backward", **options, margin_ms=0).get_traces()

-        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6)
+        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-4)

     def test_causal_filter_custom_coeff(self, recording_and_data):
         """
@@ -89,7 +89,7 @@ def test_causal_filter_custom_coeff(self, recording_and_data):

         filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces()

-        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True)
+        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-4, equal_nan=True)

         # Next, in "sos" mode
         options["filter_mode"] = "sos"
@@ -100,7 +100,7 @@ def test_causal_filter_custom_coeff(self, recording_and_data):

         filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces()

-        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True)
+        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-4, equal_nan=True)

     def test_causal_kwarg_error_raised(self, recording_and_data):
         """

src/spikeinterface/qualitymetrics/pca_metrics.py

Lines changed: 27 additions & 33 deletions
@@ -2,15 +2,16 @@

 from __future__ import annotations

-
+import warnings
 from copy import deepcopy
-
-import numpy as np
+import platform
 from tqdm.auto import tqdm
-from concurrent.futures import ProcessPoolExecutor

+import numpy as np

-import warnings
+import multiprocessing as mp
+from concurrent.futures import ProcessPoolExecutor
+from threadpoolctl import threadpool_limits

 from .misc_metrics import compute_num_spikes, compute_firing_rates

@@ -56,6 +57,8 @@ def compute_pc_metrics(
     seed=None,
     n_jobs=1,
     progress_bar=False,
+    mp_context=None,
+    max_threads_per_process=None,
 ) -> dict:
     """
     Calculate principal component derived metrics.
@@ -144,17 +147,7 @@
         pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices]
         pcs_flat = pcs.reshape(pcs.shape[0], -1)

-        func_args = (
-            pcs_flat,
-            labels,
-            non_nn_metrics,
-            unit_id,
-            unit_ids,
-            qm_params,
-            seed,
-            n_spikes_all_units,
-            fr_all_units,
-        )
+        func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process)
         items.append(func_args)

     if not run_in_parallel and non_nn_metrics:
@@ -167,7 +160,15 @@
             for metric_name, metric in pca_metrics_unit.items():
                 pc_metrics[metric_name][unit_id] = metric
     elif run_in_parallel and non_nn_metrics:
-        with ProcessPoolExecutor(n_jobs) as executor:
+        if mp_context is not None and platform.system() == "Windows":
+            assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
+        elif mp_context == "fork" and platform.system() == "Darwin":
+            warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
+
+        with ProcessPoolExecutor(
+            max_workers=n_jobs,
+            mp_context=mp.get_context(mp_context),
+        ) as executor:
             results = executor.map(pca_metrics_one_unit, items)
             if progress_bar:
                 results = tqdm(results, total=len(unit_ids), desc="calculate_pc_metrics")
@@ -976,26 +977,19 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):


 def pca_metrics_one_unit(args):
-    (
-        pcs_flat,
-        labels,
-        metric_names,
-        unit_id,
-        unit_ids,
-        qm_params,
-        seed,
-        # we_folder,
-        n_spikes_all_units,
-        fr_all_units,
-    ) = args
-
-    # if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names:
-    #     we = load_waveforms(we_folder)
+    (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args
+
+    if max_threads_per_process is None:
+        return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)
+    else:
+        with threadpool_limits(limits=int(max_threads_per_process)):
+            return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)
+

+def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params):
     pc_metrics = {}
     # metrics
     if "isolation_distance" in metric_names or "l_ratio" in metric_names:
-
         try:
             isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id)
         except:
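The start-method guard now appears verbatim in both principal_component.py and pca_metrics.py. A shared helper along these lines (hypothetical name, not part of this commit) would remove the duplication:

import multiprocessing as mp
import platform
import warnings


def get_checked_mp_context(mp_context=None):
    # Same guard as in the diff: reject "fork" on Windows, warn about it on macOS.
    if mp_context is not None and platform.system() == "Windows":
        assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
    elif mp_context == "fork" and platform.system() == "Darwin":
        warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
    # mp.get_context(None) returns the platform's default start method.
    return mp.get_context(mp_context)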

src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py

Lines changed: 22 additions & 3 deletions
@@ -1,9 +1,7 @@
 import pytest
 import numpy as np

-from spikeinterface.qualitymetrics import (
-    compute_pc_metrics,
-)
+from spikeinterface.qualitymetrics import compute_pc_metrics, get_quality_pca_metric_list


 def test_calculate_pc_metrics(small_sorting_analyzer):
@@ -22,3 +20,24 @@ def test_calculate_pc_metrics(small_sorting_analyzer):
         assert not np.all(np.isnan(res2[metric_name].values))

     assert np.array_equal(res1[metric_name].values, res2[metric_name].values)
+
+
+def test_pca_metrics_multi_processing(small_sorting_analyzer):
+    sorting_analyzer = small_sorting_analyzer
+
+    metric_names = get_quality_pca_metric_list()
+    metric_names.remove("nn_isolation")
+    metric_names.remove("nn_noise_overlap")
+
+    print("Computing PCA metrics with 1 thread per process")
+    res1 = compute_pc_metrics(
+        sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=1, progress_bar=True
+    )
+    print("Computing PCA metrics with 2 threads per process")
+    res2 = compute_pc_metrics(
+        sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True
+    )
+    print("Computing PCA metrics with spawn context")
+    res3 = compute_pc_metrics(
+        sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, mp_context="spawn", progress_bar=True
+    )
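The max_threads_per_process values exercised here take effect through threadpoolctl. A quick way to see the limit applied (assumes numpy is linked against a BLAS backend that threadpoolctl can detect):

from threadpoolctl import threadpool_info, threadpool_limits

with threadpool_limits(limits=2):
    # Every detected BLAS/OpenMP pool reports at most 2 threads inside the block.
    print([pool["num_threads"] for pool in threadpool_info()])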
