diff --git a/python/ray/serve/schema.py b/python/ray/serve/schema.py
index 9d5d9176259d..86be0c9ff8f3 100644
--- a/python/ray/serve/schema.py
+++ b/python/ray/serve/schema.py
@@ -1501,6 +1501,17 @@ def health_check_sync(self) -> List[Dict]:
         """
         pass

+    @abstractmethod
+    def get_queue_lengths_sync(self) -> Dict[str, int]:
+        """
+        Get the lengths of all queues synchronously.
+
+        Returns:
+            Dict[str, int]: A dictionary mapping queue names to their lengths.
+                Returns empty dict if queue length information is unavailable.
+        """
+        pass
+
     async def enqueue_task_async(
         self,
         task_name: str,
diff --git a/python/ray/serve/task_processor.py b/python/ray/serve/task_processor.py
index 92e5f68de90c..85c0161c91f1 100644
--- a/python/ray/serve/task_processor.py
+++ b/python/ray/serve/task_processor.py
@@ -1,3 +1,5 @@
+import asyncio
+import concurrent.futures
 import logging
 import threading
 import time
@@ -5,6 +7,7 @@

 from celery import Celery
 from celery.signals import task_failure, task_unknown
+from flower.utils.broker import Broker

 from ray.serve import get_replica_context
 from ray.serve._private.constants import (
@@ -37,6 +40,141 @@
 ]


+class FlowerQueueMonitor:
+    """
+    Thread-safe queue length monitor using Flower's broker utility.
+
+    This class provides a broker-agnostic way to query queue lengths by:
+    1. Running a background asyncio event loop in a dedicated thread
+    2. Using Flower's Broker utility to query queue information
+    3. Bridging async/sync with run_coroutine_threadsafe
+    """
+
+    def __init__(self, app: Celery):
+        """Initialize the queue monitor with Celery app configuration."""
+        self.app = app
+
+        # Initialize Flower's Broker utility (broker-agnostic)
+        # This utility handles Redis, RabbitMQ, SQS, and other brokers
+        try:
+            self.broker = Broker(
+                app.connection().as_uri(include_password=True),
+                broker_options=app.conf.broker_transport_options or {},
+                broker_use_ssl=app.conf.broker_use_ssl or {},
+            )
+        except Exception as e:
+            logger.error(f"Failed to initialize Flower Broker: {e}")
+            raise
+
+        # Event loop management
+        self._loop = None
+        self._loop_thread = None
+        self._loop_ready = threading.Event()
+        self._should_stop = threading.Event()
+        self._executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=1, thread_name_prefix="FlowerQueueMonitor"
+        )
+
+    def start(self):
+        """
+        Start the background event loop in a dedicated thread.
+
+        This creates a new thread that runs an asyncio event loop.
+        The thread acts as the "main thread" for the event loop,
+        avoiding signal handler issues.
+        """
+        if self._loop is not None:
+            logger.warning("Queue monitor already started")
+            return
+
+        def _run_event_loop():
+            """Run the event loop in the background thread."""
+            try:
+                # Create new event loop for this thread
+                self._loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(self._loop)
+
+                # Signal that loop is ready
+                self._loop_ready.set()
+
+                # Run loop until stop is called
+                while not self._should_stop.is_set():
+                    self._loop.run_until_complete(asyncio.sleep(0.1))
+
+            except Exception as e:
+                logger.error(f"Error in event loop thread: {e}")
+            finally:
+                # Clean up loop
+                if self._loop and not self._loop.is_closed():
+                    self._loop.close()
+
+        # Start event loop in background thread
+        self._loop_thread = self._executor.submit(_run_event_loop)
+
+        # Wait for loop to be ready (with timeout)
+        if not self._loop_ready.wait(timeout=10):
+            raise RuntimeError("Failed to start event loop thread within 10 seconds")
+
+        logger.info("Queue monitor event loop started successfully")
+
+    def stop(self):
+        """Stop the background event loop and cleanup resources."""
+        if self._loop is None:
+            logger.info("Flower queue monitor not running, nothing to stop")
+            return
+
+        try:
+            # Signal the loop to stop
+            self._should_stop.set()
+
+            # Wait for thread to finish (with timeout)
+            if self._loop_thread:
+                self._loop_thread.result(timeout=20)
+
+            # Shutdown executor
+            self._executor.shutdown(wait=True)
+
+            logger.info("Queue monitor stopped successfully")
+            self._should_stop.clear()
+        except Exception as e:
+            logger.error(f"Error stopping queue monitor: {e}")
+        finally:
+            self._loop = None
+            self._loop_thread = None
+            self._loop_ready.clear()
+
+    def get_queue_lengths(
+        self, queue_names: List[str], timeout: float = 5.0
+    ) -> Dict[str, int]:
+        """Get queue lengths synchronously (thread-safe)."""
+        if self._loop is None or not self._loop_ready.is_set():
+            logger.error("Event loop not initialized. Call start() first.")
+            return {}
+
+        try:
+            # Schedule coroutine in background event loop thread
+            future = asyncio.run_coroutine_threadsafe(
+                self.broker.queues(queue_names), self._loop
+            )
+
+            # Wait for result with timeout
+            queue_stats = future.result(timeout=timeout)
+
+            # Convert to simple dict format
+            queue_lengths = {}
+            for queue_info in queue_stats:
+                queue_name = queue_info.get("name")
+                messages = queue_info.get("messages", 0)
+                if queue_name:
+                    queue_lengths[queue_name] = messages
+
+            return queue_lengths
+
+        except Exception as e:
+            logger.error(f"Error getting queue lengths: {e}")
+            return {}
+
+
 @PublicAPI(stability="alpha")
 class CeleryTaskProcessorAdapter(TaskProcessorAdapter):
     """
@@ -50,6 +188,7 @@ class CeleryTaskProcessorAdapter(TaskProcessorAdapter):
     _worker_thread: Optional[threading.Thread] = None
     _worker_hostname: Optional[str] = None
     _worker_concurrency: int = DEFAULT_CONSUMER_CONCURRENCY
+    _queue_monitor: Optional[FlowerQueueMonitor] = None

     def __init__(self, config: TaskProcessorConfig, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -136,6 +275,13 @@ def initialize(self, consumer_concurrency: int = DEFAULT_CONSUMER_CONCURRENCY):
         if self._config.unprocessable_task_queue_name:
             task_unknown.connect(self._handle_unknown_task)

+        try:
+            self._queue_monitor = FlowerQueueMonitor(self._app)
+            logger.info("Queue monitor initialized")
+        except Exception as e:
+            logger.warning(f"Failed to initialize flower queue monitor: {e}.")
+            self._queue_monitor = None
+
     def register_task_handle(self, func, name=None):
         task_options = {
             "autoretry_for": (Exception,),
@@ -179,7 +325,7 @@ def get_task_status_sync(self, task_id) -> TaskResult:
         )

     def start_consumer(self, **kwargs):
-        """Starts the Celery worker thread."""
+        """Starts the Celery worker thread and queue monitoring."""
         if self._worker_thread is not None and self._worker_thread.is_alive():
             logger.info("Celery worker thread is already running.")
             return
@@ -204,8 +350,27 @@
             f"Celery worker thread started with hostname: {self._worker_hostname}"
         )

+        # Start queue monitor
+        if self._queue_monitor:
+            try:
+                self._queue_monitor.start()
+                logger.info("Queue monitor started")
+            except Exception as e:
+                logger.error(f"Failed to start queue monitor: {e}")
+                self._queue_monitor = None
+
     def stop_consumer(self, timeout: float = 10.0):
         """Signals the Celery worker to shut down and waits for it to terminate."""
+        # Stop queue monitor first
+        if self._queue_monitor:
+            try:
+                self._queue_monitor.stop()
+                logger.info("Queue monitor stopped")
+            except Exception as e:
+                logger.warning(f"Error stopping queue monitor: {e}")
+            finally:
+                self._queue_monitor = None
+
         if self._worker_thread is None or not self._worker_thread.is_alive():
             logger.info("Celery worker thread is not running.")
             return
@@ -253,6 +418,34 @@ def health_check_sync(self) -> List[Dict]:
         """
         return self._app.control.ping()

+    def get_queue_lengths_sync(self) -> Dict[str, int]:
+        """
+        Get the lengths of all queues by querying broker directly.
+
+        Returns:
+            Dict[str, int]: A dictionary mapping queue names to their lengths.
+                Returns empty dict if monitoring is unavailable.
+        """
+        if self._queue_monitor is None:
+            logger.warning(
+                "Flower queue monitor not initialized. Cannot retrieve queue lengths."
+            )
+            return {}
+
+        # Collect all queue names to check
+        queue_names = [self._config.queue_name]
+        if self._config.failed_task_queue_name:
+            queue_names.append(self._config.failed_task_queue_name)
+        if self._config.unprocessable_task_queue_name:
+            queue_names.append(self._config.unprocessable_task_queue_name)
+
+        try:
+            queue_lengths = self._queue_monitor.get_queue_lengths(queue_names)
+            return queue_lengths
+        except Exception as e:
+            logger.error(f"Flower failed to retrieve queue lengths: {e}")
+            return {}
+
     def _handle_task_failure(
         self,
         sender: Any = None,
diff --git a/python/ray/serve/tests/test_task_processor.py b/python/ray/serve/tests/test_task_processor.py
index f09e09c1b3f6..af5643208d07 100644
--- a/python/ray/serve/tests/test_task_processor.py
+++ b/python/ray/serve/tests/test_task_processor.py
@@ -706,6 +706,78 @@ def get_message_received(self):
             lambda: "test_data_3" in handle.get_message_received.remote().result()
         )

+    def test_flower_monitors_queue_with_redis(
+        self, external_redis, serve_instance  # noqa: F811
+    ):
+        """Test that queue monitor can monitor queue lengths with Redis broker."""
+
+        redis_address = os.environ.get("RAY_REDIS_ADDRESS")
+        processor_config = TaskProcessorConfig(
+            queue_name="flower_test_queue",
+            adapter_config=CeleryAdapterConfig(
+                broker_url=f"redis://{redis_address}/0",
+                backend_url=f"redis://{redis_address}/1",
+                app_custom_config={"worker_prefetch_multiplier": 1},
+            ),
+        )
+
+        signal = SignalActor.remote()
+
+        @serve.deployment(max_ongoing_requests=1)
+        @task_consumer(task_processor_config=processor_config)
+        class BlockingTaskConsumer:
+            def __init__(self, signal_actor):
+                self._signal = signal_actor
+                self.tasks_started = 0
+
+            @task_handler(name="blocking_task")
+            def blocking_task(self, data):
+                self.tasks_started += 1
+                ray.get(
+                    self._signal.wait.remote()
+                )  # Block indefinitely waiting for signal
+                return f"processed: {data}"
+
+            def get_tasks_started(self):
+                return self.tasks_started
+
+            def get_queue_lengths(self):
+                return self._adapter.get_queue_lengths_sync()
+
+        handle = serve.run(BlockingTaskConsumer.bind(signal))
+
+        # Push 10 tasks to the queue
+        num_tasks = 10
+        for i in range(num_tasks):
+            send_request_to_queue.remote(
+                processor_config, f"task_{i}", task_name="blocking_task"
+            )
+
+        # Wait for one task to start (will be blocked waiting for signal)
+        wait_for_condition(
+            lambda: handle.get_tasks_started.remote().result() == 1,
+            timeout=30,
+            retry_interval_ms=1000,
+        )
+
+        def check_queue_lengths():
+            result = handle.get_queue_lengths.remote().result()
+            # 8 should still be queued (1 executing, possibly 1 prefetched)
+            return result.get("flower_test_queue", 0) >= 8
+
+        wait_for_condition(check_queue_lengths, timeout=30)
+
+        # Get final queue lengths
+        queue_lengths = handle.get_queue_lengths.remote().result()
+
+        assert len(queue_lengths) > 0, "Should have queue length data"
+        assert "flower_test_queue" in queue_lengths, "Should report main queue"
+        assert (
+            queue_lengths["flower_test_queue"] == 8
+        ), f"Expected 8 tasks in queue, got {queue_lengths['flower_test_queue']} instead"
+
+        ray.get(signal.send.remote())
+

 @pytest.mark.skipif(sys.platform == "win32", reason="Flaky on Windows.")
 class TestTaskConsumerWithDLQsConfiguration:
diff --git a/python/ray/serve/tests/unit/test_task_consumer.py b/python/ray/serve/tests/unit/test_task_consumer.py
index 107a592ed071..10cc1f769082 100644
--- a/python/ray/serve/tests/unit/test_task_consumer.py
+++ b/python/ray/serve/tests/unit/test_task_consumer.py
@@ -58,6 +58,9 @@ def get_metrics_sync(self) -> Dict[str, Any]:
     def health_check_sync(self) -> List[Dict]:
         pass

+    def get_queue_lengths_sync(self) -> Dict[str, int]:
+        pass
+

 @pytest.fixture
 def config():
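Not part of the patch: the FlowerQueueMonitor docstring above describes a sync-over-async bridge (an asyncio event loop owned by a dedicated thread, with synchronous callers served via asyncio.run_coroutine_threadsafe). A minimal, self-contained sketch of that pattern follows; BackgroundLoop and fake_queue_lengths are illustrative names only, and the real class delegates to flower.utils.broker.Broker.queues() instead of the stand-in coroutine.

```python
import asyncio
import threading


class BackgroundLoop:
    """Toy illustration of the bridge: an asyncio loop running in a worker
    thread, queried synchronously via asyncio.run_coroutine_threadsafe."""

    def __init__(self):
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)

    def start(self):
        self._thread.start()

    def stop(self):
        # Ask the loop to stop from its own thread, then join and close it.
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()
        self._loop.close()

    def run_sync(self, coro, timeout=5.0):
        # Schedule the coroutine on the background loop and block for the
        # result, mirroring FlowerQueueMonitor.get_queue_lengths().
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result(timeout=timeout)


async def fake_queue_lengths(names):
    # Stand-in for Broker.queues(); the real monitor reads the "messages"
    # count reported for each queue.
    await asyncio.sleep(0)
    return {name: 0 for name in names}


if __name__ == "__main__":
    monitor = BackgroundLoop()
    monitor.start()
    print(monitor.run_sync(fake_queue_lengths(["flower_test_queue"])))
    monitor.stop()
```

The design choice this illustrates is why the adapter can expose a synchronous get_queue_lengths_sync() even though the Flower broker API is asynchronous: the event loop never runs on the caller's thread, so replica code can block on future.result() with a timeout and fall back to an empty dict on failure.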