Skip to content

Commit db46c85

Browse files
committed
implement signalDefinition
Signed-off-by: Tim Li <ltim@uber.com>
1 parent e02576e commit db46c85

File tree

3 files changed

+244
-18
lines changed

3 files changed

+244
-18
lines changed

cadence/signal.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
"""
2+
Signal definition and registration for Cadence workflows.
3+
4+
This module provides functionality to define and register signal handlers
5+
for workflows, similar to ActivityDefinition but for signals.
6+
"""
7+
8+
import inspect
9+
from dataclasses import dataclass
10+
from functools import update_wrapper
11+
from inspect import Parameter, signature
12+
from typing import (
13+
Callable,
14+
Generic,
15+
ParamSpec,
16+
Type,
17+
TypeVar,
18+
TypedDict,
19+
Unpack,
20+
overload,
21+
get_type_hints,
22+
Any,
23+
)
24+
25+
P = ParamSpec("P")
26+
T = TypeVar("T")
27+
28+
29+
@dataclass(frozen=True)
30+
class SignalParameter:
31+
"""Parameter metadata for a signal handler."""
32+
33+
name: str
34+
type_hint: Type | None
35+
has_default: bool
36+
default_value: Any
37+
38+
39+
class SignalDefinitionOptions(TypedDict, total=False):
40+
"""Options for defining a signal."""
41+
42+
name: str
43+
44+
45+
class SignalDefinition(Generic[P, T]):
46+
"""
47+
Definition of a signal handler with metadata.
48+
49+
Similar to ActivityDefinition but for signal handlers.
50+
Provides type safety and metadata for signal handlers.
51+
"""
52+
53+
def __init__(
54+
self,
55+
wrapped: Callable[P, T],
56+
name: str,
57+
params: list[SignalParameter],
58+
is_async: bool,
59+
):
60+
self._wrapped = wrapped
61+
self._name = name
62+
self._params = params
63+
self._is_async = is_async
64+
update_wrapper(self, wrapped)
65+
66+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
67+
"""Call the wrapped signal handler function."""
68+
return self._wrapped(*args, **kwargs)
69+
70+
@property
71+
def name(self) -> str:
72+
"""Get the signal name."""
73+
return self._name
74+
75+
@property
76+
def params(self) -> list[SignalParameter]:
77+
"""Get the signal parameters."""
78+
return self._params
79+
80+
@property
81+
def is_async(self) -> bool:
82+
"""Check if the signal handler is async."""
83+
return self._is_async
84+
85+
@property
86+
def wrapped(self) -> Callable[P, T]:
87+
"""Get the wrapped signal handler function."""
88+
return self._wrapped
89+
90+
@staticmethod
91+
def wrap(
92+
fn: Callable[P, T], opts: SignalDefinitionOptions
93+
) -> "SignalDefinition[P, T]":
94+
"""
95+
Wrap a function as a SignalDefinition.
96+
97+
Args:
98+
fn: The signal handler function to wrap
99+
opts: Options for the signal definition
100+
101+
Returns:
102+
A SignalDefinition instance
103+
104+
Raises:
105+
ValueError: If name is not provided in options or return type is not None
106+
"""
107+
name = opts.get("name") or fn.__qualname__
108+
is_async = inspect.iscoroutinefunction(fn)
109+
params = _get_signal_signature(fn)
110+
_validate_signal_return_type(fn)
111+
112+
return SignalDefinition(fn, name, params, is_async)
113+
114+
115+
SignalDecorator = Callable[[Callable[P, T]], SignalDefinition[P, T]]
116+
117+
118+
@overload
119+
def defn(fn: Callable[P, T]) -> SignalDefinition[P, T]: ...
120+
121+
122+
@overload
123+
def defn(**kwargs: Unpack[SignalDefinitionOptions]) -> SignalDecorator: ...
124+
125+
126+
def defn(
127+
fn: Callable[P, T] | None = None, **kwargs: Unpack[SignalDefinitionOptions]
128+
) -> SignalDecorator | SignalDefinition[P, T]:
129+
"""
130+
Decorator to define a signal handler.
131+
132+
Can be used with or without parentheses:
133+
@signal.defn(name="approval")
134+
async def handle_approval(self, approved: bool):
135+
...
136+
137+
@signal.defn(name="approval")
138+
def handle_approval(self, approved: bool):
139+
...
140+
141+
Args:
142+
fn: The signal handler function to decorate
143+
**kwargs: Options for the signal definition (name is required)
144+
145+
Returns:
146+
The decorated function as a SignalDefinition instance
147+
148+
Raises:
149+
ValueError: If name is not provided
150+
"""
151+
options = SignalDefinitionOptions(**kwargs)
152+
153+
def decorator(inner_fn: Callable[P, T]) -> SignalDefinition[P, T]:
154+
return SignalDefinition.wrap(inner_fn, options)
155+
156+
if fn is not None:
157+
return decorator(fn)
158+
159+
return decorator
160+
161+
162+
def _validate_signal_return_type(fn: Callable) -> None:
163+
"""
164+
Validate that signal handler returns None.
165+
166+
Args:
167+
fn: The signal handler function
168+
169+
Raises:
170+
ValueError: If return type is not None
171+
"""
172+
try:
173+
hints = get_type_hints(fn)
174+
ret_type = hints.get("return", inspect.Signature.empty)
175+
176+
if ret_type is not None and ret_type is not inspect.Signature.empty:
177+
raise ValueError(
178+
f"Signal handler '{fn.__qualname__}' must return None "
179+
f"(signals cannot return values), got {ret_type}"
180+
)
181+
except NameError:
182+
pass
183+
184+
185+
def _get_signal_signature(fn: Callable[P, T]) -> list[SignalParameter]:
186+
"""
187+
Extract parameter information from a signal handler function.
188+
189+
Args:
190+
fn: The signal handler function
191+
192+
Returns:
193+
List of SignalParameter objects
194+
195+
Raises:
196+
ValueError: If parameters are not positional
197+
"""
198+
sig = signature(fn)
199+
args = sig.parameters
200+
hints = get_type_hints(fn)
201+
params = []
202+
203+
for name, param in args.items():
204+
# Filter out the self parameter for instance methods
205+
if param.name == "self":
206+
continue
207+
208+
has_default = param.default != Parameter.empty
209+
default = param.default if has_default else None
210+
211+
if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
212+
type_hint = hints.get(name, None)
213+
params.append(SignalParameter(name, type_hint, has_default, default))
214+
else:
215+
raise ValueError(
216+
f"Signal handler '{fn.__qualname__}' parameter '{name}' must be positional, "
217+
f"got {param.kind.name}"
218+
)
219+
220+
return params

