|
| 1 | +import asyncio |
| 2 | +import logging |
| 3 | +from typing import Callable |
| 4 | + |
| 5 | + |
| 6 | +async def wait_for_condition( |
| 7 | + predicate: Callable[[], bool], |
| 8 | + timeout: float = 0.2, |
| 9 | + check_interval: float = 0.01, |
| 10 | + error_message: str = "Timeout waiting for condition", |
| 11 | +) -> None: |
| 12 | + """ |
| 13 | + Poll a condition until it becomes True or timeout is reached. |
| 14 | +
|
| 15 | + Args: |
| 16 | + predicate: A callable that returns True when the condition is met |
| 17 | + timeout: Maximum time to wait in seconds (default: 0.2s = 20 * 0.01s) |
| 18 | + check_interval: Time to sleep between checks in seconds (default: 0.01s) |
| 19 | + error_message: Error message to raise if timeout occurs |
| 20 | +
|
| 21 | + Raises: |
| 22 | + AssertionError: If the condition is not met within the timeout period |
| 23 | +
|
| 24 | + Example: |
| 25 | + # Wait for circuit breaker to open |
| 26 | + await wait_for_condition( |
| 27 | + lambda: cb2.state == CBState.OPEN, |
| 28 | + timeout=0.2, |
| 29 | + error_message="Timeout waiting for cb2 to open" |
| 30 | + ) |
| 31 | +
|
| 32 | + # Wait for failover strategy to select a specific database |
| 33 | + await wait_for_condition( |
| 34 | + lambda: client.command_executor.active_database is mock_db, |
| 35 | + timeout=0.2, |
| 36 | + error_message="Timeout waiting for active database to change" |
| 37 | + ) |
| 38 | + """ |
| 39 | + max_retries = int(timeout / check_interval) |
| 40 | + |
| 41 | + for attempt in range(max_retries): |
| 42 | + if predicate(): |
| 43 | + logging.debug(f"Condition met after {attempt} attempts") |
| 44 | + return |
| 45 | + await asyncio.sleep(check_interval) |
| 46 | + |
| 47 | + raise AssertionError(error_message) |
0 commit comments