Skip to content

Commit b35dee7

Browse files
authored
fix: Fix KeyValueStore.auto_saved_value failing in some scenarios (#1438)
### Description

- Reduce the amount of global side effects in `service_locator` by using an explicit KVS factory in `RecoverableState`.
- Fix `KeyValueStore.auto_saved_value` not working properly if the global storage_client was different from the current kvs storage client.
- Improve test isolation.

### Issues

- Closes: #1354

### Testing

- Added tests for some edge cases.
1 parent 11944b7 commit b35dee7

File tree

8 files changed

+124
-22
lines changed

8 files changed

+124
-22
lines changed

src/crawlee/_utils/recoverable_state.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
from pydantic import BaseModel
66

7+
from crawlee._utils.raise_if_too_many_kwargs import raise_if_too_many_kwargs
78
from crawlee.events._types import Event, EventPersistStateData
89

910
if TYPE_CHECKING:
1011
import logging
12+
from collections.abc import Callable, Coroutine
1113

12-
from crawlee.storages._key_value_store import KeyValueStore
14+
from crawlee.storages import KeyValueStore
1315

1416
TStateModel = TypeVar('TStateModel', bound=BaseModel)
1517

@@ -37,6 +39,7 @@ def __init__(
3739
persistence_enabled: Literal[True, False, 'explicit_only'] = False,
3840
persist_state_kvs_name: str | None = None,
3941
persist_state_kvs_id: str | None = None,
42+
persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
4043
logger: logging.Logger,
4144
) -> None:
4245
"""Initialize a new recoverable state object.
@@ -51,16 +54,40 @@ def __init__(
5154
If neither a name nor an id is supplied, the default store will be used.
5255
persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
5356
If neither a name nor an id is supplied, the default store will be used.
57+
persist_state_kvs_factory: Async factory that is called and awaited to create the KeyValueStore used for persistence. If
58+
not provided, a system-wide KeyValueStore will be used, based on service locator configuration.
5459
logger: A logger instance for logging operations related to state persistence
5560
"""
61+
raise_if_too_many_kwargs(
62+
persist_state_kvs_name=persist_state_kvs_name,
63+
persist_state_kvs_id=persist_state_kvs_id,
64+
persist_state_kvs_factory=persist_state_kvs_factory,
65+
)
66+
if not persist_state_kvs_factory:
67+
logger.debug(
68+
'No explicit key_value_store set for recoverable state. Recovery will use a system-wide KeyValueStore '
69+
'based on service_locator configuration, potentially calling service_locator.set_storage_client in the '
70+
'process. It is recommended to initialize RecoverableState with explicit key_value_store to avoid '
71+
'global side effects.'
72+
)
73+
5674
self._default_state = default_state
5775
self._state_type: type[TStateModel] = self._default_state.__class__
5876
self._state: TStateModel | None = None
5977
self._persistence_enabled = persistence_enabled
6078
self._persist_state_key = persist_state_key
61-
self._persist_state_kvs_name = persist_state_kvs_name
62-
self._persist_state_kvs_id = persist_state_kvs_id
63-
self._key_value_store: 'KeyValueStore | None' = None # noqa: UP037
79+
if persist_state_kvs_factory is None:
80+
81+
async def kvs_factory() -> KeyValueStore:
82+
from crawlee.storages import KeyValueStore # noqa: PLC0415 avoid circular import
83+
84+
return await KeyValueStore.open(name=persist_state_kvs_name, id=persist_state_kvs_id)
85+
86+
self._persist_state_kvs_factory = kvs_factory
87+
else:
88+
self._persist_state_kvs_factory = persist_state_kvs_factory
89+
90+
self._key_value_store: KeyValueStore | None = None
6491
self._log = logger
6592

6693
async def initialize(self) -> TStateModel:
@@ -77,11 +104,8 @@ async def initialize(self) -> TStateModel:
77104
return self.current_value
78105

79106
# Obtain the store from the factory; the default factory imports KeyValueStore lazily to avoid circular imports.
80-
from crawlee.storages._key_value_store import KeyValueStore # noqa: PLC0415
81107

82-
self._key_value_store = await KeyValueStore.open(
83-
name=self._persist_state_kvs_name, id=self._persist_state_kvs_id
84-
)
108+
self._key_value_store = await self._persist_state_kvs_factory()
85109

86110
await self._load_saved_state()
87111

src/crawlee/statistics/_statistics.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717
from crawlee.statistics._error_tracker import ErrorTracker
1818

1919
if TYPE_CHECKING:
20+
from collections.abc import Callable, Coroutine
2021
from types import TracebackType
2122

23+
from crawlee.storages import KeyValueStore
24+
2225
TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)
2326
TNewStatisticsState = TypeVar('TNewStatisticsState', bound=StatisticsState, default=StatisticsState)
2427
logger = getLogger(__name__)
@@ -70,6 +73,7 @@ def __init__(
7073
persistence_enabled: bool | Literal['explicit_only'] = False,
7174
persist_state_kvs_name: str | None = None,
7275
persist_state_key: str | None = None,
76+
persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
7377
log_message: str = 'Statistics',
7478
periodic_message_logger: Logger | None = None,
7579
log_interval: timedelta = timedelta(minutes=1),
@@ -95,6 +99,7 @@ def __init__(
9599
persist_state_key=persist_state_key or f'SDK_CRAWLER_STATISTICS_{self._id}',
96100
persistence_enabled=persistence_enabled,
97101
persist_state_kvs_name=persist_state_kvs_name,
102+
persist_state_kvs_factory=persist_state_kvs_factory,
98103
logger=logger,
99104
)
100105

@@ -110,8 +115,8 @@ def replace_state_model(self, state_model: type[TNewStatisticsState]) -> Statist
110115
"""Create near copy of the `Statistics` with replaced `state_model`."""
111116
new_statistics: Statistics[TNewStatisticsState] = Statistics(
112117
persistence_enabled=self._state._persistence_enabled, # noqa: SLF001
113-
persist_state_kvs_name=self._state._persist_state_kvs_name, # noqa: SLF001
114118
persist_state_key=self._state._persist_state_key, # noqa: SLF001
119+
persist_state_kvs_factory=self._state._persist_state_kvs_factory, # noqa: SLF001
115120
log_message=self._log_message,
116121
periodic_message_logger=self._periodic_message_logger,
117122
state_model=state_model,

src/crawlee/storage_clients/_file_system/_request_queue_client.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from collections.abc import Sequence
3232

3333
from crawlee.configuration import Configuration
34+
from crawlee.storages import KeyValueStore
3435

3536
logger = getLogger(__name__)
3637

@@ -92,6 +93,7 @@ def __init__(
9293
metadata: RequestQueueMetadata,
9394
path_to_rq: Path,
9495
lock: asyncio.Lock,
96+
recoverable_state: RecoverableState[RequestQueueState],
9597
) -> None:
9698
"""Initialize a new instance.
9799
@@ -114,12 +116,7 @@ def __init__(
114116
self._is_empty_cache: bool | None = None
115117
"""Cache for is_empty result: None means unknown, True/False is cached state."""
116118

117-
self._state = RecoverableState[RequestQueueState](
118-
default_state=RequestQueueState(),
119-
persist_state_key=f'__RQ_STATE_{self._metadata.id}',
120-
persistence_enabled=True,
121-
logger=logger,
122-
)
119+
self._state = recoverable_state
123120
"""Recoverable state to maintain request ordering, in-progress status, and handled status."""
124121

125122
@override
@@ -136,6 +133,22 @@ def path_to_metadata(self) -> Path:
136133
"""The full path to the request queue metadata file."""
137134
return self.path_to_rq / METADATA_FILENAME
138135

136+
@classmethod
137+
async def _create_recoverable_state(cls, id: str, configuration: Configuration) -> RecoverableState:
138+
async def kvs_factory() -> KeyValueStore:
139+
from crawlee.storage_clients import FileSystemStorageClient # noqa: PLC0415 avoid circular import
140+
from crawlee.storages import KeyValueStore # noqa: PLC0415 avoid circular import
141+
142+
return await KeyValueStore.open(storage_client=FileSystemStorageClient(), configuration=configuration)
143+
144+
return RecoverableState[RequestQueueState](
145+
default_state=RequestQueueState(),
146+
persist_state_key=f'__RQ_STATE_{id}',
147+
persist_state_kvs_factory=kvs_factory,
148+
persistence_enabled=True,
149+
logger=logger,
150+
)
151+
139152
@classmethod
140153
async def open(
141154
cls,
@@ -194,6 +207,9 @@ async def open(
194207
metadata=metadata,
195208
path_to_rq=rq_base_path / rq_dir,
196209
lock=asyncio.Lock(),
210+
recoverable_state=await cls._create_recoverable_state(
211+
id=id, configuration=configuration
212+
),
197213
)
198214
await client._state.initialize()
199215
await client._discover_existing_requests()
@@ -230,6 +246,7 @@ async def open(
230246
metadata=metadata,
231247
path_to_rq=path_to_rq,
232248
lock=asyncio.Lock(),
249+
recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration),
233250
)
234251

235252
await client._state.initialize()
@@ -254,6 +271,7 @@ async def open(
254271
metadata=metadata,
255272
path_to_rq=path_to_rq,
256273
lock=asyncio.Lock(),
274+
recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration),
257275
)
258276
await client._state.initialize()
259277
await client._update_metadata()

src/crawlee/storages/_key_value_store.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,14 @@ async def get_auto_saved_value(
281281
if key in cache:
282282
return cache[key].current_value.root
283283

284+
async def kvs_factory() -> KeyValueStore:
285+
return self
286+
284287
cache[key] = recoverable_state = RecoverableState(
285288
default_state=AutosavedValue(default_value),
286-
persistence_enabled=True,
287-
persist_state_kvs_id=self.id,
288289
persist_state_key=key,
290+
persistence_enabled=True,
291+
persist_state_kvs_factory=kvs_factory,
289292
logger=logger,
290293
)
291294

tests/unit/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from crawlee.fingerprint_suite._browserforge_adapter import get_available_header_network
1818
from crawlee.http_clients import CurlImpersonateHttpClient, HttpxHttpClient, ImpitHttpClient
1919
from crawlee.proxy_configuration import ProxyInfo
20+
from crawlee.statistics import Statistics
2021
from crawlee.storages import KeyValueStore
2122
from tests.unit.server import TestServer, app, serve_in_thread
2223

@@ -72,6 +73,10 @@ def _prepare_test_env() -> None:
7273
# Verify that the test environment was set up correctly.
7374
assert os.environ.get('CRAWLEE_STORAGE_DIR') == str(tmp_path)
7475

76+
# Reset global class variables to ensure test isolation.
77+
KeyValueStore._autosaved_values = {}
78+
Statistics._Statistics__next_id = 0 # type:ignore[attr-defined] # Mangled attribute
79+
7580
return _prepare_test_env
7681

7782

tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,6 @@ async def test_adaptive_playwright_crawler_statistics_in_init() -> None:
493493
assert type(crawler._statistics.state) is AdaptivePlaywrightCrawlerStatisticState
494494

495495
assert crawler._statistics._state._persistence_enabled == persistence_enabled
496-
assert crawler._statistics._state._persist_state_kvs_name == persist_state_kvs_name
497496
assert crawler._statistics._state._persist_state_key == persist_state_key
498497

499498
assert crawler._statistics._log_message == log_message

tests/unit/storage_clients/_file_system/test_fs_rq_client.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
import pytest
88

9-
from crawlee import Request
9+
from crawlee import Request, service_locator
1010
from crawlee.configuration import Configuration
11-
from crawlee.storage_clients import FileSystemStorageClient
11+
from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient
1212

1313
if TYPE_CHECKING:
1414
from collections.abc import AsyncGenerator
@@ -78,6 +78,14 @@ async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient)
7878
assert request_data['url'].startswith('https://example.com/')
7979

8080

81+
async def test_opening_rq_does_not_have_side_effect_on_service_locator(configuration: Configuration) -> None:
82+
"""Opening a request queue client should not cause setting the storage client in the global service locator."""
83+
await FileSystemStorageClient().create_rq_client(name='test_request_queue', configuration=configuration)
84+
85+
# Set some specific storage client in the service locator. There should be no `ServiceConflictError`.
86+
service_locator.set_storage_client(MemoryStorageClient())
87+
88+
8189
async def test_drop_removes_directory(rq_client: FileSystemRequestQueueClient) -> None:
8290
"""Test that drop removes the entire RQ directory from disk."""
8391
await rq_client.add_batch_of_requests([Request.from_url('https://example.com')])

tests/unit/storages/test_key_value_store.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010

1111
from crawlee import service_locator
1212
from crawlee.configuration import Configuration
13+
from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, SqlStorageClient, StorageClient
1314
from crawlee.storages import KeyValueStore
1415
from crawlee.storages._storage_instance_manager import StorageInstanceManager
1516

1617
if TYPE_CHECKING:
1718
from collections.abc import AsyncGenerator
18-
19-
from crawlee.storage_clients import StorageClient
19+
from pathlib import Path
2020

2121

2222
@pytest.fixture
@@ -1095,3 +1095,43 @@ async def test_validate_name(storage_client: StorageClient, name: str, *, is_val
10951095
else:
10961096
with pytest.raises(ValueError, match=rf'Invalid storage name "{name}".*'):
10971097
await KeyValueStore.open(name=name, storage_client=storage_client)
1098+
1099+
1100+
@pytest.mark.parametrize(
1101+
'tested_storage_client',
1102+
[
1103+
pytest.param(MemoryStorageClient(), id='tested=MemoryStorageClient'),
1104+
pytest.param(FileSystemStorageClient(), id='tested=FileSystemStorageClient'),
1105+
pytest.param(SqlStorageClient(), id='tested=SqlStorageClient'),
1106+
],
1107+
)
1108+
@pytest.mark.parametrize(
1109+
'global_storage_client',
1110+
[
1111+
pytest.param(MemoryStorageClient(), id='global=MemoryStorageClient'),
1112+
pytest.param(FileSystemStorageClient(), id='global=FileSystemStorageClient'),
1113+
pytest.param(SqlStorageClient(), id='global=SqlStorageClient'),
1114+
],
1115+
)
1116+
async def test_get_auto_saved_value_various_global_clients(
1117+
tmp_path: Path, tested_storage_client: StorageClient, global_storage_client: StorageClient
1118+
) -> None:
1119+
"""Ensure that persistence is working for all clients regardless of what is set in service locator."""
1120+
service_locator.set_configuration(
1121+
Configuration(
1122+
crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg]
1123+
purge_on_start=True,
1124+
)
1125+
)
1126+
service_locator.set_storage_client(global_storage_client)
1127+
1128+
kvs = await KeyValueStore.open(storage_client=tested_storage_client)
1129+
values_kvs = {'key': 'some_value'}
1130+
test_key = 'test_key'
1131+
1132+
autosaved_value_kvs = await kvs.get_auto_saved_value(test_key)
1133+
assert autosaved_value_kvs == {}
1134+
autosaved_value_kvs.update(values_kvs)
1135+
await kvs.persist_autosaved_values()
1136+
1137+
assert await kvs.get_value(test_key) == autosaved_value_kvs

0 commit comments

Comments
 (0)