Skip to content

Commit 6f52972

Browse files
committed
add error message for invalid multiprocess
1 parent e513626 commit 6f52972

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

src/pyper/_core/task.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,19 @@ def __init__(
4848
or inspect.iscoroutinefunction(func.__call__) \
4949
or inspect.isasyncgenfunction(func.__call__)
5050

51-
if self.is_async and multiprocess:
52-
raise ValueError("multiprocess cannot be True for an async task")
53-
51+
if multiprocess:
52+
# Asynchronous functions cannot be multiprocessed
53+
if self.is_async:
54+
raise ValueError("multiprocess cannot be True for an async task")
55+
56+
# The function needs to be globally accessible to be multiprocessed
57+
# This excludes objects like lambdas and closures
58+
# We capture these cases to throw a clear error message
59+
module = inspect.getmodule(func)
60+
if module is None or getattr(module, func.__name__, None) is not func:
61+
raise RuntimeError(f"{func} cannot be multiprocessed because it is not globally accessible"
62+
f" -- it must be a globally defined object accessible by the name {func.__name__}")
63+
5464
self.func = func if bind is None else functools.partial(func, *bind[0], **bind[1])
5565
self.branch = branch
5666
self.join = join

tests/test_task.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,32 @@ def test_raise_for_invalid_func():
7474
else:
7575
raise AssertionError
7676

77-
def test_raise_for_invalid_multiprocess():
77+
def test_raise_for_async_multiprocess():
7878
try:
7979
task(afunc, multiprocess=True)
8080
except Exception as e:
8181
assert isinstance(e, ValueError)
8282
else:
8383
raise AssertionError
8484

85+
def test_raise_for_lambda_multiprocess():
86+
try:
87+
task(lambda x: x, multiprocess=True)
88+
except Exception as e:
89+
assert isinstance(e, RuntimeError)
90+
else:
91+
raise AssertionError
92+
93+
def test_raise_for_non_global_multiprocess():
94+
try:
95+
@task(multiprocess=True)
96+
def f(x):
97+
return x
98+
except Exception as e:
99+
assert isinstance(e, RuntimeError)
100+
else:
101+
raise AssertionError
102+
85103
def test_async_task():
86104
p = task(afunc)
87105
assert isinstance(p, AsyncPipeline)

0 commit comments

Comments
 (0)