cadence/workflow.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from cadence.client import Client
2121
from cadence.data_converter import DataConverter
22+
from cadence.signal import SignalDefinition, SignalDefinitionOptions
2223

2324
ResultType = TypeVar("ResultType")
2425

@@ -64,16 +65,16 @@ def __init__(
6465
cls: Type,
6566
name: str,
6667
run_method_name: str,
67-
signals: dict[str, Callable[..., Any]],
68+
signals: dict[str, SignalDefinition[..., Any]],
6869
):
6970
self._cls = cls
7071
self._name = name
7172
self._run_method_name = run_method_name
7273
self._signals = signals
7374

7475
@property
75-
def signals(self) -> dict[str, Callable[..., Any]]:
76-
"""Get the signals."""
76+
def signals(self) -> dict[str, SignalDefinition[..., Any]]:
77+
"""Get the signal definitions."""
7778
return self._signals
7879

7980
@property
@@ -111,7 +112,7 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> "WorkflowDefinition":
111112

112113
# Validate that the class has exactly one run method and find it
113114
# Also validate that class does not have multiple signal methods with the same name
114-
signals: dict[str, Callable[..., Any]] = {}
115+
signals: dict[str, SignalDefinition[..., Any]] = {}
115116
signal_names: dict[
116117
str, str
117118
] = {} # Map signal name to method name for duplicate detection
@@ -139,7 +140,11 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> "WorkflowDefinition":
139140
f"Multiple @workflow.signal methods found in class {cls.__name__} "
140141
f"with signal name '{signal_name}': '{attr_name}' and '{signal_names[signal_name]}'"
141142
)
142-
signals[attr_name] = attr
143+
# Create SignalDefinition from the decorated method
144+
signal_def = SignalDefinition.wrap(
145+
attr, SignalDefinitionOptions(name=signal_name)
146+
)
147+
signals[signal_name] = signal_def
143148
signal_names[signal_name] = attr_name
144149

145150
if run_method_name is None:

tests/cadence/worker/test_registry.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from cadence import workflow
1010
from cadence.worker import Registry
1111
from cadence.workflow import WorkflowDefinition
12+
from cadence.signal import SignalDefinition
1213
from tests.cadence import common_activities
1314

1415

@@ -230,9 +231,13 @@ async def handle_approval(self, approved: bool):
230231
workflow_def = reg.get_workflow("WorkflowWithSignal")
231232
assert isinstance(workflow_def, WorkflowDefinition)
232233
assert len(workflow_def.signals) == 1
233-
assert "handle_approval" in workflow_def.signals
234-
assert hasattr(workflow_def.signals["handle_approval"], "_workflow_signal")
235-
assert workflow_def.signals["handle_approval"]._workflow_signal == "approval"
234+
assert "approval" in workflow_def.signals
235+
signal_def = workflow_def.signals["approval"]
236+
assert isinstance(signal_def, SignalDefinition)
237+
assert signal_def.name == "approval"
238+
assert signal_def.is_async is True
239+
assert len(signal_def.params) == 1
240+
assert signal_def.params[0].name == "approved"
236241

237242
def test_workflow_with_multiple_signals(self):
238243
"""Test workflow with multiple signal handlers."""
@@ -254,16 +259,12 @@ async def handle_cancel(self):
254259

255260
workflow_def = reg.get_workflow("WorkflowWithMultipleSignals")
256261
assert len(workflow_def.signals) == 2
257-
assert "handle_approval" in workflow_def.signals
258-
assert "handle_cancel" in workflow_def.signals
259-
assert (
260-
getattr(workflow_def.signals["handle_approval"], "_workflow_signal")
261-
== "approval"
262-
)
263-
assert (
264-
getattr(workflow_def.signals["handle_cancel"], "_workflow_signal")
265-
== "cancel"
266-
)
262+
assert "approval" in workflow_def.signals
263+
assert "cancel" in workflow_def.signals
264+
assert isinstance(workflow_def.signals["approval"], SignalDefinition)
265+
assert isinstance(workflow_def.signals["cancel"], SignalDefinition)
266+
assert workflow_def.signals["approval"].name == "approval"
267+
assert workflow_def.signals["cancel"].name == "cancel"
267268

268269
def test_signal_decorator_requires_name(self):
269270
"""Test that signal decorator requires name parameter."""

0 commit comments

Comments
 (0)