diff --git a/dapr/actor/__init__.py b/dapr/actor/__init__.py index 4323caae2..bf21f488c 100644 --- a/dapr/actor/__init__.py +++ b/dapr/actor/__init__.py @@ -20,7 +20,6 @@ from dapr.actor.runtime.remindable import Remindable from dapr.actor.runtime.runtime import ActorRuntime - __all__ = [ 'ActorInterface', 'ActorProxy', diff --git a/dapr/actor/client/proxy.py b/dapr/actor/client/proxy.py index a7648bf97..aeafdd58c 100644 --- a/dapr/actor/client/proxy.py +++ b/dapr/actor/client/proxy.py @@ -21,8 +21,8 @@ from dapr.actor.runtime._type_utils import get_dispatchable_attrs_from_interface from dapr.clients import DaprActorClientBase, DaprActorHttpClient from dapr.clients.retry import RetryPolicy -from dapr.serializers import Serializer, DefaultJSONSerializer from dapr.conf import settings +from dapr.serializers import DefaultJSONSerializer, Serializer # Actor factory Callable type hint. ACTOR_FACTORY_CALLBACK = Callable[[ActorInterface, str, str], 'ActorProxy'] diff --git a/dapr/actor/runtime/_reminder_data.py b/dapr/actor/runtime/_reminder_data.py index 8821c94bc..5453b8162 100644 --- a/dapr/actor/runtime/_reminder_data.py +++ b/dapr/actor/runtime/_reminder_data.py @@ -14,7 +14,6 @@ """ import base64 - from datetime import timedelta from typing import Any, Dict, Optional diff --git a/dapr/actor/runtime/_state_provider.py b/dapr/actor/runtime/_state_provider.py index 54f6b5837..eeb1e4995 100644 --- a/dapr/actor/runtime/_state_provider.py +++ b/dapr/actor/runtime/_state_provider.py @@ -14,12 +14,11 @@ """ import io +from typing import Any, List, Tuple, Type -from typing import Any, List, Type, Tuple -from dapr.actor.runtime.state_change import StateChangeKind, ActorStateChange +from dapr.actor.runtime.state_change import ActorStateChange, StateChangeKind from dapr.clients.base import DaprActorClientBase -from dapr.serializers import Serializer, DefaultJSONSerializer - +from dapr.serializers import DefaultJSONSerializer, Serializer # Mapping StateChangeKind to Dapr State Operation _MAP_CHANGE_KIND_TO_OPERATION = { diff --git a/dapr/actor/runtime/_type_information.py b/dapr/actor/runtime/_type_information.py index 72566eb17..f9171aea8 100644 --- a/dapr/actor/runtime/_type_information.py +++ b/dapr/actor/runtime/_type_information.py @@ -13,10 +13,10 @@ limitations under the License. 
""" -from dapr.actor.runtime.remindable import Remindable -from dapr.actor.runtime._type_utils import is_dapr_actor, get_actor_interfaces +from typing import TYPE_CHECKING, List, Type -from typing import List, Type, TYPE_CHECKING +from dapr.actor.runtime._type_utils import get_actor_interfaces, is_dapr_actor +from dapr.actor.runtime.remindable import Remindable if TYPE_CHECKING: from dapr.actor.actor_interface import ActorInterface # noqa: F401 diff --git a/dapr/actor/runtime/actor.py b/dapr/actor/runtime/actor.py index 79b1e6ab1..fab02fc70 100644 --- a/dapr/actor/runtime/actor.py +++ b/dapr/actor/runtime/actor.py @@ -14,16 +14,15 @@ """ import uuid - from datetime import timedelta from typing import Any, Optional from dapr.actor.id import ActorId from dapr.actor.runtime._method_context import ActorMethodContext -from dapr.actor.runtime.context import ActorRuntimeContext -from dapr.actor.runtime.state_manager import ActorStateManager from dapr.actor.runtime._reminder_data import ActorReminderData from dapr.actor.runtime._timer_data import TIMER_CALLBACK, ActorTimerData +from dapr.actor.runtime.context import ActorRuntimeContext +from dapr.actor.runtime.state_manager import ActorStateManager class Actor: diff --git a/dapr/actor/runtime/context.py b/dapr/actor/runtime/context.py index ec66ba366..b2610bed4 100644 --- a/dapr/actor/runtime/context.py +++ b/dapr/actor/runtime/context.py @@ -13,16 +13,16 @@ limitations under the License. """ +from typing import TYPE_CHECKING, Callable, Optional + from dapr.actor.id import ActorId from dapr.actor.runtime._state_provider import StateProvider from dapr.clients.base import DaprActorClientBase from dapr.serializers import Serializer -from typing import Callable, Optional, TYPE_CHECKING - if TYPE_CHECKING: - from dapr.actor.runtime.actor import Actor from dapr.actor.runtime._type_information import ActorTypeInformation + from dapr.actor.runtime.actor import Actor class ActorRuntimeContext: diff --git a/dapr/actor/runtime/manager.py b/dapr/actor/runtime/manager.py index a6d1a792a..969e48e2a 100644 --- a/dapr/actor/runtime/manager.py +++ b/dapr/actor/runtime/manager.py @@ -15,17 +15,16 @@ import asyncio import uuid - from typing import Any, Callable, Coroutine, Dict, Optional from dapr.actor.id import ActorId -from dapr.clients.exceptions import DaprInternalError +from dapr.actor.runtime._method_context import ActorMethodContext +from dapr.actor.runtime._reminder_data import ActorReminderData from dapr.actor.runtime.actor import Actor from dapr.actor.runtime.context import ActorRuntimeContext -from dapr.actor.runtime._method_context import ActorMethodContext from dapr.actor.runtime.method_dispatcher import ActorMethodDispatcher -from dapr.actor.runtime._reminder_data import ActorReminderData from dapr.actor.runtime.reentrancy_context import reentrancy_ctx +from dapr.clients.exceptions import DaprInternalError TIMER_METHOD_NAME = 'fire_timer' REMINDER_METHOD_NAME = 'receive_reminder' diff --git a/dapr/actor/runtime/method_dispatcher.py b/dapr/actor/runtime/method_dispatcher.py index 8d9b65114..ffe66d991 100644 --- a/dapr/actor/runtime/method_dispatcher.py +++ b/dapr/actor/runtime/method_dispatcher.py @@ -14,9 +14,10 @@ """ from typing import Any, Dict, List -from dapr.actor.runtime.actor import Actor + from dapr.actor.runtime._type_information import ActorTypeInformation from dapr.actor.runtime._type_utils import get_dispatchable_attrs +from dapr.actor.runtime.actor import Actor class ActorMethodDispatcher: diff --git 
a/dapr/actor/runtime/reentrancy_context.py b/dapr/actor/runtime/reentrancy_context.py index 0fc9927df..b295b57d7 100644 --- a/dapr/actor/runtime/reentrancy_context.py +++ b/dapr/actor/runtime/reentrancy_context.py @@ -13,7 +13,7 @@ limitations under the License. """ -from typing import Optional from contextvars import ContextVar +from typing import Optional reentrancy_ctx: ContextVar[Optional[str]] = ContextVar('reentrancy_ctx', default=None) diff --git a/dapr/actor/runtime/runtime.py b/dapr/actor/runtime/runtime.py index 3659f1479..b03f0bc75 100644 --- a/dapr/actor/runtime/runtime.py +++ b/dapr/actor/runtime/runtime.py @@ -14,20 +14,18 @@ """ import asyncio - -from typing import Dict, List, Optional, Type, Callable +from typing import Callable, Dict, List, Optional, Type from dapr.actor.id import ActorId +from dapr.actor.runtime._type_information import ActorTypeInformation from dapr.actor.runtime.actor import Actor from dapr.actor.runtime.config import ActorRuntimeConfig from dapr.actor.runtime.context import ActorRuntimeContext -from dapr.actor.runtime._type_information import ActorTypeInformation from dapr.actor.runtime.manager import ActorManager +from dapr.actor.runtime.reentrancy_context import reentrancy_ctx from dapr.clients.http.dapr_actor_http_client import DaprActorHttpClient -from dapr.serializers import Serializer, DefaultJSONSerializer from dapr.conf import settings - -from dapr.actor.runtime.reentrancy_context import reentrancy_ctx +from dapr.serializers import DefaultJSONSerializer, Serializer class ActorRuntime: diff --git a/dapr/actor/runtime/state_change.py b/dapr/actor/runtime/state_change.py index dba21e2c1..4937fcb53 100644 --- a/dapr/actor/runtime/state_change.py +++ b/dapr/actor/runtime/state_change.py @@ -14,7 +14,7 @@ """ from enum import Enum -from typing import TypeVar, Generic, Optional +from typing import Generic, Optional, TypeVar T = TypeVar('T') diff --git a/dapr/aio/clients/__init__.py b/dapr/aio/clients/__init__.py index e945b1307..3f7ce6363 100644 --- a/dapr/aio/clients/__init__.py +++ b/dapr/aio/clients/__init__.py @@ -15,14 +15,15 @@ from typing import Callable, Dict, List, Optional, Union +from google.protobuf.message import Message as GrpcMessage + +from dapr.aio.clients.grpc.client import DaprGrpcClientAsync, InvokeMethodResponse, MetadataTuple from dapr.clients.base import DaprActorClientBase -from dapr.clients.exceptions import DaprInternalError, ERROR_CODE_UNKNOWN -from dapr.aio.clients.grpc.client import DaprGrpcClientAsync, MetadataTuple, InvokeMethodResponse -from dapr.clients.grpc._jobs import Job, FailurePolicy, DropFailurePolicy, ConstantFailurePolicy +from dapr.clients.exceptions import ERROR_CODE_UNKNOWN, DaprInternalError +from dapr.clients.grpc._jobs import ConstantFailurePolicy, DropFailurePolicy, FailurePolicy, Job from dapr.clients.http.dapr_actor_http_client import DaprActorHttpClient from dapr.clients.http.dapr_invocation_http_client import DaprInvocationHttpClient from dapr.conf import settings -from google.protobuf.message import Message as GrpcMessage __all__ = [ 'DaprClient', @@ -37,10 +38,10 @@ ] from grpc.aio import ( # type: ignore - UnaryUnaryClientInterceptor, - UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor, StreamStreamClientInterceptor, + StreamUnaryClientInterceptor, + UnaryStreamClientInterceptor, + UnaryUnaryClientInterceptor, ) diff --git a/dapr/aio/clients/grpc/_request.py b/dapr/aio/clients/grpc/_request.py index b3c3ce2d4..129c556f3 100644 --- a/dapr/aio/clients/grpc/_request.py +++ 
b/dapr/aio/clients/grpc/_request.py @@ -16,7 +16,7 @@ import io from typing import Union -from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions +from dapr.clients.grpc._crypto import DecryptOptions, EncryptOptions from dapr.clients.grpc._helpers import to_bytes from dapr.clients.grpc._request import DaprRequest from dapr.proto import api_v1, common_v1 diff --git a/dapr/aio/clients/grpc/_response.py b/dapr/aio/clients/grpc/_response.py index 480eb7769..95c2ba497 100644 --- a/dapr/aio/clients/grpc/_response.py +++ b/dapr/aio/clients/grpc/_response.py @@ -15,8 +15,8 @@ from typing import AsyncGenerator, Generic -from dapr.proto import api_v1 from dapr.clients.grpc._response import DaprResponse, TCryptoResponse +from dapr.proto import api_v1 class CryptoResponse(DaprResponse, Generic[TCryptoResponse]): diff --git a/dapr/aio/clients/grpc/client.py b/dapr/aio/clients/grpc/client.py index 995b82680..6a5185770 100644 --- a/dapr/aio/clients/grpc/client.py +++ b/dapr/aio/clients/grpc/client.py @@ -14,96 +14,90 @@ """ import asyncio -import time -import socket import json +import socket +import time import uuid - from datetime import datetime +from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence, Text, Union from urllib.parse import urlencode - from warnings import warn -from typing import Callable, Dict, Optional, Text, Union, Sequence, List, Any, Awaitable -from typing_extensions import Self - -from google.protobuf.message import Message as GrpcMessage -from google.protobuf.empty_pb2 import Empty as GrpcEmpty -from google.protobuf.any_pb2 import Any as GrpcAny - import grpc.aio # type: ignore +from google.protobuf.any_pb2 import Any as GrpcAny +from google.protobuf.empty_pb2 import Empty as GrpcEmpty +from google.protobuf.message import Message as GrpcMessage from grpc.aio import ( # type: ignore - UnaryUnaryClientInterceptor, - UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor, - StreamStreamClientInterceptor, AioRpcError, + StreamStreamClientInterceptor, + StreamUnaryClientInterceptor, + UnaryStreamClientInterceptor, + UnaryUnaryClientInterceptor, ) +from typing_extensions import Self -from dapr.aio.clients.grpc.subscription import Subscription -from dapr.clients.exceptions import DaprInternalError, DaprGrpcError -from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions -from dapr.clients.grpc._state import StateOptions, StateItem -from dapr.clients.grpc._helpers import getWorkflowRuntimeStatus -from dapr.clients.health import DaprHealth -from dapr.clients.retry import RetryPolicy -from dapr.common.pubsub.subscription import StreamInactiveError -from dapr.conf.helpers import GrpcEndpoint -from dapr.conf import settings -from dapr.proto import api_v1, api_service_v1, common_v1 -from dapr.proto.runtime.v1.dapr_pb2 import UnsubscribeConfigurationResponse -from dapr.version import __version__ - +from dapr.aio.clients.grpc._request import ( + DecryptRequestIterator, + EncryptRequestIterator, +) +from dapr.aio.clients.grpc._response import ( + DecryptResponse, + EncryptResponse, +) from dapr.aio.clients.grpc.interceptors import ( DaprClientInterceptorAsync, DaprClientTimeoutInterceptorAsync, ) +from dapr.aio.clients.grpc.subscription import Subscription +from dapr.clients.exceptions import DaprGrpcError, DaprInternalError +from dapr.clients.grpc import conversation +from dapr.clients.grpc._crypto import DecryptOptions, EncryptOptions from dapr.clients.grpc._helpers import ( MetadataTuple, - to_bytes, - validateNotNone, - 
validateNotBlankString, convert_dict_to_grpc_dict_of_any, convert_value_to_struct, + getWorkflowRuntimeStatus, + to_bytes, + validateNotBlankString, + validateNotNone, ) -from dapr.aio.clients.grpc._request import ( - EncryptRequestIterator, - DecryptRequestIterator, -) -from dapr.aio.clients.grpc._response import ( - EncryptResponse, - DecryptResponse, -) +from dapr.clients.grpc._jobs import Job from dapr.clients.grpc._request import ( - InvokeMethodRequest, BindingRequest, + InvokeMethodRequest, TransactionalStateOperation, ) -from dapr.clients.grpc import conversation - -from dapr.clients.grpc._jobs import Job from dapr.clients.grpc._response import ( BindingResponse, + BulkStateItem, + BulkStatesResponse, + ConfigurationResponse, + ConfigurationWatcher, DaprResponse, - GetSecretResponse, GetBulkSecretResponse, GetMetadataResponse, + GetSecretResponse, + GetWorkflowResponse, InvokeMethodResponse, - UnlockResponseStatus, - StateResponse, - BulkStatesResponse, - BulkStateItem, - ConfigurationResponse, QueryResponse, QueryResponseItem, RegisteredComponents, - ConfigurationWatcher, - TryLockResponse, - UnlockResponse, - GetWorkflowResponse, StartWorkflowResponse, + StateResponse, TopicEventResponse, + TryLockResponse, + UnlockResponse, + UnlockResponseStatus, ) +from dapr.clients.grpc._state import StateItem, StateOptions +from dapr.clients.health import DaprHealth +from dapr.clients.retry import RetryPolicy +from dapr.common.pubsub.subscription import StreamInactiveError +from dapr.conf import settings +from dapr.conf.helpers import GrpcEndpoint +from dapr.proto import api_service_v1, api_v1, common_v1 +from dapr.proto.runtime.v1.dapr_pb2 import UnsubscribeConfigurationResponse +from dapr.version import __version__ class DaprGrpcClientAsync: @@ -170,7 +164,7 @@ def __init__( if not address: address = settings.DAPR_GRPC_ENDPOINT or ( - f'{settings.DAPR_RUNTIME_HOST}:' f'{settings.DAPR_GRPC_PORT}' + f'{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}' ) try: diff --git a/dapr/aio/clients/grpc/interceptors.py b/dapr/aio/clients/grpc/interceptors.py index bf83cf56a..0444d5acb 100644 --- a/dapr/aio/clients/grpc/interceptors.py +++ b/dapr/aio/clients/grpc/interceptors.py @@ -16,7 +16,11 @@ from collections import namedtuple from typing import List, Tuple -from grpc.aio import UnaryUnaryClientInterceptor, StreamStreamClientInterceptor, ClientCallDetails # type: ignore +from grpc.aio import ( # type: ignore + ClientCallDetails, + StreamStreamClientInterceptor, + UnaryUnaryClientInterceptor, +) from dapr.conf import settings diff --git a/dapr/aio/clients/grpc/subscription.py b/dapr/aio/clients/grpc/subscription.py index 9aabf8b28..137badd21 100644 --- a/dapr/aio/clients/grpc/subscription.py +++ b/dapr/aio/clients/grpc/subscription.py @@ -1,13 +1,14 @@ import asyncio + from grpc import StatusCode from grpc.aio import AioRpcError from dapr.clients.grpc._response import TopicEventResponse from dapr.clients.health import DaprHealth from dapr.common.pubsub.subscription import ( + StreamCancelledError, StreamInactiveError, SubscriptionMessage, - StreamCancelledError, ) from dapr.proto import api_v1, appcallback_v1 diff --git a/dapr/clients/__init__.py b/dapr/clients/__init__.py index 78ad99eb4..5d92b56c7 100644 --- a/dapr/clients/__init__.py +++ b/dapr/clients/__init__.py @@ -16,16 +16,16 @@ from typing import Callable, Dict, List, Optional, Union from warnings import warn +from google.protobuf.message import Message as GrpcMessage + from dapr.clients.base import DaprActorClientBase -from 
dapr.clients.exceptions import DaprInternalError, ERROR_CODE_UNKNOWN -from dapr.clients.grpc.client import DaprGrpcClient, MetadataTuple, InvokeMethodResponse -from dapr.clients.grpc._jobs import Job, FailurePolicy, DropFailurePolicy, ConstantFailurePolicy +from dapr.clients.exceptions import ERROR_CODE_UNKNOWN, DaprInternalError +from dapr.clients.grpc._jobs import ConstantFailurePolicy, DropFailurePolicy, FailurePolicy, Job +from dapr.clients.grpc.client import DaprGrpcClient, InvokeMethodResponse, MetadataTuple from dapr.clients.http.dapr_actor_http_client import DaprActorHttpClient from dapr.clients.http.dapr_invocation_http_client import DaprInvocationHttpClient from dapr.clients.retry import RetryPolicy from dapr.conf import settings -from google.protobuf.message import Message as GrpcMessage - __all__ = [ 'DaprClient', @@ -41,10 +41,10 @@ from grpc import ( # type: ignore - UnaryUnaryClientInterceptor, - UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor, StreamStreamClientInterceptor, + StreamUnaryClientInterceptor, + UnaryStreamClientInterceptor, + UnaryUnaryClientInterceptor, ) diff --git a/dapr/clients/exceptions.py b/dapr/clients/exceptions.py index 61ae0d8b6..f6358cb85 100644 --- a/dapr/clients/exceptions.py +++ b/dapr/clients/exceptions.py @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import base64 import json from typing import TYPE_CHECKING, Optional @@ -20,9 +21,9 @@ from dapr.serializers import Serializer from google.protobuf.json_format import MessageToDict +from google.rpc import error_details_pb2 # type: ignore from grpc import RpcError # type: ignore from grpc_status import rpc_status # type: ignore -from google.rpc import error_details_pb2 # type: ignore ERROR_CODE_UNKNOWN = 'UNKNOWN' ERROR_CODE_DOES_NOT_EXIST = 'ERR_DOES_NOT_EXIST' diff --git a/dapr/clients/grpc/_conversation_helpers.py b/dapr/clients/grpc/_conversation_helpers.py index 37bb81c18..53cc87cf8 100644 --- a/dapr/clients/grpc/_conversation_helpers.py +++ b/dapr/clients/grpc/_conversation_helpers.py @@ -16,6 +16,7 @@ import inspect import random import string +import types from dataclasses import fields, is_dataclass from enum import Enum from typing import ( @@ -23,21 +24,19 @@ Callable, Dict, List, + Literal, Mapping, Optional, Sequence, Union, - Literal, + cast, get_args, get_origin, get_type_hints, - cast, ) from dapr.conf import settings -import types - # Make mypy happy. Runtime handle: real class on 3.10+, else None. # TODO: Python 3.9 is about to be end-of-life, so we can drop this at some point next year (2026) UnionType: Any = getattr(types, 'UnionType', None) @@ -190,14 +189,14 @@ def _json_primitive_type(v: Any) -> str: if settings.DAPR_CONVERSATION_TOOLS_LARGE_ENUM_BEHAVIOR == 'error': raise ValueError( f"Enum '{getattr(python_type, '__name__', str(python_type))}' has {count} members, " - f"exceeding DAPR_CONVERSATION_MAX_ENUM_ITEMS={settings.DAPR_CONVERSATION_TOOLS_MAX_ENUM_ITEMS}. " - f"Either reduce the enum size or set DAPR_CONVERSATION_LARGE_ENUM_BEHAVIOR=string to allow compact schema." + f'exceeding DAPR_CONVERSATION_MAX_ENUM_ITEMS={settings.DAPR_CONVERSATION_TOOLS_MAX_ENUM_ITEMS}. ' + f'Either reduce the enum size or set DAPR_CONVERSATION_LARGE_ENUM_BEHAVIOR=string to allow compact schema.' 
) # Default behavior: compact schema as a string with helpful context and a few examples example_values = [item.value for item in members[:5]] if members else [] desc = ( - f"{getattr(python_type, '__name__', 'Enum')} (enum with {count} values). " - f"Provide a valid value. Schema compacted to avoid oversized enum listing." + f'{getattr(python_type, "__name__", "Enum")} (enum with {count} values). ' + f'Provide a valid value. Schema compacted to avoid oversized enum listing.' ) schema = {'type': 'string', 'description': desc} if example_values: @@ -696,8 +695,8 @@ def stringify_tool_output(value: Any) -> str: * dataclass -> asdict If JSON serialization still fails, fallback to str(value). If that fails, return ''. """ - import json as _json import base64 as _b64 + import json as _json from dataclasses import asdict as _asdict if isinstance(value, str): @@ -962,7 +961,7 @@ def _coerce_and_validate(value: Any, expected_type: Any) -> Any: missing.append(pname) if missing: raise ValueError( - f"Missing required constructor arg(s) for {expected_type.__name__}: {', '.join(missing)}" + f'Missing required constructor arg(s) for {expected_type.__name__}: {", ".join(missing)}' ) try: return expected_type(**kwargs) @@ -978,7 +977,7 @@ def _coerce_and_validate(value: Any, expected_type: Any) -> Any: if expected_type is Any or isinstance(value, expected_type): return value raise ValueError( - f"Expected {getattr(expected_type, '__name__', str(expected_type))}, got {type(value).__name__}" + f'Expected {getattr(expected_type, "__name__", str(expected_type))}, got {type(value).__name__}' ) @@ -1014,12 +1013,12 @@ def bind_params_to_func(fn: Callable[..., Any], params: Params): and p.name not in bound.arguments ] if missing: - raise ToolArgumentError(f"Missing required parameter(s): {', '.join(missing)}") + raise ToolArgumentError(f'Missing required parameter(s): {", ".join(missing)}') # unexpected kwargs unless **kwargs present if not any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()): extra = set(params) - set(sig.parameters) if extra: - raise ToolArgumentError(f"Unexpected parameter(s): {', '.join(sorted(extra))}") + raise ToolArgumentError(f'Unexpected parameter(s): {", ".join(sorted(extra))}') elif isinstance(params, Sequence): bound = sig.bind(*params) else: diff --git a/dapr/clients/grpc/_helpers.py b/dapr/clients/grpc/_helpers.py index c68b0f56a..8eb9a1e97 100644 --- a/dapr/clients/grpc/_helpers.py +++ b/dapr/clients/grpc/_helpers.py @@ -12,22 +12,22 @@ See the License for the specific language governing permissions and limitations under the License. """ -from enum import Enum -from typing import Any, Dict, List, Optional, Union, Tuple +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union +from google.protobuf import json_format from google.protobuf.any_pb2 import Any as GrpcAny from google.protobuf.message import Message as GrpcMessage +from google.protobuf.struct_pb2 import Struct from google.protobuf.wrappers_pb2 import ( BoolValue, - StringValue, + BytesValue, + DoubleValue, Int32Value, Int64Value, - DoubleValue, - BytesValue, + StringValue, ) -from google.protobuf.struct_pb2 import Struct -from google.protobuf import json_format MetadataDict = Dict[str, List[Union[bytes, str]]] MetadataTuple = Tuple[Tuple[str, Union[bytes, str]], ...] 
diff --git a/dapr/clients/grpc/_jobs.py b/dapr/clients/grpc/_jobs.py index 896c8db3c..5df9975f0 100644 --- a/dapr/clients/grpc/_jobs.py +++ b/dapr/clients/grpc/_jobs.py @@ -117,9 +117,10 @@ def _get_proto(self): Returns: api_v1.Job: The proto representation of this job. """ - from dapr.proto.runtime.v1 import dapr_pb2 as api_v1 from google.protobuf.any_pb2 import Any as GrpcAny + from dapr.proto.runtime.v1 import dapr_pb2 as api_v1 + # Build the job proto job_proto = api_v1.Job(name=self.name) diff --git a/dapr/clients/grpc/_response.py b/dapr/clients/grpc/_response.py index fff511ff7..e72045e93 100644 --- a/dapr/clients/grpc/_response.py +++ b/dapr/clients/grpc/_response.py @@ -21,19 +21,19 @@ from datetime import datetime from enum import Enum from typing import ( + TYPE_CHECKING, Callable, Dict, + Generator, + Generic, List, + Mapping, + NamedTuple, Optional, - Text, - Union, Sequence, - TYPE_CHECKING, - NamedTuple, - Generator, + Text, TypeVar, - Generic, - Mapping, + Union, ) from google.protobuf.any_pb2 import Any as GrpcAny @@ -43,11 +43,11 @@ from dapr.clients.grpc._helpers import ( MetadataDict, MetadataTuple, + WorkflowRuntimeStatus, to_bytes, to_str, tuple_to_dict, unpack, - WorkflowRuntimeStatus, ) from dapr.proto import api_service_v1, api_v1, appcallback_v1, common_v1 @@ -719,7 +719,7 @@ def _read_subscribe_config( if len(response.items) > 0: handler(response.id, ConfigurationResponse(response.items)) except Exception: - print(f'{self.store_name} configuration watcher for keys ' f'{self.keys} stopped.') + print(f'{self.store_name} configuration watcher for keys {self.keys} stopped.') pass diff --git a/dapr/clients/grpc/_state.py b/dapr/clients/grpc/_state.py index 3dc266b22..e20df4293 100644 --- a/dapr/clients/grpc/_state.py +++ b/dapr/clients/grpc/_state.py @@ -1,7 +1,8 @@ from enum import Enum -from dapr.proto import common_v1 from typing import Dict, Optional, Union +from dapr.proto import common_v1 + class Consistency(Enum): """Represents the consistency mode for a Dapr State Api Call""" diff --git a/dapr/clients/grpc/client.py b/dapr/clients/grpc/client.py index e4ffb2646..469681c96 100644 --- a/dapr/clients/grpc/client.py +++ b/dapr/clients/grpc/client.py @@ -12,89 +12,85 @@ See the License for the specific language governing permissions and limitations under the License. 
""" + +import json +import socket import threading import time -import socket -import json import uuid - +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union from urllib.parse import urlencode - from warnings import warn -from typing import Callable, Dict, Optional, Text, Union, Sequence, List, Any - -from typing_extensions import Self -from datetime import datetime -from google.protobuf.message import Message as GrpcMessage -from google.protobuf.empty_pb2 import Empty as GrpcEmpty -from google.protobuf.any_pb2 import Any as GrpcAny - import grpc # type: ignore +from google.protobuf.any_pb2 import Any as GrpcAny +from google.protobuf.empty_pb2 import Empty as GrpcEmpty +from google.protobuf.message import Message as GrpcMessage from grpc import ( # type: ignore - UnaryUnaryClientInterceptor, - UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor, - StreamStreamClientInterceptor, RpcError, + StreamStreamClientInterceptor, + StreamUnaryClientInterceptor, + UnaryStreamClientInterceptor, + UnaryUnaryClientInterceptor, ) +from typing_extensions import Self -from dapr.clients.exceptions import DaprInternalError, DaprGrpcError -from dapr.clients.grpc._state import StateOptions, StateItem -from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions -from dapr.clients.grpc.subscription import Subscription, StreamInactiveError -from dapr.clients.grpc.interceptors import DaprClientInterceptor, DaprClientTimeoutInterceptor -from dapr.clients.health import DaprHealth -from dapr.clients.retry import RetryPolicy -from dapr.common.pubsub.subscription import StreamCancelledError -from dapr.conf import settings -from dapr.proto import api_v1, api_service_v1, common_v1 -from dapr.proto.runtime.v1.dapr_pb2 import UnsubscribeConfigurationResponse -from dapr.version import __version__ - +from dapr.clients.exceptions import DaprGrpcError, DaprInternalError +from dapr.clients.grpc import conversation +from dapr.clients.grpc._crypto import DecryptOptions, EncryptOptions from dapr.clients.grpc._helpers import ( - getWorkflowRuntimeStatus, MetadataTuple, - to_bytes, - validateNotNone, - validateNotBlankString, convert_dict_to_grpc_dict_of_any, convert_value_to_struct, + getWorkflowRuntimeStatus, + to_bytes, + validateNotBlankString, + validateNotNone, ) -from dapr.conf.helpers import GrpcEndpoint +from dapr.clients.grpc._jobs import Job from dapr.clients.grpc._request import ( - InvokeMethodRequest, BindingRequest, - TransactionalStateOperation, - EncryptRequestIterator, DecryptRequestIterator, + EncryptRequestIterator, + InvokeMethodRequest, + TransactionalStateOperation, ) -from dapr.clients.grpc import conversation -from dapr.clients.grpc._jobs import Job from dapr.clients.grpc._response import ( BindingResponse, + BulkStateItem, + BulkStatesResponse, + ConfigurationResponse, + ConfigurationWatcher, DaprResponse, - GetSecretResponse, + DecryptResponse, + EncryptResponse, GetBulkSecretResponse, GetMetadataResponse, + GetSecretResponse, + GetWorkflowResponse, InvokeMethodResponse, - UnlockResponseStatus, - StateResponse, - BulkStatesResponse, - BulkStateItem, - ConfigurationResponse, QueryResponse, QueryResponseItem, RegisteredComponents, - ConfigurationWatcher, - TryLockResponse, - UnlockResponse, - GetWorkflowResponse, StartWorkflowResponse, - EncryptResponse, - DecryptResponse, + StateResponse, TopicEventResponse, + TryLockResponse, + UnlockResponse, + UnlockResponseStatus, ) +from dapr.clients.grpc._state import StateItem, StateOptions 
+from dapr.clients.grpc.interceptors import DaprClientInterceptor, DaprClientTimeoutInterceptor +from dapr.clients.grpc.subscription import StreamInactiveError, Subscription +from dapr.clients.health import DaprHealth +from dapr.clients.retry import RetryPolicy +from dapr.common.pubsub.subscription import StreamCancelledError +from dapr.conf import settings +from dapr.conf.helpers import GrpcEndpoint, build_grpc_channel_options +from dapr.proto import api_service_v1, api_v1, common_v1 +from dapr.proto.runtime.v1.dapr_pb2 import UnsubscribeConfigurationResponse +from dapr.version import __version__ class DaprGrpcClient: @@ -150,11 +146,9 @@ def __init__( useragent = f'dapr-sdk-python/{__version__}' if not max_grpc_message_length: - options = [ - ('grpc.primary_user_agent', useragent), - ] + base_options = [('grpc.primary_user_agent', useragent)] else: - options = [ + base_options = [ ('grpc.max_send_message_length', max_grpc_message_length), # type: ignore ('grpc.max_receive_message_length', max_grpc_message_length), # type: ignore ('grpc.primary_user_agent', useragent), @@ -162,7 +156,7 @@ def __init__( if not address: address = settings.DAPR_GRPC_ENDPOINT or ( - f'{settings.DAPR_RUNTIME_HOST}:' f'{settings.DAPR_GRPC_PORT}' + f'{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}' ) try: @@ -170,6 +164,9 @@ def __init__( except ValueError as error: raise DaprInternalError(f'{error}') from error + # Merge standard + keepalive + retry options + options = build_grpc_channel_options(base_options) + if self._uri.tls: self._channel = grpc.secure_channel( # type: ignore self._uri.endpoint, diff --git a/dapr/clients/grpc/conversation.py b/dapr/clients/grpc/conversation.py index 1da02dac2..cc5acb390 100644 --- a/dapr/clients/grpc/conversation.py +++ b/dapr/clients/grpc/conversation.py @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" + from __future__ import annotations import asyncio diff --git a/dapr/clients/grpc/interceptors.py b/dapr/clients/grpc/interceptors.py index 15bde1857..a574fb8c6 100644 --- a/dapr/clients/grpc/interceptors.py +++ b/dapr/clients/grpc/interceptors.py @@ -1,7 +1,11 @@ from collections import namedtuple from typing import List, Tuple -from grpc import UnaryUnaryClientInterceptor, ClientCallDetails, StreamStreamClientInterceptor # type: ignore +from grpc import ( # type: ignore + ClientCallDetails, + StreamStreamClientInterceptor, + UnaryUnaryClientInterceptor, +) from dapr.conf import settings diff --git a/dapr/clients/grpc/subscription.py b/dapr/clients/grpc/subscription.py index 111946b1b..152ca84bb 100644 --- a/dapr/clients/grpc/subscription.py +++ b/dapr/clients/grpc/subscription.py @@ -1,16 +1,17 @@ -from grpc import RpcError, StatusCode, Call # type: ignore +import queue +import threading +from typing import Optional + +from grpc import Call, RpcError, StatusCode # type: ignore from dapr.clients.grpc._response import TopicEventResponse from dapr.clients.health import DaprHealth from dapr.common.pubsub.subscription import ( + StreamCancelledError, StreamInactiveError, SubscriptionMessage, - StreamCancelledError, ) from dapr.proto import api_v1, appcallback_v1 -import queue -import threading -from typing import Optional class Subscription: diff --git a/dapr/clients/health.py b/dapr/clients/health.py index e3daec79d..4f3bdf8dc 100644 --- a/dapr/clients/health.py +++ b/dapr/clients/health.py @@ -12,11 +12,12 @@ See the License for the specific language governing permissions and limitations under the License. """ -import urllib.request -import urllib.error + import time +import urllib.error +import urllib.request -from dapr.clients.http.conf import DAPR_API_TOKEN_HEADER, USER_AGENT_HEADER, DAPR_USER_AGENT +from dapr.clients.http.conf import DAPR_API_TOKEN_HEADER, DAPR_USER_AGENT, USER_AGENT_HEADER from dapr.clients.http.helpers import get_api_url from dapr.conf import settings diff --git a/dapr/clients/http/client.py b/dapr/clients/http/client.py index 86e9ab6f0..072f31b15 100644 --- a/dapr/clients/http/client.py +++ b/dapr/clients/http/client.py @@ -13,25 +13,25 @@ limitations under the License. """ -import aiohttp +from typing import TYPE_CHECKING, Callable, Dict, Mapping, Optional, Tuple, Union -from typing import Callable, Mapping, Dict, Optional, Union, Tuple, TYPE_CHECKING +import aiohttp from dapr.clients.health import DaprHealth from dapr.clients.http.conf import ( + CONTENT_TYPE_HEADER, DAPR_API_TOKEN_HEADER, - USER_AGENT_HEADER, DAPR_USER_AGENT, - CONTENT_TYPE_HEADER, + USER_AGENT_HEADER, ) from dapr.clients.retry import RetryPolicy if TYPE_CHECKING: from dapr.serializers import Serializer -from dapr.conf import settings from dapr.clients._constants import DEFAULT_JSON_CONTENT_TYPE from dapr.clients.exceptions import DaprHttpError, DaprInternalError +from dapr.conf import settings class DaprHttpClient: diff --git a/dapr/clients/http/dapr_actor_http_client.py b/dapr/clients/http/dapr_actor_http_client.py index 186fdbc1c..711153659 100644 --- a/dapr/clients/http/dapr_actor_http_client.py +++ b/dapr/clients/http/dapr_actor_http_client.py @@ -13,15 +13,15 @@ limitations under the License. 
""" -from typing import Callable, Dict, Optional, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Dict, Optional, Union from dapr.clients.http.helpers import get_api_url if TYPE_CHECKING: from dapr.serializers import Serializer -from dapr.clients.http.client import DaprHttpClient from dapr.clients.base import DaprActorClientBase +from dapr.clients.http.client import DaprHttpClient from dapr.clients.retry import RetryPolicy DAPR_REENTRANCY_ID_HEADER = 'Dapr-Reentrancy-Id' diff --git a/dapr/clients/http/dapr_invocation_http_client.py b/dapr/clients/http/dapr_invocation_http_client.py index df4e6d222..604c483c0 100644 --- a/dapr/clients/http/dapr_invocation_http_client.py +++ b/dapr/clients/http/dapr_invocation_http_client.py @@ -14,13 +14,13 @@ """ import asyncio - from typing import Callable, Dict, Optional, Union + from multidict import MultiDict -from dapr.clients.http.client import DaprHttpClient -from dapr.clients.grpc._helpers import MetadataTuple, GrpcMessage +from dapr.clients.grpc._helpers import GrpcMessage, MetadataTuple from dapr.clients.grpc._response import InvokeMethodResponse +from dapr.clients.http.client import DaprHttpClient from dapr.clients.http.conf import CONTENT_TYPE_HEADER from dapr.clients.http.helpers import get_api_url from dapr.clients.retry import RetryPolicy diff --git a/dapr/clients/retry.py b/dapr/clients/retry.py index 171c96fbd..e895e46f3 100644 --- a/dapr/clients/retry.py +++ b/dapr/clients/retry.py @@ -12,11 +12,12 @@ See the License for the specific language governing permissions and limitations under the License. """ + import asyncio -from typing import Optional, List, Callable +import time +from typing import Callable, List, Optional from grpc import RpcError, StatusCode # type: ignore -import time from dapr.conf import settings diff --git a/dapr/common/pubsub/subscription.py b/dapr/common/pubsub/subscription.py index 6f68e180d..eb22a48da 100644 --- a/dapr/common/pubsub/subscription.py +++ b/dapr/common/pubsub/subscription.py @@ -1,7 +1,9 @@ import json +from typing import Optional, Union + from google.protobuf.json_format import MessageToDict + from dapr.proto.runtime.v1.appcallback_pb2 import TopicEventRequest -from typing import Optional, Union class SubscriptionMessage: diff --git a/dapr/conf/__init__.py b/dapr/conf/__init__.py index 7fbe5f2f7..0959311ab 100644 --- a/dapr/conf/__init__.py +++ b/dapr/conf/__init__.py @@ -24,9 +24,7 @@ def __init__(self): default_value = getattr(global_settings, setting) env_variable = os.environ.get(setting) if env_variable: - val = ( - type(default_value)(env_variable) if default_value is not None else env_variable - ) + val = self._coerce_env_value(default_value, env_variable) setattr(self, setting, val) else: setattr(self, setting, default_value) @@ -36,5 +34,27 @@ def __getattr__(self, name): raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") return getattr(self, name) + @staticmethod + def _coerce_env_value(default_value, env_variable: str): + if default_value is None: + return env_variable + # Handle booleans explicitly to avoid bool('false') == True + if isinstance(default_value, bool): + s = env_variable.strip().lower() + if s in ('1', 'true', 't', 'yes', 'y', 'on'): + return True + if s in ('0', 'false', 'f', 'no', 'n', 'off'): + return False + # Fallback: non-empty -> True for backward-compat + return bool(s) + # Integers + if isinstance(default_value, int) and not isinstance(default_value, bool): + return int(env_variable) + # Floats + if 
isinstance(default_value, float):
+            return float(env_variable)
+        # Other types: try to cast as before
+        return type(default_value)(env_variable)
+

 settings = Settings()
diff --git a/dapr/conf/global_settings.py b/dapr/conf/global_settings.py
index 5a64e5d4c..dd2bae86d 100644
--- a/dapr/conf/global_settings.py
+++ b/dapr/conf/global_settings.py
@@ -34,6 +34,23 @@
 DAPR_HTTP_TIMEOUT_SECONDS = 60


+# gRPC keepalive (disabled by default; enable via env to help with idle debugging sessions)
+DAPR_GRPC_KEEPALIVE_ENABLED: bool = False
+DAPR_GRPC_KEEPALIVE_TIME_MS: int = 120000  # send keepalive pings every 120s
+DAPR_GRPC_KEEPALIVE_TIMEOUT_MS: int = (
+    20000  # wait 20s for ack before considering the connection dead
+)
+DAPR_GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS: bool = False  # allow pings when there are no active calls
+
+# gRPC retries (disabled by default; enable via env to apply channel service config)
+DAPR_GRPC_RETRY_ENABLED: bool = False
+DAPR_GRPC_RETRY_MAX_ATTEMPTS: int = 4
+DAPR_GRPC_RETRY_INITIAL_BACKOFF_MS: int = 100
+DAPR_GRPC_RETRY_MAX_BACKOFF_MS: int = 1000
+DAPR_GRPC_RETRY_BACKOFF_MULTIPLIER: float = 2.0
+# Comma-separated list of status codes, e.g., 'UNAVAILABLE,DEADLINE_EXCEEDED'
+DAPR_GRPC_RETRY_CODES: str = 'UNAVAILABLE,DEADLINE_EXCEEDED'
+
 # ----- Conversation API settings ------

 # Configuration for handling large enums to avoid massive JSON schemas that can exceed LLM token limits
diff --git a/dapr/conf/helpers.py b/dapr/conf/helpers.py
index ab1e494b2..c6c121f47 100644
--- a/dapr/conf/helpers.py
+++ b/dapr/conf/helpers.py
@@ -1,5 +1,8 @@
+import json
+from urllib.parse import ParseResult, parse_qs, urlparse
 from warnings import warn
-from urllib.parse import urlparse, parse_qs, ParseResult
+
+from dapr.conf import settings


 class URIParseConfig:
@@ -174,7 +177,7 @@ def tls(self) -> bool:
     def _validate_path_and_query(self) -> None:
         if self._parsed_url.path:
             raise ValueError(
-                f'paths are not supported for gRPC endpoints:' f" '{self._parsed_url.path}'"
+                f"paths are not supported for gRPC endpoints: '{self._parsed_url.path}'"
             )
         if self._parsed_url.query:
             query_dict = parse_qs(self._parsed_url.query)
@@ -189,3 +192,68 @@ def _validate_path_and_query(self) -> None:
                 f'query parameters are not supported for gRPC endpoints:'
                 f" '{self._parsed_url.query}'"
             )
+
+
+# ------------------------------
+# gRPC channel options helpers
+# ------------------------------
+
+
+def get_grpc_keepalive_options():
+    """Return a list of keepalive channel options if enabled, else empty list.
+
+    Options are tuples suitable for passing to grpc.{secure,insecure}_channel.
+    """
+    if not settings.DAPR_GRPC_KEEPALIVE_ENABLED:
+        return []
+    return [
+        ('grpc.keepalive_time_ms', int(settings.DAPR_GRPC_KEEPALIVE_TIME_MS)),
+        ('grpc.keepalive_timeout_ms', int(settings.DAPR_GRPC_KEEPALIVE_TIMEOUT_MS)),
+        (
+            'grpc.keepalive_permit_without_calls',
+            1 if settings.DAPR_GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS else 0,
+        ),
+    ]
+
+
+def get_grpc_retry_service_config_option():
+    """Return ('grpc.service_config', json) option if retry is enabled, else None.
+
+    Applies a universal retry policy via gRPC service config.
+    """
+    if not getattr(settings, 'DAPR_GRPC_RETRY_ENABLED', False):
+        return None
+    retry_policy = {
+        'maxAttempts': int(settings.DAPR_GRPC_RETRY_MAX_ATTEMPTS),
+        'initialBackoff': f'{int(settings.DAPR_GRPC_RETRY_INITIAL_BACKOFF_MS) / 1000.0}s',
+        'maxBackoff': f'{int(settings.DAPR_GRPC_RETRY_MAX_BACKOFF_MS) / 1000.0}s',
+        'backoffMultiplier': float(settings.DAPR_GRPC_RETRY_BACKOFF_MULTIPLIER),
+        'retryableStatusCodes': [
+            c.strip() for c in str(settings.DAPR_GRPC_RETRY_CODES).split(',') if c.strip()
+        ],
+    }
+    service_config = {
+        'methodConfig': [
+            {
+                'name': [{'service': ''}],  # apply to all services
+                'retryPolicy': retry_policy,
+            }
+        ]
+    }
+    return ('grpc.service_config', json.dumps(service_config))
+
+
+def build_grpc_channel_options(base_options=None):
+    """Combine base options with keepalive and retry policy options.
+
+    Args:
+        base_options: optional iterable of (key, value) tuples.
+    Returns:
+        list of (key, value) tuples.
+    """
+    options = list(base_options or [])
+    options.extend(get_grpc_keepalive_options())
+    retry_opt = get_grpc_retry_service_config_option()
+    if retry_opt is not None:
+        options.append(retry_opt)
+    return options
diff --git a/dapr/serializers/json.py b/dapr/serializers/json.py
index 4e9665187..59e1c194b 100644
--- a/dapr/serializers/json.py
+++ b/dapr/serializers/json.py
@@ -14,18 +14,18 @@
 """

 import base64
-import re
 import datetime
 import json
-
+import re
 from typing import Any, Callable, Optional, Type
+
 from dateutil import parser

 from dapr.serializers.base import Serializer
 from dapr.serializers.util import (
+    DAPR_DURATION_PARSER,
     convert_from_dapr_duration,
     convert_to_dapr_duration,
-    DAPR_DURATION_PARSER,
 )
diff --git a/dev-requirements.txt b/dev-requirements.txt
index cbd719859..612b17570 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -3,7 +3,9 @@ mypy-extensions>=0.4.3
 mypy-protobuf>=2.9
 flake8>=3.7.9
 tox>=4.3.0
+pip>=23.0.0
 coverage>=5.3
+pytest
 wheel
 # used in unit test only
 opentelemetry-sdk
diff --git a/examples/configuration/configuration.py b/examples/configuration/configuration.py
index caf676e6b..d579df7fa 100644
--- a/examples/configuration/configuration.py
+++ b/examples/configuration/configuration.py
@@ -4,8 +4,9 @@
 import asyncio
 from time import sleep
+
 from dapr.clients import DaprClient
-from dapr.clients.grpc._response import ConfigurationWatcher, ConfigurationResponse
+from dapr.clients.grpc._response import ConfigurationResponse, ConfigurationWatcher

 configuration: ConfigurationWatcher = ConfigurationWatcher()
diff --git a/examples/crypto/crypto-async.py b/examples/crypto/crypto-async.py
index 0946e9bbb..2e49a8282 100644
--- a/examples/crypto/crypto-async.py
+++ b/examples/crypto/crypto-async.py
@@ -14,7 +14,7 @@
 import asyncio

 from dapr.aio.clients import DaprClient
-from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions
+from dapr.clients.grpc._crypto import DecryptOptions, EncryptOptions

 # Name of the crypto component to use
 CRYPTO_COMPONENT_NAME = 'crypto-localstorage'
diff --git a/examples/crypto/crypto.py b/examples/crypto/crypto.py
index a282ba453..afe00f343 100644
--- a/examples/crypto/crypto.py
+++ b/examples/crypto/crypto.py
@@ -12,7 +12,7 @@
 # ------------------------------------------------------------

 from dapr.clients import DaprClient
-from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions
+from dapr.clients.grpc._crypto import DecryptOptions, EncryptOptions

 # Name of the crypto component to use
 CRYPTO_COMPONENT_NAME = 'crypto-localstorage'
diff --git
a/examples/demo_actor/demo_actor/demo_actor.py b/examples/demo_actor/demo_actor/demo_actor.py index 0d65d57d2..f9306d47c 100644 --- a/examples/demo_actor/demo_actor/demo_actor.py +++ b/examples/demo_actor/demo_actor/demo_actor.py @@ -11,10 +11,11 @@ # limitations under the License. import datetime +from typing import Optional -from dapr.actor import Actor, Remindable from demo_actor_interface import DemoActorInterface -from typing import Optional + +from dapr.actor import Actor, Remindable class DemoActor(Actor, DemoActorInterface, Remindable): diff --git a/examples/demo_actor/demo_actor/demo_actor_client.py b/examples/demo_actor/demo_actor/demo_actor_client.py index df0e9f737..ad0dfccb6 100644 --- a/examples/demo_actor/demo_actor/demo_actor_client.py +++ b/examples/demo_actor/demo_actor/demo_actor_client.py @@ -12,10 +12,11 @@ import asyncio -from dapr.actor import ActorProxy, ActorId, ActorProxyFactory -from dapr.clients.retry import RetryPolicy from demo_actor_interface import DemoActorInterface +from dapr.actor import ActorId, ActorProxy, ActorProxyFactory +from dapr.clients.retry import RetryPolicy + async def main(): # Create proxy client diff --git a/examples/demo_actor/demo_actor/demo_actor_flask.py b/examples/demo_actor/demo_actor/demo_actor_flask.py index 5715d23d8..de1245ad0 100644 --- a/examples/demo_actor/demo_actor/demo_actor_flask.py +++ b/examples/demo_actor/demo_actor/demo_actor_flask.py @@ -10,13 +10,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from demo_actor import DemoActor from flask import Flask, jsonify from flask_dapr.actor import DaprActor -from dapr.conf import settings -from dapr.actor.runtime.config import ActorRuntimeConfig, ActorTypeConfig, ActorReentrancyConfig +from dapr.actor.runtime.config import ActorReentrancyConfig, ActorRuntimeConfig, ActorTypeConfig from dapr.actor.runtime.runtime import ActorRuntime -from demo_actor import DemoActor +from dapr.conf import settings app = Flask(f'{DemoActor.__name__}Service') diff --git a/examples/demo_actor/demo_actor/demo_actor_service.py b/examples/demo_actor/demo_actor/demo_actor_service.py index c53d06e25..046a8df24 100644 --- a/examples/demo_actor/demo_actor/demo_actor_service.py +++ b/examples/demo_actor/demo_actor/demo_actor_service.py @@ -10,12 +10,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from demo_actor import DemoActor from fastapi import FastAPI # type: ignore -from dapr.actor.runtime.config import ActorRuntimeConfig, ActorTypeConfig, ActorReentrancyConfig + +from dapr.actor.runtime.config import ActorReentrancyConfig, ActorRuntimeConfig, ActorTypeConfig from dapr.actor.runtime.runtime import ActorRuntime from dapr.ext.fastapi import DaprActor # type: ignore -from demo_actor import DemoActor - app = FastAPI(title=f'{DemoActor.__name__}Service') diff --git a/examples/demo_workflow/app.py b/examples/demo_workflow/app.py index c89dcae6e..cf561aa6d 100644 --- a/examples/demo_workflow/app.py +++ b/examples/demo_workflow/app.py @@ -12,15 +12,16 @@ from datetime import timedelta from time import sleep + +from dapr.clients import DaprClient +from dapr.clients.exceptions import DaprInternalError +from dapr.conf import Settings from dapr.ext.workflow import ( - WorkflowRuntime, DaprWorkflowContext, - WorkflowActivityContext, RetryPolicy, + WorkflowActivityContext, + WorkflowRuntime, ) -from dapr.conf import Settings -from dapr.clients import DaprClient -from dapr.clients.exceptions import DaprInternalError settings = Settings() @@ -192,8 +193,7 @@ def main(): instance_id=instance_id, workflow_component=workflow_component ) print( - f'Get response from {workflow_name} ' - f'after terminate call: {get_response.runtime_status}' + f'Get response from {workflow_name} after terminate call: {get_response.runtime_status}' ) child_get_response = d.get_workflow( instance_id=child_instance_id, workflow_component=workflow_component diff --git a/examples/distributed_lock/lock.py b/examples/distributed_lock/lock.py index d18d955f6..2f6364065 100644 --- a/examples/distributed_lock/lock.py +++ b/examples/distributed_lock/lock.py @@ -11,9 +11,10 @@ # limitations under the License. 
# ------------------------------------------------------------ -from dapr.clients import DaprClient import warnings +from dapr.clients import DaprClient + def main(): # Lock parameters diff --git a/examples/error_handling/error_handling.py b/examples/error_handling/error_handling.py index b75ebed97..ae42a88cd 100644 --- a/examples/error_handling/error_handling.py +++ b/examples/error_handling/error_handling.py @@ -1,7 +1,6 @@ from dapr.clients import DaprClient from dapr.clients.exceptions import DaprGrpcError - with DaprClient() as d: storeName = 'statestore' diff --git a/examples/invoke-custom-data/invoke-caller.py b/examples/invoke-custom-data/invoke-caller.py index 27dabd4de..caeb84313 100644 --- a/examples/invoke-custom-data/invoke-caller.py +++ b/examples/invoke-custom-data/invoke-caller.py @@ -1,7 +1,7 @@ -from dapr.clients import DaprClient - import proto.response_pb2 as response_messages +from dapr.clients import DaprClient + with DaprClient() as d: # Create a typed message with content type and body resp = d.invoke_method( diff --git a/examples/invoke-custom-data/invoke-receiver.py b/examples/invoke-custom-data/invoke-receiver.py index e2ad83ce5..cd1e074d5 100644 --- a/examples/invoke-custom-data/invoke-receiver.py +++ b/examples/invoke-custom-data/invoke-receiver.py @@ -1,7 +1,7 @@ -from dapr.ext.grpc import App, InvokeMethodRequest - import proto.response_pb2 as response_messages +from dapr.ext.grpc import App, InvokeMethodRequest + app = App() diff --git a/examples/invoke-http/invoke-receiver.py b/examples/invoke-http/invoke-receiver.py index 8609464af..928f86987 100644 --- a/examples/invoke-http/invoke-receiver.py +++ b/examples/invoke-http/invoke-receiver.py @@ -1,7 +1,8 @@ # from dapr.ext.grpc import App, InvokeMethodRequest, InvokeMethodResponse -from flask import Flask, request import json +from flask import Flask, request + app = Flask(__name__) diff --git a/examples/jobs/job_management.py b/examples/jobs/job_management.py index fd8c7af88..adfeefef4 100644 --- a/examples/jobs/job_management.py +++ b/examples/jobs/job_management.py @@ -1,9 +1,10 @@ import json from datetime import datetime, timedelta -from dapr.clients import DaprClient, Job, DropFailurePolicy, ConstantFailurePolicy from google.protobuf.any_pb2 import Any as GrpcAny +from dapr.clients import ConstantFailurePolicy, DaprClient, DropFailurePolicy, Job + def create_job_data(message: str): """Helper function to create job payload data.""" diff --git a/examples/jobs/job_processing.py b/examples/jobs/job_processing.py index 9f5733b79..8df1ae5b2 100644 --- a/examples/jobs/job_processing.py +++ b/examples/jobs/job_processing.py @@ -15,8 +15,9 @@ import threading import time from datetime import datetime, timedelta + +from dapr.clients import ConstantFailurePolicy, DaprClient, Job from dapr.ext.grpc import App, JobEvent -from dapr.clients import DaprClient, Job, ConstantFailurePolicy try: from google.protobuf.any_pb2 import Any as GrpcAny diff --git a/examples/pubsub-simple/subscriber.py b/examples/pubsub-simple/subscriber.py index daa11bc89..7dab9e922 100644 --- a/examples/pubsub-simple/subscriber.py +++ b/examples/pubsub-simple/subscriber.py @@ -11,14 +11,15 @@ # limitations under the License. 
# ------------------------------------------------------------ +import json from time import sleep + from cloudevents.sdk.event import v1 -from dapr.ext.grpc import App + from dapr.clients.grpc._response import TopicEventResponse +from dapr.ext.grpc import App from dapr.proto import appcallback_v1 -import json - app = App() should_retry = True # To control whether dapr should retry sending a message diff --git a/examples/pubsub-streaming-async/subscriber-handler.py b/examples/pubsub-streaming-async/subscriber-handler.py index 06a492af5..c9c8203c2 100644 --- a/examples/pubsub-streaming-async/subscriber-handler.py +++ b/examples/pubsub-streaming-async/subscriber-handler.py @@ -1,5 +1,6 @@ import argparse import asyncio + from dapr.aio.clients import DaprClient from dapr.clients.grpc._response import TopicEventResponse diff --git a/examples/state_store/state_store.py b/examples/state_store/state_store.py index 301c675bc..b783fcdc9 100644 --- a/examples/state_store/state_store.py +++ b/examples/state_store/state_store.py @@ -5,11 +5,9 @@ import grpc from dapr.clients import DaprClient - from dapr.clients.grpc._request import TransactionalStateOperation, TransactionOperationType from dapr.clients.grpc._state import StateItem - with DaprClient() as d: storeName = 'statestore' diff --git a/examples/state_store_query/state_store_query.py b/examples/state_store_query/state_store_query.py index f532f0eb0..26c64da3e 100644 --- a/examples/state_store_query/state_store_query.py +++ b/examples/state_store_query/state_store_query.py @@ -2,10 +2,9 @@ dapr run python3 state_store_query.py """ -from dapr.clients import DaprClient - import json +from dapr.clients import DaprClient with DaprClient() as d: store_name = 'statestore' diff --git a/examples/workflow-async/README.md b/examples/workflow-async/README.md new file mode 100644 index 000000000..4ce670df9 --- /dev/null +++ b/examples/workflow-async/README.md @@ -0,0 +1,15 @@ +# Dapr Workflow Async Examples (Python) + +These examples mirror `examples/workflow/` but author orchestrators with `async def` using the +async workflow APIs. Activities remain regular functions unless noted. + +How to run: +- Ensure a Dapr sidecar is running locally. If needed, set `DURABLETASK_GRPC_ENDPOINT`, or + `DURABLETASK_GRPC_HOST/PORT`. +- Install requirements: `pip install -r requirements.txt` +- Run any example: `python simple.py` + +Notes: +- Orchestrators use `await ctx.activity(...)`, `await ctx.sleep(...)`, `await ctx.when_all/when_any(...)`, etc. +- No event loop is started manually; the Durable Task worker drives the async orchestrators. +- You can also launch instances using `DaprWorkflowClient` as in the non-async examples. diff --git a/examples/workflow-async/child_workflow.py b/examples/workflow-async/child_workflow.py new file mode 100644 index 000000000..9e89a7521 --- /dev/null +++ b/examples/workflow-async/child_workflow.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowRuntime, +) + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name='child_async') +async def child(ctx: AsyncWorkflowContext, n: int) -> int: + return n * 2 + + +@wfr.async_workflow(name='parent_async') +async def parent(ctx: AsyncWorkflowContext, n: int) -> int: + r = await ctx.call_child_workflow(child, input=n) + print(f'Child workflow returned {r}') + return r + 1 + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'parent_async_instance' + client.schedule_new_workflow(workflow=parent, input=5, instance_id=instance_id) + client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/fan_out_fan_in.py b/examples/workflow-async/fan_out_fan_in.py new file mode 100644 index 000000000..16e6e48d4 --- /dev/null +++ b/examples/workflow-async/fan_out_fan_in.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, +) + +wfr = WorkflowRuntime() + + +@wfr.activity(name='square') +def square(ctx: WorkflowActivityContext, x: int) -> int: + return x * x + + +@wfr.async_workflow(name='fan_out_fan_in_async') +async def orchestrator(ctx: AsyncWorkflowContext): + tasks = [ctx.call_activity(square, input=i) for i in range(1, 6)] + results = await ctx.when_all(tasks) + total = sum(results) + return total + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'fofi_async' + client.schedule_new_workflow(workflow=orchestrator, instance_id=instance_id) + wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print(f'Workflow state: {wf_state}') + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/human_approval.py b/examples/workflow-async/human_approval.py new file mode 100644 index 000000000..7ce177225 --- /dev/null +++ b/examples/workflow-async/human_approval.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow import AsyncWorkflowContext, DaprWorkflowClient, WorkflowRuntime + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name='human_approval_async') +async def orchestrator(ctx: AsyncWorkflowContext, request_id: str): + decision = await ctx.when_any( + [ + ctx.wait_for_external_event(f'approve:{request_id}'), + ctx.wait_for_external_event(f'reject:{request_id}'), + ctx.create_timer(300.0), + ] + ) + if isinstance(decision, dict) and decision.get('approved'): + return 'APPROVED' + if isinstance(decision, dict) and decision.get('rejected'): + return 'REJECTED' + return 'TIMEOUT' + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'human_approval_async_1' + client.schedule_new_workflow(workflow=orchestrator, input='REQ-1', instance_id=instance_id) + # In a real scenario, raise approve/reject event from another service. + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/requirements.txt b/examples/workflow-async/requirements.txt new file mode 100644 index 000000000..e220036d6 --- /dev/null +++ b/examples/workflow-async/requirements.txt @@ -0,0 +1,2 @@ +dapr-ext-workflow-dev>=1.15.0.dev +dapr-dev>=1.15.0.dev diff --git a/examples/workflow-async/simple.py b/examples/workflow-async/simple.py new file mode 100644 index 000000000..1e7cf0398 --- /dev/null +++ b/examples/workflow-async/simple.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from datetime import timedelta +from time import sleep + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + RetryPolicy, + WorkflowActivityContext, + WorkflowRuntime, +) + +counter = 0 +retry_count = 0 +child_orchestrator_string = '' +instance_id = 'asyncExampleInstanceID' +child_instance_id = 'asyncChildInstanceID' +workflow_name = 'async_hello_world_wf' +child_workflow_name = 'async_child_wf' +input_data = 'Hi Async Counter!' 
+event_name = 'event1' +event_data = 'eventData' + +retry_policy = RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + backoff_coefficient=2, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=100), +) + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name=workflow_name) +async def hello_world_wf(ctx: AsyncWorkflowContext, wf_input): + # activities + result_1 = await ctx.call_activity(hello_act, input=1) + print(f'Activity 1 returned {result_1}') + result_2 = await ctx.call_activity(hello_act, input=10) + print(f'Activity 2 returned {result_2}') + result_3 = await ctx.call_activity(hello_retryable_act, retry_policy=retry_policy) + print(f'Activity 3 returned {result_3}') + result_4 = await ctx.call_child_workflow(child_retryable_wf, retry_policy=retry_policy) + print(f'Child workflow returned {result_4}') + + # Event vs timeout using when_any + first = await ctx.when_any( + [ + ctx.wait_for_external_event(event_name), + ctx.create_timer(timedelta(seconds=30)), + ] + ) + + # Proceed only if event won + if isinstance(first, dict) and 'event' in first: + await ctx.call_activity(hello_act, input=100) + await ctx.call_activity(hello_act, input=1000) + return 'Completed' + return 'Timeout' + + +@wfr.activity(name='async_hello_act') +def hello_act(ctx: WorkflowActivityContext, wf_input): + global counter + counter += wf_input + return f'Activity returned {wf_input}' + + +@wfr.activity(name='async_hello_retryable_act') +def hello_retryable_act(ctx: WorkflowActivityContext): + global retry_count + if (retry_count % 2) == 0: + retry_count += 1 + raise ValueError('Retryable Error') + retry_count += 1 + return f'Activity returned {retry_count}' + + +@wfr.async_workflow(name=child_workflow_name) +async def child_retryable_wf(ctx: AsyncWorkflowContext): + # Call activity with retry and simulate retryable workflow failure until certain state + child_activity_result = await ctx.call_activity( + act_for_child_wf, input='x', retry_policy=retry_policy + ) + print(f'Child activity returned {child_activity_result}') + # In a real sample, you might check state and raise to trigger retry + return 'ok' + + +@wfr.activity(name='async_act_for_child_wf') +def act_for_child_wf(ctx: WorkflowActivityContext, inp): + global child_orchestrator_string + child_orchestrator_string += inp + + +def main(): + wfr.start() + wf_client = DaprWorkflowClient() + + wf_client.schedule_new_workflow( + workflow=hello_world_wf, input=input_data, instance_id=instance_id + ) + + wf_client.wait_for_workflow_start(instance_id) + + # Let initial activities run + sleep(5) + + # Raise event to continue + wf_client.raise_workflow_event( + instance_id=instance_id, event_name=event_name, data={'ok': True} + ) + + # Wait for completion + state = wf_client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print(f'Workflow status: {state.runtime_status.name}') + + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/task_chaining.py b/examples/workflow-async/task_chaining.py new file mode 100644 index 000000000..c9c92addc --- /dev/null +++ b/examples/workflow-async/task_chaining.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, +) + +wfr = WorkflowRuntime() + + +@wfr.activity(name='sum') +def sum_act(ctx: WorkflowActivityContext, nums): + return sum(nums) + + +@wfr.async_workflow(name='task_chaining_async') +async def orchestrator(ctx: AsyncWorkflowContext): + a = await ctx.call_activity(sum_act, input=[1, 2]) + b = await ctx.call_activity(sum_act, input=[a, 3]) + c = await ctx.call_activity(sum_act, input=[b, 4]) + return c + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'task_chain_async' + client.schedule_new_workflow(workflow=orchestrator, instance_id=instance_id) + client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/README.md b/examples/workflow/README.md index f5b901d1c..6f3a97f9f 100644 --- a/examples/workflow/README.md +++ b/examples/workflow/README.md @@ -12,6 +12,8 @@ This directory contains examples of using the [Dapr Workflow](https://docs.dapr. You can install dapr SDK package using pip command: ```sh +python3 -m venv .venv +source .venv/bin/activate pip3 install -r requirements.txt ``` diff --git a/examples/workflow/aio/async_activity_sequence.py b/examples/workflow/aio/async_activity_sequence.py new file mode 100644 index 000000000..8eecd1f87 --- /dev/null +++ b/examples/workflow/aio/async_activity_sequence.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import AsyncWorkflowContext, WorkflowRuntime + + +def main(): + rt = WorkflowRuntime() + + @rt.activity(name='add') + def add(ctx, xy): + return xy[0] + xy[1] + + @rt.workflow(name='sum_three') + async def sum_three(ctx: AsyncWorkflowContext, nums): + a = await ctx.call_activity(add, input=[nums[0], nums[1]]) + b = await ctx.call_activity(add, input=[a, nums[2]]) + return b + + rt.start() + print("Registered async workflow 'sum_three' and activity 'add'") + + # This example registers only; use Dapr client to start instances externally. + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/aio/async_external_event.py b/examples/workflow/aio/async_external_event.py new file mode 100644 index 000000000..905314224 --- /dev/null +++ b/examples/workflow/aio/async_external_event.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import AsyncWorkflowContext, WorkflowRuntime + + +def main(): + rt = WorkflowRuntime() + + @rt.async_workflow(name='wait_event') + async def wait_event(ctx: AsyncWorkflowContext): + data = await ctx.wait_for_external_event('go') + return {'event': data} + + rt.start() + print("Registered async workflow 'wait_event'") + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/aio/async_sub_orchestrator.py b/examples/workflow/aio/async_sub_orchestrator.py new file mode 100644 index 000000000..c00d9ca93 --- /dev/null +++ b/examples/workflow/aio/async_sub_orchestrator.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import AsyncWorkflowContext, WorkflowRuntime + + +def main(): + rt = WorkflowRuntime() + + @rt.async_workflow(name='child') + async def child(ctx: AsyncWorkflowContext, n): + return n * 2 + + @rt.async_workflow(name='parent') + async def parent(ctx: AsyncWorkflowContext, n): + r = await ctx.call_child_workflow(child, input=n) + return r + 1 + + rt.start() + print("Registered async workflows 'parent' and 'child'") + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/aio/context_interceptors_example.py b/examples/workflow/aio/context_interceptors_example.py new file mode 100644 index 000000000..d005bca17 --- /dev/null +++ b/examples/workflow/aio/context_interceptors_example.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- + +""" +Example: Interceptors for context propagation (client + runtime). + +This example shows how to: + - Define a small context (dict) carried via contextvars + - Implement ClientInterceptor to inject that context into outbound inputs + - Implement RuntimeInterceptor to restore the context before user code runs + - Wire interceptors into WorkflowRuntime and DaprWorkflowClient + +Note: Scheduling/running requires a Dapr sidecar. This file focuses on the wiring pattern. 
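+To run it end-to-end, start a Dapr sidecar and uncomment the scheduling lines in the __main__ block.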
+""" + +from __future__ import annotations + +import contextvars +from typing import Any, Callable + +from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityInput, + CallChildWorkflowInput, + DaprWorkflowClient, + ExecuteActivityInput, + ExecuteWorkflowInput, + ScheduleWorkflowInput, + WorkflowRuntime, +) + +# A simple context carried across boundaries +_current_ctx: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar( + 'wf_ctx', default=None +) + + +def set_ctx(ctx: dict[str, Any] | None) -> None: + _current_ctx.set(ctx) + + +def get_ctx() -> dict[str, Any] | None: + return _current_ctx.get() + + +def _merge_ctx(args: Any) -> Any: + ctx = get_ctx() + if ctx and isinstance(args, dict) and 'context' not in args: + return {**args, 'context': ctx} + return args + + +class ContextClientInterceptor(BaseClientInterceptor): + def schedule_new_workflow( + self, input: ScheduleWorkflowInput, nxt: Callable[[ScheduleWorkflowInput], Any] + ) -> Any: # type: ignore[override] + input = ScheduleWorkflowInput( + workflow_name=input.workflow_name, + args=_merge_ctx(input.args), + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + ) + return nxt(input) + + +class ContextWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + def call_child_workflow( + self, input: CallChildWorkflowInput, nxt: Callable[[CallChildWorkflowInput], Any] + ) -> Any: + return nxt( + CallChildWorkflowInput( + workflow_name=input.workflow_name, + args=_merge_ctx(input.args), + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + local_context=input.local_context, + ) + ) + + def call_activity( + self, input: CallActivityInput, nxt: Callable[[CallActivityInput], Any] + ) -> Any: + return nxt( + CallActivityInput( + activity_name=input.activity_name, + args=_merge_ctx(input.args), + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + local_context=input.local_context, + ) + ) + + +class ContextRuntimeInterceptor(BaseRuntimeInterceptor): + def execute_workflow( + self, input: ExecuteWorkflowInput, nxt: Callable[[ExecuteWorkflowInput], Any] + ) -> Any: # type: ignore[override] + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + def execute_activity( + self, input: ExecuteActivityInput, nxt: Callable[[ExecuteActivityInput], Any] + ) -> Any: # type: ignore[override] + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + +# Example workflow and activity +def activity_log(ctx, data: dict[str, Any]) -> str: # noqa: ANN001 (example) + # Access restored context inside activity via contextvars + return f'ok:{get_ctx()}' + + +def workflow_example(ctx, x: int): # noqa: ANN001 (example) + y = yield ctx.call_activity(activity_log, input={'msg': 'hello'}) + return y + + +def wire_up() -> tuple[WorkflowRuntime, DaprWorkflowClient]: + runtime = WorkflowRuntime( + runtime_interceptors=[ContextRuntimeInterceptor()], + workflow_outbound_interceptors=[ContextWorkflowOutboundInterceptor()], + ) + client = DaprWorkflowClient(interceptors=[ContextClientInterceptor()]) + + # Register workflow/activity + runtime.workflow(name='example')(workflow_example) + runtime.activity(name='activity_log')(activity_log) + return 
runtime, client + + +if __name__ == '__main__': + # This section demonstrates how you would set a context and schedule a workflow. + # Requires a running Dapr sidecar to actually execute. + rt, cli = wire_up() + set_ctx({'tenant': 'acme', 'request_id': 'r-123'}) + # instance_id = cli.schedule_new_workflow(workflow_example, input={'x': 1}) + # print('scheduled:', instance_id) + # rt.start(); rt.wait_for_ready(); ... + pass diff --git a/examples/workflow/aio/model_tool_serialization_example.py b/examples/workflow/aio/model_tool_serialization_example.py new file mode 100644 index 000000000..2c1bdf4c8 --- /dev/null +++ b/examples/workflow/aio/model_tool_serialization_example.py @@ -0,0 +1,66 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from typing import Any, Dict + +from dapr.ext.workflow import ensure_canonical_json + +""" +Example of implementing provider-specific model/tool serialization OUTSIDE the core package. + +This demonstrates how to build and use your own contracts using the generic helpers from +`dapr.ext.workflow.serializers`. +""" + + +def to_model_request(payload: Dict[str, Any]) -> Dict[str, Any]: + req = { + 'schema_version': 'model_req@v1', + 'model_name': payload.get('model_name'), + 'system_instructions': payload.get('system_instructions'), + 'input': payload.get('input'), + 'model_settings': payload.get('model_settings') or {}, + 'tools': payload.get('tools') or [], + } + return ensure_canonical_json(req, strict=True) + + +def from_model_response(obj: Any) -> Dict[str, Any]: + if isinstance(obj, dict): + content = obj.get('content') + tool_calls = obj.get('tool_calls') or [] + out = {'schema_version': 'model_res@v1', 'content': content, 'tool_calls': tool_calls} + return ensure_canonical_json(out, strict=False) + return ensure_canonical_json( + {'schema_version': 'model_res@v1', 'content': str(obj), 'tool_calls': []}, strict=False + ) + + +def to_tool_request(name: str, args: list | None, kwargs: dict | None) -> Dict[str, Any]: + req = { + 'schema_version': 'tool_req@v1', + 'tool_name': name, + 'args': args or [], + 'kwargs': kwargs or {}, + } + return ensure_canonical_json(req, strict=True) + + +def from_tool_result(obj: Any) -> Dict[str, Any]: + if isinstance(obj, dict) and ('result' in obj or 'error' in obj): + return ensure_canonical_json({'schema_version': 'tool_res@v1', **obj}, strict=False) + return ensure_canonical_json( + {'schema_version': 'tool_res@v1', 'result': obj, 'error': None}, strict=False + ) diff --git a/examples/workflow/aio/tracing_interceptors_example.py b/examples/workflow/aio/tracing_interceptors_example.py new file mode 100644 index 000000000..ea4834fb0 --- /dev/null +++ b/examples/workflow/aio/tracing_interceptors_example.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Any, Callable + +from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityInput, + 
CallChildWorkflowInput, + DaprWorkflowClient, + ExecuteActivityInput, + ExecuteWorkflowInput, + ScheduleWorkflowInput, + WorkflowRuntime, +) + +TRACE_ID_KEY = 'otel.trace_id' +SPAN_ID_KEY = 'otel.span_id' + + +class TracingClientInterceptor(BaseClientInterceptor): + def __init__(self, get_current_trace: Callable[[], tuple[str, str]]): + self._get = get_current_trace + + def schedule_new_workflow(self, input: ScheduleWorkflowInput, next): # type: ignore[override] + trace_id, span_id = self._get() + md = dict(input.metadata or {}) + md[TRACE_ID_KEY] = trace_id + md[SPAN_ID_KEY] = span_id + return next( + ScheduleWorkflowInput( + workflow_name=input.workflow_name, + args=input.args, + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + metadata=md, + local_context=input.local_context, + ) + ) + + +class TracingRuntimeInterceptor(BaseRuntimeInterceptor): + def __init__(self, on_span: Callable[[str, dict[str, str]], Any]): + self._on_span = on_span + + def execute_workflow(self, input: ExecuteWorkflowInput, next): # type: ignore[override] + # Suppress spans during replay + if not input.ctx.is_replaying: + self._on_span('dapr:executeWorkflow', input.metadata or {}) + return next(input) + + def execute_activity(self, input: ExecuteActivityInput, next): # type: ignore[override] + self._on_span('dapr:executeActivity', input.metadata or {}) + return next(input) + + +class TracingWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + def __init__(self, get_current_trace: Callable[[], tuple[str, str]]): + self._get = get_current_trace + + def call_activity(self, input: CallActivityInput, next): # type: ignore[override] + trace_id, span_id = self._get() + md = dict((input.metadata or {}) or {}) + md[TRACE_ID_KEY] = md.get(TRACE_ID_KEY, trace_id) + md[SPAN_ID_KEY] = span_id + return next( + type(input)( + activity_name=input.activity_name, + args=input.args, + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=md, + local_context=input.local_context, + ) + ) + + def call_child_workflow(self, input: CallChildWorkflowInput, next): # type: ignore[override] + trace_id, span_id = self._get() + md = dict((input.metadata or {}) or {}) + md[TRACE_ID_KEY] = md.get(TRACE_ID_KEY, trace_id) + md[SPAN_ID_KEY] = span_id + return next( + type(input)( + workflow_name=input.workflow_name, + args=input.args, + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=md, + local_context=input.local_context, + ) + ) + + +def example_usage(): + # Simplified trace getter and span recorder + def _get_trace(): + return ('trace-123', 'span-abc') + + spans: list[tuple[str, dict[str, str]]] = [] + + def _on_span(name: str, attrs: dict[str, str]): + spans.append((name, attrs)) + + runtime = WorkflowRuntime( + runtime_interceptors=[TracingRuntimeInterceptor(_on_span)], + workflow_outbound_interceptors=[TracingWorkflowOutboundInterceptor(_get_trace)], + ) + + client = DaprWorkflowClient(interceptors=[TracingClientInterceptor(_get_trace)]) + + # Register and run as you would normally; spans list can be asserted in tests + return runtime, client, spans + + +if __name__ == '__main__': # pragma: no cover + example_usage() diff --git a/examples/workflow/child_workflow.py b/examples/workflow/child_workflow.py index dccaa631b..57ab2fc3e 100644 --- a/examples/workflow/child_workflow.py +++ b/examples/workflow/child_workflow.py @@ -10,9 +10,10 @@ # See the License for the specific language governing permissions and # limitations under 
the License. -import dapr.ext.workflow as wf import time +import dapr.ext.workflow as wf + wfr = wf.WorkflowRuntime() diff --git a/examples/workflow/e2e_execinfo.py b/examples/workflow/e2e_execinfo.py new file mode 100644 index 000000000..91f0b295f --- /dev/null +++ b/examples/workflow/e2e_execinfo.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import time + +from dapr.ext.workflow import DaprWorkflowClient, WorkflowRuntime + + +def main(): + port = '50001' + + rt = WorkflowRuntime(port=port) + + def activity_noop(ctx): + ei = ctx.execution_info + # Return attempt (may be None if engine doesn't set it) + return { + 'attempt': ei.attempt if ei else None, + 'workflow_id': ei.workflow_id if ei else None, + } + + @rt.workflow(name='child-to-parent') + def child(ctx, x): + ei = ctx.execution_info + out = yield ctx.call_activity(activity_noop, input=None) + return { + 'child_workflow_name': ei.workflow_name if ei else None, + 'parent_instance_id': ei.parent_instance_id if ei else None, + 'activity': out, + } + + @rt.workflow(name='parent') + def parent(ctx, x): + res = yield ctx.call_child_workflow(child, input={'x': x}) + return res + + rt.register_activity(activity_noop, name='activity_noop') + + rt.start() + try: + # Wait for the worker to be ready to accept work + rt.wait_for_ready(timeout=10) + + client = DaprWorkflowClient(port=port) + instance_id = client.schedule_new_workflow(parent, input=1) + state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=30) + print('instance:', instance_id) + print('runtime_status:', state.runtime_status if state else None) + print('state:', state) + finally: + # Give a moment for logs to flush then shutdown + time.sleep(0.5) + rt.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/fan_out_fan_in.py b/examples/workflow/fan_out_fan_in.py index e5799862f..f625ea287 100644 --- a/examples/workflow/fan_out_fan_in.py +++ b/examples/workflow/fan_out_fan_in.py @@ -12,6 +12,7 @@ import time from typing import List + import dapr.ext.workflow as wf wfr = wf.WorkflowRuntime() diff --git a/examples/workflow/human_approval.py b/examples/workflow/human_approval.py index 6a8a725d7..0ca373e53 100644 --- a/examples/workflow/human_approval.py +++ b/examples/workflow/human_approval.py @@ -11,12 +11,12 @@ # limitations under the License. import threading +import time from dataclasses import asdict, dataclass from datetime import timedelta -import time -from dapr.clients import DaprClient import dapr.ext.workflow as wf +from dapr.clients import DaprClient wfr = wf.WorkflowRuntime() diff --git a/examples/workflow/monitor.py b/examples/workflow/monitor.py index 6cf575cfe..d4f534df5 100644 --- a/examples/workflow/monitor.py +++ b/examples/workflow/monitor.py @@ -10,10 +10,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import random from dataclasses import dataclass from datetime import timedelta -import random from time import sleep + import dapr.ext.workflow as wf wfr = wf.WorkflowRuntime() diff --git a/examples/workflow/requirements.txt b/examples/workflow/requirements.txt index e220036d6..db1d5c539 100644 --- a/examples/workflow/requirements.txt +++ b/examples/workflow/requirements.txt @@ -1,2 +1,12 @@ -dapr-ext-workflow-dev>=1.15.0.dev -dapr-dev>=1.15.0.dev +# dapr-ext-workflow-dev>=1.15.0.dev +# dapr-dev>=1.15.0.dev + +# local development: install local packages in editable mode + +# if using dev version of durabletask-python +-e ../../../durabletask-python + +# if using dev version of dapr-ext-workflow +-e ../../ext/dapr-ext-workflow +-e ../.. + diff --git a/examples/workflow/simple.py b/examples/workflow/simple.py index 76f21eba4..5bbdf68a9 100644 --- a/examples/workflow/simple.py +++ b/examples/workflow/simple.py @@ -11,16 +11,17 @@ # limitations under the License. from datetime import timedelta from time import sleep + +from dapr.clients.exceptions import DaprInternalError +from dapr.conf import Settings from dapr.ext.workflow import ( - WorkflowRuntime, + DaprWorkflowClient, DaprWorkflowContext, - WorkflowActivityContext, RetryPolicy, - DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, when_any, ) -from dapr.conf import Settings -from dapr.clients.exceptions import DaprInternalError settings = Settings() diff --git a/examples/workflow/task_chaining.py b/examples/workflow/task_chaining.py index 074cadcd2..8a2058e1c 100644 --- a/examples/workflow/task_chaining.py +++ b/examples/workflow/task_chaining.py @@ -14,7 +14,6 @@ import dapr.ext.workflow as wf - wfr = wf.WorkflowRuntime() diff --git a/ext/dapr-ext-fastapi/dapr/ext/fastapi/__init__.py b/ext/dapr-ext-fastapi/dapr/ext/fastapi/__init__.py index 942603078..e43df65c9 100644 --- a/ext/dapr-ext-fastapi/dapr/ext/fastapi/__init__.py +++ b/ext/dapr-ext-fastapi/dapr/ext/fastapi/__init__.py @@ -16,5 +16,4 @@ from .actor import DaprActor from .app import DaprApp - __all__ = ['DaprActor', 'DaprApp'] diff --git a/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py b/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py index 93b7860e1..4b3990da4 100644 --- a/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py +++ b/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py @@ -13,12 +13,12 @@ limitations under the License. """ -from typing import Any, Optional, Type, List +from typing import Any, List, Optional, Type from dapr.actor import Actor, ActorRuntime from dapr.clients.exceptions import ERROR_CODE_UNKNOWN, DaprInternalError from dapr.serializers import DefaultJSONSerializer -from fastapi import FastAPI, APIRouter, Request, Response, status # type: ignore +from fastapi import APIRouter, FastAPI, Request, Response, status # type: ignore from fastapi.logger import logger from fastapi.responses import JSONResponse diff --git a/ext/dapr-ext-fastapi/dapr/ext/fastapi/app.py b/ext/dapr-ext-fastapi/dapr/ext/fastapi/app.py index d926fac5c..6bede5234 100644 --- a/ext/dapr-ext-fastapi/dapr/ext/fastapi/app.py +++ b/ext/dapr-ext-fastapi/dapr/ext/fastapi/app.py @@ -13,6 +13,7 @@ """ from typing import Dict, List, Optional + from fastapi import FastAPI # type: ignore diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py b/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py index 7d73b4a48..bf3ad7b52 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py @@ -13,13 +13,11 @@ limitations under the License. 
""" -from dapr.clients.grpc._request import InvokeMethodRequest, BindingRequest, JobEvent +from dapr.clients.grpc._jobs import ConstantFailurePolicy, DropFailurePolicy, FailurePolicy, Job +from dapr.clients.grpc._request import BindingRequest, InvokeMethodRequest, JobEvent from dapr.clients.grpc._response import InvokeMethodResponse, TopicEventResponse -from dapr.clients.grpc._jobs import Job, FailurePolicy, DropFailurePolicy, ConstantFailurePolicy - from dapr.ext.grpc.app import App, Rule # type:ignore - __all__ = [ 'App', 'Rule', diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/_health_servicer.py b/ext/dapr-ext-grpc/dapr/ext/grpc/_health_servicer.py index 029dff745..f6d782da1 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/_health_servicer.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/_health_servicer.py @@ -1,6 +1,6 @@ -import grpc from typing import Callable, Optional +import grpc from dapr.proto import appcallback_service_v1 from dapr.proto.runtime.v1.appcallback_pb2 import HealthCheckResponse diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py b/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py index 996267fdd..8de632f97 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py @@ -12,25 +12,25 @@ See the License for the specific language governing permissions and limitations under the License. """ -import grpc -from cloudevents.sdk.event import v1 # type: ignore from typing import Callable, Dict, List, Optional, Tuple, Union +from cloudevents.sdk.event import v1 # type: ignore from google.protobuf import empty_pb2 from google.protobuf.message import Message as GrpcMessage from google.protobuf.struct_pb2 import Struct -from dapr.proto import appcallback_service_v1, common_v1, appcallback_v1 +import grpc +from dapr.clients._constants import DEFAULT_JSON_CONTENT_TYPE +from dapr.clients.grpc._request import BindingRequest, InvokeMethodRequest, JobEvent +from dapr.clients.grpc._response import InvokeMethodResponse, TopicEventResponse +from dapr.proto import appcallback_service_v1, appcallback_v1, common_v1 +from dapr.proto.common.v1.common_pb2 import InvokeRequest from dapr.proto.runtime.v1.appcallback_pb2 import ( - TopicEventRequest, BindingEventRequest, JobEventRequest, + TopicEventRequest, ) -from dapr.proto.common.v1.common_pb2 import InvokeRequest -from dapr.clients._constants import DEFAULT_JSON_CONTENT_TYPE -from dapr.clients.grpc._request import InvokeMethodRequest, BindingRequest, JobEvent -from dapr.clients.grpc._response import InvokeMethodResponse, TopicEventResponse InvokeMethodCallable = Callable[[InvokeMethodRequest], Union[str, bytes, InvokeMethodResponse]] TopicSubscribeCallable = Callable[[v1.Event], Optional[TopicEventResponse]] diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/app.py b/ext/dapr-ext-grpc/dapr/ext/grpc/app.py index 9f9ac8472..87543bdf3 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/app.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/app.py @@ -13,14 +13,13 @@ limitations under the License. 
""" -import grpc - from concurrent import futures from typing import Dict, Optional +import grpc from dapr.conf import settings -from dapr.ext.grpc._servicer import _CallbackServicer, Rule # type: ignore from dapr.ext.grpc._health_servicer import _HealthCheckServicer # type: ignore +from dapr.ext.grpc._servicer import Rule, _CallbackServicer # type: ignore from dapr.proto import appcallback_service_v1 diff --git a/ext/dapr-ext-grpc/tests/test_app.py b/ext/dapr-ext-grpc/tests/test_app.py index 2a33dd668..6d6661b53 100644 --- a/ext/dapr-ext-grpc/tests/test_app.py +++ b/ext/dapr-ext-grpc/tests/test_app.py @@ -16,7 +16,8 @@ import unittest from cloudevents.sdk.event import v1 -from dapr.ext.grpc import App, Rule, InvokeMethodRequest, BindingRequest + +from dapr.ext.grpc import App, BindingRequest, InvokeMethodRequest, Rule class AppTests(unittest.TestCase): diff --git a/ext/dapr-ext-grpc/tests/test_servicier.py b/ext/dapr-ext-grpc/tests/test_servicier.py index 2447eea3c..4cfa6a0e1 100644 --- a/ext/dapr-ext-grpc/tests/test_servicier.py +++ b/ext/dapr-ext-grpc/tests/test_servicier.py @@ -14,15 +14,14 @@ """ import unittest - from unittest.mock import MagicMock, Mock +from google.protobuf.any_pb2 import Any as GrpcAny + from dapr.clients.grpc._request import InvokeMethodRequest from dapr.clients.grpc._response import InvokeMethodResponse, TopicEventResponse from dapr.ext.grpc._servicer import _CallbackServicer -from dapr.proto import common_v1, appcallback_v1 - -from google.protobuf.any_pb2 import Any as GrpcAny +from dapr.proto import appcallback_v1, common_v1 class OnInvokeTests(unittest.TestCase): diff --git a/ext/dapr-ext-workflow/Dapr-Workflow-Middleware-Outbound-Hooks.md b/ext/dapr-ext-workflow/Dapr-Workflow-Middleware-Outbound-Hooks.md new file mode 100644 index 000000000..5de1a9cb0 --- /dev/null +++ b/ext/dapr-ext-workflow/Dapr-Workflow-Middleware-Outbound-Hooks.md @@ -0,0 +1,116 @@ +## Dapr Workflow Middleware: Outbound Hooks for Context Propagation + +Goal +- Add outbound hooks to Dapr Workflow middleware so adapters can inject tracing/context when scheduling activities/child workflows/signals without wrapping user code. +- Achieve parity with Temporal’s interceptor model: workflow outbound injection + activity inbound restoration. + +Why +- Current `RuntimeMiddleware` exposes only inbound lifecycle hooks (workflow/activity start/complete/error). There is no hook before scheduling activities to mutate inputs/headers. +- We currently wrap `ctx.call_activity` to inject tracing. This is effective but adapter-specific and leaks into application flow. + +Proposed API +- Extend `dapr.ext.workflow` with additional hook points (illustrative names): + +```python +class RuntimeMiddleware(Protocol): + # existing inbound hooks... + def on_workflow_start(self, ctx: Any, input: Any): ... + def on_workflow_complete(self, ctx: Any, result: Any): ... + def on_workflow_error(self, ctx: Any, error: BaseException): ... + def on_activity_start(self, ctx: Any, input: Any): ... + def on_activity_complete(self, ctx: Any, result: Any): ... + def on_activity_error(self, ctx: Any, error: BaseException): ... + + # new outbound hooks (workflow outbound) + def on_schedule_activity( + self, + ctx: Any, + activity: Callable[..., Any] | str, + input: Any, + retry_policy: Any | None, + ) -> Any: # returns possibly-modified input + """Called just before scheduling an activity. 
Return new input to use.""" + + def on_start_child_workflow( + self, + ctx: Any, + workflow: Callable[..., Any] | str, + input: Any, + ) -> Any: # returns possibly-modified input + """Called before starting a child workflow.""" + + def on_signal_workflow( + self, + ctx: Any, + signal_name: str, + input: Any, + ) -> Any: # returns possibly-modified input + """Called before signaling a workflow.""" +``` + +Behavior +- Hooks run within workflow sandbox; must be deterministic and side-effect free. +- The engine uses the middleware’s return value as the actual input for the scheduling call. +- If multiple middlewares are installed, chain them in order (each sees the previous result). +- If a hook raises, log and continue with the last good value (non-fatal by default). + +Reference Impl (engine changes) +- In the workflow context implementation, just before delegating to DurableTask’s schedule API: + - Call `on_schedule_activity(ctx, activity, input, retry_policy)` for each installed middleware. + - Use the returned input for the actual schedule call. + - Repeat pattern for child workflows and signals. + +Adapter usage (example) +```python +class TraceContextMiddleware(RuntimeMiddleware): + def on_schedule_activity(self, ctx, activity, input, retry_policy): + from agents_sdk.adapters.openai.tracing import serialize_trace_context + tracing = serialize_trace_context() + if input is None: + return {"tracing": tracing} + if isinstance(input, dict) and "tracing" not in input: + return {**input, "tracing": tracing} + return input + + # inbound restore already implemented via on_activity_start/complete/error +``` + +Determinism Constraints +- No non-deterministic APIs (time, random, network) inside hooks. +- Pure data transformation of provided `input` + current context-derived data already in scope. +- If adapters need time/ids, they must obtain deterministic values from the workflow context (e.g., `ctx.current_utc_datetime`). + +Error Handling Policy +- Hook exceptions are caught and logged; scheduling proceeds with the last known value. +- Optionally, add a strict mode flag at runtime init to fail-fast on hook errors. + +Testing Plan +- Unit + - on_schedule_activity merges tracing into None and dict inputs; leaves non-dict unchanged. + - Chaining: two middlewares both modify input; verify order and final payload. + - Exceptions: first middleware raises, second still runs; input falls back correctly. +- Integration (workflow sandbox) + - Workflow calls multiple activities; verify each receives tracing in input. + - Mixed: activity + sleep + when_all; ensure only activities are modified. + - Child workflow path: verify `on_start_child_workflow` injects tracing. + - Signal path: verify `on_signal_workflow` injects tracing payload. +- Determinism + - Re-run workflow with same history: identical decisions and payloads. + - Ensure no network/time/random usage in hooks; static analysis/lint rules. + +Migration for this repo +- Once the SDK exposes outbound hooks: + - Remove `wrap_ctx_inject_tracing` and `activity_restore_wrapper` from the adapter wiring. + - Keep inbound restoration in middleware only (already implemented). + - Simplify `AgentWorkflowRuntime` so it doesn’t need context wrappers. + +Open Questions +- Should hooks support header/metadata objects in addition to input payload mutation? +- Do we need an outbound hook for external events (emit) beyond signals? + +Timeline +- Week 1: Implement engine hooks + unit tests in SDK. +- Week 2: Add integration tests; update docs and examples. 
+- Week 3: Migrate adapters here to middleware-only; delete wrappers. + + diff --git a/ext/dapr-ext-workflow/README.rst b/ext/dapr-ext-workflow/README.rst index aa0003c6e..bc7b5d098 100644 --- a/ext/dapr-ext-workflow/README.rst +++ b/ext/dapr-ext-workflow/README.rst @@ -16,6 +16,490 @@ Installation pip install dapr-ext-workflow +Async authoring (experimental) +------------------------------ + +This package supports authoring workflows with ``async def`` in addition to the existing generator-based orchestrators. + +- Register async workflows using ``WorkflowRuntime.workflow`` (auto-detects coroutine) or ``async_workflow`` / ``register_async_workflow``. +- Use ``AsyncWorkflowContext`` for deterministic operations: + + - Activities: ``await ctx.call_activity(activity_fn, input=...)`` + - Child workflows: ``await ctx.call_child_workflow(workflow_fn, input=...)`` + - Timers: ``await ctx.create_timer(seconds|timedelta)`` + - External events: ``await ctx.wait_for_external_event(name)`` + - Concurrency: ``await ctx.when_all([...])``, ``await ctx.when_any([...])`` + - Deterministic utils: ``ctx.now()``, ``ctx.random()``, ``ctx.uuid4()``, ``ctx.new_guid()``, ``ctx.random_string(length)`` + +Interceptors (client/runtime/outbound) +-------------------------------------- + +Interceptors provide a simple, composable way to apply cross-cutting behavior with a single +enter/exit per call. There are three types: + +- Client interceptors: wrap outbound scheduling from the client (schedule_new_workflow). +- Workflow outbound interceptors: wrap calls made inside workflows (call_activity, call_child_workflow). +- Runtime interceptors: wrap inbound execution of workflows and activities (before user code). + +Use cases include context propagation, request metadata stamping, replay-aware logging, validation, +and policy enforcement. + +Response/output shaping +~~~~~~~~~~~~~~~~~~~~~~~ + +Interceptors are "around" hooks: they can shape inputs before calling ``next(...)`` and may also +shape the returned value (or map exceptions) after ``next(...)`` returns. This mirrors gRPC +interceptors and keeps the surface simple – one hook per interception point. + +- Client interceptors can transform schedule/query/signal responses. +- Runtime interceptors can transform workflow/activity results (with guardrails below). +- Workflow-outbound interceptors remain input-only to keep awaitable composition simple. + +Examples +^^^^^^^^ + +Client schedule response shaping:: + + from dapr.ext.workflow import ( + DaprWorkflowClient, ClientInterceptor, ScheduleWorkflowRequest + ) + + class ShapeId(ClientInterceptor): + def schedule_new_workflow(self, input: ScheduleWorkflowRequest, next): + raw = next(input) + return f"tenant-A:{raw}" + + client = DaprWorkflowClient(interceptors=[ShapeId()]) + instance_id = client.schedule_new_workflow(my_workflow, input={}) + # instance_id == "tenant-A:" + +Runtime activity result shaping:: + + from dapr.ext.workflow import WorkflowRuntime, RuntimeInterceptor, ExecuteActivityRequest + + class WrapResult(RuntimeInterceptor): + def execute_activity(self, input: ExecuteActivityRequest, next): + res = next(input) + return {"value": res} + + rt = WorkflowRuntime(runtime_interceptors=[WrapResult()]) + @rt.activity + def echo(ctx, x): + return x + # echo(...) returns {"value": x} + +Determinism guardrails (workflows) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Workflow response shaping must be replay-safe: pure transforms only (no I/O, time, RNG). 
+- Base the transform solely on (input, metadata, original_result). Map errors to typed exceptions. +- Activities are not replayed, so result shaping may perform I/O, but keep it lightweight. + +Quick start +~~~~~~~~~~~ + +.. code-block:: python + + from __future__ import annotations + import contextvars + from typing import Any, Callable, List + + from dapr.ext.workflow import ( + WorkflowRuntime, + DaprWorkflowClient, + ClientInterceptor, + WorkflowOutboundInterceptor, + RuntimeInterceptor, + ScheduleWorkflowRequest, + CallActivityRequest, + CallChildWorkflowRequest, + ExecuteWorkflowRequest, + ExecuteActivityRequest, + ) + + # Example: propagate a lightweight context dict through inputs + _current_ctx: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar( + 'wf_ctx', default=None + ) + + def set_ctx(ctx: dict[str, Any] | None): + _current_ctx.set(ctx) + + def _merge_ctx(args: Any) -> Any: + ctx = _current_ctx.get() + if ctx and isinstance(args, dict) and 'context' not in args: + return {**args, 'context': ctx} + return args + + # Typed payloads + class MyWorkflowInput: + def __init__(self, question: str, tags: List[str] | None = None): + self.question = question + self.tags = tags or [] + + class MyActivityInput: + def __init__(self, name: str, count: int): + self.name = name + self.count = count + + class ContextClientInterceptor(ClientInterceptor[MyWorkflowInput]): + def schedule_new_workflow(self, input: ScheduleWorkflowRequest[MyWorkflowInput], nxt: Callable[[ScheduleWorkflowRequest[MyWorkflowInput]], Any]) -> Any: + input = ScheduleWorkflowRequest( + workflow_name=input.workflow_name, + input=_merge_ctx(input.input), + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + ) + return nxt(input) + + class ContextWorkflowOutboundInterceptor(WorkflowOutboundInterceptor[MyWorkflowInput, MyActivityInput]): + def call_child_workflow(self, input: CallChildWorkflowRequest[MyWorkflowInput], nxt: Callable[[CallChildWorkflowRequest[MyWorkflowInput]], Any]) -> Any: + return nxt(CallChildWorkflowRequest[MyWorkflowInput]( + workflow_name=input.workflow_name, + input=_merge_ctx(input.input), + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + )) + + def call_activity(self, input: CallActivityRequest[MyActivityInput], nxt: Callable[[CallActivityRequest[MyActivityInput]], Any]) -> Any: + return nxt(CallActivityRequest[MyActivityInput]( + activity_name=input.activity_name, + input=_merge_ctx(input.input), + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + )) + + class ContextRuntimeInterceptor(RuntimeInterceptor[MyWorkflowInput, MyActivityInput]): + def execute_workflow(self, input: ExecuteWorkflowRequest[MyWorkflowInput], nxt: Callable[[ExecuteWorkflowRequest[MyWorkflowInput]], Any]) -> Any: + # Restore context from input if present (no I/O, replay-safe) + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + def execute_activity(self, input: ExecuteActivityRequest[MyActivityInput], nxt: Callable[[ExecuteActivityRequest[MyActivityInput]], Any]) -> Any: + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + # Wire into client and runtime + runtime = WorkflowRuntime( + runtime_interceptors=[ContextRuntimeInterceptor()], + 
workflow_outbound_interceptors=[ContextWorkflowOutboundInterceptor()], + ) + + client = DaprWorkflowClient(interceptors=[ContextClientInterceptor()]) + +Context metadata (durable propagation) +------------------------------------- + +Interceptors support a durable context channel: + +- ``metadata``: a string-only dict that is durably persisted and propagated across workflow + boundaries (schedule, child workflows, activities). Typical use: tracing and correlation ids + (e.g., ``otel.trace_id``), tenancy, request ids. This is provider-agnostic and does not require + changes to your workflow/activities. + +How it works +~~~~~~~~~~~~ + +- Client interceptors can set ``metadata`` when scheduling a workflow or calling activities/children. +- Runtime unwraps a reserved envelope before user code runs and exposes the metadata to + ``RuntimeInterceptor`` via ``ExecuteWorkflowRequest.metadata`` / ``ExecuteActivityRequest.metadata``, + while delivering only the original payload to the user function. +- Outbound calls made inside a workflow use client interceptors; when ``metadata`` is present on the + call input, the runtime re-wraps the payload to persist and propagate it. + +Envelope (backward compatible) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Internally, the runtime persists metadata by wrapping inputs in an envelope: + +:: + + { + "__dapr_meta__": { "v": 1, "metadata": { "otel.trace_id": "abc" } }, + "__dapr_payload__": { ... original user input ... } + } + +- The runtime unwraps this automatically so user code continues to receive the exact original input + structure and types. +- The version field (``v``) is reserved for forward compatibility. + +Minimal input guidance (SDK-facing) +----------------------------------- + +- Workflow input SHOULD be JSON serializable and preferably a single dict carried under ``ExecuteWorkflowRequest.input``. Prefer a + single object over positional ``input`` to avoid shape ambiguity and ease future evolution. This is + a recommendation for consistency and versioning; the SDK accepts any JSON-serializable input type + (dict, list, or scalar) and preserves the original shape when unwrapping the envelope. + +- For contextual data, you can use "headers" (aliases for metadata) on the workflow context: + ``set_headers``/``get_headers`` behave the same as ``set_metadata``/``get_metadata`` and are + provided for familiarity with systems that use header terminology. ``continue_as_new`` also + supports ``carryover_headers`` as an alias to ``carryover_metadata``. +- If your app needs a tracing or correlation fallback, include a small ``trace_context`` dict in + your input envelope. Interceptors should restore from ``metadata`` first (see below), then + optionally fall back to this field when present. + +Example (generic): + +.. code-block:: json + + { + "schema_version": "your-app:workflow_input@v1", + "trace_context": { "trace_id": "...", "span_id": "..." }, + "payload": { } + } + +Determinism and safety +~~~~~~~~~~~~~~~~~~~~~~ + +- In workflows, read metadata and avoid non-deterministic operations inside interceptors. Do not + perform network I/O in orchestrators. +- Activities may read/modify metadata and perform I/O inside the activity function if desired. + +Metadata persistence lifecycle +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``ctx.set_metadata()`` attaches a string-only dict to the current workflow activation. The runtime + persists it by wrapping inputs in the envelope shown above. 
Set metadata before yielding or + returning from an activation to ensure it is durably recorded. +- ``continue_as_new``: metadata is not implicitly carried. Use + ``ctx.continue_as_new(new_input, carryover_metadata=True)`` to carry current metadata or provide a + dict to merge/override: ``carryover_metadata={"key": "value"}``. +- Child workflows and activities: metadata is propagated when set on the outbound call input by + interceptors. If you maintain a baseline via ``ctx.set_metadata(...)``, your + ``WorkflowOutboundInterceptor`` can merge it into call-specific metadata. + +Tracing interceptors (example) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can implement tracing as interceptors that stamp/propagate IDs in ``metadata`` and suppress +spans during replay. A minimal sketch: + +.. code-block:: python + + from typing import Any, Callable + from dapr.ext.workflow import ( + BaseClientInterceptor, BaseWorkflowOutboundInterceptor, BaseRuntimeInterceptor, + WorkflowRuntime, DaprWorkflowClient, + ScheduleWorkflowRequest, CallActivityRequest, CallChildWorkflowRequest, + ExecuteWorkflowRequest, ExecuteActivityRequest, + ) + + TRACE_ID_KEY = 'otel.trace_id' + + class TracingClientInterceptor(BaseClientInterceptor): + def __init__(self, get_trace: Callable[[], str]): + self._get = get_trace + def schedule_new_workflow(self, input: ScheduleWorkflowRequest, next): + md = dict(input.metadata or {}) + md.setdefault(TRACE_ID_KEY, self._get()) + return next(ScheduleWorkflowRequest( + workflow_name=input.workflow_name, + input=input.input, + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + metadata=md, + )) + + class TracingWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + def __init__(self, get_trace: Callable[[], str]): + self._get = get_trace + def call_activity(self, input: CallActivityRequest, next): + md = dict(input.metadata or {}) + md.setdefault(TRACE_ID_KEY, self._get()) + return next(type(input)( + activity_name=input.activity_name, + input=input.input, + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=md, + )) + def call_child_workflow(self, input: CallChildWorkflowRequest, next): + md = dict(input.metadata or {}) + md.setdefault(TRACE_ID_KEY, self._get()) + return next(type(input)( + workflow_name=input.workflow_name, + input=input.input, + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=md, + )) + + class TracingRuntimeInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + if not input.ctx.is_replaying: + _trace_id = (input.metadata or {}).get(TRACE_ID_KEY) + # start workflow span here + return next(input) + def execute_activity(self, input: ExecuteActivityRequest, next): + _trace_id = (input.metadata or {}).get(TRACE_ID_KEY) + # start activity span here + return next(input) + + rt = WorkflowRuntime( + runtime_interceptors=[TracingRuntimeInterceptor()], + workflow_outbound_interceptors=[TracingWorkflowOutboundInterceptor(lambda: 'trace-123')], + ) + client = DaprWorkflowClient(interceptors=[TracingClientInterceptor(lambda: 'trace-123')]) + +See the full runnable example in ``ext/dapr-ext-workflow/examples/tracing_interceptors_example.py``. + +Recommended tracing restoration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Restore tracing from ``ExecuteWorkflowRequest.metadata`` first (e.g., a key like ``otel.trace_id``) + to preserve determinism and cross-activation continuity without touching user payloads. 
+- If no tracing metadata is present, optionally fall back to ``input.trace_context`` in your + application-defined input envelope. +- Suppress workflow spans during replay by checking ``input.ctx.is_replaying`` in runtime + interceptors. + +Engine-provided tracing +~~~~~~~~~~~~~~~~~~~~~~~ + +- When available from the runtime, use engine-provided fields surfaced on the contexts instead of + reconstructing from headers/metadata: + + - ``ctx.trace_parent`` / ``ctx.trace_state`` (and the same on ``activity_ctx``) + - ``ctx.workflow_span_id`` (identifier for the workflow span) + +- Interceptors should prefer these fields. Use headers/metadata only as a fallback or for + application-specific context. + +Execution info (minimal) and context properties +----------------------------------------------- + +``execution_info`` is now minimal and only includes the durable ``inbound_metadata`` that was +propagated into this activation. Use context properties directly for all engine fields: + +- ``ctx.trace_parent``, ``ctx.workflow_span_id``, ``ctx.workflow_attempt`` (and equivalents on the + activity context like ``ctx.attempt``). +- Manage outbound propagation via ``ctx.set_metadata(...)`` / ``ctx.get_metadata()``. The runtime + persists and propagates these values through the metadata envelope. + +Example: + +.. code-block:: python + + # In a workflow function + inbound = ctx.execution_info.inbound_metadata if ctx.execution_info else None + # Prepare outbound propagation + baseline = ctx.get_metadata() or {} + ctx.set_metadata({**baseline, 'tenant': 'acme'}) + +Notes +~~~~~ + +- User functions never see the envelope keys; they get the same input as before. +- Only string keys/values should be stored in headers/metadata; enforce size limits and redaction + policies as needed. +- With newer durabletask-python, the engine provides deterministic context fields on + ``OrchestrationContext``/``ActivityContext`` that the SDK surfaces via + ``ctx.execution_info``/``activity_ctx.execution_info``: ``workflow_name``, + ``parent_instance_id``, ``history_event_sequence``, and ``attempt``. The SDK no longer stamps + parent linkage in metadata when these are present. +- Interceptors are synchronous and must not perform I/O in orchestrators. Activities may perform + I/O inside the user function; interceptor code should remain fast and replay-safe. +- Client interceptors are applied when calling ``DaprWorkflowClient.schedule_new_workflow(...)`` and + when orchestrators call ``ctx.call_activity(...)`` or ``ctx.call_child_workflow(...)``. + + +Best-effort sandbox +~~~~~~~~~~~~~~~~~~~ + +Opt-in scoped compatibility mode maps ``asyncio.sleep``, ``random``, ``uuid.uuid4``, and ``time.time`` to deterministic equivalents during workflow execution. Use ``sandbox_mode="best_effort"`` or ``"strict"`` when registering async workflows. Strict mode blocks ``asyncio.create_task`` in orchestrators. + +Examples +~~~~~~~~ + +See ``ext/dapr-ext-workflow/examples`` for: + +- ``async_activity_sequence.py`` +- ``async_external_event.py`` +- ``async_sub_orchestrator.py`` + +Determinism and semantics +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``when_any`` losers: the first-completer result is returned; non-winning awaitables are ignored deterministically (no additional commands are emitted by the orchestrator for cancellation). This ensures replay stability. Integration behavior with the sidecar is subject to the Durable Task scheduler; the orchestrator does not actively cancel losers. 
+- Suspension and termination: when an instance is suspended, only new external events are buffered while replay continues to reconstruct state; async orchestrators can inspect ``ctx.is_suspended`` if exposed by the runtime. Termination completes the orchestrator with TERMINATED status and does not raise into the coroutine. End-to-end confirmation requires running against a sidecar; unit tests in this repo do not start a sidecar. + +Async patterns +~~~~~~~~~~~~~~ + +- Activities + + - Call: ``await ctx.call_activity(activity_fn, input=..., retry_policy=...)`` + - Activity functions can be ``def`` or ``async def``. When ``async def`` is used, the runtime awaits them. + +- Timers + + - Create a durable timer: ``await ctx.create_timer(seconds|timedelta)`` + +- External events + + - Wait: ``await ctx.wait_for_external_event(name)`` + - Raise (from client): ``DaprWorkflowClient.raise_workflow_event(instance_id, name, data)`` + +- Concurrency + + - All: ``results = await ctx.when_all([ ...awaitables... ])`` + - Any: ``first = await ctx.when_any([ ...awaitables... ])`` (non-winning awaitables are ignored deterministically) + +- Child workflows + + - Call: ``await ctx.call_child_workflow(workflow_fn, input=..., retry_policy=...)`` + +- Deterministic utilities + + - ``ctx.now()`` returns orchestration time from history + - ``ctx.random()`` returns a deterministic PRNG + - ``ctx.uuid4()`` returns a PRNG-derived deterministic UUID + +Runtime compatibility +--------------------- + +- ``ctx.is_suspended`` is surfaced if provided by the underlying runtime/context version; behavior may vary by Durable Task build. Integration tests that validate suspension semantics are gated behind a sidecar harness. + +when_any losers diagnostics (integration) +----------------------------------------- + +- When the sidecar exposes command diagnostics, you can assert only a single command set is emitted for a ``when_any`` (the orchestrator completes after the first winner without emitting cancels). Until then, unit tests assert single-yield behavior and README documents the expected semantics. + +Micro-bench guidance +-------------------- + +- The coroutine-to-generator driver yields at each deterministic suspension point and avoids polling. In practice, overhead vs. generator orchestrators is negligible relative to activity I/O. To measure locally: + + - Create paired generator/async orchestrators that call N no-op activities and 1 timer. + - Drive them against a local sidecar and compare wall-clock per activation and total completion time. + - Ensure identical history/inputs; differences should be within noise vs. activity latency. + +Notes +----- + +- Orchestrators authored as ``async def`` are not driven by a global event loop you start. The Durable Task worker drives them via a coroutine-to-generator bridge; do not call ``asyncio.run`` around orchestrators. +- Use ``WorkflowRuntime.workflow`` with an ``async def`` (auto-detected) or ``WorkflowRuntime.async_workflow`` to register async orchestrators. + +Why async without an event loop? +-------------------------------- + +- Each ``await`` in an async orchestrator corresponds to a deterministic Durable Task decision (activity, timer, external event, ``when_all/any``). The worker advances the coroutine by sending results/exceptions back in, preserving replay and ordering. +- This gives you the readability and structure of ``async/await`` while enforcing workflow determinism (no ad-hoc I/O in orchestrators; all I/O happens in activities). 
+- The pattern follows other workflow engines (e.g., Durable Functions/Temporal): async authoring for clarity, runtime-driven scheduling for correctness. + References ---------- diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py index f78615112..b6a75e472 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py @@ -14,17 +14,48 @@ """ # Import your main classes here -from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name +from dapr.ext.workflow.aio import AsyncWorkflowContext from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext, when_all, when_any +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ClientInterceptor, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + ScheduleWorkflowRequest, + WorkflowOutboundInterceptor, + compose_runtime_chain, + compose_workflow_outbound_chain, +) +from dapr.ext.workflow.retry_policy import RetryPolicy +from dapr.ext.workflow.serializers import ( + ActivityIOAdapter, + CanonicalSerializable, + GenericSerializer, + ensure_canonical_json, + get_activity_adapter, + get_serializer, + register_activity_adapter, + register_serializer, + serialize_activity_input, + serialize_activity_output, + use_activity_adapter, +) from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name from dapr.ext.workflow.workflow_state import WorkflowState, WorkflowStatus -from dapr.ext.workflow.retry_policy import RetryPolicy __all__ = [ 'WorkflowRuntime', 'DaprWorkflowClient', 'DaprWorkflowContext', + 'AsyncWorkflowContext', 'WorkflowActivityContext', 'WorkflowState', 'WorkflowStatus', @@ -32,4 +63,32 @@ 'when_any', 'alternate_name', 'RetryPolicy', + # interceptors + 'ClientInterceptor', + 'BaseClientInterceptor', + 'WorkflowOutboundInterceptor', + 'BaseWorkflowOutboundInterceptor', + 'RuntimeInterceptor', + 'BaseRuntimeInterceptor', + 'ScheduleWorkflowRequest', + 'CallChildWorkflowRequest', + 'CallActivityRequest', + 'ExecuteWorkflowRequest', + 'ExecuteActivityRequest', + 'compose_workflow_outbound_chain', + 'compose_runtime_chain', + 'WorkflowExecutionInfo', + 'ActivityExecutionInfo', + # serializers + 'CanonicalSerializable', + 'GenericSerializer', + 'ActivityIOAdapter', + 'ensure_canonical_json', + 'register_serializer', + 'get_serializer', + 'register_activity_adapter', + 'get_activity_adapter', + 'use_activity_adapter', + 'serialize_activity_input', + 'serialize_activity_output', ] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py new file mode 100644 index 000000000..a195bc0aa --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py @@ -0,0 +1,43 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +# Note: Do not import WorkflowRuntime here to avoid circular imports +# Re-export async context and awaitables +from .async_context import AsyncWorkflowContext # noqa: F401 +from .async_driver import CoroutineOrchestratorRunner # noqa: F401 +from .awaitables import ( # noqa: F401 + ActivityAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) + +""" +Async I/O surface for Dapr Workflow extension. + +This package provides explicit async-focused imports that mirror the top-level +exports, improving discoverability and aligning with dapr.aio patterns. +""" + +__all__ = [ + 'AsyncWorkflowContext', + 'CoroutineOrchestratorRunner', + 'ActivityAwaitable', + 'SubOrchestratorAwaitable', + 'SleepAwaitable', + 'WhenAllAwaitable', + 'WhenAnyAwaitable', +] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py new file mode 100644 index 000000000..ec68dc1cc --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py @@ -0,0 +1,189 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any, Awaitable, Callable, Sequence + +from durabletask import task +from durabletask.aio.awaitables import gather as _dt_gather # type: ignore[import-not-found] +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, +) + +from .awaitables import ( + ActivityAwaitable, + ExternalEventAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) + +""" +Async workflow context that exposes deterministic awaitables for activities, timers, +external events, and concurrency, along with deterministic utilities. 
+""" + + +class AsyncWorkflowContext(DeterministicContextMixin): + def __init__(self, base_ctx: task.OrchestrationContext): + self._base_ctx = base_ctx + + # Core workflow metadata parity with sync context + @property + def instance_id(self) -> str: + return self._base_ctx.instance_id + + @property + def current_utc_datetime(self) -> datetime: + return self._base_ctx.current_utc_datetime + + # Activities & Sub-orchestrations + def call_activity( + self, + activity_fn: Callable[..., Any], + *, + input: Any = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ) -> Awaitable[Any]: + return ActivityAwaitable( + self._base_ctx, activity_fn, input=input, retry_policy=retry_policy, metadata=metadata + ) + + def call_child_workflow( + self, + workflow_fn: Callable[..., Any], + *, + input: Any = None, + instance_id: str | None = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ) -> Awaitable[Any]: + return SubOrchestratorAwaitable( + self._base_ctx, + workflow_fn, + input=input, + instance_id=instance_id, + retry_policy=retry_policy, + metadata=metadata, + ) + + @property + def is_replaying(self) -> bool: + return self._base_ctx.is_replaying + + # Tracing (engine-provided) pass-throughs when available + @property + def trace_parent(self) -> str | None: + return self._base_ctx.trace_parent + + @property + def trace_state(self) -> str | None: + return self._base_ctx.trace_state + + @property + def workflow_span_id(self) -> str | None: + return self._base_ctx.orchestration_span_id + + @property + def workflow_attempt(self) -> int | None: + getter = getattr(self._base_ctx, 'workflow_attempt', None) + return ( + getter + if isinstance(getter, int) or getter is None + else getattr(self._base_ctx, 'workflow_attempt', None) + ) + + # Timers & Events + def create_timer(self, fire_at: float | timedelta | datetime) -> Awaitable[None]: + # If float provided, interpret as seconds + if isinstance(fire_at, (int, float)): + fire_at = timedelta(seconds=float(fire_at)) + return SleepAwaitable(self._base_ctx, fire_at) + + def sleep(self, duration: float | timedelta | datetime) -> Awaitable[None]: + return self.create_timer(duration) + + def wait_for_external_event(self, name: str) -> Awaitable[Any]: + return ExternalEventAwaitable(self._base_ctx, name) + + # Concurrency + def when_all(self, awaitables: Sequence[Awaitable[Any]]) -> Awaitable[list[Any]]: + return WhenAllAwaitable(awaitables) + + def when_any(self, awaitables: Sequence[Awaitable[Any]]) -> Awaitable[Any]: + return WhenAnyAwaitable(awaitables) + + def gather(self, *aws: Awaitable[Any], return_exceptions: bool = False) -> Awaitable[list[Any]]: + return _dt_gather(*aws, return_exceptions=return_exceptions) + + # Deterministic utilities are provided by mixin (now, random, uuid4, new_guid) + + @property + def is_suspended(self) -> bool: + # Placeholder; will be wired when Durable Task exposes this state in context + return self._base_ctx.is_suspended + + # Pass-throughs for completeness + def set_custom_status(self, custom_status: str) -> None: + if hasattr(self._base_ctx, 'set_custom_status'): + self._base_ctx.set_custom_status(custom_status) + + def continue_as_new( + self, + new_input: Any, + *, + save_events: bool = False, + carryover_metadata: bool | dict[str, str] = False, + carryover_headers: bool | dict[str, str] | None = None, + ) -> None: + effective_carryover = ( + carryover_headers if carryover_headers is not None else carryover_metadata + ) + # Try extended signature; fall back to minimal 
for older fakes/contexts + try: + self._base_ctx.continue_as_new( + new_input, save_events=save_events, carryover_metadata=effective_carryover + ) + except TypeError: + self._base_ctx.continue_as_new(new_input, save_events=save_events) + + # Metadata parity + def set_metadata(self, metadata: dict[str, str] | None) -> None: + setter = getattr(self._base_ctx, 'set_metadata', None) + if callable(setter): + setter(metadata) + + def get_metadata(self) -> dict[str, str] | None: + getter = getattr(self._base_ctx, 'get_metadata', None) + return getter() if callable(getter) else None + + # Header aliases (ergonomic alias for users familiar with Temporal terminology) + def set_headers(self, headers: dict[str, str] | None) -> None: + self.set_metadata(headers) + + def get_headers(self) -> dict[str, str] | None: + return self.get_metadata() + + # Execution info parity + @property + def execution_info(self): # type: ignore[override] + return getattr(self._base_ctx, 'execution_info', None) + + +__all__ = [ + 'AsyncWorkflowContext', +] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py new file mode 100644 index 000000000..7d964174c --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py @@ -0,0 +1,95 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from __future__ import annotations + +from typing import Any, Awaitable, Callable, Generator, Optional + +from durabletask import task +from durabletask.aio.sandbox import SandboxMode, sandbox_scope + + +class CoroutineOrchestratorRunner: + """Wraps an async orchestrator into a generator-compatible runner.""" + + def __init__( + self, + async_orchestrator: Callable[..., Awaitable[Any]], + *, + sandbox_mode: SandboxMode = SandboxMode.OFF, + ): + self._async_orchestrator = async_orchestrator + self._sandbox_mode = sandbox_mode + + def to_generator( + self, async_ctx: Any, input_data: Optional[Any] + ) -> Generator[task.Task, Any, Any]: + # Instantiate the coroutine with or without input depending on signature/usage + try: + if input_data is None: + coro = self._async_orchestrator(async_ctx) + else: + coro = self._async_orchestrator(async_ctx, input_data) + except TypeError: + # Fallback for orchestrators that only accept a single ctx arg + coro = self._async_orchestrator(async_ctx) + + # Prime the coroutine + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.send(None) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.send(None) + except StopIteration as stop: + return stop.value # type: ignore[misc] + + # Drive the coroutine by yielding the underlying Durable Task(s) + while True: + try: + result = yield awaited + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.send(result) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.send(result) + except StopIteration as stop: + return stop.value + except Exception as exc: + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.throw(exc) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.throw(exc) + except StopIteration as stop: + return stop.value + except BaseException as base_exc: + # Handle cancellation that may not derive from Exception in some environments + try: + import asyncio as _asyncio # local import to avoid hard dep at module import time + + is_cancel = isinstance(base_exc, _asyncio.CancelledError) + except Exception: + is_cancel = False + if is_cancel: + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.throw(base_exc) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.throw(base_exc) + except StopIteration as stop: + return stop.value + continue + raise diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py new file mode 100644 index 000000000..2946c53a9 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py @@ -0,0 +1,124 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from __future__ import annotations + +from typing import Any, Callable, Iterable + +from durabletask import task +from durabletask.aio.awaitables import ( + AwaitableBase as _BaseAwaitable, # type: ignore[import-not-found] +) +from durabletask.aio.awaitables import ( + ExternalEventAwaitable as _DTExternalEventAwaitable, +) +from durabletask.aio.awaitables import ( + SleepAwaitable as _DTSleepAwaitable, +) +from durabletask.aio.awaitables import ( + WhenAllAwaitable as _DTWhenAllAwaitable, +) + +AwaitableBase = _BaseAwaitable + + +class ActivityAwaitable(AwaitableBase): + def __init__( + self, + ctx: Any, + activity_fn: Callable[..., Any], + *, + input: Any = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ): + self._ctx = ctx + self._activity_fn = activity_fn + self._input = input + self._retry_policy = retry_policy + self._metadata = metadata + + def _to_task(self) -> task.Task: + if self._retry_policy is None: + return self._ctx.call_activity( + self._activity_fn, input=self._input, metadata=self._metadata + ) + return self._ctx.call_activity( + self._activity_fn, + input=self._input, + retry_policy=self._retry_policy, + metadata=self._metadata, + ) + + +class SubOrchestratorAwaitable(AwaitableBase): + def __init__( + self, + ctx: Any, + workflow_fn: Callable[..., Any], + *, + input: Any = None, + instance_id: str | None = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ): + self._ctx = ctx + self._workflow_fn = workflow_fn + self._input = input + self._instance_id = instance_id + self._retry_policy = retry_policy + self._metadata = metadata + + def _to_task(self) -> task.Task: + if self._retry_policy is None: + return self._ctx.call_child_workflow( + self._workflow_fn, + input=self._input, + instance_id=self._instance_id, + metadata=self._metadata, + ) + return self._ctx.call_child_workflow( + self._workflow_fn, + input=self._input, + instance_id=self._instance_id, + retry_policy=self._retry_policy, + metadata=self._metadata, + ) + + +class SleepAwaitable(_DTSleepAwaitable): + pass + + +class ExternalEventAwaitable(_DTExternalEventAwaitable): + pass + + +class WhenAllAwaitable(_DTWhenAllAwaitable): + pass + + +class WhenAnyAwaitable(AwaitableBase): + def __init__(self, tasks_like: Iterable[AwaitableBase | task.Task]): + self._tasks_like = list(tasks_like) + + def _to_task(self) -> task.Task: + underlying: list[task.Task] = [] + for a in self._tasks_like: + if isinstance(a, AwaitableBase): + underlying.append(a._to_task()) # type: ignore[attr-defined] + elif isinstance(a, task.Task): + underlying.append(a) + else: + raise TypeError('when_any expects AwaitableBase or durabletask.task.Task') + return task.when_any(underlying) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/sandbox.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/sandbox.py new file mode 100644 index 000000000..3d887e17b --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/sandbox.py @@ -0,0 +1,233 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import asyncio as _asyncio +import random as _random +import time as _time +import uuid as _uuid +from contextlib import ContextDecorator +from typing import Any + +from durabletask.aio.sandbox import SandboxMode +from durabletask.deterministic import deterministic_random, deterministic_uuid4 + +""" +Scoped sandbox patching for async workflows (best-effort, strict). +""" + + +def _ctx_instance_id(async_ctx: Any) -> str: + if hasattr(async_ctx, 'instance_id'): + return getattr(async_ctx, 'instance_id') + if hasattr(async_ctx, '_base_ctx') and hasattr(async_ctx._base_ctx, 'instance_id'): + return async_ctx._base_ctx.instance_id + return '' + + +def _ctx_now(async_ctx: Any): + if hasattr(async_ctx, 'now'): + try: + return async_ctx.now() + except Exception: + pass + if hasattr(async_ctx, 'current_utc_datetime'): + return async_ctx.current_utc_datetime + if hasattr(async_ctx, '_base_ctx') and hasattr(async_ctx._base_ctx, 'current_utc_datetime'): + return async_ctx._base_ctx.current_utc_datetime + import datetime as _dt + + return _dt.datetime.utcfromtimestamp(0) + + +class _Sandbox(ContextDecorator): + def __init__(self, async_ctx: Any, mode: str): + self._async_ctx = async_ctx + self._mode = mode + self._saved: dict[str, Any] = {} + + def __enter__(self): + self._saved['asyncio.sleep'] = _asyncio.sleep + self._saved['asyncio.gather'] = getattr(_asyncio, 'gather', None) + self._saved['asyncio.create_task'] = getattr(_asyncio, 'create_task', None) + self._saved['random.random'] = _random.random + self._saved['random.randrange'] = _random.randrange + self._saved['random.randint'] = _random.randint + self._saved['uuid.uuid4'] = _uuid.uuid4 + self._saved['time.time'] = _time.time + self._saved['time.time_ns'] = getattr(_time, 'time_ns', None) + + rnd = deterministic_random(_ctx_instance_id(self._async_ctx), _ctx_now(self._async_ctx)) + + async def _sleep_patched(delay: float, result: Any = None): # type: ignore[override] + try: + if float(delay) <= 0: + return await self._saved['asyncio.sleep'](0) + except Exception: + return await self._saved['asyncio.sleep'](delay) # type: ignore[arg-type] + + await self._async_ctx.sleep(delay) + return result + + def _random_patched() -> float: + return rnd.random() + + def _randrange_patched(start, stop=None, step=1): + return rnd.randrange(start, stop, step) if stop is not None else rnd.randrange(start) + + def _randint_patched(a, b): + return rnd.randint(a, b) + + def _uuid4_patched(): + return deterministic_uuid4(rnd) + + def _time_patched() -> float: + return float(_ctx_now(self._async_ctx).timestamp()) + + def _time_ns_patched() -> int: + return int(_ctx_now(self._async_ctx).timestamp() * 1_000_000_000) + + def _create_task_blocked(coro, *args, **kwargs): + try: + close = getattr(coro, 'close', None) + if callable(close): + try: + close() + except Exception: + pass + finally: + raise RuntimeError( + 'asyncio.create_task is not allowed inside workflow (strict mode)' + ) + + def _is_workflow_awaitable(obj: Any) -> bool: + try: + if hasattr(obj, '_to_dapr_task') or hasattr(obj, '_to_task'): + return True + except Exception: + pass + try: + from durabletask import task as _dt + + if isinstance(obj, _dt.Task): + return True + except Exception: + pass + return False + + class _OneShot: + def __init__(self, factory): + self._factory = factory + self._done = False + self._res: Any = None + self._exc: BaseException | 
None = None + + def __await__(self): # type: ignore[override] + if self._done: + + async def _replay(): + if self._exc is not None: + raise self._exc + return self._res + + return _replay().__await__() + + async def _compute(): + try: + out = await self._factory() + self._res = out + self._done = True + return out + except BaseException as e: # noqa: BLE001 + self._exc = e + self._done = True + raise + + return _compute().__await__() + + def _patched_gather(*aws: Any, return_exceptions: bool = False): # type: ignore[override] + if not aws: + + async def _empty(): + return [] + + return _OneShot(_empty) + + if all(_is_workflow_awaitable(a) for a in aws): + + async def _await_when_all(): + from dapr.ext.workflow.aio.awaitables import WhenAllAwaitable # local import + + combined = WhenAllAwaitable(list(aws)) + return await combined + + return _OneShot(_await_when_all) + + async def _run_mixed(): + results = [] + for a in aws: + try: + results.append(await a) + except Exception as e: # noqa: BLE001 + if return_exceptions: + results.append(e) + else: + raise + return results + + return _OneShot(_run_mixed) + + _asyncio.sleep = _sleep_patched # type: ignore[assignment] + if self._saved['asyncio.gather'] is not None: + _asyncio.gather = _patched_gather # type: ignore[assignment] + _random.random = _random_patched # type: ignore[assignment] + _random.randrange = _randrange_patched # type: ignore[assignment] + _random.randint = _randint_patched # type: ignore[assignment] + _uuid.uuid4 = _uuid4_patched # type: ignore[assignment] + _time.time = _time_patched # type: ignore[assignment] + if self._saved['time.time_ns'] is not None: + _time.time_ns = _time_ns_patched # type: ignore[assignment] + if self._mode == 'strict' and self._saved['asyncio.create_task'] is not None: + _asyncio.create_task = _create_task_blocked # type: ignore[assignment] + + return self + + def __exit__(self, exc_type, exc, tb): + _asyncio.sleep = self._saved['asyncio.sleep'] # type: ignore[assignment] + if self._saved['asyncio.gather'] is not None: + _asyncio.gather = self._saved['asyncio.gather'] # type: ignore[assignment] + if self._saved['asyncio.create_task'] is not None: + _asyncio.create_task = self._saved['asyncio.create_task'] # type: ignore[assignment] + _random.random = self._saved['random.random'] # type: ignore[assignment] + _random.randrange = self._saved['random.randrange'] # type: ignore[assignment] + _random.randint = self._saved['random.randint'] # type: ignore[assignment] + _uuid.uuid4 = self._saved['uuid.uuid4'] # type: ignore[assignment] + _time.time = self._saved['time.time'] # type: ignore[assignment] + if self._saved['time.time_ns'] is not None: + _time.time_ns = self._saved['time.time_ns'] # type: ignore[assignment] + return False + + +def sandbox_scope(async_ctx: Any, mode: SandboxMode): + if mode == SandboxMode.OFF: + + class _Null(ContextDecorator): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + return _Null() + return _Sandbox(async_ctx, 'strict' if mode == SandboxMode.STRICT else 'best_effort') diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py index cc384503a..5ca13d999 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py @@ -14,23 +14,28 @@ """ from __future__ import annotations -from datetime import datetime -from typing import Any, Optional, TypeVar +from datetime 
import datetime +from typing import Any, List, Optional, TypeVar -from durabletask import client import durabletask.internal.orchestrator_service_pb2 as pb - -from dapr.ext.workflow.workflow_state import WorkflowState -from dapr.ext.workflow.workflow_context import Workflow -from dapr.ext.workflow.util import getAddress +from durabletask import client from grpc import RpcError from dapr.clients import DaprInternalError from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings -from dapr.conf.helpers import GrpcEndpoint -from dapr.ext.workflow.logger import LoggerOptions, Logger +from dapr.conf.helpers import GrpcEndpoint, build_grpc_channel_options +from dapr.ext.workflow.interceptors import ( + ClientInterceptor, + ScheduleWorkflowRequest, + compose_client_chain, + wrap_payload_with_metadata, +) +from dapr.ext.workflow.logger import Logger, LoggerOptions +from dapr.ext.workflow.util import getAddress +from dapr.ext.workflow.workflow_context import Workflow +from dapr.ext.workflow.workflow_state import WorkflowState T = TypeVar('T') TInput = TypeVar('TInput') @@ -52,6 +57,8 @@ def __init__( host: Optional[str] = None, port: Optional[str] = None, logger_options: Optional[LoggerOptions] = None, + *, + interceptors: Optional[List[ClientInterceptor]] = None, ): address = getAddress(host, port) @@ -62,18 +69,31 @@ def __init__( self._logger = Logger('DaprWorkflowClient', logger_options) - metadata = tuple() + metadata = () if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) options = self._logger.get_options() + # Optional gRPC channel options (keepalive, retry policy) via helpers + channel_options = build_grpc_channel_options() + + # Construct base kwargs for TaskHubGrpcClient + base_kwargs = { + 'host_address': uri.endpoint, + 'metadata': metadata, + 'secure_channel': uri.tls, + 'log_handler': options.log_handler, + 'log_formatter': options.log_formatter, + } + + # Initialize TaskHubGrpcClient (DurableTask supports options) self.__obj = client.TaskHubGrpcClient( - host_address=uri.endpoint, - metadata=metadata, - secure_channel=uri.tls, - log_handler=options.log_handler, - log_formatter=options.log_formatter, + **base_kwargs, + options=channel_options, ) + # Interceptors + self._client_interceptors: List[ClientInterceptor] = list(interceptors or []) + def schedule_new_workflow( self, workflow: Workflow, @@ -82,6 +102,7 @@ def schedule_new_workflow( instance_id: Optional[str] = None, start_at: Optional[datetime] = None, reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None, + metadata: Optional[dict[str, str]] = None, ) -> str: """Schedules a new workflow instance for execution. @@ -100,21 +121,33 @@ def schedule_new_workflow( Returns: The ID of the scheduled workflow instance. 
""" - if hasattr(workflow, '_dapr_alternate_name'): + wf_name = ( + workflow.__dict__['_dapr_alternate_name'] + if hasattr(workflow, '_dapr_alternate_name') + else workflow.__name__ + ) + + # Build interceptor chain around schedule call + def terminal(term_req: ScheduleWorkflowRequest) -> str: + payload = wrap_payload_with_metadata(term_req.input, term_req.metadata) return self.__obj.schedule_new_orchestration( - workflow.__dict__['_dapr_alternate_name'], - input=input, - instance_id=instance_id, - start_at=start_at, - reuse_id_policy=reuse_id_policy, + term_req.workflow_name, + input=payload, + instance_id=term_req.instance_id, + start_at=term_req.start_at, + reuse_id_policy=term_req.reuse_id_policy, ) - return self.__obj.schedule_new_orchestration( - workflow.__name__, + + chain = compose_client_chain(self._client_interceptors, terminal) + schedule_req = ScheduleWorkflowRequest( + workflow_name=wf_name, input=input, instance_id=instance_id, start_at=start_at, reuse_id_policy=reuse_id_policy, + metadata=metadata, ) + return chain(schedule_req) def get_workflow_state( self, instance_id: str, *, fetch_payloads: bool = True diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py index 2dee46fe2..56fe2acdb 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,29 +11,62 @@ limitations under the License. """ -from typing import Any, Callable, List, Optional, TypeVar, Union +import enum from datetime import datetime, timedelta +from typing import Any, Callable, List, Optional, TypeVar, Union from durabletask import task +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, +) -from dapr.ext.workflow.workflow_context import WorkflowContext, Workflow -from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext -from dapr.ext.workflow.logger import LoggerOptions, Logger +from dapr.ext.workflow.execution_info import WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import unwrap_payload_with_metadata, wrap_payload_with_metadata +from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.retry_policy import RetryPolicy +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext +from dapr.ext.workflow.workflow_context import Workflow, WorkflowContext T = TypeVar('T') TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') -class DaprWorkflowContext(WorkflowContext): - """DaprWorkflowContext that provides proxy access to internal OrchestrationContext instance.""" +class Handlers(enum.Enum): + CALL_ACTIVITY = 'call_activity' + CALL_CHILD_WORKFLOW = 'call_child_workflow' + CONTINUE_AS_NEW = 'continue_as_new' + + +class DaprWorkflowContext(WorkflowContext, DeterministicContextMixin): + """Workflow context wrapper with deterministic utilities and metadata helpers. + + Purpose + ------- + - Proxy to the underlying ``durabletask.task.OrchestrationContext`` (engine fields like + ``trace_parent``, ``orchestration_span_id``, and ``workflow_attempt`` pass through). + - Provide SDK-level helpers for durable metadata propagation via interceptors. + - Expose ``execution_info`` as a per-activation snapshot complementing live properties. 
+ + Tips + ---- + - Use ``ctx.get_metadata()/set_metadata()`` to manage outbound propagation. + - Use ``ctx.execution_info.inbound_metadata`` to inspect what arrived on this activation. + - Prefer engine-backed properties for tracing/attempts when available (not yet available in dapr sidecar); fall back to + metadata only for app-specific context. + """ def __init__( - self, ctx: task.OrchestrationContext, logger_options: Optional[LoggerOptions] = None + self, + ctx: task.OrchestrationContext, + logger_options: Optional[LoggerOptions] = None, + *, + outbound_handlers: Optional[dict[Handlers, Any]] = None, ): self.__obj = ctx self._logger = Logger('DaprWorkflowContext', logger_options) + self._outbound_handlers = outbound_handlers or {} + self._metadata: dict[str, str] | None = None # provide proxy access to regular attributes of wrapped object def __getattr__(self, name): @@ -53,10 +84,53 @@ def current_utc_datetime(self) -> datetime: def is_replaying(self) -> bool: return self.__obj.is_replaying + # Deterministic utilities are provided by mixin (now, random, uuid4, new_guid) + + # Tracing (engine-provided) pass-throughs when available + @property + def trace_parent(self) -> str | None: + return self.__obj.trace_parent + + @property + def trace_state(self) -> str | None: + return self.__obj.trace_state + + @property + def workflow_span_id(self) -> str | None: + # provided by durabletask; naming aligned to workflow + return self.__obj.orchestration_span_id + + @property + def workflow_attempt(self) -> int | None: + # Provided by durabletask when available (e.g., sub-orchestrator retry attempt) + return getattr(self.__obj, 'workflow_attempt', None) + + # Metadata API + def set_metadata(self, metadata: dict[str, str] | None) -> None: + self._metadata = dict(metadata) if metadata else None + + def get_metadata(self) -> dict[str, str] | None: + return dict(self._metadata) if self._metadata else None + + # Header aliases (ergonomic alias for users familiar with Temporal terminology) + def set_headers(self, headers: dict[str, str] | None) -> None: + self.set_metadata(headers) + + def get_headers(self) -> dict[str, str] | None: + return self.get_metadata() + def set_custom_status(self, custom_status: str) -> None: self._logger.debug(f'{self.instance_id}: Setting custom status to {custom_status}') self.__obj.set_custom_status(custom_status) + # Execution info (populated by runtime when available) + @property + def execution_info(self) -> WorkflowExecutionInfo | None: + return getattr(self, '_execution_info', None) + + def _set_execution_info(self, info: WorkflowExecutionInfo) -> None: + self._execution_info = info + def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: self._logger.debug(f'{self.instance_id}: Creating timer to fire at {fire_at} time') return self.__obj.create_timer(fire_at) @@ -67,6 +141,7 @@ def call_activity( *, input: TInput = None, retry_policy: Optional[RetryPolicy] = None, + metadata: dict[str, str] | None = None, ) -> task.Task[TOutput]: self._logger.debug(f'{self.instance_id}: Creating activity {activity.__name__}') if hasattr(activity, '_dapr_alternate_name'): @@ -74,9 +149,19 @@ def call_activity( else: # this case should ideally never happen act = activity.__name__ + # Apply outbound client interceptor transformations if provided via runtime wiring + transformed_input: Any = input + if Handlers.CALL_ACTIVITY in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CALL_ACTIVITY] + ): + transformed_input = 
self._outbound_handlers[Handlers.CALL_ACTIVITY]( + self, activity, input, retry_policy, metadata or self.get_metadata() + ) if retry_policy is None: - return self.__obj.call_activity(activity=act, input=input) - return self.__obj.call_activity(activity=act, input=input, retry_policy=retry_policy.obj) + return self.__obj.call_activity(activity=act, input=transformed_input) + return self.__obj.call_activity( + activity=act, input=transformed_input, retry_policy=retry_policy.obj + ) def call_child_workflow( self, @@ -85,12 +170,13 @@ def call_child_workflow( input: Optional[TInput] = None, instance_id: Optional[str] = None, retry_policy: Optional[RetryPolicy] = None, + metadata: dict[str, str] | None = None, ) -> task.Task[TOutput]: self._logger.debug(f'{self.instance_id}: Creating child workflow {workflow.__name__}') def wf(ctx: task.OrchestrationContext, inp: TInput): - daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - return workflow(daprWfContext, inp) + dapr_wf_context = DaprWorkflowContext(ctx, self._logger.get_options()) + return workflow(dapr_wf_context, inp) # copy workflow name so durabletask.worker can find the orchestrator in its registry @@ -99,22 +185,67 @@ def wf(ctx: task.OrchestrationContext, inp: TInput): else: # this case should ideally never happen wf.__name__ = workflow.__name__ + # Apply outbound client interceptor transformations if provided via runtime wiring + transformed_input: Any = input + if Handlers.CALL_CHILD_WORKFLOW in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CALL_CHILD_WORKFLOW] + ): + transformed_input = self._outbound_handlers[Handlers.CALL_CHILD_WORKFLOW]( + self, workflow, input, metadata or self.get_metadata() + ) if retry_policy is None: - return self.__obj.call_sub_orchestrator(wf, input=input, instance_id=instance_id) + return self.__obj.call_sub_orchestrator( + wf, input=transformed_input, instance_id=instance_id + ) return self.__obj.call_sub_orchestrator( - wf, input=input, instance_id=instance_id, retry_policy=retry_policy.obj + wf, input=transformed_input, instance_id=instance_id, retry_policy=retry_policy.obj ) def wait_for_external_event(self, name: str) -> task.Task: self._logger.debug(f'{self.instance_id}: Waiting for external event {name}') return self.__obj.wait_for_external_event(name) - def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None: + def continue_as_new( + self, + new_input: Any, + *, + save_events: bool = False, + carryover_metadata: bool | dict[str, str] = False, + carryover_headers: bool | dict[str, str] | None = None, + ) -> None: self._logger.debug(f'{self.instance_id}: Continuing as new') - self.__obj.continue_as_new(new_input, save_events=save_events) + # Allow workflow outbound interceptors (wired via runtime) to modify payload/metadata + transformed_input: Any = new_input + if Handlers.CONTINUE_AS_NEW in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CONTINUE_AS_NEW] + ): + transformed_input = self._outbound_handlers[Handlers.CONTINUE_AS_NEW]( + self, new_input, self.get_metadata() + ) + + # Merge/carry metadata if requested, unwrapping any envelope produced by interceptors + payload, base_md = unwrap_payload_with_metadata(transformed_input) + # Start with current context metadata; then layer any interceptor-provided metadata on top + current_md = self.get_metadata() or {} + effective_md = {**current_md, **(base_md or {})} + effective_carryover = ( + carryover_headers if carryover_headers is not None else 
carryover_metadata + ) + if effective_carryover: + base = effective_md or {} + if isinstance(effective_carryover, dict): + md = {**base, **effective_carryover} + else: + md = base + payload = wrap_payload_with_metadata(payload, md) + else: + # If we had metadata from interceptors or context, preserve it + if effective_md: + payload = wrap_payload_with_metadata(payload, effective_md) + self.__obj.continue_as_new(payload, save_events=save_events) -def when_all(tasks: List[task.Task[T]]) -> task.WhenAllTask[T]: +def when_all(tasks: List[task.Task]) -> task.WhenAllTask: """Returns a task that completes when all of the provided tasks complete or when one of the tasks fail.""" return task.when_all(tasks) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py b/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py new file mode 100644 index 000000000..d33a02c60 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py @@ -0,0 +1,27 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +# Backward-compatible shim: import deterministic utilities from durabletask +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, + deterministic_random, + deterministic_uuid4, +) + +__all__ = [ + 'DeterministicContextMixin', + 'deterministic_random', + 'deterministic_uuid4', +] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py b/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py new file mode 100644 index 000000000..0aacd7106 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py @@ -0,0 +1,49 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +""" +Minimal, deterministic snapshots of inbound durable metadata. + +Rationale +--------- + +Execution info previously mirrored many engine fields (IDs, tracing, attempts) already +available on the workflow/activity contexts. To remove redundancy and simplify usage, the +execution info types now only capture the durable ``inbound_metadata`` that was actually +propagated into this activation. Use context properties directly for engine fields. +""" + + +@dataclass +class WorkflowExecutionInfo: + """Per-activation snapshot for workflows. + + Only includes ``inbound_metadata`` that arrived with this activation. 
+ """ + + inbound_metadata: dict[str, str] + + +@dataclass +class ActivityExecutionInfo: + """Per-activation snapshot for activities. + + Only includes ``inbound_metadata`` that arrived with this activity invocation. + """ + + inbound_metadata: dict[str, str] + activity_name: str diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py b/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py new file mode 100644 index 000000000..0af8df22e --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py @@ -0,0 +1,420 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Generic, Protocol, TypeVar + +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext +from dapr.ext.workflow.workflow_context import WorkflowContext + +# Type variables for generic interceptor payload typing +TInput = TypeVar('TInput') +TWorkflowInput = TypeVar('TWorkflowInput') +TActivityInput = TypeVar('TActivityInput') + +""" +Interceptor interfaces and chain utilities for the Dapr Workflow SDK. + +Providing a single enter/exit around calls. + +IMPORTANT: Generator wrappers for async workflows +-------------------------------------------------- +When writing runtime interceptors that touch workflow execution, be careful with generator +handling. If an interceptor obtains a workflow generator from user code (e.g., an async +orchestrator adapted into a generator) it must not manually iterate it using a for-loop +and yield the produced items. Doing so breaks send()/throw() propagation back into the +inner generator, which can cause resumed results from the durable runtime to be dropped +and appear as None to awaiters. + +Best practices: +- If the interceptor participates in composition and needs to return the generator, + return it directly (do not iterate it). +- If the interceptor must wrap the generator, always use "yield from inner_gen" so that + send()/throw() are forwarded correctly. + +Context managers with async workflows +-------------------------------------- +When using context managers (like ExitStack, logging contexts, or trace contexts) in an +interceptor for async workflows, be aware that calling `next(input)` returns a generator +object immediately, NOT the final result. The generator executes later when the durable +task runtime drives it. + +If you need a context manager to remain active during the workflow execution: + +**WRONG - Context exits before workflow runs:** + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + with setup_context(): + return next(input) # Returns generator, context exits immediately! 
+ +**CORRECT - Context stays active throughout execution:** + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + def wrapper(): + with setup_context(): + gen = next(input) + yield from gen # Keep context alive while generator executes + return wrapper() + +For more complex scenarios with ExitStack or async context managers, wrap the generator +with `yield from` to ensure your context spans the entire workflow execution, including +all replay and continuation events. + +Example with ExitStack: + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + def wrapper(): + with ExitStack() as stack: + # Set up contexts (trace, logging, etc.) + stack.enter_context(trace_context(...)) + stack.enter_context(logging_context(...)) + + # Get the generator from the next interceptor/handler + gen = next(input) + + # Keep contexts alive while generator executes + yield from gen + return wrapper() + +This pattern ensures your context manager remains active during: +- Initial workflow execution +- Replays from durable state +- Continuation after awaits +- Activity calls and child workflow invocations +""" + + +# Context metadata propagation +# ---------------------------- +# "metadata" is a durable, string-only map. It is serialized on the wire and propagates across +# boundaries (client → runtime → activity/child), surviving replays/retries. Use it when downstream +# components must observe the value. In-process ephemeral state should be handled within interceptors +# without attempting to propagate across process boundaries. + + +# ------------------------------ +# Client-side interceptor surface +# ------------------------------ + + +@dataclass +class ScheduleWorkflowRequest(Generic[TInput]): + workflow_name: str + input: TInput + instance_id: str | None + start_at: Any | None + reuse_id_policy: Any | None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class CallChildWorkflowRequest(Generic[TInput]): + workflow_name: str + input: TInput + instance_id: str | None + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class ContinueAsNewRequest(Generic[TInput]): + input: TInput + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class CallActivityRequest(Generic[TInput]): + activity_name: str + input: TInput + retry_policy: Any | None + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +class ClientInterceptor(Protocol, Generic[TInput]): + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[TInput], + next: Callable[[ScheduleWorkflowRequest[TInput]], Any], + ) -> Any: + ... 
+ + +# ------------------------------- +# Runtime-side interceptor surface +# ------------------------------- + + +@dataclass +class ExecuteWorkflowRequest(Generic[TInput]): + ctx: WorkflowContext + input: TInput + # Durable metadata (runtime chain only; not injected into user code) + metadata: dict[str, str] | None = None + + +@dataclass +class ExecuteActivityRequest(Generic[TInput]): + ctx: WorkflowActivityContext + input: TInput + # Durable metadata (runtime chain only; not injected into user code) + metadata: dict[str, str] | None = None + + +class RuntimeInterceptor(Protocol, Generic[TWorkflowInput, TActivityInput]): + def execute_workflow( + self, + input: ExecuteWorkflowRequest[TWorkflowInput], + next: Callable[[ExecuteWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: + ... + + def execute_activity( + self, + input: ExecuteActivityRequest[TActivityInput], + next: Callable[[ExecuteActivityRequest[TActivityInput]], Any], + ) -> Any: + ... + + +# ------------------------------ +# Convenience base classes (devex) +# ------------------------------ + + +class BaseClientInterceptor(Generic[TInput]): + """Subclass this to get method name completion and safe defaults. + + Override any of the methods to customize behavior. By default, these + methods simply call `next` unchanged. + """ + + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[TInput], + next: Callable[[ScheduleWorkflowRequest[TInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + # No workflow-outbound methods here; use WorkflowOutboundInterceptor for those + + +class BaseRuntimeInterceptor(Generic[TWorkflowInput, TActivityInput]): + """Subclass this to get method name completion and safe defaults.""" + + def execute_workflow( + self, + input: ExecuteWorkflowRequest[TWorkflowInput], + next: Callable[[ExecuteWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + def execute_activity( + self, + input: ExecuteActivityRequest[TActivityInput], + next: Callable[[ExecuteActivityRequest[TActivityInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + +# ------------------------------ +# Helper: chain composition +# ------------------------------ + + +def compose_client_chain( + interceptors: list[ClientInterceptor], terminal: Callable[[Any], Any] +) -> Callable[[Any], Any]: + """Compose client interceptors into a single callable. + + Interceptors are applied in list order; each receives a ``next``. + The ``terminal`` callable is the final handler invoked after all interceptors; it + performs the base operation (e.g., scheduling the workflow) when the chain ends. + """ + next_fn = terminal + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: ClientInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + if isinstance(input, ScheduleWorkflowRequest): + return curr_icpt.schedule_new_workflow(input, nxt) + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn + + +# ------------------------------ +# Workflow outbound interceptor surface +# ------------------------------ + + +class WorkflowOutboundInterceptor(Protocol, Generic[TWorkflowInput, TActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[TWorkflowInput], + next: Callable[[CallChildWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: + ... + + def continue_as_new( + self, + input: ContinueAsNewRequest[TWorkflowInput], + next: Callable[[ContinueAsNewRequest[TWorkflowInput]], Any], + ) -> Any: + ... 
+ + def call_activity( + self, + input: CallActivityRequest[TActivityInput], + next: Callable[[CallActivityRequest[TActivityInput]], Any], + ) -> Any: + ... + + +class BaseWorkflowOutboundInterceptor(Generic[TWorkflowInput, TActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[TWorkflowInput], + next: Callable[[CallChildWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: + return next(input) + + def continue_as_new( + self, + input: ContinueAsNewRequest[TWorkflowInput], + next: Callable[[ContinueAsNewRequest[TWorkflowInput]], Any], + ) -> Any: + return next(input) + + def call_activity( + self, + input: CallActivityRequest[TActivityInput], + next: Callable[[CallActivityRequest[TActivityInput]], Any], + ) -> Any: + return next(input) + + +# ------------------------------ +# Backward-compat typing aliases +# ------------------------------ + + +def compose_workflow_outbound_chain( + interceptors: list[WorkflowOutboundInterceptor], + terminal: Callable[[Any], Any], +) -> Callable[[Any], Any]: + """Compose workflow outbound interceptors into a single callable. + + Interceptors are applied in list order; each receives a ``next``. + The ``terminal`` callable is the final handler invoked after all interceptors; it + performs the base operation (e.g., preparing outbound call args) when the chain ends. + """ + next_fn = terminal + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: WorkflowOutboundInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + # Dispatch to the appropriate outbound method on the interceptor + if isinstance(input, CallActivityRequest): + return curr_icpt.call_activity(input, nxt) + if isinstance(input, CallChildWorkflowRequest): + return curr_icpt.call_child_workflow(input, nxt) + if isinstance(input, ContinueAsNewRequest): + return curr_icpt.continue_as_new(input, nxt) + # Fallback to next if input type unknown + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn + + +# ------------------------------ +# Helper: envelope for durable metadata +# ------------------------------ + +_META_KEY = '__dapr_meta__' +_META_VERSION = 1 +_PAYLOAD_KEY = '__dapr_payload__' + + +def wrap_payload_with_metadata(payload: Any, metadata: dict[str, str] | None) -> Any: + """If metadata is provided and non-empty, wrap payload in an envelope for persistence. + + Backward compatible: if metadata is falsy, return payload unchanged. + """ + if metadata: + return { + _META_KEY: { + 'v': _META_VERSION, + 'metadata': metadata, + }, + _PAYLOAD_KEY: payload, + } + return payload + + +def unwrap_payload_with_metadata(obj: Any) -> tuple[Any, dict[str, str] | None]: + """Extract payload and metadata from envelope if present. + + Returns (payload, metadata_dict_or_none). + """ + try: + if isinstance(obj, dict) and _META_KEY in obj and _PAYLOAD_KEY in obj: + meta = obj.get(_META_KEY) or {} + md = meta.get('metadata') if isinstance(meta, dict) else None + return obj.get(_PAYLOAD_KEY), md if isinstance(md, dict) else None + except Exception: + # Be robust: on any error, treat as raw payload + pass + return obj, None + + +def compose_runtime_chain( + interceptors: list[RuntimeInterceptor], final_handler: Callable[[Any], Any] +): + """Compose runtime interceptors into a single callable (synchronous). 
+ + The ``final_handler`` callable is the final handler invoked after all interceptors; it + performs the core operation (e.g., calling user workflow/activity or returning a + workflow generator) when the chain ends. + """ + next_fn = final_handler + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: RuntimeInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + if isinstance(input, ExecuteWorkflowRequest): + return curr_icpt.execute_workflow(input, nxt) + if isinstance(input, ExecuteActivityRequest): + return curr_icpt.execute_activity(input, nxt) + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/__init__.py index 5583bde7e..b63a763bd 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/__init__.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/__init__.py @@ -1,4 +1,4 @@ -from dapr.ext.workflow.logger.options import LoggerOptions from dapr.ext.workflow.logger.logger import Logger +from dapr.ext.workflow.logger.options import LoggerOptions __all__ = ['LoggerOptions', 'Logger'] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py index 6b0f3fec4..b93e7074f 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py @@ -1,5 +1,6 @@ import logging from typing import Union + from dapr.ext.workflow.logger.options import LoggerOptions diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/options.py b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/options.py index 0be44c52b..15cee8cc3 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/options.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/options.py @@ -13,8 +13,8 @@ limitations under the License. """ -from typing import Union import logging +from typing import Union class LoggerOptions: diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/retry_policy.py b/ext/dapr-ext-workflow/dapr/ext/workflow/retry_policy.py index af1f5ea9e..aa12f479d 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/retry_policy.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/retry_policy.py @@ -13,8 +13,8 @@ limitations under the License. """ -from typing import Optional, TypeVar from datetime import timedelta +from typing import Optional, TypeVar from durabletask import task diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/serializers.py b/ext/dapr-ext-workflow/dapr/ext/workflow/serializers.py new file mode 100644 index 000000000..9d8fcf0c2 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/serializers.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from __future__ import annotations + +import json +from typing import ( + Any, + Callable, + Dict, + MutableMapping, + MutableSequence, + Optional, + Protocol, + cast, +) + +""" +General-purpose, provider-agnostic JSON serialization helpers for workflow activities. + +This module focuses on generic extension points to ensure activity inputs/outputs are JSON-only +and replay-safe. It intentionally avoids provider-specific shapes (e.g., model/tool contracts), +which should live in examples or external packages. +""" + + +def _is_json_primitive(value: Any) -> bool: + return value is None or isinstance(value, (str, int, float, bool)) + + +def _to_json_safe(value: Any, *, strict: bool) -> Any: + """Convert a Python object to a JSON-serializable structure. + + - Dict keys become strings (lenient) or error (strict) if not str. + - Unsupported values become str(value) (lenient) or error (strict). + """ + + if _is_json_primitive(value): + return value + + if isinstance(value, MutableSequence) or isinstance(value, tuple): + return [_to_json_safe(v, strict=strict) for v in value] + + if isinstance(value, MutableMapping) or isinstance(value, dict): + output: Dict[str, Any] = {} + for k, v in value.items(): + if not isinstance(k, str): + if strict: + raise ValueError('dict keys must be strings in strict mode') + k = str(k) + output[k] = _to_json_safe(v, strict=strict) + return output + + if strict: + # Attempt final json.dumps to surface type + try: + json.dumps(value) + except Exception as err: + raise ValueError(f'non-JSON-serializable value: {type(value).__name__}') from err + return value + + return str(value) + + +def _ensure_json(obj: Any, *, strict: bool) -> Any: + converted = _to_json_safe(obj, strict=strict) + # json.dumps as a final guard + json.dumps(converted) + return converted + + +# ---------------------------------------------------------------------------------------------- +# Generic helpers and extension points +# ---------------------------------------------------------------------------------------------- + + +class CanonicalSerializable(Protocol): + """Objects implementing this can produce a canonical JSON-serializable structure.""" + + def to_canonical_json(self, *, strict: bool = True) -> Any: + ... + + +class GenericSerializer(Protocol): + """Serializer that converts arbitrary Python objects to/from JSON-serializable data.""" + + def serialize(self, obj: Any, *, strict: bool = True) -> Any: + ... + + def deserialize(self, data: Any) -> Any: + ... + + +_SERIALIZERS: Dict[str, GenericSerializer] = {} + + +def register_serializer(name: str, serializer: GenericSerializer) -> None: + if not name: + raise ValueError('serializer name must be non-empty') + _SERIALIZERS[name] = serializer + + +def get_serializer(name: str) -> Optional[GenericSerializer]: + return _SERIALIZERS.get(name) + + +def ensure_canonical_json(obj: Any, *, strict: bool = True) -> Any: + """Ensure any object is converted into a JSON-serializable structure. 
+ + - If the object implements CanonicalSerializable, call to_canonical_json + - Else, coerce via the internal JSON-safe conversion + """ + + if hasattr(obj, 'to_canonical_json') and callable(getattr(obj, 'to_canonical_json')): + return _ensure_json( + cast(CanonicalSerializable, obj).to_canonical_json(strict=strict), strict=strict + ) + return _ensure_json(obj, strict=strict) + + +class ActivityIOAdapter(Protocol): + """Adapter to control how activity inputs/outputs are serialized.""" + + def serialize_input(self, input: Any, *, strict: bool = True) -> Any: + ... + + def serialize_output(self, output: Any, *, strict: bool = True) -> Any: + ... + + +_ACTIVITY_ADAPTERS: Dict[str, ActivityIOAdapter] = {} + + +def register_activity_adapter(name: str, adapter: ActivityIOAdapter) -> None: + if not name: + raise ValueError('activity adapter name must be non-empty') + _ACTIVITY_ADAPTERS[name] = adapter + + +def get_activity_adapter(name: str) -> Optional[ActivityIOAdapter]: + return _ACTIVITY_ADAPTERS.get(name) + + +def use_activity_adapter( + adapter: ActivityIOAdapter, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Decorator to attach an ActivityIOAdapter to an activity function.""" + + def _decorate(f: Callable[..., Any]) -> Callable[..., Any]: + cast(Any, f).__dapr_activity_io_adapter__ = adapter + return f + + return _decorate + + +def serialize_activity_input(func: Callable[..., Any], input: Any, *, strict: bool = True) -> Any: + adapter = getattr(func, '__dapr_activity_io_adapter__', None) + if adapter: + return cast(ActivityIOAdapter, adapter).serialize_input(input, strict=strict) + return ensure_canonical_json(input, strict=strict) + + +def serialize_activity_output(func: Callable[..., Any], output: Any, *, strict: bool = True) -> Any: + adapter = getattr(func, '__dapr_activity_io_adapter__', None) + if adapter: + return cast(ActivityIOAdapter, adapter).serialize_output(output, strict=strict) + return ensure_canonical_json(output, strict=strict) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/util.py b/ext/dapr-ext-workflow/dapr/ext/workflow/util.py index 648bc973d..3199e2558 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/util.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/util.py @@ -21,7 +21,7 @@ def getAddress(host: Optional[str] = None, port: Optional[str] = None) -> str: if not host and not port: address = settings.DAPR_GRPC_ENDPOINT or ( - f'{settings.DAPR_RUNTIME_HOST}:' f'{settings.DAPR_GRPC_PORT}' + f'{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}' ) else: host = host or settings.DAPR_RUNTIME_HOST diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py index f460e8013..759b4438d 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py @@ -14,20 +14,33 @@ """ from __future__ import annotations + from typing import Callable, TypeVar from durabletask import task +from dapr.ext.workflow.execution_info import ActivityExecutionInfo + T = TypeVar('T') TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') class WorkflowActivityContext: - """Defines properties and methods for task activity context objects.""" + """Wrapper for ``durabletask.task.ActivityContext`` with metadata helpers. + + Purpose + ------- + - Provide pass-throughs for engine fields (``trace_parent``, ``trace_state``, + and parent ``workflow_span_id`` when available). 
+ - Surface ``execution_info``: a per-activation snapshot that includes the retry + ``attempt`` and ``inbound_metadata`` actually received for this activity. + - Offer ``get_metadata()/set_metadata()`` for SDK-level durable metadata management. + """ def __init__(self, ctx: task.ActivityContext): self.__obj = ctx + self._metadata: dict[str, str] | None = None @property def workflow_id(self) -> str: @@ -42,6 +55,39 @@ def task_id(self) -> int: def get_inner_context(self) -> task.ActivityContext: return self.__obj + # Tracing fields (engine-provided) — pass-throughs when available + @property + def trace_parent(self) -> str | None: + return getattr(self.__obj, 'trace_parent', None) + + @property + def trace_state(self) -> str | None: + return getattr(self.__obj, 'trace_state', None) + + @property + def workflow_span_id(self) -> str | None: + # Parent workflow's span id for this activity invocation, if available + return getattr(self.__obj, 'workflow_span_id', None) + + @property + def attempt(self) -> int | None: + """Retry attempt for this activity invocation when provided by the engine.""" + return getattr(self.__obj, 'attempt', None) + + @property + def execution_info(self) -> ActivityExecutionInfo | None: + return getattr(self, '_execution_info', None) + + def _set_execution_info(self, info: ActivityExecutionInfo) -> None: + self._execution_info = info + + # Metadata accessors (SDK-level; set by runtime inbound if available) + def set_metadata(self, metadata: dict[str, str] | None) -> None: + self._metadata = dict(metadata) if metadata else None + + def get_metadata(self) -> dict[str, str] | None: + return dict(self._metadata) if self._metadata else None + # Activities are simple functions that can be scheduled by workflows Activity = Callable[..., TOutput] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py index b4c85f6a6..f92689fe9 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,12 +12,14 @@ """ from __future__ import annotations + from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Any, Callable, Generator, Optional, TypeVar, Union +from typing import Any, Callable, Generator, TypeVar, Union from durabletask import task +from dapr.ext.workflow.retry_policy import RetryPolicy from dapr.ext.workflow.workflow_activity_context import Activity T = TypeVar('T') @@ -90,7 +90,7 @@ def set_custom_status(self, custom_status: str) -> None: pass @abstractmethod - def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: + def create_timer(self, fire_at: datetime | timedelta) -> task.Task: """Create a Timer Task to fire after at the specified deadline. Parameters @@ -107,7 +107,7 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: @abstractmethod def call_activity( - self, activity: Activity[TOutput], *, input: Optional[TInput] = None + self, activity: Activity[TOutput], *, input: TInput | None = None ) -> task.Task[TOutput]: """Schedule an activity for execution. 
@@ -132,8 +132,9 @@ def call_child_workflow( self, orchestrator: Workflow[TOutput], *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, + input: TInput | None = None, + instance_id: str | None = None, + retry_policy: RetryPolicy | None = None, ) -> task.Task[TOutput]: """Schedule child-workflow function for execution. @@ -146,6 +147,9 @@ def call_child_workflow( instance_id: str A unique ID to use for the sub-orchestration instance. If not specified, a random UUID will be used. + retry_policy: RetryPolicy | None + Optional retry policy for the child-workflow. When provided, failures will be retried + according to the policy. Returns ------- diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index d1f02b354..c84e110c0 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -13,22 +13,43 @@ limitations under the License. """ +import asyncio import inspect +import traceback from functools import wraps -from typing import Optional, TypeVar +from typing import Any, Awaitable, Callable, List, Optional, TypeVar -from durabletask import worker, task +try: + from typing import Literal # py39+ +except ImportError: # pragma: no cover + Literal = str # type: ignore -from dapr.ext.workflow.workflow_context import Workflow -from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext -from dapr.ext.workflow.workflow_activity_context import Activity, WorkflowActivityContext -from dapr.ext.workflow.util import getAddress +from durabletask import task, worker +from durabletask.aio.sandbox import SandboxMode from dapr.clients import DaprInternalError from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings from dapr.conf.helpers import GrpcEndpoint -from dapr.ext.workflow.logger import LoggerOptions, Logger +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext, Handlers +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import ( + CallActivityRequest, + CallChildWorkflowRequest, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + WorkflowOutboundInterceptor, + compose_runtime_chain, + compose_workflow_outbound_chain, + unwrap_payload_with_metadata, + wrap_payload_with_metadata, +) +from dapr.ext.workflow.logger import Logger, LoggerOptions +from dapr.ext.workflow.util import getAddress +from dapr.ext.workflow.workflow_activity_context import Activity, WorkflowActivityContext +from dapr.ext.workflow.workflow_context import Workflow T = TypeVar('T') TInput = TypeVar('TInput') @@ -43,9 +64,12 @@ def __init__( host: Optional[str] = None, port: Optional[str] = None, logger_options: Optional[LoggerOptions] = None, + *, + runtime_interceptors: Optional[list[RuntimeInterceptor]] = None, + workflow_outbound_interceptors: Optional[list[WorkflowOutboundInterceptor]] = None, ): self._logger = Logger('WorkflowRuntime', logger_options) - metadata = tuple() + metadata = () if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) address = getAddress(host, port) @@ -63,16 +87,128 @@ def __init__( log_handler=options.log_handler, log_formatter=options.log_formatter, ) + # Interceptors + self._runtime_interceptors: List[RuntimeInterceptor] = 
list(runtime_interceptors or []) + self._workflow_outbound_interceptors: List[WorkflowOutboundInterceptor] = list( + workflow_outbound_interceptors or [] + ) + + # Outbound helpers apply interceptors and wrap metadata; no built-in transformations. + def _apply_outbound_activity( + self, + ctx: Any, + activity: Callable[..., Any] | str, + input: Any, + retry_policy: Any | None, + metadata: dict[str, str] | None = None, + ): + # Build workflow-outbound chain to transform CallActivityRequest + name = ( + activity + if isinstance(activity, str) + else ( + activity.__dict__['_dapr_alternate_name'] + if hasattr(activity, '_dapr_alternate_name') + else activity.__name__ + ) + ) + + def terminal(term_req: CallActivityRequest) -> CallActivityRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + # Use per-context default metadata when not provided + metadata = metadata or ctx.get_metadata() + act_req = CallActivityRequest( + activity_name=name, + input=input, + retry_policy=retry_policy, + workflow_ctx=ctx, + metadata=metadata, + ) + out = chain(act_req) + if isinstance(out, CallActivityRequest): + return wrap_payload_with_metadata(out.input, out.metadata) + return input + + def _apply_outbound_child( + self, + ctx: Any, + workflow: Callable[..., Any] | str, + input: Any, + metadata: dict[str, str] | None = None, + ): + name = ( + workflow + if isinstance(workflow, str) + else ( + workflow.__dict__['_dapr_alternate_name'] + if hasattr(workflow, '_dapr_alternate_name') + else workflow.__name__ + ) + ) + + def terminal(term_req: CallChildWorkflowRequest) -> CallChildWorkflowRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + metadata = metadata or ctx.get_metadata() + child_req = CallChildWorkflowRequest( + workflow_name=name, input=input, instance_id=None, workflow_ctx=ctx, metadata=metadata + ) + out = chain(child_req) + if isinstance(out, CallChildWorkflowRequest): + return wrap_payload_with_metadata(out.input, out.metadata) + return input + + def _apply_outbound_continue_as_new( + self, + ctx: Any, + new_input: Any, + metadata: dict[str, str] | None = None, + ): + # Build workflow-outbound chain to transform ContinueAsNewRequest + from dapr.ext.workflow.interceptors import ContinueAsNewRequest + + def terminal(term_req: ContinueAsNewRequest) -> ContinueAsNewRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + metadata = metadata or ctx.get_metadata() + cnr = ContinueAsNewRequest(input=new_input, workflow_ctx=ctx, metadata=metadata) + out = chain(cnr) + if isinstance(out, ContinueAsNewRequest): + return wrap_payload_with_metadata(out.input, out.metadata) + return new_input def register_workflow(self, fn: Workflow, *, name: Optional[str] = None): + # Seamlessly support async workflows using the existing API + if inspect.iscoroutinefunction(fn): + return self.register_async_workflow(fn, name=name) + self._logger.info(f"Registering workflow '{fn.__name__}' with runtime") - def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): - """Responsible to call Workflow function in orchestrationWrapper""" - daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - if inp is None: - return fn(daprWfContext) - return fn(daprWfContext, inp) + def orchestration_wrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): + """Orchestration entrypoint wrapped 
by runtime interceptors.""" + payload, md = unwrap_payload_with_metadata(inp) + dapr_wf_context = self._get_workflow_context(ctx, md) + + # Build interceptor chain; terminal calls the user function (generator or non-generator) + def final_handler(exec_req: ExecuteWorkflowRequest) -> Any: + try: + return ( + fn(dapr_wf_context) + if exec_req.input is None + else fn(dapr_wf_context, exec_req.input) + ) + except Exception as exc: # log and re-raise to surface failure details + self._logger.error( + f"{ctx.instance_id}: workflow '{fn.__name__}' raised {type(exc).__name__}: {exc}\n{traceback.format_exc()}" + ) + raise + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain(ExecuteWorkflowRequest(ctx=dapr_wf_context, input=payload, metadata=md)) if hasattr(fn, '_workflow_registered'): # whenever a workflow is registered, it has a _dapr_alternate_name attribute @@ -87,7 +223,7 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ self.__worker._registry.add_named_orchestrator( - fn.__dict__['_dapr_alternate_name'], orchestrationWrapper + fn.__dict__['_dapr_alternate_name'], orchestration_wrapper ) fn.__dict__['_workflow_registered'] = True @@ -97,12 +233,44 @@ def register_activity(self, fn: Activity, *, name: Optional[str] = None): """ self._logger.info(f"Registering activity '{fn.__name__}' with runtime") - def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): - """Responsible to call Activity function in activityWrapper""" - wfActivityContext = WorkflowActivityContext(ctx) - if inp is None: - return fn(wfActivityContext) - return fn(wfActivityContext, inp) + def activity_wrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): + """Activity entrypoint wrapped by runtime interceptors.""" + wf_activity_context = WorkflowActivityContext(ctx) + payload, md = unwrap_payload_with_metadata(inp) + # Populate inbound metadata onto activity context + wf_activity_context.set_metadata(md or {}) + + # Populate execution info + try: + # Determine activity name (registered alternate name or function __name__) + act_name = getattr(fn, '_dapr_alternate_name', fn.__name__) + ainfo = ActivityExecutionInfo(inbound_metadata=md or {}, activity_name=act_name) + wf_activity_context._set_execution_info(ainfo) + except Exception: + pass + + def final_handler(exec_req: ExecuteActivityRequest) -> Any: + try: + # Support async and sync activities + if inspect.iscoroutinefunction(fn): + if exec_req.input is None: + return asyncio.run(fn(wf_activity_context)) + return asyncio.run(fn(wf_activity_context, exec_req.input)) + if exec_req.input is None: + return fn(wf_activity_context) + return fn(wf_activity_context, exec_req.input) + except Exception as exc: + # Log details for troubleshooting (metadata, error type) + self._logger.error( + f"{ctx.orchestration_id}:{ctx.task_id} activity '{fn.__name__}' failed with {type(exc).__name__}: {exc}" + ) + self._logger.error(traceback.format_exc()) + raise + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain( + ExecuteActivityRequest(ctx=wf_activity_context, input=payload, metadata=md) + ) if hasattr(fn, '_activity_registered'): # whenever an activity is registered, it has a _dapr_alternate_name attribute @@ -117,17 +285,50 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ 
self.__worker._registry.add_named_activity( - fn.__dict__['_dapr_alternate_name'], activityWrapper + fn.__dict__['_dapr_alternate_name'], activity_wrapper ) fn.__dict__['_activity_registered'] = True def start(self): """Starts the listening for work items on a background thread.""" self.__worker.start() + # Block until ready similar to durabletask e2e harness to avoid race conditions + try: + if hasattr(self.__worker, 'wait_for_ready'): + try: + # type: ignore[attr-defined] + self.__worker.wait_for_ready(timeout=10) + except TypeError: + self.__worker.wait_for_ready(10) # type: ignore[misc] + except Exception: + # If readiness isn't supported, proceed best-effort + pass def shutdown(self): """Stops the listening for work items on a background thread.""" - self.__worker.stop() + try: + self._logger.info('Stopping gRPC worker...') + self.__worker.stop() + self._logger.info('Worker shutdown completed') + except Exception as exc: # pragma: no cover + # DurableTask worker may emit CANCELLED warnings during local shutdown; not fatal + self._logger.warning(f'Worker stop encountered {type(exc).__name__}: {exc}') + + def wait_for_ready(self, timeout: Optional[float] = None) -> None: + """Optionally block until the underlying worker is connected and ready. + + If the durable task worker supports a readiness API, this will delegate to it. Otherwise it is a no-op. + + Args: + timeout: Optional timeout in seconds. + """ + if hasattr(self.__worker, 'wait_for_ready'): + try: + # type: ignore[attr-defined] + self.__worker.wait_for_ready(timeout=timeout) + except TypeError: + # Some implementations may not accept named arg + self.__worker.wait_for_ready(timeout) # type: ignore[misc] def workflow(self, __fn: Workflow = None, *, name: Optional[str] = None): """Decorator to register a workflow function. @@ -157,7 +358,11 @@ def add(ctx, x: int, y: int) -> int: """ def wrapper(fn: Workflow): - self.register_workflow(fn, name=name) + # Auto-detect coroutine and delegate to async registration + if inspect.iscoroutinefunction(fn): + self.register_async_workflow(fn, name=name) + else: + self.register_workflow(fn, name=name) @wraps(fn) def innerfn(): @@ -177,6 +382,121 @@ def innerfn(): return wrapper + # Async orchestrator registration (additive) + def register_async_workflow( + self, + fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]], + *, + name: Optional[str] = None, + sandbox_mode: SandboxMode = SandboxMode.OFF, + ) -> None: + """Register an async workflow function. + + The async workflow is wrapped by a coroutine-to-generator driver so it can be + executed by the Durable Task runtime alongside existing generator workflows. + + Args: + fn: The async workflow function, taking ``AsyncWorkflowContext`` and optional input. + name: Optional alternate name for registration. + sandbox_mode: Scoped compatibility patching mode. 
+ """ + self._logger.info(f"Registering ASYNC workflow '{fn.__name__}' with runtime") + + if hasattr(fn, '_workflow_registered'): + alt_name = fn.__dict__['_dapr_alternate_name'] + raise ValueError(f'Workflow {fn.__name__} already registered as {alt_name}') + if hasattr(fn, '_dapr_alternate_name'): + alt_name = fn._dapr_alternate_name + if name is not None: + m = f'Workflow {fn.__name__} already has an alternate name {alt_name}' + raise ValueError(m) + else: + fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + + runner = CoroutineOrchestratorRunner(fn, sandbox_mode=sandbox_mode) + + def generator_orchestrator(ctx: task.OrchestrationContext, inp: Optional[Any] = None): + """Orchestration entrypoint wrapped by runtime interceptors.""" + payload, md = unwrap_payload_with_metadata(inp) + base_ctx = self._get_workflow_context(ctx, md) + + async_ctx = AsyncWorkflowContext(base_ctx) + + def final_handler(exec_req: ExecuteWorkflowRequest) -> Any: + # Build the generator using the (potentially shaped) input from interceptors. + shaped_input = exec_req.input + return runner.to_generator(async_ctx, shaped_input) + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain(ExecuteWorkflowRequest(ctx=async_ctx, input=payload, metadata=md)) + + self.__worker._registry.add_named_orchestrator( + fn.__dict__['_dapr_alternate_name'], generator_orchestrator + ) + fn.__dict__['_workflow_registered'] = True + + def _get_workflow_context( + self, ctx: task.OrchestrationContext, metadata: dict[str, str] | None = None + ) -> DaprWorkflowContext: + """Get the workflow context and execution info for the given orchestration context and metadata. + Execution info serves as a read-only snapshot of the workflow context. + + Args: + ctx: The orchestration context. + metadata: The metadata for the workflow. + + Returns: + The workflow context. + """ + base_ctx = DaprWorkflowContext( + ctx, + self._logger.get_options(), + outbound_handlers={ + Handlers.CALL_ACTIVITY: self._apply_outbound_activity, + Handlers.CALL_CHILD_WORKFLOW: self._apply_outbound_child, + Handlers.CONTINUE_AS_NEW: self._apply_outbound_continue_as_new, + }, + ) + # Populate minimal execution info (only inbound metadata) + info = WorkflowExecutionInfo(inbound_metadata=metadata or {}) + base_ctx._set_execution_info(info) + base_ctx.set_metadata(metadata or {}) + return base_ctx + + def async_workflow( + self, + __fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]] = None, + *, + name: Optional[str] = None, + sandbox_mode: SandboxMode = SandboxMode.OFF, + ): + """Decorator to register an async workflow function. + + Usage: + @runtime.async_workflow(name="my_wf") + async def my_wf(ctx: AsyncWorkflowContext, input): + ... + """ + + def wrapper(fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]]): + self.register_async_workflow(fn, name=name, sandbox_mode=sandbox_mode) + + @wraps(fn) + def innerfn(): + return fn + + if hasattr(fn, '_dapr_alternate_name'): + innerfn.__dict__['_dapr_alternate_name'] = fn.__dict__['_dapr_alternate_name'] + else: + innerfn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + innerfn.__signature__ = inspect.signature(fn) + return innerfn + + if __fn: + return wrapper(__fn) + + return wrapper + def activity(self, __fn: Activity = None, *, name: Optional[str] = None): """Decorator to register an activity function. 
diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_state.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_state.py index 10847fc54..af1d7e735 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_state.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_state.py @@ -13,8 +13,8 @@ limitations under the License. """ -from enum import Enum import json +from enum import Enum from durabletask import client diff --git a/ext/dapr-ext-workflow/examples/generics_interceptors_example.py b/ext/dapr-ext-workflow/examples/generics_interceptors_example.py new file mode 100644 index 000000000..7ef9c8f4c --- /dev/null +++ b/ext/dapr-ext-workflow/examples/generics_interceptors_example.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import os +from dataclasses import asdict, dataclass +from typing import List + +from dapr.ext.workflow import ( + DaprWorkflowClient, + WorkflowRuntime, +) +from dapr.ext.workflow.interceptors import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ContinueAsNewRequest, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + ScheduleWorkflowRequest, +) + +# ------------------------------ +# Typed payloads carried by interceptors +# ------------------------------ + + +@dataclass +class MyWorkflowInput: + question: str + tags: List[str] + + +@dataclass +class MyActivityInput: + name: str + count: int + + +# ------------------------------ +# Interceptors with generics + minimal (de)serialization +# ------------------------------ + + +class MyClientInterceptor(BaseClientInterceptor[MyWorkflowInput]): + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[MyWorkflowInput], + nxt, + ) -> str: + # Ensure wire format is JSON-serializable (dict) + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = ScheduleWorkflowRequest[MyWorkflowInput]( + workflow_name=input.workflow_name, + input=payload, # type: ignore[arg-type] + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + metadata=input.metadata, + ) + return nxt(shaped) + + +class MyRuntimeInterceptor(BaseRuntimeInterceptor[MyWorkflowInput, MyActivityInput]): + def execute_workflow( + self, + input: ExecuteWorkflowRequest[MyWorkflowInput], + nxt, + ): + # Convert inbound dict into typed model for workflow code + data = input.input + if isinstance(data, dict) and 'question' in data: + input.input = MyWorkflowInput( + question=data.get('question', ''), tags=list(data.get('tags', [])) + ) # type: ignore[assignment] + return nxt(input) + + def execute_activity( + self, + input: ExecuteActivityRequest[MyActivityInput], + nxt, + ): + data = input.input + if isinstance(data, dict) and 'name' in data: + input.input = MyActivityInput( + name=data.get('name', ''), count=int(data.get('count', 0)) + ) # type: ignore[assignment] + return nxt(input) + + +class MyOutboundInterceptor(BaseWorkflowOutboundInterceptor[MyWorkflowInput, MyActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[MyWorkflowInput], + nxt, + ): + # Convert typed payload back to wire before sending + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = CallChildWorkflowRequest[MyWorkflowInput]( + workflow_name=input.workflow_name, + input=payload, # type: ignore[arg-type] + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, 
+ metadata=input.metadata, + ) + return nxt(shaped) + + def continue_as_new( + self, + input: ContinueAsNewRequest[MyWorkflowInput], + nxt, + ): + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = ContinueAsNewRequest[MyWorkflowInput]( + input=payload, # type: ignore[arg-type] + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + ) + return nxt(shaped) + + def call_activity( + self, + input: CallActivityRequest[MyActivityInput], + nxt, + ): + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = CallActivityRequest[MyActivityInput]( + activity_name=input.activity_name, + input=payload, # type: ignore[arg-type] + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + ) + return nxt(shaped) + + +# ------------------------------ +# Minimal runnable example with sidecar +# ------------------------------ + + +def main() -> None: + # Expect DAPR_GRPC_ENDPOINT (e.g., dns:127.0.0.1:56179) to be set for local sidecar/dev hub + ep = os.getenv('DAPR_GRPC_ENDPOINT') + if not ep: + print('WARNING: DAPR_GRPC_ENDPOINT not set; default sidecar address will be used') + + # Build runtime with interceptors + runtime = WorkflowRuntime( + runtime_interceptors=[MyRuntimeInterceptor()], + workflow_outbound_interceptors=[MyOutboundInterceptor()], + ) + + # Register a simple activity + @runtime.activity(name='greet') + def greet(_ctx, x: dict | None = None) -> str: # wire format at activity boundary is dict + x = x or {} + return f'Hello {x.get("name", "world")} x{x.get("count", 0)}' + + # Register an async workflow that calls the activity once + @runtime.async_workflow(name='wf_greet') + async def wf_greet(ctx, arg: MyWorkflowInput | dict | None = None): + # At this point, runtime interceptor converted inbound to MyWorkflowInput + if isinstance(arg, MyWorkflowInput): + act_in = MyActivityInput(name=arg.question, count=len(arg.tags)) + else: + # Fallback if interceptor not present + d = arg or {} + act_in = MyActivityInput(name=str(d.get('question', '')), count=len(d.get('tags', []))) + return await ctx.call_activity('greet', input=asdict(act_in)) + + runtime.start() + try: + # Client with client-side interceptor for schedule typing + client = DaprWorkflowClient(interceptors=[MyClientInterceptor()]) + wf_input = MyWorkflowInput(question='World', tags=['a', 'b']) + instance_id = client.schedule_new_workflow(wf_greet, input=wf_input) + print('Started instance:', instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print('Final status:', getattr(st, 'runtime_status', None)) + if st: + print('Output:', st.to_json().get('serialized_output')) + finally: + runtime.shutdown() + + +if __name__ == '__main__': + main() diff --git a/ext/dapr-ext-workflow/tests/README.md b/ext/dapr-ext-workflow/tests/README.md new file mode 100644 index 000000000..6759a362d --- /dev/null +++ b/ext/dapr-ext-workflow/tests/README.md @@ -0,0 +1,94 @@ +## Workflow tests: unit, integration, and custom ports + +This directory contains unit tests (no sidecar required) and integration tests (require a running sidecar/runtime). 
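+
+If you don't already have a sidecar running, one way to start one locally for the integration tests is with the Dapr CLI (a minimal sketch; the app id is arbitrary and the ports are examples that should line up with the env vars described under "Configure custom sidecar ports/endpoints" below):
+
+```bash
+# Starts only the sidecar (no app command); choose any free ports
+dapr run --app-id wf-tests --dapr-grpc-port 50001 --dapr-http-port 3500
+```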
+
+### Prereqs
+
+- Python 3.11+ (tox will create an isolated venv)
+- Dapr sidecar for integration tests (HTTP and gRPC ports)
+- Optional: Durable Task gRPC endpoint for DT e2e tests
+
+### Run all tests via tox (recommended)
+
+```bash
+tox -e py311
+```
+
+This runs:
+- Core SDK tests (unittest)
+- Workflow extension unit tests (pytest)
+- Workflow extension integration tests (pytest) if your sidecar/runtime is reachable
+
+### Run only workflow unit tests
+
+Unit tests live at `ext/dapr-ext-workflow/tests` excluding the `integration/` subfolder.
+
+With tox:
+```bash
+tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests -k "not integration"
+```
+
+Directly (outside tox):
+```bash
+pytest -q ext/dapr-ext-workflow/tests -k "not integration"
+```
+
+### Run workflow integration tests
+
+Integration tests live under `ext/dapr-ext-workflow/tests/integration/` and require a running sidecar/runtime.
+
+With tox:
+```bash
+tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration
+```
+
+Directly (outside tox):
+```bash
+pytest -q ext/dapr-ext-workflow/tests/integration
+```
+
+If tests cannot reach your sidecar/runtime, they will skip or fail fast depending on the specific test.
+
+### Configure custom sidecar ports/endpoints
+
+The SDK reads connection settings from env vars (see `dapr.conf.global_settings`). Use these to point tests at custom ports:
+
+- Dapr gRPC:
+  - `DAPR_GRPC_ENDPOINT` (preferred): endpoint string, e.g. `dns:127.0.0.1:50051`
+  - or `DAPR_RUNTIME_HOST` and `DAPR_GRPC_PORT`, e.g. `DAPR_RUNTIME_HOST=127.0.0.1`, `DAPR_GRPC_PORT=50051`
+
+- Dapr HTTP (only for HTTP-based tests):
+  - `DAPR_HTTP_ENDPOINT`: e.g. `http://127.0.0.1:3600`
+  - or `DAPR_RUNTIME_HOST` and `DAPR_HTTP_PORT`, e.g. `DAPR_HTTP_PORT=3600`
+
+Examples:
+```bash
+# Use custom gRPC 50051 and HTTP 3600
+export DAPR_GRPC_ENDPOINT=dns:127.0.0.1:50051
+export DAPR_HTTP_ENDPOINT=http://127.0.0.1:3600
+
+# Alternatively, using host/port pairs
+export DAPR_RUNTIME_HOST=127.0.0.1
+export DAPR_GRPC_PORT=50051
+export DAPR_HTTP_PORT=3600
+
+tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration
+```
+
+Note: For gRPC, avoid `http://` or `https://` schemes. Use `dns:host:port` or just set host/port separately.
+
+### Durable Task e2e tests (optional)
+
+Some tests (e.g., `integration/test_async_e2e_dt.py`) talk directly to a Durable Task gRPC endpoint. They use:
+
+- `DURABLETASK_GRPC_ENDPOINT` (default `localhost:50001`, matching the default in the test code)
+
+If your DT runtime listens elsewhere:
+```bash
+export DURABLETASK_GRPC_ENDPOINT=127.0.0.1:56179
+tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py
+```
+
+
+
+
diff --git a/ext/dapr-ext-workflow/tests/_fakes.py b/ext/dapr-ext-workflow/tests/_fakes.py
new file mode 100644
index 000000000..56245e109
--- /dev/null
+++ b/ext/dapr-ext-workflow/tests/_fakes.py
@@ -0,0 +1,73 @@
+"""
+Copyright 2025 The Dapr Authors
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + + +class FakeOrchestrationContext: + def __init__( + self, + *, + instance_id: str = 'wf-1', + current_utc_datetime: datetime | None = None, + is_replaying: bool = False, + workflow_name: str = 'wf', + parent_instance_id: str | None = None, + history_event_sequence: int | None = 1, + trace_parent: str | None = None, + trace_state: str | None = None, + orchestration_span_id: str | None = None, + workflow_attempt: int | None = None, + ) -> None: + self.instance_id = instance_id + self.current_utc_datetime = ( + current_utc_datetime if current_utc_datetime else datetime(2025, 1, 1) + ) + self.is_replaying = is_replaying + self.workflow_name = workflow_name + self.parent_instance_id = parent_instance_id + self.history_event_sequence = history_event_sequence + self.trace_parent = trace_parent + self.trace_state = trace_state + self.orchestration_span_id = orchestration_span_id + self.workflow_attempt = workflow_attempt + + +class FakeActivityContext: + def __init__( + self, + *, + orchestration_id: str = 'wf-1', + task_id: int = 1, + attempt: int | None = None, + trace_parent: str | None = None, + trace_state: str | None = None, + workflow_span_id: str | None = None, + ) -> None: + self.orchestration_id = orchestration_id + self.task_id = task_id + self.attempt = attempt + self.trace_parent = trace_parent + self.trace_state = trace_state + self.workflow_span_id = workflow_span_id + + +def make_orch_ctx(**overrides: Any) -> FakeOrchestrationContext: + return FakeOrchestrationContext(**overrides) + + +def make_act_ctx(**overrides: Any) -> FakeActivityContext: + return FakeActivityContext(**overrides) diff --git a/ext/dapr-ext-workflow/tests/conftest.py b/ext/dapr-ext-workflow/tests/conftest.py new file mode 100644 index 000000000..f20a225e7 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/conftest.py @@ -0,0 +1,74 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Ensure tests prefer the local python-sdk repository over any installed site-packages +# This helps when running pytest directly (outside tox/CI), so changes in the repo are exercised. +from __future__ import annotations # noqa: I001 + +import sys +from pathlib import Path +import importlib +import pytest + + +def pytest_configure(config): # noqa: D401 (pytest hook) + """Pytest configuration hook that prepends the repo root to sys.path. + + This ensures `import dapr` resolves to the local source tree when running tests directly. + Under tox/CI (editable installs), this is a no-op but still safe. 
+ """ + try: + # ext/dapr-ext-workflow/tests/conftest.py -> repo root is 3 parents up + repo_root = Path(__file__).resolve().parents[3] + except Exception: + return + + repo_str = str(repo_root) + if repo_str not in sys.path: + sys.path.insert(0, repo_str) + + # Best-effort diagnostic: show where dapr was imported from + try: + dapr_mod = importlib.import_module('dapr') + dapr_path = Path(getattr(dapr_mod, '__file__', '')).resolve() + where = 'site-packages' if 'site-packages' in str(dapr_path) else 'local-repo' + print(f'[dapr-ext-workflow/tests] dapr resolved from {where}: {dapr_path}', file=sys.stderr) + except Exception: + # If dapr isn't importable yet, that's fine; tests importing it later will use modified sys.path + pass + + +@pytest.fixture(autouse=True) +def cleanup_workflow_registrations(request): + """Clean up workflow/activity registration markers after each test. + + This prevents test interference when the same function objects are reused across tests. + The workflow runtime marks functions with _dapr_alternate_name and _activity_registered + attributes, which can cause 'already registered' errors in subsequent tests. + """ + yield # Run the test + + # After test completes, clean up functions defined in the test module + test_module = sys.modules.get(request.module.__name__) + if test_module: + for name in dir(test_module): + obj = getattr(test_module, name, None) + if callable(obj) and hasattr(obj, '__dict__'): + try: + # Only clean up if __dict__ is writable (not mappingproxy) + if isinstance(obj.__dict__, dict): + obj.__dict__.pop('_dapr_alternate_name', None) + obj.__dict__.pop('_activity_registered', None) + except (AttributeError, TypeError): + # Skip objects with read-only __dict__ + pass diff --git a/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py b/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py new file mode 100644 index 000000000..a15df47c1 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- + +""" +Async e2e tests using durabletask worker/client directly. + +These validate basic orchestration behavior against a running sidecar +to isolate environment issues from WorkflowRuntime wiring. 
+""" + +from __future__ import annotations + +import os +import time + +import pytest +from durabletask.aio import AsyncWorkflowContext +from durabletask.client import TaskHubGrpcClient +from durabletask.worker import TaskHubGrpcWorker + + +def _is_runtime_available(ep_str: str) -> bool: + import socket + + try: + host, port = ep_str.split(':') + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + result = sock.connect_ex((host, int(port))) + sock.close() + return result == 0 + except Exception: + return False + + +endpoint = os.getenv('DURABLETASK_GRPC_ENDPOINT', 'localhost:50001') + +skip_if_no_runtime = pytest.mark.skipif( + not _is_runtime_available(endpoint), + reason='DurableTask runtime not available', +) + + +@skip_if_no_runtime +def test_dt_simple_activity_e2e(): + # using global read-only endpoint variable + worker = TaskHubGrpcWorker(host_address=endpoint) + client = TaskHubGrpcClient(host_address=endpoint) + + def act(ctx, x: int) -> int: + return x * 3 + + worker.add_activity(act) + + @worker.add_async_orchestrator + async def orch(ctx: AsyncWorkflowContext, x: int) -> int: + return await ctx.call_activity(act, input=x) + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-act-{int(time.time() * 1000)}' + client.schedule_new_orchestration(orch, input=5, instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + # Output is JSON serialized scalar + assert st.serialized_output.strip() in ('15', '"15"') + finally: + try: + worker.stop() + except Exception: + pass + + +@skip_if_no_runtime +def test_dt_timer_e2e(): + # using global read-only endpoint variable + worker = TaskHubGrpcWorker(host_address=endpoint) + client = TaskHubGrpcClient(host_address=endpoint) + + @worker.add_async_orchestrator + async def orch(ctx: AsyncWorkflowContext, delay: float) -> dict: + start = ctx.now() + await ctx.sleep(delay) + end = ctx.now() + return {'start': start.isoformat(), 'end': end.isoformat(), 'delay': delay} + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-timer-{int(time.time() * 1000)}' + delay = 1.0 + client.schedule_new_orchestration(orch, input=delay, instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + finally: + try: + worker.stop() + except Exception: + pass + + +@skip_if_no_runtime +def test_dt_sub_orchestrator_e2e(): + # using global read-only endpoint variable + worker = TaskHubGrpcWorker(host_address=endpoint) + client = TaskHubGrpcClient(host_address=endpoint) + + def act(ctx, s: str) -> str: + return f'A:{s}' + + worker.add_activity(act) + + async def child(ctx: AsyncWorkflowContext, s: str) -> str: + print('[E2E DEBUG] child start', s) + try: + res = await ctx.call_activity(act, input=s) + print('[E2E DEBUG] child done', res) + return res + except Exception as exc: # pragma: no cover - troubleshooting aid + import traceback as _tb + + print('[E2E DEBUG] child exception:', type(exc).__name__, str(exc)) + print(_tb.format_exc()) + raise + + # Explicit registration to avoid decorator replacing symbol with a string in newer versions + worker.add_async_orchestrator(child) + + async def parent(ctx: 
AsyncWorkflowContext, s: str) -> str: + print('[E2E DEBUG] parent start', s) + try: + c = await ctx.call_sub_orchestrator(child, input=s) + out = f'P:{c}' + print('[E2E DEBUG] parent done', out) + return out + except Exception as exc: # pragma: no cover - troubleshooting aid + import traceback as _tb + + print('[E2E DEBUG] parent exception:', type(exc).__name__, str(exc)) + print(_tb.format_exc()) + raise + + worker.add_async_orchestrator(parent) + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-sub-{int(time.time() * 1000)}' + print('[E2E DEBUG] scheduling instance', iid) + client.schedule_new_orchestration(parent, input='x', instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + if st.runtime_status.name != 'COMPLETED': + # Print orchestration state details to aid debugging + print('[E2E DEBUG] orchestration FAILED; details:') + to_json = getattr(st, 'to_json', None) + if callable(to_json): + try: + print(to_json()) + except Exception: + pass + print('status=', getattr(st, 'runtime_status', None)) + print('output=', getattr(st, 'serialized_output', None)) + print('failure=', getattr(st, 'failure_details', None)) + assert st.runtime_status.name == 'COMPLETED' + finally: + try: + worker.stop() + except Exception: + pass diff --git a/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py b/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py new file mode 100644 index 000000000..1acabb43a --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py @@ -0,0 +1,896 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +import time + +import pytest + +from dapr.ext.workflow import AsyncWorkflowContext, DaprWorkflowClient, WorkflowRuntime +from dapr.ext.workflow.interceptors import ( + BaseRuntimeInterceptor, + ExecuteActivityRequest, + ExecuteWorkflowRequest, +) + +skip_integration = pytest.mark.skipif( + False, + reason='integration enabled', +) + + +@skip_integration +def test_integration_suspension_and_buffering(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='suspend_orchestrator_async') + async def suspend_orchestrator(ctx: AsyncWorkflowContext): + # Expose suspension state via custom status + ctx.set_custom_status({'is_suspended': getattr(ctx, 'is_suspended', False)}) + # Wait for 'resume_event' and then complete + data = await ctx.wait_for_external_event('resume_event') + return {'resumed_with': data} + + runtime.start() + try: + # Allow connection to stabilize before scheduling + time.sleep(3) + + client = DaprWorkflowClient() + instance_id = f'suspend-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=suspend_orchestrator, instance_id=instance_id) + + # Wait until started + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + + # Pause and verify state becomes SUSPENDED and custom status updates on next activation + client.pause_workflow(instance_id) + # Give the worker time to process suspension + time.sleep(1) + state = client.get_workflow_state(instance_id) + assert state is not None + assert state.runtime_status.name in ( + 'SUSPENDED', + 'RUNNING', + ) # some hubs report SUSPENDED explicitly + + # While suspended, raise the event; it should buffer + client.raise_workflow_event(instance_id, 'resume_event', data={'ok': True}) + + # Resume and expect completion + client.resume_workflow(instance_id) + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + assert final is not None + assert final.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_generator_metadata_propagation(): + runtime = WorkflowRuntime() + + @runtime.activity(name='recv_md_gen') + def recv_md_gen(ctx, _=None): + return ctx.get_metadata() or {} + + @runtime.workflow(name='gen_parent_sets_md') + def parent_gen(ctx: 'WorkflowContext'): + ctx.set_metadata({'tenant': 'acme', 'tier': 'gold'}) + md = yield ctx.call_activity(recv_md_gen, input=None) + return md + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'gen-md-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent_gen, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(state.to_json().get('serialized_output') or '{}') + assert out.get('tenant') == 'acme' + assert out.get('tier') == 'gold' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_trace_context_child_workflow(): + runtime = WorkflowRuntime() + + @runtime.activity(name='trace_probe') + def trace_probe(ctx, _=None): + return { + 'tp': getattr(ctx, 'trace_parent', None), + 'ts': getattr(ctx, 'trace_state', None), + 'wf_span': getattr(ctx, 'workflow_span_id', None), + } + + @runtime.async_workflow(name='child_trace') + async def child(ctx: AsyncWorkflowContext, _=None): + return { + 'wf_tp': getattr(ctx, 'trace_parent', None), + 'wf_ts': getattr(ctx, 'trace_state', None), + 
'wf_span': getattr(ctx, 'workflow_span_id', None), + 'act': await ctx.call_activity(trace_probe, input=None), + } + + @runtime.async_workflow(name='parent_trace') + async def parent(ctx: AsyncWorkflowContext): + child_out = await ctx.call_child_workflow(child, input=None) + return { + 'parent_tp': getattr(ctx, 'trace_parent', None), + 'parent_span': getattr(ctx, 'workflow_span_id', None), + 'child': child_out, + } + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'trace-child-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + data = _json.loads(st.to_json().get('serialized_output') or '{}') + + # TODO: assert more specifically when we have trace context information + + # Parent (engine-provided fields may be absent depending on runtime build/config) + assert isinstance(data.get('parent_tp'), (str, type(None))) + assert isinstance(data.get('parent_span'), (str, type(None))) + # Child orchestrator fields + _child = data.get('child') or {} + assert isinstance(_child.get('wf_tp'), (str, type(None))) + assert isinstance(_child.get('wf_span'), (str, type(None))) + # Activity fields under child + act = _child.get('act') or {} + assert isinstance(act.get('tp'), (str, type(None))) + assert isinstance(act.get('wf_span'), (str, type(None))) + + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_trace_context_child_workflow_injected_metadata(): + # Deterministic trace propagation using interceptors via durable metadata + from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + ScheduleWorkflowRequest, + ) + + TRACE_KEY = 'otel.trace_id' + + class InjectTraceClient(BaseClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next( + ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + ) + + class InjectTraceOutbound(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next( + CallActivityRequest( + activity_name=request.activity_name, + input=request.input, + retry_policy=request.retry_policy, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + def call_child_workflow(self, request: CallChildWorkflowRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next( + CallChildWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + class RestoreTraceRuntime(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Ensure metadata arrives + assert isinstance((request.metadata or {}).get(TRACE_KEY), str) + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, 
next): + assert isinstance((request.metadata or {}).get(TRACE_KEY), str) + return next(request) + + runtime = WorkflowRuntime( + runtime_interceptors=[RestoreTraceRuntime()], + workflow_outbound_interceptors=[InjectTraceOutbound()], + ) + + @runtime.activity(name='trace_probe2') + def trace_probe2(ctx, _=None): + return getattr(ctx, 'get_metadata', lambda: {})().get(TRACE_KEY) + + @runtime.async_workflow(name='child_trace2') + async def child2(ctx: AsyncWorkflowContext, _=None): + return { + 'wf_md': (ctx.get_metadata() or {}).get(TRACE_KEY), + 'act_md': await ctx.call_activity(trace_probe2, input=None), + } + + @runtime.async_workflow(name='parent_trace2') + async def parent2(ctx: AsyncWorkflowContext): + out = await ctx.call_child_workflow(child2, input=None) + return { + 'parent_md': (ctx.get_metadata() or {}).get(TRACE_KEY), + 'child': out, + } + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient(interceptors=[InjectTraceClient()]) + iid = f'trace-child-md-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent2, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + data = _json.loads(st.to_json().get('serialized_output') or '{}') + assert data.get('parent_md') == 'sdk-trace-123' + child = data.get('child') or {} + assert child.get('wf_md') == 'sdk-trace-123' + assert child.get('act_md') == 'sdk-trace-123' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_termination_semantics(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='termination_orchestrator_async') + async def termination_orchestrator(ctx: AsyncWorkflowContext): + # Long timer; test will terminate before it fires + await ctx.create_timer(300.0) + return 'not-reached' + + print(list(runtime._WorkflowRuntime__worker._registry.orchestrators.keys())) + + runtime.start() + try: + time.sleep(3) + + client = DaprWorkflowClient() + instance_id = f'term-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=termination_orchestrator, instance_id=instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + + # Terminate and assert TERMINATED state, not raising inside orchestrator + client.terminate_workflow(instance_id, output='terminated') + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + assert final is not None + assert final.runtime_status.name == 'TERMINATED' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_when_any_first_wins(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='when_any_async') + async def when_any_orchestrator(ctx: AsyncWorkflowContext): + first = await ctx.when_any( + [ + ctx.wait_for_external_event('go'), + ctx.create_timer(300.0), + ] + ) + # Return a simple, serializable value (winner's result) to avoid output serialization issues + try: + result = first.get_result() + except Exception: + result = None + return {'winner_result': result} + + runtime.start() + try: + # Ensure worker has established streams before scheduling + try: + if hasattr(runtime, 'wait_for_ready'): + runtime.wait_for_ready(timeout=15) # type: ignore[attr-defined] + except Exception: + pass + time.sleep(2) + + client = DaprWorkflowClient() + instance_id = f'whenany-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=when_any_orchestrator, 
instance_id=instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + # Confirm RUNNING state before raising event (mitigates race conditions) + try: + st = client.get_workflow_state(instance_id, fetch_payloads=False) + if ( + st is None + or getattr(st, 'runtime_status', None) is None + or st.runtime_status.name != 'RUNNING' + ): + end = time.time() + 10 + while time.time() < end: + st = client.get_workflow_state(instance_id, fetch_payloads=False) + if ( + st is not None + and getattr(st, 'runtime_status', None) is not None + and st.runtime_status.name == 'RUNNING' + ): + break + time.sleep(0.2) + except Exception: + pass + + # Raise event immediately to win the when_any + client.raise_workflow_event(instance_id, 'go', data={'ok': True}) + + # Brief delay to allow event processing, then strictly use DaprWorkflowClient + time.sleep(1.0) + final = None + try: + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + except TimeoutError: + final = None + if final is None: + deadline = time.time() + 30 + while time.time() < deadline: + s = client.get_workflow_state(instance_id, fetch_payloads=False) + if s is not None and getattr(s, 'runtime_status', None) is not None: + if s.runtime_status.name in ('COMPLETED', 'FAILED', 'TERMINATED'): + final = s + break + time.sleep(0.5) + assert final is not None + assert final.runtime_status.name == 'COMPLETED' + # TODO: when sidecar exposes command diagnostics, assert only one command set was emitted + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_async_activity_completes(): + runtime = WorkflowRuntime() + + @runtime.activity(name='echo_int') + def echo_act(ctx, x: int) -> int: + return x + + @runtime.async_workflow(name='async_activity_once') + async def wf(ctx: AsyncWorkflowContext): + out = await ctx.call_activity(echo_act, input=7) + return out + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'act-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + if state.runtime_status.name != 'COMPLETED': + fd = getattr(state, 'failure_details', None) + msg = getattr(fd, 'message', None) if fd else None + et = getattr(fd, 'error_type', None) if fd else None + print(f'[INTEGRATION DEBUG] Failure details: {et} {msg}') + assert state.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_metadata_outbound_to_activity(): + runtime = WorkflowRuntime() + + @runtime.activity(name='recv_md') + def recv_md(ctx, _=None): + md = ctx.get_metadata() if hasattr(ctx, 'get_metadata') else {} + return md + + @runtime.async_workflow(name='wf_with_md') + async def wf(ctx: AsyncWorkflowContext): + ctx.set_metadata({'tenant': 'acme'}) + md = await ctx.call_activity(recv_md, input=None) + return md + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'md-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_metadata_outbound_to_child_workflow(): + runtime = 
WorkflowRuntime() + + @runtime.async_workflow(name='child_recv_md') + async def child(ctx: AsyncWorkflowContext, _=None): + # Echo inbound metadata + return ctx.get_metadata() or {} + + @runtime.async_workflow(name='parent_sets_md') + async def parent(ctx: AsyncWorkflowContext): + ctx.set_metadata({'tenant': 'acme', 'role': 'user'}) + out = await ctx.call_child_workflow(child, input=None) + return out + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'md-child-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + # Validate output has metadata keys + data = state.to_json() + import json as _json + + out = _json.loads(data.get('serialized_output') or '{}') + assert out.get('tenant') == 'acme' + assert out.get('role') == 'user' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_trace_context_with_runtime_interceptors(): + """E2E: Verify trace_parent and orchestration_span_id via runtime interceptors.""" + records = { # captured by interceptor + 'wf_tp': None, + 'wf_span': None, + 'act_tp': None, + 'act_span': None, + } + + class TraceInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + ctx = request.ctx + try: + records['wf_tp'] = getattr(ctx, 'trace_parent', None) + records['wf_span'] = getattr(ctx, 'workflow_span_id', None) + except Exception: + pass + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + ctx = request.ctx + try: + records['act_tp'] = getattr(ctx, 'trace_parent', None) + # Activity contexts don't have orchestration_span_id; capture task span if present + records['act_span'] = getattr(ctx, 'activity_span_id', None) + except Exception: + pass + return next(request) + + runtime = WorkflowRuntime(runtime_interceptors=[TraceInterceptor()]) + + @runtime.activity(name='trace_probe') + def trace_probe(ctx, _=None): + # Return trace context seen inside activity + return { + 'trace_parent': getattr(ctx, 'trace_parent', None), + 'trace_state': getattr(ctx, 'trace_state', None), + } + + @runtime.async_workflow(name='trace_parent_wf') + async def wf(ctx: AsyncWorkflowContext): + # Access orchestration span id and trace parent from workflow context + _ = getattr(ctx, 'workflow_span_id', None) + _ = getattr(ctx, 'trace_parent', None) + return await ctx.call_activity(trace_probe, input=None) + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'trace-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(state.to_json().get('serialized_output') or '{}') + # Activity returned strings (may be empty); assert types + assert isinstance(out.get('trace_parent'), (str, type(None))) + assert isinstance(out.get('trace_state'), (str, type(None))) + # Interceptor captured workflow and activity contexts + wf_tp = records['wf_tp'] + wf_span = records['wf_span'] + act_tp = records['act_tp'] + # TODO: assert more specifically when we 
have trace context information + assert isinstance(wf_tp, (str, type(None))) + assert isinstance(wf_span, (str, type(None))) + assert isinstance(act_tp, (str, type(None))) + # If we have a workflow span id, it should appear as parent-id inside activity traceparent + if isinstance(wf_span, str) and wf_span and isinstance(act_tp, str) and act_tp: + assert wf_span.lower() in act_tp.lower() + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_runtime_shutdown_is_clean(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='noop') + async def noop(ctx: AsyncWorkflowContext): + return 'ok' + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient() + iid = f'shutdown-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=noop, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=30) + assert st is not None and st.runtime_status.name == 'COMPLETED' + finally: + # Call shutdown multiple times to ensure idempotent and clean behavior + for _ in range(3): + try: + runtime.shutdown() + except Exception: + # Test should not raise even if worker logs cancellation warnings + assert False, 'runtime.shutdown() raised unexpectedly' + # Recreate and shutdown again to ensure no lingering background threads break next startup + rt2 = WorkflowRuntime() + rt2.start() + try: + time.sleep(1) + finally: + try: + rt2.shutdown() + except Exception: + assert False, 'second runtime.shutdown() raised unexpectedly' + + +@skip_integration +def test_integration_continue_as_new_outbound_interceptor_metadata(): + # Verify continue_as_new outbound interceptor can inject metadata carried to the new run + from dapr.ext.workflow import BaseWorkflowOutboundInterceptor + + INJECT_KEY = 'injected' + + class InjectOnContinueAsNew(BaseWorkflowOutboundInterceptor): + def continue_as_new(self, request, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(INJECT_KEY, 'yes') + request.metadata = md + return next(request) + + runtime = WorkflowRuntime( + workflow_outbound_interceptors=[InjectOnContinueAsNew()], + ) + + @runtime.workflow(name='continue_as_new_probe') + def wf(ctx, arg: dict | None = None): + if not arg or arg.get('phase') != 'second': + ctx.set_metadata({'tenant': 'acme'}) + # carry over existing metadata; interceptor will also inject + ctx.continue_as_new({'phase': 'second'}, carryover_metadata=True) + return # Must not yield after continue_as_new + # Second run: return inbound metadata observed + return ctx.get_metadata() or {} + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient() + iid = f'can-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + # Confirm both carried and injected metadata are present + assert out.get('tenant') == 'acme' + assert out.get(INJECT_KEY) == 'yes' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_activity_attempt_exposed(): + # Verify that activity ctx exposes attempt via pass-through property + runtime = WorkflowRuntime() + + @runtime.activity(name='probe_attempt_activity') + def probe_attempt_activity(ctx, _=None): + attempt = 
getattr(ctx, 'attempt', None) + return {'attempt': attempt} + + @runtime.async_workflow(name='activity_attempt_wf') + async def activity_attempt_wf(ctx: AsyncWorkflowContext): + return await ctx.call_activity(probe_attempt_activity, input=None) + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient() + iid = f'act-attempt-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=activity_attempt_wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + val = out.get('attempt', None) + assert (val is None) or isinstance(val, int) + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_child_workflow_attempt_exposed(): + # Verify that child workflow ctx exposes workflow_attempt + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='child_probe_attempt') + async def child_probe_attempt(ctx: AsyncWorkflowContext, _=None): + att = getattr(ctx, 'workflow_attempt', None) + return {'wf_attempt': att} + + @runtime.async_workflow(name='parent_calls_child_for_attempt') + async def parent_calls_child_for_attempt(ctx: AsyncWorkflowContext): + return await ctx.call_child_workflow(child_probe_attempt, input=None) + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient() + iid = f'child-attempt-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent_calls_child_for_attempt, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + val = out.get('wf_attempt', None) + assert (val is None) or isinstance(val, int) + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_async_contextvars_trace_propagation(monkeypatch): + # Demonstrates contextvars-based trace propagation via interceptors in async workflows + import contextvars + import json as _json + + from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ScheduleWorkflowRequest, + ) + + TRACE_KEY = 'otel.trace_ctx' + current_trace: contextvars.ContextVar[str | None] = contextvars.ContextVar( + 'trace', default=None + ) + + class CVClient(BaseClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'wf-parent') + return next( + ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + ) + + class CVOutbound(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, current_trace.get()) + return next( + CallActivityRequest( + activity_name=request.activity_name, + input=request.input, + retry_policy=request.retry_policy, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + def call_child_workflow(self, request: CallChildWorkflowRequest, next): # type: 
ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, current_trace.get()) + return next( + CallChildWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + class CVRuntime(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + prev = current_trace.set((request.metadata or {}).get(TRACE_KEY)) + try: + return next(request) + finally: + current_trace.reset(prev) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + prev = current_trace.set((request.metadata or {}).get(TRACE_KEY)) + try: + return next(request) + finally: + current_trace.reset(prev) + + runtime = WorkflowRuntime( + runtime_interceptors=[CVRuntime()], workflow_outbound_interceptors=[CVOutbound()] + ) + + @runtime.activity(name='cv_probe') + def cv_probe(_ctx, _=None): + before = current_trace.get() + tok = current_trace.set(f'{before}/act') if before else None + try: + inner = current_trace.get() + finally: + if tok is not None: + current_trace.reset(tok) + after = current_trace.get() + return {'before': before, 'inner': inner, 'after': after} + + @runtime.activity(name='cv_flaky_probe') + def cv_flaky_probe(ctx, _=None): + before = current_trace.get() + attempt = getattr(ctx, 'attempt', None) + if attempt is not None and attempt == 0: + # Fail first attempt (when engine exposes attempt) to trigger retry + raise Exception('fail-once') + tok = current_trace.set(f'{before}/act-retry') if before else None + try: + inner = current_trace.get() + finally: + if tok is not None: + current_trace.reset(tok) + after = current_trace.get() + return {'before': before, 'inner': inner, 'after': after, 'attempt': attempt} + + @runtime.async_workflow(name='cv_child') + async def cv_child(ctx: AsyncWorkflowContext, _=None): + before = current_trace.get() + tok = current_trace.set(f'{before}/child') if before else None + try: + act = await ctx.call_activity(cv_probe, input=None) + finally: + if tok is not None: + current_trace.reset(tok) + restored = current_trace.get() + return {'before': before, 'restored': restored, 'act': act} + + @runtime.async_workflow(name='cv_parent') + async def cv_parent(ctx: AsyncWorkflowContext, _=None): + from datetime import timedelta + + from dapr.ext.workflow import RetryPolicy + + top_before = current_trace.get() + child = await ctx.call_child_workflow(cv_child, input=None) + after_child = current_trace.get() + act = await ctx.call_activity(cv_probe, input=None) + after_act = current_trace.get() + act_retry = await ctx.call_activity( + cv_flaky_probe, + input=None, + retry_policy=RetryPolicy( + first_retry_interval=timedelta(seconds=0), max_number_of_attempts=2 + ), + ) + return { + 'before': top_before, + 'child': child, + 'act': act, + 'act_retry': act_retry, + 'after_child': after_child, + 'after_act': after_act, + } + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient(interceptors=[CVClient()]) + iid = f'cv-ctx-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=cv_parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + out = _json.loads(st.to_json().get('serialized_output') or '{}') + # Top-level activity sees parent trace context during execution + act = 
out.get('act') or {}
+        assert act.get('before') == 'wf-parent'
+        assert act.get('inner') == 'wf-parent/act'
+        assert act.get('after') == 'wf-parent'
+        # Child workflow's activity at least inherits parent context
+        child = out.get('child') or {}
+        child_act = child.get('act') or {}
+        assert child_act.get('before') == 'wf-parent'
+        assert child_act.get('inner') == 'wf-parent/act'
+        assert child_act.get('after') == 'wf-parent'
+        # Flaky activity retried: second attempt returns with attempt >= 1 and parent context
+        act_retry = out.get('act_retry') or {}
+        att = act_retry.get('attempt')
+        if att is not None:
+            assert att in (1, 2)
+        assert act_retry.get('before') == 'wf-parent'
+        assert act_retry.get('inner') == 'wf-parent/act-retry'
+        assert act_retry.get('after') == 'wf-parent'
+    finally:
+        runtime.shutdown()
diff --git a/ext/dapr-ext-workflow/tests/integration/test_perf_real_activity.py b/ext/dapr-ext-workflow/tests/integration/test_perf_real_activity.py
new file mode 100644
index 000000000..f7686d6b8
--- /dev/null
+++ b/ext/dapr-ext-workflow/tests/integration/test_perf_real_activity.py
@@ -0,0 +1,99 @@
+# -*- coding: utf-8 -*-
+
+"""
+Copyright 2025 The Dapr Authors
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the specific language governing permissions and
+limitations under the License.
+"""
+
+import os
+import time
+
+import pytest
+
+from dapr.ext.workflow import (
+    AsyncWorkflowContext,
+    DaprWorkflowClient,
+    DaprWorkflowContext,
+    WorkflowActivityContext,
+    WorkflowRuntime,
+)
+
+"""
+Integration micro-benchmark using real activities via the sidecar.
+
+Skipped by default. Enable with RUN_INTEGRATION_BENCH=1 and ensure a sidecar
+with workflows enabled is running and DURABLETASK_GRPC_ENDPOINT is set.
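+
+For example (the endpoint value and step count below are illustrative):
+
+    RUN_INTEGRATION_BENCH=1 BENCH_STEPS=100 DURABLETASK_GRPC_ENDPOINT=localhost:50001 \
+        python -m pytest ext/dapr-ext-workflow/tests/integration/test_perf_real_activity.py -s
+
+The test prints a timing summary (generator vs. async chain times and their ratio) to stdout.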
+""" + +skip_bench = pytest.mark.skipif( + os.getenv('RUN_INTEGRATION_BENCH', '0') != '1', + reason='Set RUN_INTEGRATION_BENCH=1 to run integration benchmark', +) + + +@skip_bench +def test_real_activity_benchmark(): + runtime = WorkflowRuntime() + + @runtime.activity(name='echo') + def echo_act(ctx: WorkflowActivityContext, x: int) -> int: + return x + + @runtime.workflow(name='gen_chain') + def gen_chain(ctx: DaprWorkflowContext, num_steps: int) -> int: + total = 0 + for i in range(num_steps): + total += yield ctx.call_activity(echo_act, input=i) + return total + + @runtime.async_workflow(name='async_chain') + async def async_chain(ctx: AsyncWorkflowContext, num_steps: int) -> int: + total = 0 + for i in range(num_steps): + total += await ctx.call_activity(echo_act, input=i) + return total + + runtime.start() + try: + try: + runtime.wait_for_ready(timeout=15) + except Exception: + pass + + client = DaprWorkflowClient() + steps = int(os.getenv('BENCH_STEPS', '100')) + + # Generator run + gid = 'bench-gen' + t0 = time.perf_counter() + client.schedule_new_workflow(gen_chain, input=steps, instance_id=gid) + state_g = client.wait_for_workflow_completion(gid, timeout_in_seconds=300) + t_gen = time.perf_counter() - t0 + assert state_g is not None and state_g.runtime_status.name == 'COMPLETED' + + # Async run + aid = 'bench-async' + t1 = time.perf_counter() + client.schedule_new_workflow(async_chain, input=steps, instance_id=aid) + state_a = client.wait_for_workflow_completion(aid, timeout_in_seconds=300) + t_async = time.perf_counter() - t1 + assert state_a is not None and state_a.runtime_status.name == 'COMPLETED' + + print( + { + 'steps': steps, + 'gen_time_s': t_gen, + 'async_time_s': t_async, + 'ratio': (t_async / t_gen) if t_gen else None, + } + ) + finally: + runtime.shutdown() diff --git a/ext/dapr-ext-workflow/tests/perf/test_driver_overhead.py b/ext/dapr-ext-workflow/tests/perf/test_driver_overhead.py new file mode 100644 index 000000000..76ac75fbd --- /dev/null +++ b/ext/dapr-ext-workflow/tests/perf/test_driver_overhead.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +import os +import time + +import pytest + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + +skip_bench = pytest.mark.skipif( + os.getenv('RUN_DRIVER_BENCH', '0') != '1', + reason='Set RUN_DRIVER_BENCH=1 to run driver micro-benchmark', +) + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'bench' + + def create_timer(self, fire_at): + return FakeTask('timer') + + +def drive(gen, steps: int): + next(gen) + for _ in range(steps - 1): + try: + gen.send(None) + except StopIteration: + break + + +def gen_orchestrator(ctx, steps: int): + for _ in range(steps): + yield ctx.create_timer(0) + return 'done' + + +async def async_orchestrator(ctx: AsyncWorkflowContext, steps: int): + for _ in range(steps): + await ctx.create_timer(0) + return 'done' + + +@skip_bench +def test_driver_overhead_vs_generator(): + fake = FakeCtx() + steps = 1000 + + # Generator path timing + def gen_wrapper(ctx): + return gen_orchestrator(ctx, steps) + + start = time.perf_counter() + g = gen_wrapper(fake) + drive(g, steps) + gen_time = time.perf_counter() - start + + # Async driver timing + runner = CoroutineOrchestratorRunner(async_orchestrator) + start = time.perf_counter() + ag = runner.to_generator(AsyncWorkflowContext(fake), steps) + drive(ag, steps) + async_time = time.perf_counter() - start + + ratio = async_time / gen_time if gen_time > 0 else float('inf') + print({'gen_time_s': gen_time, 'async_time_s': async_time, 'ratio': ratio}) + # Assert driver overhead stays within reasonable bound + assert ratio < 3.0 diff --git a/ext/dapr-ext-workflow/tests/test_async_activity_registration.py b/ext/dapr-ext-workflow/tests/test_async_activity_registration.py new file mode 100644 index 000000000..cbffdd5a5 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_activity_registration.py @@ -0,0 +1,58 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class _FakeRegistry: + def __init__(self): + self.activities = {} + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +def test_activity_decorator_supports_async(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.activity(name='async_act') + async def async_act(ctx, x: int) -> int: + return x + 2 + + # Ensure registered + reg = rt._WorkflowRuntime__worker._registry + assert 'async_act' in reg.activities + + # Call the wrapper and ensure it runs the coroutine to completion + wrapper = reg.activities['async_act'] + + class _Ctx: + pass + + out = wrapper(_Ctx(), 5) + assert out == 7 diff --git a/ext/dapr-ext-workflow/tests/test_async_activity_retry_failure.py b/ext/dapr-ext-workflow/tests/test_async_activity_retry_failure.py new file mode 100644 index 000000000..1aa920e29 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_activity_retry_failure.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import pytest + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-act-retry' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + +async def wf(ctx: AsyncWorkflowContext): + # One activity that ultimately fails after retries + await ctx.call_activity(lambda: None, retry_policy={'dummy': True}) + return 'not-reached' + + +def test_activity_retry_final_failure_raises(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Prime + next(gen) + # Simulate final failure after retry policy exhausts + with pytest.raises(RuntimeError, match='activity failed'): + gen.throw(RuntimeError('activity failed')) diff --git a/ext/dapr-ext-workflow/tests/test_async_api_coverage.py b/ext/dapr-ext-workflow/tests/test_async_api_coverage.py new file mode 100644 index 000000000..c4bb28bce --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_api_coverage.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from datetime import datetime + +from dapr.ext.workflow.aio import AsyncWorkflowContext + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1) + self.instance_id = 'iid-cov' + self._status = None + + def set_custom_status(self, status): + self._status = status + + def continue_as_new(self, new_input, *, save_events=False): + self._continued = (new_input, save_events) + + # methods used by awaitables + def call_activity(self, activity, *, input=None, retry_policy=None): + class _T: + pass + + return _T() + + def call_child_workflow(self, workflow, *, input=None, instance_id=None, retry_policy=None): + class _T: + pass + + return _T() + + def create_timer(self, fire_at): + class _T: + pass + + return _T() + + def wait_for_external_event(self, name: str): + class _T: + pass + + return _T() + + +def test_async_context_exposes_required_methods(): + base = FakeCtx() + ctx = AsyncWorkflowContext(base) + + # basic deterministic utils existence + assert isinstance(ctx.now(), datetime) + _ = ctx.random() + _ = ctx.uuid4() + + # pass-throughs + ctx.set_custom_status('ok') + assert base._status == 'ok' + ctx.continue_as_new({'foo': 1}, save_events=True) + assert getattr(base, '_continued', None) == ({'foo': 1}, True) + + # awaitable constructors do not raise + ctx.call_activity(lambda: None, input={'x': 1}) + ctx.call_child_workflow(lambda: None) + ctx.sleep(1.0) + ctx.wait_for_external_event('go') + ctx.when_all([]) + ctx.when_any([]) diff --git a/ext/dapr-ext-workflow/tests/test_async_concurrency_and_determinism.py b/ext/dapr-ext-workflow/tests/test_async_concurrency_and_determinism.py new file mode 100644 index 000000000..4228096f3 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_concurrency_and_determinism.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from datetime import datetime + +from durabletask import task as durable_task_module + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.deterministic import deterministic_random, deterministic_uuid4 + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1) + self.instance_id = 'iid-123' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask(f'activity:{getattr(activity, "__name__", str(activity))}') + + def call_child_workflow( + self, workflow, *, input=None, instance_id=None, retry_policy=None, metadata=None + ): + return FakeTask(f'sub:{getattr(workflow, "__name__", str(workflow))}') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive_first_wins(gen, winner_name): + # Simulate when_any: first send the winner, then finish + next(gen) # prime + result = gen.send({'task': winner_name}) + # the coroutine should complete; StopIteration will be raised by caller + return result + + +async def wf_when_all(ctx: AsyncWorkflowContext): + a = ctx.call_activity(lambda: None) + b = ctx.sleep(1.0) + res = await ctx.when_all([a, b]) + return res + + +def test_when_all_maps_and_completes(monkeypatch): + # Patch durabletask.when_all to accept our FakeTask inputs and return a FakeTask + monkeypatch.setattr(durable_task_module, 'when_all', lambda tasks: FakeTask('when_all')) + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_when_all) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Drive two yields: when_all yields a task once; we simply return a list result + try: + t = gen.send(None) + assert isinstance(t, FakeTask) + out = gen.send([{'task': 'activity:lambda'}, {'task': 'timer'}]) + except StopIteration as stop: + out = stop.value + assert isinstance(out, list) + assert len(out) == 2 + + +async def wf_when_any(ctx: AsyncWorkflowContext): + a = ctx.call_activity(lambda: None) + b = ctx.sleep(5.0) + first = await ctx.when_any([a, b]) + # Return the first result only; losers ignored deterministically + return first + + +def test_when_any_first_wins_behavior(monkeypatch): + monkeypatch.setattr(durable_task_module, 'when_any', lambda tasks: FakeTask('when_any')) + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_when_any) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + try: + t = gen.send(None) + assert isinstance(t, FakeTask) + out = gen.send({'task': 'activity:lambda'}) + except StopIteration as stop: + out = stop.value + assert out == {'task': 'activity:lambda'} + + +def test_deterministic_random_and_uuid_are_stable(): + iid = 'iid-123' + now = datetime(2024, 1, 1) + rnd1 = deterministic_random(iid, now) + rnd2 = deterministic_random(iid, now) + seq1 = [rnd1.random() for _ in range(5)] + seq2 = [rnd2.random() for _ in range(5)] + assert seq1 == seq2 + u1 = deterministic_uuid4(deterministic_random(iid, now)) + u2 = deterministic_uuid4(deterministic_random(iid, now)) + assert u1 == u2 diff --git a/ext/dapr-ext-workflow/tests/test_async_context.py b/ext/dapr-ext-workflow/tests/test_async_context.py new file mode 100644 index 000000000..c5f4fed08 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_context.py @@ -0,0 +1,296 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the 
"License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import types +from datetime import datetime, timedelta, timezone + +from dapr.ext.workflow import AsyncWorkflowContext, WorkflowRuntime +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext +from dapr.ext.workflow.interceptors import BaseRuntimeInterceptor, ExecuteWorkflowRequest +from dapr.ext.workflow.workflow_context import WorkflowContext + + +class DummyBaseCtx: + def __init__(self): + self.instance_id = 'abc-123' + # freeze a deterministic timestamp + self.current_utc_datetime = datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + self.is_replaying = False + self._custom_status = None + self._continued = None + self._metadata = None + self._ei = types.SimpleNamespace( + workflow_id='abc-123', + workflow_name='wf', + is_replaying=False, + history_event_sequence=1, + inbound_metadata={'a': 'b'}, + parent_instance_id=None, + ) + + def set_custom_status(self, s: str): + self._custom_status = s + + def continue_as_new(self, new_input, *, save_events: bool = False): + self._continued = (new_input, save_events) + + # Metadata parity + def set_metadata(self, md): + self._metadata = md + + def get_metadata(self): + return self._metadata + + @property + def execution_info(self): + return self._ei + + +def test_parity_properties_and_now(): + ctx = AsyncWorkflowContext(DummyBaseCtx()) + assert ctx.instance_id == 'abc-123' + assert ctx.current_utc_datetime == datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + # now() should mirror current_utc_datetime + assert ctx.now() == ctx.current_utc_datetime + + +def test_timer_accepts_float_and_timedelta(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + + # Float should be interpreted as seconds and produce a SleepAwaitable + aw1 = ctx.create_timer(1.5) + # Timedelta should pass through + aw2 = ctx.create_timer(timedelta(seconds=2)) + + # We only assert types by duck-typing public attribute presence to avoid + # importing internal classes in tests + assert hasattr(aw1, '_ctx') and hasattr(aw1, '__await__') + assert hasattr(aw2, '_ctx') and hasattr(aw2, '__await__') + + +def test_wait_for_external_event_and_concurrency_factories(): + ctx = AsyncWorkflowContext(DummyBaseCtx()) + + evt = ctx.wait_for_external_event('go') + assert hasattr(evt, '__await__') + + # when_all/when_any/gather return awaitables + a = ctx.create_timer(0.1) + b = ctx.create_timer(0.2) + + all_aw = ctx.when_all([a, b]) + any_aw = ctx.when_any([a, b]) + gat_aw = ctx.gather(a, b) + gat_exc_aw = ctx.gather(a, b, return_exceptions=True) + + for x in (all_aw, any_aw, gat_aw, gat_exc_aw): + assert hasattr(x, '__await__') + + +def test_deterministic_utils_and_passthroughs(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + + rnd = ctx.random() + # should behave like a random.Random-like object; test a stable first value + val = rnd.random() + # Just assert it is within (0,1) and stable across two calls to the seeded RNG instance + assert 0.0 < val < 1.0 + assert rnd.random() != val # next value changes + + uid = ctx.uuid4() + # Should be a UUID-like string representation + 
assert isinstance(str(uid), str) and len(str(uid)) >= 32 + + # passthroughs + ctx.set_custom_status('hello') + assert base._custom_status == 'hello' + + ctx.continue_as_new({'x': 1}, save_events=True) + assert base._continued == ({'x': 1}, True) + + +def test_async_metadata_api_and_execution_info(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + ctx.set_metadata({'k': 'v'}) + assert base._metadata == {'k': 'v'} + assert ctx.get_metadata() == {'k': 'v'} + ei = ctx.execution_info + assert ei and ei.workflow_id == 'abc-123' and ei.workflow_name == 'wf' + + +def test_async_outbound_metadata_plumbed_into_awaitables(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + a = ctx.call_activity(lambda: None, input=1, metadata={'m': 'n'}) + c = ctx.call_child_workflow(lambda c, x: None, input=2, metadata={'x': 'y'}) + # Introspect for test (internal attribute) + assert getattr(a, '_metadata', None) == {'m': 'n'} + assert getattr(c, '_metadata', None) == {'x': 'y'} + + +def test_async_parity_surface_exists(): + # Guard: ensure essential parity members exist + ctx = AsyncWorkflowContext(DummyBaseCtx()) + for name in ( + 'set_metadata', + 'get_metadata', + 'execution_info', + 'call_activity', + 'call_child_workflow', + 'continue_as_new', + ): + assert hasattr(ctx, name) + + +def test_public_api_parity_against_workflowcontext_abc(): + # Derive the required sync API surface from the ABC plus metadata/execution_info + required = { + name + for name, attr in WorkflowContext.__dict__.items() + if getattr(attr, '__isabstractmethod__', False) + } + required.update({'set_metadata', 'get_metadata', 'execution_info'}) + + # Async context must expose the same names + async_ctx = AsyncWorkflowContext(DummyBaseCtx()) + missing_in_async = [name for name in required if not hasattr(async_ctx, name)] + assert not missing_in_async, f'AsyncWorkflowContext missing: {missing_in_async}' + + # Sync context should also expose these names + class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'abc-123' + self.current_utc_datetime = datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + self.is_replaying = False + + def set_custom_status(self, s: str): + pass + + def create_timer(self, fire_at): + return object() + + def wait_for_external_event(self, name: str): + return object() + + def continue_as_new(self, new_input, *, save_events: bool = False): + pass + + def call_activity(self, *, activity, input=None, retry_policy=None): + return object() + + def call_sub_orchestrator(self, fn, *, input=None, instance_id=None, retry_policy=None): + return object() + + sync_ctx = DaprWorkflowContext(_FakeOrchCtx()) + missing_in_sync = [name for name in required if not hasattr(sync_ctx, name)] + assert not missing_in_sync, f'DaprWorkflowContext missing: {missing_in_sync}' + + +def test_runtime_interceptor_shapes_async_input(): + runtime = WorkflowRuntime() + + class ShapeInput(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + data = request.input + # Mutate input passed to workflow + if isinstance(data, dict): + shaped = {**data, 'shaped': True} + else: + shaped = {'value': data, 'shaped': True} + request.input = shaped + return next(request) + + # Recreate runtime with interceptor wired in + runtime = WorkflowRuntime(runtime_interceptors=[ShapeInput()]) + + @runtime.async_workflow(name='wf_shape_input') + async def wf_shape_input(ctx: AsyncWorkflowContext, arg: dict | None = None): + # Verify shaped input is observed by the workflow + 
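+        # (ShapeInput above replaces request.input before the workflow runs, so the echoed
+        # arg is expected to contain the original keys plus 'shaped': True)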
return arg + + runtime.start() + try: + from dapr.ext.workflow import DaprWorkflowClient + + client = DaprWorkflowClient() + iid = f'shape-{id(runtime)}' + client.schedule_new_workflow(workflow=wf_shape_input, instance_id=iid, input={'x': 1}) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + assert out.get('x') == 1 + assert out.get('shaped') is True + finally: + runtime.shutdown() + + +def test_runtime_interceptor_context_manager_with_async_workflow(): + """Test that context managers stay active during async workflow execution.""" + runtime = WorkflowRuntime() + + # Track when context enters and exits + context_state = {'entered': False, 'exited': False, 'workflow_ran': False} + + class ContextInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + # Wrapper generator to keep context manager alive + def wrapper(): + from contextlib import ExitStack + + with ExitStack(): + # Mark context as entered + context_state['entered'] = True + + # Get the workflow generator + gen = next(request) + + # Use yield from to keep context alive during execution + yield from gen + + # Context will exit after generator completes + context_state['exited'] = True + + return wrapper() + + runtime = WorkflowRuntime(runtime_interceptors=[ContextInterceptor()]) + + @runtime.async_workflow(name='wf_context_test') + async def wf_context_test(ctx: AsyncWorkflowContext, arg: dict | None = None): + context_state['workflow_ran'] = True + return {'result': 'ok'} + + runtime.start() + try: + from dapr.ext.workflow import DaprWorkflowClient + + client = DaprWorkflowClient() + iid = f'ctx-test-{id(runtime)}' + client.schedule_new_workflow(workflow=wf_context_test, instance_id=iid, input={}) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + + # Verify context manager was active during workflow execution + assert context_state['entered'], 'Context should have been entered' + assert context_state['workflow_ran'], 'Workflow should have executed' + assert context_state['exited'], 'Context should have exited after completion' + finally: + runtime.shutdown() diff --git a/ext/dapr-ext-workflow/tests/test_async_errors_and_backcompat.py b/ext/dapr-ext-workflow/tests/test_async_errors_and_backcompat.py new file mode 100644 index 000000000..9fa9735bb --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_errors_and_backcompat.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeOrchestrationContext: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-errors' + self.is_replaying = False + self._custom_status = None + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + +def drive_raise(gen, exc: Exception): + # Prime + task = gen.send(None) + assert isinstance(task, FakeTask) + # Simulate runtime failure of yielded task + try: + gen.throw(exc) + except StopIteration as stop: + return stop.value + + +async def wf_catches_activity_error(ctx: AsyncWorkflowContext): + try: + await ctx.call_activity(lambda: (_ for _ in ()).throw(RuntimeError('boom'))) + except RuntimeError as e: + return f'caught:{e}' + return 'not-reached' + + +def test_activity_error_propagates_into_coroutine_and_can_be_caught(): + fake = FakeOrchestrationContext() + runner = CoroutineOrchestratorRunner(wf_catches_activity_error) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive_raise(gen, RuntimeError('boom')) + assert result == 'caught:boom' + + +async def wf_returns_sync(ctx: AsyncWorkflowContext): + return 42 + + +def test_sync_return_is_handled_without_runtime_error(): + fake = FakeOrchestrationContext() + runner = CoroutineOrchestratorRunner(wf_returns_sync) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Prime and complete + try: + gen.send(None) + except StopIteration as stop: + assert stop.value == 42 + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + self.activities = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +def test_generator_and_async_registration_coexist(monkeypatch): + # Monkeypatch TaskHubGrpcWorker to avoid real gRPC + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='gen_wf') + def gen(ctx): + yield ctx.create_timer(0) + return 'ok' + + async def async_wf(ctx: AsyncWorkflowContext): + await ctx.sleep(0) + return 'ok' + + rt.register_async_workflow(async_wf, name='async_wf') + + # Verify registry got both entries + reg = rt._WorkflowRuntime__worker._registry + assert 'gen_wf' in reg.orchestrators + assert 'async_wf' in reg.orchestrators + + # Drive generator orchestrator wrapper + gen_fn = reg.orchestrators['gen_wf'] + g = gen_fn(FakeOrchestrationContext()) + t = next(g) + assert isinstance(t, FakeTask) + try: + g.send(None) + except StopIteration as stop: + assert stop.value == 'ok' + + # Also verify CancelledError propagates and can be caught + 
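+    # (gen_2.throw(asyncio.CancelledError()) below simulates the runtime cancelling the
+    # outstanding task; the coroutine catches it and completes with 'cancelled')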
import asyncio + + async def wf_cancel(ctx: AsyncWorkflowContext): + try: + await ctx.call_activity(lambda: None) + except asyncio.CancelledError: + return 'cancelled' + return 'not-reached' + + runner = CoroutineOrchestratorRunner(wf_cancel) + gen_2 = runner.to_generator(AsyncWorkflowContext(FakeOrchestrationContext()), None) + # prime + next(gen_2) + try: + gen_2.throw(asyncio.CancelledError()) + except StopIteration as stop: + assert stop.value == 'cancelled' diff --git a/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py b/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py new file mode 100644 index 000000000..f1df67202 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +def test_workflow_decorator_detects_async_and_registers(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='async_wf') + async def async_wf(ctx: AsyncWorkflowContext, x: int) -> int: + # no awaits to keep simple + return x + 1 + + # ensure it was placed into registry + reg = rt._WorkflowRuntime__worker._registry + assert 'async_wf' in reg.orchestrators diff --git a/ext/dapr-ext-workflow/tests/test_async_replay.py b/ext/dapr-ext-workflow/tests/test_async_replay.py new file mode 100644 index 000000000..8fa33f122 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_replay.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from datetime import datetime, timedelta + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self, instance_id: str = 'iid-replay', now: datetime | None = None): + self.current_utc_datetime = now or datetime(2024, 1, 1) + self.instance_id = instance_id + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask(f'activity:{getattr(activity, "__name__", str(activity))}:{input}') + + def create_timer(self, fire_at): + return FakeTask(f'timer:{fire_at}') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive_with_history(gen, results): + """Drive the generator with a pre-baked sequence of results, simulating replay history.""" + try: + next(gen) + idx = 0 + while True: + gen.send(results[idx]) + idx += 1 + except StopIteration as stop: + return stop.value + + +async def wf_mixed(ctx: AsyncWorkflowContext): + # activity + r1 = await ctx.call_activity(lambda: None, input={'x': 1}) + # timer + await ctx.sleep(timedelta(seconds=5)) + # event + e = await ctx.wait_for_external_event('go') + # deterministic utils + t = ctx.now() + u = str(ctx.uuid4()) + return {'a': r1, 'e': e, 't': t.isoformat(), 'u': u} + + +def test_replay_same_history_same_outputs(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_mixed) + # Pre-bake results sequence corresponding to activity -> timer -> event + history = [ + {'task': "activity:lambda:{'x': 1}"}, + None, + {'event': 42}, + ] + out1 = drive_with_history(runner.to_generator(AsyncWorkflowContext(fake), None), history) + out2 = drive_with_history(runner.to_generator(AsyncWorkflowContext(fake), None), history) + assert out1 == out2 diff --git a/ext/dapr-ext-workflow/tests/test_async_sandbox.py b/ext/dapr-ext-workflow/tests/test_async_sandbox.py new file mode 100644 index 000000000..d940b3b85 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_sandbox.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import asyncio +import random +import time + +import pytest +from durabletask.aio.errors import SandboxViolationError +from durabletask.aio.sandbox import SandboxMode + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = __import__('datetime').datetime(2024, 1, 1) + self.instance_id = 'iid-sandbox' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +async def wf_sleep(ctx: AsyncWorkflowContext): + # asyncio.sleep should be patched to workflow timer + await asyncio.sleep(0.1) + return 'ok' + + +def drive(gen, first_result=None): + try: + task = gen.send(None) + assert isinstance(task, FakeTask) + result = first_result + while True: + task = gen.send(result) + assert isinstance(task, FakeTask) + result = None + except StopIteration as stop: + return stop.value + + +def test_sandbox_best_effort_patches_sleep(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_sleep, sandbox_mode=SandboxMode.BEST_EFFORT) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen) + assert result == 'ok' + + +def test_sandbox_random_uuid_time_are_deterministic(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner( + lambda ctx: _wf_random_uuid_time(ctx), sandbox_mode=SandboxMode.BEST_EFFORT + ) + gen1 = runner.to_generator(AsyncWorkflowContext(fake), None) + out1 = drive(gen1) + gen2 = runner.to_generator(AsyncWorkflowContext(fake), None) + out2 = drive(gen2) + assert out1 == out2 + + +async def _wf_random_uuid_time(ctx: AsyncWorkflowContext): + r1 = random.random() + u1 = __import__('uuid').uuid4() + t1 = time.time(), getattr(time, 'time_ns', lambda: int(time.time() * 1_000_000_000))() + # no awaits needed; return tuple + return (r1, str(u1), t1[0], t1[1]) + + +def test_strict_blocks_create_task(): + async def wf(ctx: AsyncWorkflowContext): + with pytest.raises(SandboxViolationError): + asyncio.create_task(asyncio.sleep(0)) + return 'ok' + + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf, sandbox_mode=SandboxMode.STRICT) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen) + assert result == 'ok' diff --git a/ext/dapr-ext-workflow/tests/test_async_sub_orchestrator.py b/ext/dapr-ext-workflow/tests/test_async_sub_orchestrator.py new file mode 100644 index 000000000..2e7740363 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_sub_orchestrator.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pytest + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-sub' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask('activity') + + def call_child_workflow( + self, workflow, *, input=None, instance_id=None, retry_policy=None, metadata=None + ): + return FakeTask(f'sub:{getattr(workflow, "__name__", str(workflow))}') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive_success(gen, results): + try: + next(gen) + idx = 0 + while True: + gen.send(results[idx]) + idx += 1 + except StopIteration as stop: + return stop.value + + +def drive_raise(gen, exc: Exception): + # Prime + next(gen) + # Throw failure into orchestrator + return pytest.raises(Exception, gen.throw, exc) + + +async def child(ctx: AsyncWorkflowContext, n: int) -> int: + return n * 2 + + +async def parent_success(ctx: AsyncWorkflowContext): + res = await ctx.call_child_workflow(child, input=3) + return res + 1 + + +def test_sub_orchestrator_success(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(parent_success) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # First yield is the sub-orchestrator task + result = drive_success(gen, results=[6]) + assert result == 7 + + +async def parent_failure(ctx: AsyncWorkflowContext): + # Do not catch; allow failure to propagate + await ctx.call_child_workflow(child, input=1) + return 'not-reached' + + +def test_sub_orchestrator_failure_raises_into_orchestrator(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(parent_failure) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Prime and then throw into the coroutine to simulate child failure + next(gen) + with pytest.raises(RuntimeError, match='child failed'): + gen.throw(RuntimeError('child failed')) diff --git a/ext/dapr-ext-workflow/tests/test_async_when_any_losers_policy.py b/ext/dapr-ext-workflow/tests/test_async_when_any_losers_policy.py new file mode 100644 index 000000000..728653ea3 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_when_any_losers_policy.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from durabletask import task as durable_task_module + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-any' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + +async def wf_when_any(ctx: AsyncWorkflowContext): + # Two awaitables: an activity and a timer + a = ctx.call_activity(lambda: None) + b = ctx.sleep(10) + first = await ctx.when_any([a, b]) + return first + + +def test_when_any_yields_once_and_returns_first_result(monkeypatch): + # Patch durabletask.when_any to avoid requiring real durabletask.Task objects + monkeypatch.setattr(durable_task_module, 'when_any', lambda tasks: FakeTask('when_any')) + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_when_any) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + + # Prime; expect a single composite yield + yielded = gen.send(None) + assert isinstance(yielded, FakeTask) + # Send the 'first' completion; generator should complete without yielding again + try: + gen.send({'task': 'activity'}) + raise AssertionError('generator should have completed') + except StopIteration as stop: + assert stop.value == {'task': 'activity'} diff --git a/ext/dapr-ext-workflow/tests/test_async_workflow_basic.py b/ext/dapr-ext-workflow/tests/test_async_workflow_basic.py new file mode 100644 index 000000000..b56f5af62 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_workflow_basic.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = __import__('datetime').datetime(2024, 1, 1) + self.instance_id = 'test-instance' + self._events: dict[str, list] = {} + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask(f'activity:{getattr(activity, "__name__", str(activity))}') + + def call_child_workflow( + self, workflow, *, input=None, instance_id=None, retry_policy=None, metadata=None + ): + return FakeTask(f'sub:{getattr(workflow, "__name__", str(workflow))}') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive(gen, first_result=None): + """Drive a generator produced by the async driver, emulating the runtime.""" + try: + task = gen.send(None) + assert isinstance(task, FakeTask) + result = first_result + while True: + task = gen.send(result) + assert isinstance(task, FakeTask) + # Provide a generic result for every yield + result = {'task': task.name} + except StopIteration as stop: + return stop.value + + +async def sample_activity(ctx: AsyncWorkflowContext): + return await ctx.call_activity(lambda: None) + + +def test_activity_awaitable_roundtrip(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(sample_activity) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen, first_result={'task': 'activity:lambda'}) + assert result == {'task': 'activity:lambda'} + + +async def sample_timer(ctx: AsyncWorkflowContext): + await ctx.create_timer(1.0) + return 'done' + + +def test_timer_awaitable_roundtrip(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(sample_timer) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen, first_result=None) + assert result == 'done' + + +async def sample_event(ctx: AsyncWorkflowContext): + data = await ctx.wait_for_external_event('go') + return ('event', data) + + +def test_event_awaitable_roundtrip(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(sample_event) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen, first_result={'hello': 'world'}) + assert result == ('event', {'hello': 'world'}) diff --git a/ext/dapr-ext-workflow/tests/test_dapr_workflow_context.py b/ext/dapr-ext-workflow/tests/test_dapr_workflow_context.py index 9fdfe0440..fcdfff902 100644 --- a/ext/dapr-ext-workflow/tests/test_dapr_workflow_context.py +++ b/ext/dapr-ext-workflow/tests/test_dapr_workflow_context.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -Copyright 2023 The Dapr Authors +Copyright 2025 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -13,12 +13,14 @@ limitations under the License. 
""" +import unittest from datetime import datetime from unittest import mock -import unittest + +from durabletask import worker + from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext -from durabletask import worker mock_date_time = datetime(2023, 4, 27) mock_instance_id = 'instance001' diff --git a/ext/dapr-ext-workflow/tests/test_deterministic.py b/ext/dapr-ext-workflow/tests/test_deterministic.py new file mode 100644 index 000000000..6e1828790 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_deterministic.py @@ -0,0 +1,75 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import datetime as _dt + +import pytest + +from dapr.ext.workflow.aio import AsyncWorkflowContext +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext + +""" +Tests for deterministic helpers shared across workflow contexts. +""" + + +class _FakeBaseCtx: + def __init__(self, instance_id: str, dt: _dt.datetime): + self.instance_id = instance_id + self.current_utc_datetime = dt + + +def _fixed_dt(): + return _dt.datetime(2024, 1, 1) + + +def test_random_string_deterministic_across_instances_async(): + base = _FakeBaseCtx('iid-1', _fixed_dt()) + a_ctx = AsyncWorkflowContext(base) + b_ctx = AsyncWorkflowContext(base) + a = a_ctx.random_string(16) + b = b_ctx.random_string(16) + assert a == b + + +def test_random_string_deterministic_across_context_types(): + base = _FakeBaseCtx('iid-2', _fixed_dt()) + a_ctx = AsyncWorkflowContext(base) + s1 = a_ctx.random_string(12) + + # Minimal fake orchestration context for DaprWorkflowContext + d_ctx = DaprWorkflowContext(base) + s2 = d_ctx.random_string(12) + assert s1 == s2 + + +def test_random_string_respects_alphabet(): + base = _FakeBaseCtx('iid-3', _fixed_dt()) + ctx = AsyncWorkflowContext(base) + s = ctx.random_string(20, alphabet='abc') + assert set(s).issubset(set('abc')) + + +def test_random_string_length_and_edge_cases(): + base = _FakeBaseCtx('iid-4', _fixed_dt()) + ctx = AsyncWorkflowContext(base) + + assert ctx.random_string(0) == '' + + with pytest.raises(ValueError): + ctx.random_string(-1) + + with pytest.raises(ValueError): + ctx.random_string(5, alphabet='') diff --git a/ext/dapr-ext-workflow/tests/test_generic_serialization.py b/ext/dapr-ext-workflow/tests/test_generic_serialization.py new file mode 100644 index 000000000..0aeb0c841 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_generic_serialization.py @@ -0,0 +1,64 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +from dataclasses import dataclass +from typing import Any + +from dapr.ext.workflow import ( + ActivityIOAdapter, + CanonicalSerializable, + ensure_canonical_json, + serialize_activity_input, + serialize_activity_output, + use_activity_adapter, +) + + +@dataclass +class _Point(CanonicalSerializable): + x: int + y: int + + def to_canonical_json(self, *, strict: bool = True) -> Any: + return {'x': self.x, 'y': self.y} + + +def test_ensure_canonical_json_on_custom_object(): + p = _Point(1, 2) + out = ensure_canonical_json(p, strict=True) + assert out == {'x': 1, 'y': 2} + + +class _IO(ActivityIOAdapter): + def serialize_input(self, input: Any, *, strict: bool = True) -> Any: + if isinstance(input, _Point): + return {'pt': [input.x, input.y]} + return ensure_canonical_json(input, strict=strict) + + def serialize_output(self, output: Any, *, strict: bool = True) -> Any: + return {'ok': ensure_canonical_json(output, strict=strict)} + + +def test_activity_adapter_decorator_customizes_io(): + _use = use_activity_adapter(_IO()) + + @_use + def act(obj): + return obj + + pt = _Point(3, 4) + inp = serialize_activity_input(act, pt, strict=True) + assert inp == {'pt': [3, 4]} + + out = serialize_activity_output(act, {'k': 'v'}, strict=True) + assert out == {'ok': {'k': 'v'}} diff --git a/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py b/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py new file mode 100644 index 000000000..bc90ad528 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py @@ -0,0 +1,560 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest + +from dapr.ext.workflow import ( + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + WorkflowRuntime, +) + +from ._fakes import make_act_ctx as _make_act_ctx +from ._fakes import make_orch_ctx as _make_orch_ctx + +""" +Comprehensive inbound interceptor tests for Dapr WorkflowRuntime. + +Tests the current interceptor system for runtime-side workflow and activity execution. 
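+
+These tests monkeypatch durabletask's TaskHubGrpcWorker with a fake worker/registry so the
+interceptor chain (logging, tracing, validation, ordering, error handling, short-circuiting,
+and input transformation) can be exercised deterministically without a running sidecar.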
+""" + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _TracingInterceptor(RuntimeInterceptor): + """Interceptor that injects and restores trace context.""" + + def __init__(self, events: list[str]): + self.events = events + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Extract tracing from input + tracing_data = None + if isinstance(request.input, dict) and 'tracing' in request.input: + tracing_data = request.input['tracing'] + self.events.append(f'wf_trace_restored:{tracing_data}') + + # Call next in chain + result = next(request) + + if tracing_data: + self.events.append(f'wf_trace_cleanup:{tracing_data}') + + return result + + def execute_activity(self, request: ExecuteActivityRequest, next): + # Extract tracing from input + tracing_data = None + if isinstance(request.input, dict) and 'tracing' in request.input: + tracing_data = request.input['tracing'] + self.events.append(f'act_trace_restored:{tracing_data}') + + # Call next in chain + result = next(request) + + if tracing_data: + self.events.append(f'act_trace_cleanup:{tracing_data}') + + return result + + +class _LoggingInterceptor(RuntimeInterceptor): + """Interceptor that logs workflow and activity execution.""" + + def __init__(self, events: list[str], label: str): + self.events = events + self.label = label + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + self.events.append(f'{self.label}:wf_start:{request.input!r}') + try: + result = next(request) + self.events.append(f'{self.label}:wf_complete:{result!r}') + return result + except Exception as e: + self.events.append(f'{self.label}:wf_error:{type(e).__name__}') + raise + + def execute_activity(self, request: ExecuteActivityRequest, next): + self.events.append(f'{self.label}:act_start:{request.input!r}') + try: + result = next(request) + self.events.append(f'{self.label}:act_complete:{result!r}') + return result + except Exception as e: + self.events.append(f'{self.label}:act_error:{type(e).__name__}') + raise + + +class _ValidationInterceptor(RuntimeInterceptor): + """Interceptor that validates inputs and outputs.""" + + def __init__(self, events: list[str]): + self.events = events + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Validate input + if isinstance(request.input, dict) and request.input.get('invalid'): + self.events.append('wf_validation_failed') + raise ValueError('Invalid workflow input') + + self.events.append('wf_validation_passed') + result = next(request) + + # Validate output + if isinstance(result, dict) and result.get('invalid_output'): + self.events.append('wf_output_validation_failed') + raise ValueError('Invalid workflow output') + + self.events.append('wf_output_validation_passed') + return result + + def execute_activity(self, request: ExecuteActivityRequest, next): + # Validate input + if isinstance(request.input, dict) and request.input.get('invalid'): + self.events.append('act_validation_failed') + raise ValueError('Invalid activity input') + + self.events.append('act_validation_passed') + result = next(request) + + # Validate output + if isinstance(result, str) and 'invalid' in 
result: + self.events.append('act_output_validation_failed') + raise ValueError('Invalid activity output') + + self.events.append('act_output_validation_passed') + return result + + +def test_single_interceptor_workflow_execution(monkeypatch): + """Test single interceptor around workflow execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='simple') + def simple(ctx, x: int): + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['simple'] + result = orch(_make_orch_ctx(), 5) + + # For non-generator workflows, the result is returned directly + assert result == 10 + assert events == [ + 'log:wf_start:5', + 'log:wf_complete:10', + ] + + +def test_single_interceptor_activity_execution(monkeypatch): + """Test single interceptor around activity execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='double') + def double(ctx, x: int) -> int: + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['double'] + result = act(_make_act_ctx(), 7) + + assert result == 14 + assert events == [ + 'log:act_start:7', + 'log:act_complete:14', + ] + + +def test_multiple_interceptors_execution_order(monkeypatch): + """Test multiple interceptors execute in correct order.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + outer_interceptor = _LoggingInterceptor(events, 'outer') + inner_interceptor = _LoggingInterceptor(events, 'inner') + + # First interceptor in list is outermost + rt = WorkflowRuntime(runtime_interceptors=[outer_interceptor, inner_interceptor]) + + @rt.workflow(name='ordered') + def ordered(ctx, x: int): + return x + 1 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['ordered'] + result = orch(_make_orch_ctx(), 3) + + assert result == 4 + # Outer interceptor enters first, exits last (stack semantics) + assert events == [ + 'outer:wf_start:3', + 'inner:wf_start:3', + 'inner:wf_complete:4', + 'outer:wf_complete:4', + ] + + +def test_tracing_interceptor_context_restoration(monkeypatch): + """Test tracing interceptor properly handles trace context.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + tracing_interceptor = _TracingInterceptor(events) + rt = WorkflowRuntime(runtime_interceptors=[tracing_interceptor]) + + @rt.workflow(name='traced') + def traced(ctx, input_data): + # Workflow can access the trace context that was restored + return {'result': input_data.get('value', 0) * 2} + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['traced'] + + # Input with tracing data + input_with_trace = {'value': 5, 'tracing': {'trace_id': 'abc123', 'span_id': 'def456'}} + + result = orch(_make_orch_ctx(), input_with_trace) + + assert result == {'result': 10} + assert events == [ + "wf_trace_restored:{'trace_id': 'abc123', 'span_id': 'def456'}", + "wf_trace_cleanup:{'trace_id': 'abc123', 'span_id': 'def456'}", + ] + + +def 
test_validation_interceptor_input_validation(monkeypatch): + """Test validation interceptor catches invalid inputs.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + validation_interceptor = _ValidationInterceptor(events) + rt = WorkflowRuntime(runtime_interceptors=[validation_interceptor]) + + @rt.workflow(name='validated') + def validated(ctx, input_data): + return {'result': 'ok'} + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['validated'] + + # Test valid input + result = orch(_make_orch_ctx(), {'value': 5}) + + assert result == {'result': 'ok'} + assert 'wf_validation_passed' in events + assert 'wf_output_validation_passed' in events + + # Test invalid input + events.clear() + + with pytest.raises(ValueError, match='Invalid workflow input'): + orch(_make_orch_ctx(), {'invalid': True}) + + assert 'wf_validation_failed' in events + + +def test_interceptor_error_handling_workflow(monkeypatch): + """Test interceptor properly handles workflow errors.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='error_wf') + def error_wf(ctx, x: int): + raise ValueError('workflow error') + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['error_wf'] + + with pytest.raises(ValueError, match='workflow error'): + orch(_make_orch_ctx(), 1) + + assert events == [ + 'log:wf_start:1', + 'log:wf_error:ValueError', + ] + + +def test_interceptor_error_handling_activity(monkeypatch): + """Test interceptor properly handles activity errors.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='error_act') + def error_act(ctx, x: int) -> int: + raise RuntimeError('activity error') + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['error_act'] + + with pytest.raises(RuntimeError, match='activity error'): + act(_make_act_ctx(), 5) + + assert events == [ + 'log:act_start:5', + 'log:act_error:RuntimeError', + ] + + +def test_async_workflow_with_interceptors(monkeypatch): + """Test interceptors work with async workflows.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='async_wf') + async def async_wf(ctx, x: int): + # Simple async workflow + return x * 3 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['async_wf'] + gen_result = orch(_make_orch_ctx(), 4) + + # Async workflows return a generator that needs to be driven + with pytest.raises(StopIteration) as stop: + next(gen_result) + result = stop.value.value + + assert result == 12 + # The interceptor sees the generator being returned, not the final result + assert events[0] == 'log:wf_start:4' + assert 'log:wf_complete:' in events[1] # The generator object is logged + + +def test_async_activity_with_interceptors(monkeypatch): + """Test interceptors work with async 
activities.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='async_act') + async def async_act(ctx, x: int) -> int: + await asyncio.sleep(0) # Simulate async work + return x * 4 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['async_act'] + result = act(_make_act_ctx(), 3) + + assert result == 12 + assert events == [ + 'log:act_start:3', + 'log:act_complete:12', + ] + + +def test_generator_workflow_with_interceptors(monkeypatch): + """Test interceptors work with generator workflows.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='gen_wf') + def gen_wf(ctx, x: int): + v1 = yield 'step1' + v2 = yield 'step2' + return (x, v1, v2) + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['gen_wf'] + gen_orch = orch(_make_orch_ctx(), 1) + + # Drive the generator + assert next(gen_orch) == 'step1' + assert gen_orch.send('result1') == 'step2' + with pytest.raises(StopIteration) as stop: + gen_orch.send('result2') + result = stop.value.value + + assert result == (1, 'result1', 'result2') + # For generator workflows, interceptor sees the generator being returned + assert events[0] == 'log:wf_start:1' + assert 'log:wf_complete:' in events[1] # The generator object is logged + + +def test_interceptor_chain_with_early_return(monkeypatch): + """Test interceptor can modify or short-circuit execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _ShortCircuitInterceptor(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + events.append('short_circuit_check') + if isinstance(request.input, dict) and request.input.get('short_circuit'): + events.append('short_circuited') + return 'short_circuit_result' + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): + return next(request) + + logging_interceptor = _LoggingInterceptor(events, 'log') + short_circuit_interceptor = _ShortCircuitInterceptor() + + rt = WorkflowRuntime(runtime_interceptors=[short_circuit_interceptor, logging_interceptor]) + + @rt.workflow(name='maybe_short') + def maybe_short(ctx, input_data): + return 'normal_result' + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['maybe_short'] + + # Test normal execution + result = orch(_make_orch_ctx(), {'value': 5}) + + assert result == 'normal_result' + assert 'short_circuit_check' in events + assert 'log:wf_start' in str(events) + assert 'log:wf_complete' in str(events) + + # Test short-circuit execution + events.clear() + result = orch(_make_orch_ctx(), {'short_circuit': True}) + + assert result == 'short_circuit_result' + assert 'short_circuit_check' in events + assert 'short_circuited' in events + # Logging interceptor should not be called when short-circuited + assert 'log:wf_start' not in str(events) + + +def test_interceptor_input_transformation(monkeypatch): + """Test interceptor can transform inputs before execution.""" + import durabletask.worker as worker_mod + + 
monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _TransformInterceptor(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Transform input by adding metadata + if isinstance(request.input, dict): + transformed_input = {**request.input, 'interceptor_metadata': 'added'} + new_input = ExecuteWorkflowRequest(ctx=request.ctx, input=transformed_input) + events.append(f'transformed_input:{transformed_input}') + return next(new_input) + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): + return next(request) + + transform_interceptor = _TransformInterceptor() + rt = WorkflowRuntime(runtime_interceptors=[transform_interceptor]) + + @rt.workflow(name='transform_test') + def transform_test(ctx, input_data): + # Workflow should see the transformed input + return input_data + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['transform_test'] + result = orch(_make_orch_ctx(), {'original': 'value'}) + + # Result should include the interceptor metadata + assert result == {'original': 'value', 'interceptor_metadata': 'added'} + assert 'transformed_input:' in str(events) + + +def test_runtime_interceptor_can_shape_activity_result(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _ShapeResult(RuntimeInterceptor): + def execute_activity(self, request, next): # type: ignore[override] + res = next(request) + return {'wrapped': res} + + rt = WorkflowRuntime(runtime_interceptors=[_ShapeResult()]) + + @rt.activity(name='echo') + def echo(_ctx, x): + return x + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['echo'] + out = act(_make_act_ctx(), 7) + assert out == {'wrapped': 7} diff --git a/ext/dapr-ext-workflow/tests/test_interceptors.py b/ext/dapr-ext-workflow/tests/test_interceptors.py new file mode 100644 index 000000000..80253ca6e --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_interceptors.py @@ -0,0 +1,177 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from dapr.ext.workflow import RuntimeInterceptor, WorkflowRuntime + +from ._fakes import make_act_ctx as _make_act_ctx +from ._fakes import make_orch_ctx as _make_orch_ctx + +""" +Comprehensive inbound interceptor tests for Dapr WorkflowRuntime. + +Tests the current interceptor system for runtime-side workflow and activity execution. +""" + + +""" +Runtime interceptor chain tests for `WorkflowRuntime`. + +This suite intentionally uses a fake worker/registry to validate interceptor composition +without requiring a sidecar. It focuses on the "why" behind runtime interceptors: + +- Ensure `execute_workflow` and `execute_activity` hooks compose in order and are + invoked exactly once around workflow entry/activity execution. 
+- Cover both generator-based and async workflows, asserting the chain returns a + generator to the runtime (rather than iterating it), preserving send()/throw() + semantics during orchestration replay. +- Keep signal-to-noise high for failures in chain logic independent of gRPC/sidecar. + +These tests complement outbound/client interceptor tests and e2e tests by providing +fast, deterministic coverage of the chaining behavior and generator handling rules. +""" + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _RecorderInterceptor(RuntimeInterceptor): + def __init__(self, events: list[str], label: str): + self.events = events + self.label = label + + def execute_workflow(self, request, next): # type: ignore[override] + self.events.append(f'{self.label}:wf_enter:{request.input!r}') + ret = next(request) + self.events.append(f'{self.label}:wf_ret_type:{ret.__class__.__name__}') + return ret + + def execute_activity(self, request, next): # type: ignore[override] + self.events.append(f'{self.label}:act_enter:{request.input!r}') + res = next(request) + self.events.append(f'{self.label}:act_exit:{res!r}') + return res + + +def test_generator_workflow_hooks_sequence(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + ic = _RecorderInterceptor(events, 'mw') + rt = WorkflowRuntime(runtime_interceptors=[ic]) + + @rt.workflow(name='gen') + def gen(ctx, x: int): + v = yield 'A' + v2 = yield 'B' + return (x, v, v2) + + # Drive the registered orchestrator + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['gen'] + gen_driver = orch(_make_orch_ctx(), 10) + # Prime and run + assert next(gen_driver) == 'A' + assert gen_driver.send('ra') == 'B' + with pytest.raises(StopIteration) as stop: + gen_driver.send('rb') + result = stop.value.value + + assert result == (10, 'ra', 'rb') + # Interceptors run once around the workflow entry; they return a generator to the runtime + assert events[0] == 'mw:wf_enter:10' + assert events[1].startswith('mw:wf_ret_type:') + + +def test_async_workflow_hooks_called(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + ic = _RecorderInterceptor(events, 'mw') + rt = WorkflowRuntime(runtime_interceptors=[ic]) + + @rt.workflow(name='awf') + async def awf(ctx, x: int): + # No awaits to keep the driver simple + return x + 1 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['awf'] + gen_orch = orch(_make_orch_ctx(), 41) + with pytest.raises(StopIteration) as stop: + next(gen_orch) + result = stop.value.value + + assert result == 42 + # For async workflow, interceptor sees entry and a generator return type + assert events[0] == 'mw:wf_enter:41' + assert events[1].startswith('mw:wf_ret_type:') + + +def test_activity_hooks_and_policy(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _ExplodingActivity(RuntimeInterceptor): + def 
execute_activity(self, request, next): # type: ignore[override] + raise RuntimeError('boom') + + def execute_workflow(self, request, next): # type: ignore[override] + return next(request) + + # Continue-on-error policy + rt = WorkflowRuntime( + runtime_interceptors=[_RecorderInterceptor(events, 'mw'), _ExplodingActivity()] + ) + + @rt.activity(name='double') + def double(ctx, x: int) -> int: + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['double'] + # Error in interceptor bubbles up + with pytest.raises(RuntimeError): + act(_make_act_ctx(), 5) diff --git a/ext/dapr-ext-workflow/tests/test_metadata_context.py b/ext/dapr-ext-workflow/tests/test_metadata_context.py new file mode 100644 index 000000000..ab318c7a6 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_metadata_context.py @@ -0,0 +1,370 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Optional + +import pytest + +from dapr.ext.workflow import ( + ClientInterceptor, + DaprWorkflowClient, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + ScheduleWorkflowRequest, + WorkflowOutboundInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'id' + self.current_utc_datetime = datetime(2024, 1, 1) + self._custom_status = None + self.is_replaying = False + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + def call_activity(self, activity, *, input=None, retry_policy=None): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def call_sub_orchestrator(self, wf, *, input=None, instance_id=None, retry_policy=None): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + def create_timer(self, fire_at): + class _T: + def __init__(self, v): + self._v = v + + return _T(fire_at) + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self, v): + self._v = v + + return _T(name) + + +def _drive(gen, returned): + try: + t = gen.send(None) + assert hasattr(t, '_v') + res = returned + while True: + t = gen.send(res) + assert hasattr(t, '_v') + except StopIteration as stop: + return stop.value + + +def test_client_schedule_metadata_envelope(monkeypatch): + import durabletask.client as client_mod + + captured: dict[str, Any] = {} 
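+    # Note: _FakeClient below stands in for durabletask's TaskHubGrpcClient; it records
+    # the arguments passed to schedule_new_orchestration so the assertions can inspect
+    # the metadata envelope without talking to a real sidecar.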
+ + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, + name, + *, + input=None, + instance_id=None, + start_at: Optional[datetime] = None, + reuse_id_policy=None, + ): # noqa: E501 + captured['name'] = name + captured['input'] = input + captured['instance_id'] = instance_id + captured['start_at'] = start_at + captured['reuse_id_policy'] = reuse_id_policy + return 'id-1' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _InjectMetadata(ClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + # Add metadata without touching args + md = {'otel.trace_id': 't-123'} + new_request = ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + return next(new_request) + + client = DaprWorkflowClient(interceptors=[_InjectMetadata()]) + + def wf(ctx, x): + yield 'noop' + + wf.__name__ = 'meta_wf' + instance_id = client.schedule_new_workflow(wf, input={'a': 1}) + assert instance_id == 'id-1' + env = captured['input'] + assert isinstance(env, dict) + assert '__dapr_meta__' in env and '__dapr_payload__' in env + assert env['__dapr_payload__'] == {'a': 1} + assert env['__dapr_meta__']['metadata']['otel.trace_id'] == 't-123' + + +def test_runtime_inbound_unwrap_and_metadata_visible(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + seen: dict[str, Any] = {} + + class _Recorder(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + seen['metadata'] = request.metadata + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + seen['act_metadata'] = request.metadata + return next(request) + + rt = WorkflowRuntime(runtime_interceptors=[_Recorder()]) + + @rt.workflow(name='unwrap') + def unwrap(ctx, x): + # x should be the original payload, not the envelope + assert x == {'hello': 'world'} + return 'ok' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['unwrap'] + envelope = { + '__dapr_meta__': {'v': 1, 'metadata': {'c': 'd'}}, + '__dapr_payload__': {'hello': 'world'}, + } + result = orch(_FakeOrchCtx(), envelope) + assert result == 'ok' + assert seen['metadata'] == {'c': 'd'} + + +def test_outbound_activity_and_child_wrap_metadata(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _AddActMeta(WorkflowOutboundInterceptor): + def call_activity(self, request, next): # type: ignore[override] + # Wrap returned args with metadata by returning a new CallActivityRequest + return next( + type(request)( + activity_name=request.activity_name, + input=request.input, + retry_policy=request.retry_policy, + workflow_ctx=request.workflow_ctx, + metadata={'k': 'v'}, + ) + ) + + def call_child_workflow(self, request, next): # type: ignore[override] + return next( + type(request)( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + workflow_ctx=request.workflow_ctx, + metadata={'p': 'q'}, + ) + ) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_AddActMeta()]) + + @rt.workflow(name='parent') + def parent(ctx, x): + a = yield ctx.call_activity(lambda: None, input={'i': 1}) + 
b = yield ctx.call_child_workflow(lambda c, y: None, input={'j': 2}) + # Return both so we can assert envelopes surfaced through our fake driver + return a, b + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['parent'] + gen = orch(_FakeOrchCtx(), 0) + # First yield: activity token received by driver; shape may be envelope or raw depending on adapter + t1 = gen.send(None) + assert hasattr(t1, '_v') + # Resume with any value; our fake driver ignores and loops + t2 = gen.send({'act': 'done'}) + assert hasattr(t2, '_v') + with pytest.raises(StopIteration) as stop: + gen.send({'child': 'done'}) + result = stop.value.value + # The result is whatever user returned; envelopes validated above + assert isinstance(result, tuple) and len(result) == 2 + + +def test_context_set_metadata_default_propagation(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + # No outbound interceptor needed; runtime will wrap using ctx.get_metadata() + rt = WorkflowRuntime() + + @rt.workflow(name='use_ctx_md') + def use_ctx_md(ctx, x): + # Set default metadata on context + ctx.set_metadata({'k': 'ctx'}) + env = yield ctx.call_activity(lambda: None, input={'p': 1}) + # Return the raw yielded value for assertion + return env + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['use_ctx_md'] + gen = orch(_FakeOrchCtx(), 0) + yielded = gen.send(None) + assert hasattr(yielded, '_v') + env = yielded._v + assert isinstance(env, dict) + assert env.get('__dapr_meta__', {}).get('metadata', {}).get('k') == 'ctx' + + +def test_per_call_metadata_overrides_context(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='override_ctx_md') + def override_ctx_md(ctx, x): + ctx.set_metadata({'k': 'ctx'}) + env = yield ctx.call_activity(lambda: None, input={'p': 1}, metadata={'k': 'per'}) + return env + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['override_ctx_md'] + gen = orch(_FakeOrchCtx(), 0) + yielded = gen.send(None) + env = yielded._v + assert isinstance(env, dict) + assert env.get('__dapr_meta__', {}).get('metadata', {}).get('k') == 'per' + + +def test_execution_info_workflow_and_activity(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + def act(ctx, x): + # activity inbound metadata and execution info available + md = ctx.get_metadata() + ei = ctx.execution_info + assert md == {'m': 'v'} + assert ei is not None and ei.inbound_metadata == {'m': 'v'} + # activity_name should reflect the registered name + assert getattr(ei, 'activity_name', None) == 'act' + return x + + @rt.workflow(name='execinfo') + def execinfo(ctx, x): + # set default metadata + ctx.set_metadata({'m': 'v'}) + # workflow execution info available (minimal inbound only) + wi = ctx.execution_info + assert wi is not None and wi.inbound_metadata == {} + v = yield ctx.call_activity(act, input=42) + return v + + # register activity + rt.activity(name='act')(act) + orch = rt._WorkflowRuntime__worker._registry.orchestrators['execinfo'] + gen = orch(_FakeOrchCtx(), 7) + # drive one yield (call_activity) + gen.send(None) + # send back a value for activity result + with pytest.raises(StopIteration) as stop: + gen.send(42) + assert stop.value.value == 42 + + +def test_client_interceptor_can_shape_schedule_response(monkeypatch): + import 
durabletask.client as client_mod + + captured: dict[str, Any] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, name, *, input=None, instance_id=None, start_at=None, reuse_id_policy=None + ): + captured['name'] = name + return 'raw-id-123' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _ShapeId(ClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + rid = next(request) + return f'shaped:{rid}' + + client = DaprWorkflowClient(interceptors=[_ShapeId()]) + + def wf(ctx): + yield 'noop' + + wf.__name__ = 'shape_test' + iid = client.schedule_new_workflow(wf, input=None) + assert iid == 'shaped:raw-id-123' diff --git a/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py b/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py new file mode 100644 index 000000000..c76714e34 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py @@ -0,0 +1,201 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dapr.ext.workflow import ( + BaseWorkflowOutboundInterceptor, + WorkflowOutboundInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'id' + self.current_utc_datetime = __import__('datetime').datetime(2024, 1, 1) + self.is_replaying = False + self._custom_status = None + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + self._continued_payload = None + self.workflow_attempt = None + + def call_activity(self, activity, *, input=None, retry_policy=None): + # return input back for assertion through driver + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def call_sub_orchestrator(self, wf, *, input=None, instance_id=None, retry_policy=None): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + def create_timer(self, fire_at): + class _T: + def __init__(self, v): + self._v = v + + return _T(fire_at) + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self, v): + self._v = v + + return _T(name) + + def continue_as_new(self, new_request, *, save_events: bool = False): + # Record payload for assertions + self._continued_payload = new_request + + +def drive(gen, returned): + try: + t = gen.send(None) + assert hasattr(t, '_v') + res = returned + while True: + t = gen.send(res) + assert hasattr(t, '_v') + except StopIteration 
as stop: + return stop.value + + +class _InjectTrace(WorkflowOutboundInterceptor): + def call_activity(self, request, next): # type: ignore[override] + x = request.input + if x is None: + request = type(request)( + activity_name=request.activity_name, + input={'tracing': 'T'}, + retry_policy=request.retry_policy, + ) + elif isinstance(x, dict): + out = dict(x) + out.setdefault('tracing', 'T') + request = type(request)( + activity_name=request.activity_name, input=out, retry_policy=request.retry_policy + ) + return next(request) + + def call_child_workflow(self, request, next): # type: ignore[override] + return next( + type(request)( + workflow_name=request.workflow_name, + input={'child': request.input}, + instance_id=request.instance_id, + ) + ) + + +def test_outbound_activity_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectTrace()]) + + @rt.workflow(name='w') + def w(ctx, x): + # schedule an activity; runtime should pass transformed input to durable task + y = yield ctx.call_activity(lambda: None, input={'a': 1}) + return y['tracing'] + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w'] + gen = orch(_FakeOrchCtx(), 0) + out = drive(gen, returned={'tracing': 'T', 'a': 1}) + assert out == 'T' + + +def test_outbound_child_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectTrace()]) + + def child(ctx, x): + yield 'noop' + + @rt.workflow(name='parent') + def parent(ctx, x): + y = yield ctx.call_child_workflow(child, input={'b': 2}) + return y + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['parent'] + gen = orch(_FakeOrchCtx(), 0) + out = drive(gen, returned={'child': {'b': 2}}) + assert out == {'child': {'b': 2}} + + +def test_outbound_continue_as_new_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _InjectCAN(BaseWorkflowOutboundInterceptor): + def continue_as_new(self, request, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault('x', '1') + request.metadata = md + return next(request) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectCAN()]) + + @rt.workflow(name='w2') + def w2(ctx, x): + ctx.continue_as_new({'p': 1}) + return 'unreached' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w2'] + fake = _FakeOrchCtx() + _ = orch(fake, 0) + # Verify envelope contains injected metadata + assert isinstance(fake._continued_payload, dict) + meta = fake._continued_payload.get('__dapr_meta__') + payload = fake._continued_payload.get('__dapr_payload__') + assert isinstance(meta, dict) and isinstance(payload, dict) + assert meta.get('metadata', {}).get('x') == '1' + assert payload == {'p': 1} diff --git a/ext/dapr-ext-workflow/tests/test_sandbox_gather.py b/ext/dapr-ext-workflow/tests/test_sandbox_gather.py new file mode 100644 index 000000000..993fbdb5a --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_sandbox_gather.py @@ -0,0 +1,170 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, timedelta + +import pytest +from durabletask.aio.sandbox import SandboxMode + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.aio.sandbox import sandbox_scope + +""" +Tests for sandboxed asyncio.gather behavior in async orchestrators. +""" + + +class _FakeCtx: + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1) + self.instance_id = 'test-instance' + + def create_timer(self, fire_at): + class _T: + def __init__(self): + self._parent = None + self.is_complete = False + + return _T() + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self): + self._parent = None + self.is_complete = False + + return _T() + + +def drive(gen, results): + try: + gen.send(None) + i = 0 + while True: + gen.send(results[i]) + i += 1 + except StopIteration as stop: + return stop.value + + +async def _plain(value): + return value + + +async def awf_empty(ctx: AsyncWorkflowContext): + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + out = await asyncio.gather() + return out + + +def test_sandbox_gather_empty_returns_list(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_empty) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[None]) + assert out == [] + + +async def awf_when_all(ctx: AsyncWorkflowContext): + a = ctx.create_timer(timedelta(seconds=0)) + b = ctx.wait_for_external_event('x') + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + res = await asyncio.gather(a, b) + return res + + +def test_sandbox_gather_all_workflow_maps_to_when_all(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_when_all) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[[1, 2]]) + assert out == [1, 2] + + +async def awf_mixed(ctx: AsyncWorkflowContext): + a = ctx.create_timer(timedelta(seconds=0)) + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + res = await asyncio.gather(a, _plain('ok')) + return res + + +def test_sandbox_gather_mixed_returns_sequential_results(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_mixed) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[123]) + assert out == [123, 'ok'] + + +async def awf_return_exceptions(ctx: AsyncWorkflowContext): + async def _boom(): + raise RuntimeError('x') + + a = ctx.create_timer(timedelta(seconds=0)) + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + res = await asyncio.gather(a, _boom(), return_exceptions=True) + return res + + +def test_sandbox_gather_return_exceptions(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_return_exceptions) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[321]) + assert isinstance(out[1], RuntimeError) + + +async def awf_multi_await(ctx: AsyncWorkflowContext): + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + g = asyncio.gather() + a = await g + b = await g + return (a, b) + + +def test_sandbox_gather_multi_await_safe(): + 
fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_multi_await) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[None]) + assert out == ([], []) + + +def test_sandbox_gather_restored_outside(): + import asyncio as aio + + original = aio.gather + fake = _FakeCtx() + ctx = AsyncWorkflowContext(fake) + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + pass + # After exit, gather should be restored + assert aio.gather is original + + +def test_strict_mode_blocks_create_task(): + import asyncio as aio + + fake = _FakeCtx() + ctx = AsyncWorkflowContext(fake) + with sandbox_scope(ctx, SandboxMode.STRICT): + if hasattr(aio, 'create_task'): + with pytest.raises(RuntimeError): + # Use a dummy coroutine to trigger the block + async def _c(): + return 1 + + aio.create_task(_c()) diff --git a/ext/dapr-ext-workflow/tests/test_trace_fields.py b/ext/dapr-ext-workflow/tests/test_trace_fields.py new file mode 100644 index 000000000..3b1f6ec89 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_trace_fields.py @@ -0,0 +1,81 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from dapr.ext.workflow.aio import AsyncWorkflowContext +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'wf-123' + self.current_utc_datetime = datetime(2025, 1, 1, tzinfo=timezone.utc) + self.is_replaying = False + self.workflow_name = 'wf_name' + self.parent_instance_id = 'parent-1' + self.history_event_sequence = 42 + self.trace_parent = '00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01' + self.trace_state = 'vendor=state' + self.orchestration_span_id = 'bbbbbbbbbbbbbbbb' + + +class _FakeActivityCtx: + def __init__(self): + self.orchestration_id = 'wf-123' + self.task_id = 7 + self.trace_parent = '00-cccccccccccccccccccccccccccccccc-dddddddddddddddd-01' + self.trace_state = 'v=1' + + +def test_dapr_workflow_context_trace_properties(): + base = _FakeOrchCtx() + ctx = DaprWorkflowContext(base) + + assert ctx.trace_parent == base.trace_parent + assert ctx.trace_state == base.trace_state + # SDK renames orchestration span id to workflow_span_id + assert ctx.workflow_span_id == base.orchestration_span_id + + +def test_async_workflow_context_trace_properties(): + base = _FakeOrchCtx() + actx = AsyncWorkflowContext(DaprWorkflowContext(base)) + + assert actx.trace_parent == base.trace_parent + assert actx.trace_state == base.trace_state + assert actx.workflow_span_id == base.orchestration_span_id + + +def test_workflow_execution_info_minimal(): + ei = WorkflowExecutionInfo(inbound_metadata={'k': 'v'}) + assert ei.inbound_metadata == {'k': 'v'} + + +def test_activity_execution_info_minimal(): + aei = 
ActivityExecutionInfo(inbound_metadata={'m': 'v'}) + assert aei.inbound_metadata == {'m': 'v'} + + +def test_workflow_activity_context_execution_info_trace_fields(): + base = _FakeActivityCtx() + actx = WorkflowActivityContext(base) + aei = ActivityExecutionInfo(inbound_metadata={}) + actx._set_execution_info(aei) + got = actx.execution_info + assert got is not None + assert got.inbound_metadata == {} diff --git a/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py b/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py new file mode 100644 index 000000000..35f933611 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py @@ -0,0 +1,171 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import uuid +from datetime import datetime +from typing import Any + +from dapr.ext.workflow import ( + ClientInterceptor, + DaprWorkflowClient, + RuntimeInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchestrationContext: + def __init__(self, *, is_replaying: bool = False): + self.instance_id = 'wf-1' + self.current_utc_datetime = datetime(2025, 1, 1) + self.is_replaying = is_replaying + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + +def _drive_generator(gen, returned_value): + # Prime to first yield; then drive + next(gen) + while True: + try: + gen.send(returned_value) + except StopIteration as stop: + return stop.value + + +def test_client_injects_tracing_on_schedule(monkeypatch): + import durabletask.client as client_mod + + # monkeypatch TaskHubGrpcClient to capture inputs + scheduled: dict[str, Any] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, name, *, input=None, instance_id=None, start_at=None, reuse_id_policy=None + ): + scheduled['name'] = name + scheduled['input'] = input + scheduled['instance_id'] = instance_id + scheduled['start_at'] = start_at + scheduled['reuse_id_policy'] = reuse_id_policy + return 'id-1' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _TracingClient(ClientInterceptor): + def schedule_new_workflow(self, request, next): # type: ignore[override] + tr = {'trace_id': uuid.uuid4().hex} + if isinstance(request.input, dict) and 'tracing' not in request.input: + request = type(request)( + workflow_name=request.workflow_name, + input={**request.input, 'tracing': tr}, + instance_id=request.instance_id, + start_at=request.start_at, 
+ reuse_id_policy=request.reuse_id_policy, + ) + return next(request) + + client = DaprWorkflowClient(interceptors=[_TracingClient()]) + + # We only need a callable with a __name__ for scheduling + def wf(ctx): + yield 'noop' + + wf.__name__ = 'inject_test' + instance_id = client.schedule_new_workflow(wf, input={'a': 1}) + assert instance_id == 'id-1' + assert scheduled['name'] == 'inject_test' + assert isinstance(scheduled['input'], dict) + assert 'tracing' in scheduled['input'] + assert scheduled['input']['a'] == 1 + + +def test_runtime_restores_tracing_before_user_code(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + seen: dict[str, Any] = {} + + class _TracingRuntime(RuntimeInterceptor): + def execute_workflow(self, request, next): # type: ignore[override] + # no-op; real restoration is app concern; test just ensures input contains tracing + return next(request) + + def execute_activity(self, request, next): # type: ignore[override] + return next(request) + + class _TracingClient2(ClientInterceptor): + def schedule_new_workflow(self, request, next): # type: ignore[override] + tr = {'trace_id': 't1'} + if isinstance(request.input, dict): + request = type(request)( + workflow_name=request.workflow_name, + input={**request.input, 'tracing': tr}, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + ) + return next(request) + + rt = WorkflowRuntime( + runtime_interceptors=[_TracingRuntime()], + ) + + @rt.workflow(name='w') + def w(ctx, x): + # The tracing should already be present in input + assert isinstance(x, dict) + assert 'tracing' in x + seen['trace'] = x['tracing'] + yield 'noop' + return 'ok' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w'] + # Orchestrator input will have tracing injected via outbound when scheduled as a child or via client + # Here, we directly pass the input simulating schedule with tracing present + gen = orch(_FakeOrchestrationContext(), {'hello': 'world', 'tracing': {'trace_id': 't1'}}) + out = _drive_generator(gen, returned_value='noop') + assert out == 'ok' + assert seen['trace']['trace_id'] == 't1' diff --git a/ext/dapr-ext-workflow/tests/test_workflow_activity_context.py b/ext/dapr-ext-workflow/tests/test_workflow_activity_context.py index a45b8b7cd..ef4333fbb 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_activity_context.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_activity_context.py @@ -15,14 +15,16 @@ import unittest from unittest import mock + from durabletask import task + from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext mock_orchestration_id = 'orchestration001' mock_task = 10 -class FakeActivityContext: +class _CompatFakeActivityContext: @property def orchestration_id(self): return mock_orchestration_id @@ -34,7 +36,9 @@ def task_id(self): class WorkflowActivityContextTest(unittest.TestCase): def test_workflow_activity_context(self): - with mock.patch('durabletask.task.ActivityContext', return_value=FakeActivityContext()): + with mock.patch( + 'durabletask.task.ActivityContext', return_value=_CompatFakeActivityContext() + ): fake_act_ctx = task.ActivityContext( orchestration_id=mock_orchestration_id, task_id=mock_task ) diff --git a/ext/dapr-ext-workflow/tests/test_workflow_client.py b/ext/dapr-ext-workflow/tests/test_workflow_client.py index 540c0e801..e58dcb0d1 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_client.py +++ 
b/ext/dapr-ext-workflow/tests/test_workflow_client.py @@ -13,16 +13,18 @@ limitations under the License. """ +import unittest from datetime import datetime from typing import Any, Union -import unittest -from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from unittest import mock -from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient -from durabletask import client + import durabletask.internal.orchestrator_service_pb2 as pb +from durabletask import client from grpc import RpcError +from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext + mock_schedule_result = 'workflow001' mock_raise_event_result = 'event001' mock_terminate_result = 'terminate001' diff --git a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py index 02d6c6f3b..f367a43fd 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py @@ -13,12 +13,13 @@ limitations under the License. """ -from typing import List import unittest -from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext +from typing import List from unittest import mock -from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name + +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name listOrchestrators: List[str] = [] listActivities: List[str] = [] @@ -36,7 +37,10 @@ class WorkflowRuntimeTest(unittest.TestCase): def setUp(self): listActivities.clear() listOrchestrators.clear() - mock.patch('durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()).start() + self.patcher = mock.patch( + 'durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker() + ) + self.patcher.start() self.runtime_options = WorkflowRuntime() if hasattr(self.mock_client_wf, '_dapr_alternate_name'): del self.mock_client_wf.__dict__['_dapr_alternate_name'] @@ -47,6 +51,11 @@ def setUp(self): if hasattr(self.mock_client_activity, '_activity_registered'): del self.mock_client_activity.__dict__['_activity_registered'] + def tearDown(self): + """Stop the mock patch to prevent interference with other tests.""" + self.patcher.stop() + mock.patch.stopall() # Ensure all patches are stopped + def mock_client_wf(ctx: DaprWorkflowContext, input): print(f'{input}') diff --git a/ext/dapr-ext-workflow/tests/test_workflow_util.py b/ext/dapr-ext-workflow/tests/test_workflow_util.py index 878ee7374..c76184382 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_util.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_util.py @@ -1,11 +1,25 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + import unittest -from dapr.ext.workflow.util import getAddress from unittest.mock import patch from dapr.conf import settings +from dapr.ext.workflow.util import getAddress class DaprWorkflowUtilTest(unittest.TestCase): + @patch.object(settings, 'DAPR_GRPC_ENDPOINT', '') def test_get_address_default(self): expected = f'{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}' self.assertEqual(expected, getAddress()) diff --git a/ext/flask_dapr/flask_dapr/app.py b/ext/flask_dapr/flask_dapr/app.py index c8d5def92..80e42220f 100644 --- a/ext/flask_dapr/flask_dapr/app.py +++ b/ext/flask_dapr/flask_dapr/app.py @@ -14,6 +14,7 @@ """ from typing import Dict, List, Optional + from flask import Flask, jsonify diff --git a/mypy.ini b/mypy.ini index 8c0fee4f0..7ca609e0a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -12,6 +12,7 @@ files = dapr/clients/**/*.py, dapr/conf/**/*.py, dapr/serializers/**/*.py, + ext/dapr-ext-workflow/dapr/ext/workflow/**/*.py, ext/dapr-ext-grpc/dapr/**/*.py, ext/dapr-ext-fastapi/dapr/**/*.py, ext/flask_dapr/flask_dapr/*.py, @@ -19,3 +20,6 @@ files = [mypy-dapr.proto.*] ignore_errors = True + +[mypy-dapr.ext.workflow.*] +python_version = 3.11 diff --git a/pyproject.toml b/pyproject.toml index 2b8ddf72e..49e164031 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,12 @@ target-version = "py38" line-length = 100 fix = true -extend-exclude = [".github", "dapr/proto"] +extend-exclude = [ + ".github", + "dapr/proto", + "*_pb2.py", + "*_pb2_grpc.py", +] [tool.ruff.lint] select = [ "E", # pycodestyle errors diff --git a/tests/actor/fake_actor_classes.py b/tests/actor/fake_actor_classes.py index 50fe63fcf..04780f1dc 100644 --- a/tests/actor/fake_actor_classes.py +++ b/tests/actor/fake_actor_classes.py @@ -12,17 +12,16 @@ See the License for the specific language governing permissions and limitations under the License. """ -from dapr.serializers.json import DefaultJSONSerializer -import asyncio +import asyncio from datetime import timedelta from typing import Optional -from dapr.actor.runtime.actor import Actor -from dapr.actor.runtime.remindable import Remindable from dapr.actor.actor_interface import ActorInterface, actormethod - +from dapr.actor.runtime.actor import Actor from dapr.actor.runtime.reentrancy_context import reentrancy_ctx +from dapr.actor.runtime.remindable import Remindable +from dapr.serializers.json import DefaultJSONSerializer # Fake Simple Actor Class for testing diff --git a/tests/actor/fake_client.py b/tests/actor/fake_client.py index fa5fe1577..f349a63f7 100644 --- a/tests/actor/fake_client.py +++ b/tests/actor/fake_client.py @@ -13,9 +13,10 @@ limitations under the License. 
""" -from dapr.clients import DaprActorClientBase from typing import Optional +from dapr.clients import DaprActorClientBase + # Fake Dapr Actor Client Base Class for testing class FakeDaprActorClientBase(DaprActorClientBase): diff --git a/tests/actor/test_actor.py b/tests/actor/test_actor.py index d9b602c9d..7a7bee2d2 100644 --- a/tests/actor/test_actor.py +++ b/tests/actor/test_actor.py @@ -14,25 +14,22 @@ """ import unittest - -from unittest import mock from datetime import timedelta +from unittest import mock from dapr.actor.id import ActorId +from dapr.actor.runtime._type_information import ActorTypeInformation from dapr.actor.runtime.config import ActorRuntimeConfig from dapr.actor.runtime.context import ActorRuntimeContext from dapr.actor.runtime.runtime import ActorRuntime -from dapr.actor.runtime._type_information import ActorTypeInformation from dapr.conf import settings from dapr.serializers import DefaultJSONSerializer - from tests.actor.fake_actor_classes import ( + FakeMultiInterfacesActor, FakeSimpleActor, FakeSimpleReminderActor, FakeSimpleTimerActor, - FakeMultiInterfacesActor, ) - from tests.actor.fake_client import FakeDaprActorClient from tests.actor.utils import _async_mock, _run from tests.clients.fake_http_server import FakeHttpServer diff --git a/tests/actor/test_actor_factory.py b/tests/actor/test_actor_factory.py index 0715c33f4..4f629bb25 100644 --- a/tests/actor/test_actor_factory.py +++ b/tests/actor/test_actor_factory.py @@ -18,16 +18,13 @@ from dapr.actor import Actor from dapr.actor.id import ActorId from dapr.actor.runtime._type_information import ActorTypeInformation -from dapr.actor.runtime.manager import ActorManager from dapr.actor.runtime.context import ActorRuntimeContext +from dapr.actor.runtime.manager import ActorManager from dapr.serializers import DefaultJSONSerializer - from tests.actor.fake_actor_classes import ( FakeSimpleActorInterface, ) - from tests.actor.fake_client import FakeDaprActorClient - from tests.actor.utils import _run diff --git a/tests/actor/test_actor_manager.py b/tests/actor/test_actor_manager.py index 6c21abfb7..af0e2e410 100644 --- a/tests/actor/test_actor_manager.py +++ b/tests/actor/test_actor_manager.py @@ -19,19 +19,16 @@ from dapr.actor.id import ActorId from dapr.actor.runtime._type_information import ActorTypeInformation -from dapr.actor.runtime.manager import ActorManager from dapr.actor.runtime.context import ActorRuntimeContext +from dapr.actor.runtime.manager import ActorManager from dapr.serializers import DefaultJSONSerializer - from tests.actor.fake_actor_classes import ( FakeMultiInterfacesActor, FakeSimpleActor, FakeSimpleReminderActor, FakeSimpleTimerActor, ) - from tests.actor.fake_client import FakeDaprActorClient - from tests.actor.utils import ( _async_mock, _run, diff --git a/tests/actor/test_actor_reentrancy.py b/tests/actor/test_actor_reentrancy.py index 834273f41..a440d2750 100644 --- a/tests/actor/test_actor_reentrancy.py +++ b/tests/actor/test_actor_reentrancy.py @@ -13,22 +13,19 @@ limitations under the License. 
""" -import unittest import asyncio - +import unittest from unittest import mock +from dapr.actor.runtime.config import ActorReentrancyConfig, ActorRuntimeConfig from dapr.actor.runtime.runtime import ActorRuntime -from dapr.actor.runtime.config import ActorRuntimeConfig, ActorReentrancyConfig from dapr.conf import settings from dapr.serializers import DefaultJSONSerializer - from tests.actor.fake_actor_classes import ( - FakeReentrantActor, FakeMultiInterfacesActor, + FakeReentrantActor, FakeSlowReentrantActor, ) - from tests.actor.utils import _run from tests.clients.fake_http_server import FakeHttpServer @@ -212,9 +209,10 @@ async def expected_return_value(*args, **kwargs): _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, 'test-id')) def test_parse_incoming_reentrancy_header_flask(self): - from ext.flask_dapr import flask_dapr from flask import Flask + from ext.flask_dapr import flask_dapr + app = Flask(f'{FakeReentrantActor.__name__}Service') flask_dapr.DaprActor(app) @@ -246,6 +244,7 @@ def test_parse_incoming_reentrancy_header_flask(self): def test_parse_incoming_reentrancy_header_fastapi(self): from fastapi import FastAPI from fastapi.testclient import TestClient + from dapr.ext import fastapi app = FastAPI(title=f'{FakeReentrantActor.__name__}Service') diff --git a/tests/actor/test_actor_runtime.py b/tests/actor/test_actor_runtime.py index f17f96cc8..7725c3728 100644 --- a/tests/actor/test_actor_runtime.py +++ b/tests/actor/test_actor_runtime.py @@ -14,20 +14,17 @@ """ import unittest - from datetime import timedelta -from dapr.actor.runtime.runtime import ActorRuntime from dapr.actor.runtime.config import ActorRuntimeConfig +from dapr.actor.runtime.runtime import ActorRuntime from dapr.conf import settings from dapr.serializers import DefaultJSONSerializer - from tests.actor.fake_actor_classes import ( - FakeSimpleActor, FakeMultiInterfacesActor, + FakeSimpleActor, FakeSimpleTimerActor, ) - from tests.actor.utils import _run from tests.clients.fake_http_server import FakeHttpServer diff --git a/tests/actor/test_actor_runtime_config.py b/tests/actor/test_actor_runtime_config.py index 7bbd8cefc..e39894c77 100644 --- a/tests/actor/test_actor_runtime_config.py +++ b/tests/actor/test_actor_runtime_config.py @@ -14,9 +14,9 @@ """ import unittest - from datetime import timedelta -from dapr.actor.runtime.config import ActorRuntimeConfig, ActorReentrancyConfig, ActorTypeConfig + +from dapr.actor.runtime.config import ActorReentrancyConfig, ActorRuntimeConfig, ActorTypeConfig class ActorTypeConfigTests(unittest.TestCase): diff --git a/tests/actor/test_client_proxy.py b/tests/actor/test_client_proxy.py index fe667d629..172e5d283 100644 --- a/tests/actor/test_client_proxy.py +++ b/tests/actor/test_client_proxy.py @@ -12,22 +12,18 @@ See the License for the specific language governing permissions and limitations under the License. 
""" -import unittest +import unittest from unittest import mock - -from dapr.actor.id import ActorId from dapr.actor.client.proxy import ActorProxy +from dapr.actor.id import ActorId from dapr.serializers import DefaultJSONSerializer from tests.actor.fake_actor_classes import ( - FakeMultiInterfacesActor, FakeActorCls2Interface, + FakeMultiInterfacesActor, ) - - from tests.actor.fake_client import FakeDaprActorClient - from tests.actor.utils import _async_mock, _run diff --git a/tests/actor/test_method_dispatcher.py b/tests/actor/test_method_dispatcher.py index 94f48a7b6..a32fba455 100644 --- a/tests/actor/test_method_dispatcher.py +++ b/tests/actor/test_method_dispatcher.py @@ -15,11 +15,10 @@ import unittest +from dapr.actor.runtime._type_information import ActorTypeInformation from dapr.actor.runtime.context import ActorRuntimeContext from dapr.actor.runtime.method_dispatcher import ActorMethodDispatcher -from dapr.actor.runtime._type_information import ActorTypeInformation from dapr.serializers import DefaultJSONSerializer - from tests.actor.fake_actor_classes import FakeSimpleActor from tests.actor.fake_client import FakeDaprActorClient from tests.actor.utils import _run diff --git a/tests/actor/test_state_manager.py b/tests/actor/test_state_manager.py index c9406dbd2..11a7c4f08 100644 --- a/tests/actor/test_state_manager.py +++ b/tests/actor/test_state_manager.py @@ -15,19 +15,16 @@ import base64 import unittest - from unittest import mock from dapr.actor.id import ActorId +from dapr.actor.runtime._type_information import ActorTypeInformation from dapr.actor.runtime.context import ActorRuntimeContext from dapr.actor.runtime.state_change import StateChangeKind from dapr.actor.runtime.state_manager import ActorStateManager -from dapr.actor.runtime._type_information import ActorTypeInformation from dapr.serializers import DefaultJSONSerializer - from tests.actor.fake_actor_classes import FakeSimpleActor from tests.actor.fake_client import FakeDaprActorClient - from tests.actor.utils import _async_mock, _run diff --git a/tests/actor/test_timer_data.py b/tests/actor/test_timer_data.py index ba410cecd..8a193f416 100644 --- a/tests/actor/test_timer_data.py +++ b/tests/actor/test_timer_data.py @@ -13,9 +13,9 @@ limitations under the License. 
""" -from typing import Any import unittest from datetime import timedelta +from typing import Any from dapr.actor.runtime._timer_data import ActorTimerData diff --git a/tests/actor/test_type_information.py b/tests/actor/test_type_information.py index 1532e3956..201eb87fb 100644 --- a/tests/actor/test_type_information.py +++ b/tests/actor/test_type_information.py @@ -17,10 +17,10 @@ from dapr.actor.runtime._type_information import ActorTypeInformation from tests.actor.fake_actor_classes import ( - FakeSimpleActor, - FakeMultiInterfacesActor, FakeActorCls1Interface, FakeActorCls2Interface, + FakeMultiInterfacesActor, + FakeSimpleActor, ReentrantActorInterface, ) diff --git a/tests/actor/test_type_utils.py b/tests/actor/test_type_utils.py index f8b2eee2a..6b2a9319b 100644 --- a/tests/actor/test_type_utils.py +++ b/tests/actor/test_type_utils.py @@ -17,19 +17,18 @@ from dapr.actor.actor_interface import ActorInterface from dapr.actor.runtime._type_utils import ( + get_actor_interfaces, get_class_method_args, + get_dispatchable_attrs, get_method_arg_types, get_method_return_types, is_dapr_actor, - get_actor_interfaces, - get_dispatchable_attrs, ) - from tests.actor.fake_actor_classes import ( - FakeSimpleActor, - FakeMultiInterfacesActor, FakeActorCls1Interface, FakeActorCls2Interface, + FakeMultiInterfacesActor, + FakeSimpleActor, ) diff --git a/tests/clients/certs.py b/tests/clients/certs.py index a30b25312..9d851ca46 100644 --- a/tests/clients/certs.py +++ b/tests/clients/certs.py @@ -1,7 +1,7 @@ import os import ssl -import grpc +import grpc from OpenSSL import crypto diff --git a/tests/clients/fake_dapr_server.py b/tests/clients/fake_dapr_server.py index a1cbeb4b7..a1ee695eb 100644 --- a/tests/clients/fake_dapr_server.py +++ b/tests/clients/fake_dapr_server.py @@ -1,48 +1,47 @@ -import grpc import json - from concurrent import futures -from google.protobuf.any_pb2 import Any as GrpcAny +from typing import Dict + +import grpc from google.protobuf import empty_pb2, struct_pb2 -from google.rpc import status_pb2, code_pb2 +from google.protobuf.any_pb2 import Any as GrpcAny +from google.rpc import code_pb2, status_pb2 from grpc_status import rpc_status from dapr.clients.grpc._helpers import to_bytes -from dapr.proto import api_service_v1, common_v1, api_v1, appcallback_v1 -from dapr.proto.common.v1.common_pb2 import ConfigurationItem from dapr.clients.grpc._response import WorkflowRuntimeStatus +from dapr.proto import api_service_v1, api_v1, appcallback_v1, common_v1 +from dapr.proto.common.v1.common_pb2 import ConfigurationItem from dapr.proto.runtime.v1.dapr_pb2 import ( ActiveActorsCount, + ConversationResponseAlpha2, + ConversationResultAlpha2, + ConversationResultChoices, + ConversationResultMessage, + ConversationToolCalls, + ConversationToolCallsOfFunction, + DecryptRequest, + DecryptResponse, + EncryptRequest, + EncryptResponse, GetMetadataResponse, + GetWorkflowRequest, + GetWorkflowResponse, + PauseWorkflowRequest, + PurgeWorkflowRequest, QueryStateItem, + RaiseEventWorkflowRequest, RegisteredComponents, + ResumeWorkflowRequest, SetMetadataRequest, + StartWorkflowRequest, + StartWorkflowResponse, + TerminateWorkflowRequest, TryLockRequest, TryLockResponse, UnlockRequest, UnlockResponse, - StartWorkflowRequest, - StartWorkflowResponse, - GetWorkflowRequest, - GetWorkflowResponse, - PauseWorkflowRequest, - ResumeWorkflowRequest, - TerminateWorkflowRequest, - PurgeWorkflowRequest, - RaiseEventWorkflowRequest, - EncryptRequest, - EncryptResponse, - DecryptRequest, - DecryptResponse, - 
ConversationResultAlpha2, - ConversationResultChoices, - ConversationResultMessage, - ConversationResponseAlpha2, - ConversationToolCalls, - ConversationToolCallsOfFunction, ) -from typing import Dict - from tests.clients.certs import GrpcCerts from tests.clients.fake_http_server import FakeHttpServer diff --git a/tests/clients/fake_http_server.py b/tests/clients/fake_http_server.py index e08e82d29..8476b18ba 100644 --- a/tests/clients/fake_http_server.py +++ b/tests/clients/fake_http_server.py @@ -1,8 +1,7 @@ import time +from http.server import BaseHTTPRequestHandler, HTTPServer from ssl import PROTOCOL_TLS_SERVER, SSLContext - from threading import Thread -from http.server import BaseHTTPRequestHandler, HTTPServer from tests.clients.certs import HttpCerts diff --git a/tests/clients/test_conversation.py b/tests/clients/test_conversation.py index 8a6cc697e..50daebc64 100644 --- a/tests/clients/test_conversation.py +++ b/tests/clients/test_conversation.py @@ -13,7 +13,6 @@ limitations under the License. """ - import asyncio import json import unittest @@ -33,7 +32,14 @@ from dapr.clients.grpc.conversation import ( ConversationInput, ConversationInputAlpha2, + ConversationMessage, + ConversationMessageOfAssistant, ConversationResponseAlpha2, + ConversationResultAlpha2, + ConversationResultAlpha2Choices, + ConversationResultAlpha2Message, + ConversationToolCalls, + ConversationToolCallsOfFunction, ConversationTools, ConversationToolsFunction, FunctionBackend, @@ -41,18 +47,11 @@ create_system_message, create_tool_message, create_user_message, + execute_registered_tool, execute_registered_tool_async, get_registered_tools, register_tool, unregister_tool, - ConversationResultAlpha2Message, - ConversationResultAlpha2Choices, - ConversationResultAlpha2, - ConversationMessage, - ConversationMessageOfAssistant, - ConversationToolCalls, - ConversationToolCallsOfFunction, - execute_registered_tool, ) from dapr.clients.grpc.conversation import ( tool as tool_decorator, @@ -1010,7 +1009,7 @@ def test_multiline_example(self): def test_zero_indent(self): result = conversation._indent_lines('Title', 'Line one\nLine two', 0) - expected = 'Title: Line one\n' ' Line two' + expected = 'Title: Line one\n Line two' self.assertEqual(result, expected) def test_empty_string(self): @@ -1026,7 +1025,7 @@ def test_title_length_affects_indent(self): # Title length is 1, indent_after_first_line should be indent + len(title) + 2 # indent=2, len(title)=1 => 2 + 1 + 2 = 5 spaces on continuation lines result = conversation._indent_lines('T', 'a\nb', 2) - expected = ' T: a\n' ' b' + expected = ' T: a\n b' self.assertEqual(result, expected) diff --git a/tests/clients/test_conversation_helpers.py b/tests/clients/test_conversation_helpers.py index 62f2f69ae..73dc17270 100644 --- a/tests/clients/test_conversation_helpers.py +++ b/tests/clients/test_conversation_helpers.py @@ -12,37 +12,39 @@ See the License for the specific language governing permissions and limitations under the License. 
""" + +import base64 import io import json -import base64 import unittest import warnings from contextlib import redirect_stdout from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union, Set -from dapr.conf import settings +from typing import Any, Dict, List, Literal, Optional, Set, Union + from dapr.clients.grpc._conversation_helpers import ( - stringify_tool_output, - bind_params_to_func, - function_to_json_schema, + ToolArgumentError, _extract_docstring_args, _python_type_to_json_schema, + bind_params_to_func, extract_docstring_summary, - ToolArgumentError, + function_to_json_schema, + stringify_tool_output, ) from dapr.clients.grpc.conversation import ( - ConversationToolsFunction, - ConversationMessageOfUser, + ConversationMessage, ConversationMessageContent, - ConversationToolCalls, - ConversationToolCallsOfFunction, ConversationMessageOfAssistant, - ConversationMessageOfTool, - ConversationMessage, ConversationMessageOfDeveloper, ConversationMessageOfSystem, + ConversationMessageOfTool, + ConversationMessageOfUser, + ConversationToolCalls, + ConversationToolCallsOfFunction, + ConversationToolsFunction, ) +from dapr.conf import settings def test_string_passthrough(): diff --git a/tests/clients/test_dapr_grpc_client.py b/tests/clients/test_dapr_grpc_client.py index e0713f703..d4c642858 100644 --- a/tests/clients/test_dapr_grpc_client.py +++ b/tests/clients/test_dapr_grpc_client.py @@ -13,43 +13,43 @@ limitations under the License. """ +import asyncio import json import socket import tempfile import time import unittest import uuid -import asyncio - from unittest.mock import patch -from google.rpc import status_pb2, code_pb2 +from google.rpc import code_pb2, status_pb2 -from dapr.clients.exceptions import DaprGrpcError -from dapr.clients.grpc.client import DaprGrpcClient from dapr.clients import DaprClient -from dapr.clients.grpc.subscription import StreamInactiveError -from dapr.proto import common_v1 -from .fake_dapr_server import FakeDaprSidecar -from dapr.conf import settings +from dapr.clients.exceptions import DaprGrpcError +from dapr.clients.grpc import conversation +from dapr.clients.grpc._crypto import DecryptOptions, EncryptOptions from dapr.clients.grpc._helpers import to_bytes +from dapr.clients.grpc._jobs import Job from dapr.clients.grpc._request import ( TransactionalStateOperation, TransactionOperationType, ) -from dapr.clients.grpc._jobs import Job -from dapr.clients.grpc._state import StateOptions, Consistency, Concurrency, StateItem -from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions from dapr.clients.grpc._response import ( ConfigurationItem, ConfigurationResponse, ConfigurationWatcher, DaprResponse, + TopicEventResponse, UnlockResponseStatus, WorkflowRuntimeStatus, - TopicEventResponse, ) -from dapr.clients.grpc import conversation +from dapr.clients.grpc._state import Concurrency, Consistency, StateItem, StateOptions +from dapr.clients.grpc.client import DaprGrpcClient +from dapr.clients.grpc.subscription import StreamInactiveError +from dapr.conf import settings +from dapr.proto import common_v1 + +from .fake_dapr_server import FakeDaprSidecar class DaprGrpcClientTests(unittest.TestCase): @@ -1694,7 +1694,7 @@ def test_delete_job_alpha1_validation_error(self): def test_jobs_error_handling(self): """Test error handling for Jobs API using fake server's exception mechanism.""" - from google.rpc import status_pb2, code_pb2 + from google.rpc import code_pb2, status_pb2 dapr = 
DaprGrpcClient(f'{self.scheme}localhost:{self.grpc_port}') diff --git a/tests/clients/test_dapr_grpc_client_async.py b/tests/clients/test_dapr_grpc_client_async.py index 50043912d..8109aa21f 100644 --- a/tests/clients/test_dapr_grpc_client_async.py +++ b/tests/clients/test_dapr_grpc_client_async.py @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import json import socket import tempfile @@ -19,28 +20,29 @@ import uuid from unittest.mock import patch -from google.rpc import status_pb2, code_pb2 +from google.rpc import code_pb2, status_pb2 -from dapr.aio.clients.grpc.client import DaprGrpcClientAsync from dapr.aio.clients import DaprClient +from dapr.aio.clients.grpc.client import DaprGrpcClientAsync from dapr.clients.exceptions import DaprGrpcError -from dapr.common.pubsub.subscription import StreamInactiveError -from dapr.proto import common_v1 -from .fake_dapr_server import FakeDaprSidecar -from dapr.conf import settings -from dapr.clients.grpc._helpers import to_bytes -from dapr.clients.grpc._request import TransactionalStateOperation from dapr.clients.grpc import conversation +from dapr.clients.grpc._crypto import DecryptOptions, EncryptOptions +from dapr.clients.grpc._helpers import to_bytes from dapr.clients.grpc._jobs import Job -from dapr.clients.grpc._state import StateOptions, Consistency, Concurrency, StateItem -from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions +from dapr.clients.grpc._request import TransactionalStateOperation from dapr.clients.grpc._response import ( ConfigurationItem, - ConfigurationWatcher, ConfigurationResponse, + ConfigurationWatcher, DaprResponse, UnlockResponseStatus, ) +from dapr.clients.grpc._state import Concurrency, Consistency, StateItem, StateOptions +from dapr.common.pubsub.subscription import StreamInactiveError +from dapr.conf import settings +from dapr.proto import common_v1 + +from .fake_dapr_server import FakeDaprSidecar class DaprGrpcClientAsyncTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/clients/test_dapr_grpc_client_async_secure.py b/tests/clients/test_dapr_grpc_client_async_secure.py index 652feac20..a49fe5fc0 100644 --- a/tests/clients/test_dapr_grpc_client_async_secure.py +++ b/tests/clients/test_dapr_grpc_client_async_secure.py @@ -14,16 +14,15 @@ """ import unittest - from unittest.mock import patch from dapr.aio.clients.grpc.client import DaprGrpcClientAsync from dapr.clients.health import DaprHealth +from dapr.conf import settings from tests.clients.certs import replacement_get_credentials_func, replacement_get_health_context from tests.clients.test_dapr_grpc_client_async import DaprGrpcClientAsyncTests -from .fake_dapr_server import FakeDaprSidecar -from dapr.conf import settings +from .fake_dapr_server import FakeDaprSidecar DaprGrpcClientAsync.get_credentials = replacement_get_credentials_func DaprHealth.get_ssl_context = replacement_get_health_context diff --git a/tests/clients/test_dapr_grpc_client_secure.py b/tests/clients/test_dapr_grpc_client_secure.py index 41dedca1a..2a6710403 100644 --- a/tests/clients/test_dapr_grpc_client_secure.py +++ b/tests/clients/test_dapr_grpc_client_secure.py @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" + import unittest from unittest.mock import patch @@ -19,8 +20,8 @@ from dapr.clients.health import DaprHealth from dapr.conf import settings from tests.clients.certs import replacement_get_credentials_func, replacement_get_health_context - from tests.clients.test_dapr_grpc_client import DaprGrpcClientTests + from .fake_dapr_server import FakeDaprSidecar diff --git a/tests/clients/test_dapr_grpc_helpers.py b/tests/clients/test_dapr_grpc_helpers.py index 9e794aab7..6c7c27be9 100644 --- a/tests/clients/test_dapr_grpc_helpers.py +++ b/tests/clients/test_dapr_grpc_helpers.py @@ -1,22 +1,22 @@ import base64 import unittest -from google.protobuf.struct_pb2 import Struct from google.protobuf import json_format -from google.protobuf.json_format import ParseError from google.protobuf.any_pb2 import Any as GrpcAny +from google.protobuf.json_format import ParseError +from google.protobuf.struct_pb2 import Struct from google.protobuf.wrappers_pb2 import ( BoolValue, - StringValue, + BytesValue, + DoubleValue, Int32Value, Int64Value, - DoubleValue, - BytesValue, + StringValue, ) from dapr.clients.grpc._helpers import ( - convert_value_to_struct, convert_dict_to_grpc_dict_of_any, + convert_value_to_struct, ) diff --git a/tests/clients/test_dapr_grpc_request.py b/tests/clients/test_dapr_grpc_request.py index 98d8e2005..396a8ec95 100644 --- a/tests/clients/test_dapr_grpc_request.py +++ b/tests/clients/test_dapr_grpc_request.py @@ -16,13 +16,13 @@ import io import unittest +from dapr.clients.grpc._crypto import DecryptOptions, EncryptOptions from dapr.clients.grpc._request import ( - InvokeMethodRequest, BindingRequest, - EncryptRequestIterator, DecryptRequestIterator, + EncryptRequestIterator, + InvokeMethodRequest, ) -from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions from dapr.proto import api_v1, common_v1 diff --git a/tests/clients/test_dapr_grpc_request_async.py b/tests/clients/test_dapr_grpc_request_async.py index 75fe74fce..7782fecdf 100644 --- a/tests/clients/test_dapr_grpc_request_async.py +++ b/tests/clients/test_dapr_grpc_request_async.py @@ -16,8 +16,8 @@ import io import unittest -from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions -from dapr.aio.clients.grpc._request import EncryptRequestIterator, DecryptRequestIterator +from dapr.aio.clients.grpc._request import DecryptRequestIterator, EncryptRequestIterator +from dapr.clients.grpc._crypto import DecryptOptions, EncryptOptions from dapr.proto import api_v1 diff --git a/tests/clients/test_dapr_grpc_response.py b/tests/clients/test_dapr_grpc_response.py index 1c91805eb..c2fe237f9 100644 --- a/tests/clients/test_dapr_grpc_response.py +++ b/tests/clients/test_dapr_grpc_response.py @@ -18,15 +18,14 @@ from google.protobuf.any_pb2 import Any as GrpcAny from dapr.clients.grpc._response import ( - DaprResponse, - InvokeMethodResponse, BindingResponse, - StateResponse, BulkStateItem, - EncryptResponse, + DaprResponse, DecryptResponse, + EncryptResponse, + InvokeMethodResponse, + StateResponse, ) - from dapr.proto import api_v1, common_v1 diff --git a/tests/clients/test_dapr_grpc_response_async.py b/tests/clients/test_dapr_grpc_response_async.py index 2626cbf41..02b09716f 100644 --- a/tests/clients/test_dapr_grpc_response_async.py +++ b/tests/clients/test_dapr_grpc_response_async.py @@ -15,7 +15,7 @@ import unittest -from dapr.aio.clients.grpc._response import EncryptResponse, DecryptResponse +from dapr.aio.clients.grpc._response import DecryptResponse, EncryptResponse from dapr.proto import api_v1, common_v1 
diff --git a/tests/clients/test_exceptions.py b/tests/clients/test_exceptions.py index 08eea4d53..e8b4c6d9f 100644 --- a/tests/clients/test_exceptions.py +++ b/tests/clients/test_exceptions.py @@ -3,9 +3,9 @@ import unittest import grpc -from google.rpc import error_details_pb2, status_pb2, code_pb2 from google.protobuf.any_pb2 import Any from google.protobuf.duration_pb2 import Duration +from google.rpc import code_pb2, error_details_pb2, status_pb2 from dapr.clients import DaprGrpcClient from dapr.clients.exceptions import DaprGrpcError, DaprInternalError diff --git a/tests/clients/test_heatlhcheck.py b/tests/clients/test_heatlhcheck.py index f3be8a475..1f533dc96 100644 --- a/tests/clients/test_heatlhcheck.py +++ b/tests/clients/test_heatlhcheck.py @@ -12,9 +12,10 @@ See the License for the specific language governing permissions and limitations under the License. """ + import time import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch from dapr.clients.health import DaprHealth from dapr.conf import settings diff --git a/tests/clients/test_http_helpers.py b/tests/clients/test_http_helpers.py index ab173cd73..abf284dbe 100644 --- a/tests/clients/test_http_helpers.py +++ b/tests/clients/test_http_helpers.py @@ -1,8 +1,8 @@ import unittest from unittest.mock import patch -from dapr.conf import settings from dapr.clients.http.helpers import get_api_url +from dapr.conf import settings class DaprHttpClientHelpersTests(unittest.TestCase): diff --git a/tests/clients/test_http_service_invocation_client.py b/tests/clients/test_http_service_invocation_client.py index c0b43a863..a0a7aadd6 100644 --- a/tests/clients/test_http_service_invocation_client.py +++ b/tests/clients/test_http_service_invocation_client.py @@ -24,13 +24,12 @@ from opentelemetry.sdk.trace.sampling import ALWAYS_ON from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator - +from dapr.clients import DaprClient from dapr.clients.exceptions import DaprInternalError from dapr.conf import settings from dapr.proto import common_v1 from .fake_http_server import FakeHttpServer -from dapr.clients import DaprClient class DaprInvocationHttpClientTests(unittest.TestCase): diff --git a/tests/clients/test_jobs.py b/tests/clients/test_jobs.py index fe3d70b53..645d43256 100644 --- a/tests/clients/test_jobs.py +++ b/tests/clients/test_jobs.py @@ -5,9 +5,10 @@ """ import unittest + from google.protobuf.any_pb2 import Any as GrpcAny -from dapr.clients.grpc._jobs import Job, DropFailurePolicy, ConstantFailurePolicy +from dapr.clients.grpc._jobs import ConstantFailurePolicy, DropFailurePolicy, Job from dapr.proto.runtime.v1 import dapr_pb2 as api_v1 diff --git a/tests/clients/test_retries_policy.py b/tests/clients/test_retries_policy.py index b5137e643..d4a383fc1 100644 --- a/tests/clients/test_retries_policy.py +++ b/tests/clients/test_retries_policy.py @@ -12,11 +12,12 @@ See the License for the specific language governing permissions and limitations under the License. 
""" + import unittest from unittest import mock -from unittest.mock import Mock, MagicMock, patch, AsyncMock +from unittest.mock import AsyncMock, MagicMock, Mock, patch -from grpc import StatusCode, RpcError +from grpc import RpcError, StatusCode from dapr.clients.retry import RetryPolicy from dapr.serializers import DefaultJSONSerializer diff --git a/tests/clients/test_retries_policy_async.py b/tests/clients/test_retries_policy_async.py index ebe6865db..2b35c35c4 100644 --- a/tests/clients/test_retries_policy_async.py +++ b/tests/clients/test_retries_policy_async.py @@ -12,11 +12,12 @@ See the License for the specific language governing permissions and limitations under the License. """ + import unittest from unittest import mock -from unittest.mock import MagicMock, patch, AsyncMock +from unittest.mock import AsyncMock, MagicMock, patch -from grpc import StatusCode, RpcError +from grpc import RpcError, StatusCode from dapr.clients.retry import RetryPolicy diff --git a/tests/clients/test_secure_http_service_invocation_client.py b/tests/clients/test_secure_http_service_invocation_client.py index 4d1bdda1f..df13d8197 100644 --- a/tests/clients/test_secure_http_service_invocation_client.py +++ b/tests/clients/test_secure_http_service_invocation_client.py @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import ssl import typing from asyncio import TimeoutError @@ -29,8 +30,7 @@ from dapr.conf import settings from dapr.proto import common_v1 - -from .certs import replacement_get_health_context, replacement_get_credentials_func, GrpcCerts +from .certs import GrpcCerts, replacement_get_credentials_func, replacement_get_health_context from .fake_http_server import FakeHttpServer from .test_http_service_invocation_client import DaprInvocationHttpClientTests diff --git a/tests/clients/test_subscription.py b/tests/clients/test_subscription.py index ed2eae3fa..21018aaac 100644 --- a/tests/clients/test_subscription.py +++ b/tests/clients/test_subscription.py @@ -1,8 +1,9 @@ -from dapr.clients.grpc.subscription import SubscriptionMessage -from dapr.proto.runtime.v1.appcallback_pb2 import TopicEventRequest +import unittest + from google.protobuf.struct_pb2 import Struct -import unittest +from dapr.clients.grpc.subscription import SubscriptionMessage +from dapr.proto.runtime.v1.appcallback_pb2 import TopicEventRequest class SubscriptionMessageTests(unittest.TestCase): diff --git a/tests/clients/test_timeout_interceptor.py b/tests/clients/test_timeout_interceptor.py index 79859b2e5..c60331bed 100644 --- a/tests/clients/test_timeout_interceptor.py +++ b/tests/clients/test_timeout_interceptor.py @@ -15,6 +15,7 @@ import unittest from unittest.mock import Mock, patch + from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor from dapr.conf import settings diff --git a/tests/clients/test_timeout_interceptor_async.py b/tests/clients/test_timeout_interceptor_async.py index d057df9fc..88b5831dc 100644 --- a/tests/clients/test_timeout_interceptor_async.py +++ b/tests/clients/test_timeout_interceptor_async.py @@ -15,6 +15,7 @@ import unittest from unittest.mock import Mock, patch + from dapr.aio.clients.grpc.interceptors import DaprClientTimeoutInterceptorAsync from dapr.conf import settings diff --git a/tests/serializers/test_default_json_serializer.py b/tests/serializers/test_default_json_serializer.py index 86e727ad0..8f65595c0 100644 --- a/tests/serializers/test_default_json_serializer.py +++ 
b/tests/serializers/test_default_json_serializer.py @@ -13,8 +13,8 @@ limitations under the License. """ -import unittest import datetime +import unittest from dapr.serializers.json import DefaultJSONSerializer diff --git a/tests/serializers/test_util.py b/tests/serializers/test_util.py index 9f3b9e026..25124fdf6 100644 --- a/tests/serializers/test_util.py +++ b/tests/serializers/test_util.py @@ -13,12 +13,12 @@ limitations under the License. """ -import unittest import json +import unittest from datetime import timedelta -from dapr.serializers.util import convert_from_dapr_duration, convert_to_dapr_duration from dapr.serializers.json import DaprJSONDecoder +from dapr.serializers.util import convert_from_dapr_duration, convert_to_dapr_duration class UtilTests(unittest.TestCase): diff --git a/tox.ini b/tox.ini index ebd403c3f..837a1b8d8 100644 --- a/tox.ini +++ b/tox.ini @@ -6,37 +6,55 @@ envlist = flake8, ruff, mypy, +runner = virtualenv # TODO: switch to uv (tox-uv plugin) [testenv] setenv = PYTHONDONTWRITEBYTECODE=1 deps = -rdev-requirements.txt +package = editable commands = coverage run -m unittest discover -v ./tests - coverage run -a -m unittest discover -v ./ext/dapr-ext-workflow/tests + # ext/dapr-ext-workflow uses pytest-based tests + coverage run -a -m pytest -q ext/dapr-ext-workflow/tests coverage run -a -m unittest discover -v ./ext/dapr-ext-grpc/tests coverage run -a -m unittest discover -v ./ext/dapr-ext-fastapi/tests coverage run -a -m unittest discover -v ./ext/flask_dapr/tests coverage xml commands_pre = - pip3 install -e {toxinidir}/ - pip3 install -e {toxinidir}/ext/dapr-ext-workflow/ - pip3 install -e {toxinidir}/ext/dapr-ext-grpc/ - pip3 install -e {toxinidir}/ext/dapr-ext-fastapi/ - pip3 install -e {toxinidir}/ext/flask_dapr/ + # TODO: remove this before merging (after durable task is merged) + {envpython} -m pip install -e {toxinidir}/../durabletask-python/ + + {envpython} -m pip install -e {toxinidir}/ + {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-workflow/ + {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-grpc/ + {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-fastapi/ + {envpython} -m pip install -e {toxinidir}/ext/flask_dapr/ +# allow for overriding sidecar ports +pass_env = DAPR_GRPC_ENDPOINT,DAPR_HTTP_ENDPOINT,DAPR_RUNTIME_HOST,DAPR_GRPC_PORT,DAPR_HTTP_PORT,DURABLETASK_GRPC_ENDPOINT + +[flake8] +extend-exclude = .tox,venv,build,dist,dapr/proto,examples/**/.venv +ignore = E203,E501,W503,E701,E704,F821 +max-line-length = 100 [testenv:flake8] basepython = python3 usedevelop = False -deps = flake8 +deps = + flake8==7.3.0 + pip commands = - flake8 . + flake8 . 
--config={toxinidir}/tox.ini [testenv:ruff] basepython = python3 usedevelop = False -deps = ruff==0.2.2 +deps = + ruff==0.2.2 # TODO: upgrade to 0.13.3 + pip commands = + ruff check --select I --fix ruff format [testenv:examples] @@ -69,6 +87,8 @@ commands = ./validate.sh jobs ./validate.sh ../ commands_pre = + # TODO: remove this before merging (after durable task is merged) + pip3 install -e {toxinidir}/../durabletask-python/ pip3 install -e {toxinidir}/ pip3 install -e {toxinidir}/ext/dapr-ext-workflow/ pip3 install -e {toxinidir}/ext/dapr-ext-grpc/ @@ -101,6 +121,8 @@ deps = -rdev-requirements.txt commands = mypy --config-file mypy.ini commands_pre = + # TODO: remove this before merging (after durable task is merged) + pip3 install -e {toxinidir}/../durabletask-python/ pip3 install -e {toxinidir}/ pip3 install -e {toxinidir}/ext/dapr-ext-workflow/ pip3 install -e {toxinidir}/ext/dapr-ext-grpc/