44
55from pydantic import BaseModel
66
7+ from crawlee ._utils .raise_if_too_many_kwargs import raise_if_too_many_kwargs
78from crawlee .events ._types import Event , EventPersistStateData
89
910if TYPE_CHECKING :
1011 import logging
12+ from collections .abc import Callable , Coroutine
1113
12- from crawlee .storages . _key_value_store import KeyValueStore
14+ from crawlee .storages import KeyValueStore
1315
1416TStateModel = TypeVar ('TStateModel' , bound = BaseModel )
1517
@@ -37,6 +39,7 @@ def __init__(
3739 persistence_enabled : Literal [True , False , 'explicit_only' ] = False ,
3840 persist_state_kvs_name : str | None = None ,
3941 persist_state_kvs_id : str | None = None ,
42+ persist_state_kvs_factory : Callable [[], Coroutine [None , None , KeyValueStore ]] | None = None ,
4043 logger : logging .Logger ,
4144 ) -> None :
4245 """Initialize a new recoverable state object.
@@ -51,16 +54,40 @@ def __init__(
5154 If neither a name nor an id are supplied, the default store will be used.
5255 persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
5356 If neither a name nor an id are supplied, the default store will be used.
57+ persist_state_kvs_factory: Factory that can be awaited to create KeyValueStore to use for persistence. If
58+ not provided, a system-wide KeyValueStore will be used, based on service locator configuration.
5459 logger: A logger instance for logging operations related to state persistence
5560 """
61+ raise_if_too_many_kwargs (
62+ persist_state_kvs_name = persist_state_kvs_name ,
63+ persist_state_kvs_id = persist_state_kvs_id ,
64+ persist_state_kvs_factory = persist_state_kvs_factory ,
65+ )
66+ if not persist_state_kvs_factory :
67+ logger .debug (
68+ 'No explicit key_value_store set for recoverable state. Recovery will use a system-wide KeyValueStore '
69+ 'based on service_locator configuration, potentially calling service_locator.set_storage_client in the '
70+ 'process. It is recommended to initialize RecoverableState with explicit key_value_store to avoid '
71+ 'global side effects.'
72+ )
73+
5674 self ._default_state = default_state
5775 self ._state_type : type [TStateModel ] = self ._default_state .__class__
5876 self ._state : TStateModel | None = None
5977 self ._persistence_enabled = persistence_enabled
6078 self ._persist_state_key = persist_state_key
61- self ._persist_state_kvs_name = persist_state_kvs_name
62- self ._persist_state_kvs_id = persist_state_kvs_id
63- self ._key_value_store : 'KeyValueStore | None' = None # noqa: UP037
79+ if persist_state_kvs_factory is None :
80+
81+ async def kvs_factory () -> KeyValueStore :
82+ from crawlee .storages import KeyValueStore # noqa: PLC0415 avoid circular import
83+
84+ return await KeyValueStore .open (name = persist_state_kvs_name , id = persist_state_kvs_id )
85+
86+ self ._persist_state_kvs_factory = kvs_factory
87+ else :
88+ self ._persist_state_kvs_factory = persist_state_kvs_factory
89+
90+ self ._key_value_store : KeyValueStore | None = None
6491 self ._log = logger
6592
6693 async def initialize (self ) -> TStateModel :
@@ -77,11 +104,8 @@ async def initialize(self) -> TStateModel:
77104 return self .current_value
78105
79106 # Import here to avoid circular imports.
80- from crawlee .storages ._key_value_store import KeyValueStore # noqa: PLC0415
81107
82- self ._key_value_store = await KeyValueStore .open (
83- name = self ._persist_state_kvs_name , id = self ._persist_state_kvs_id
84- )
108+ self ._key_value_store = await self ._persist_state_kvs_factory ()
85109
86110 await self ._load_saved_state ()
87111
0 commit comments