11 changes: 11 additions & 0 deletions python/ray/serve/schema.py
@@ -1501,6 +1501,17 @@ def health_check_sync(self) -> List[Dict]:
"""
pass

@abstractmethod
def get_queue_lengths_sync(self) -> Dict[str, int]:
"""
Get the lengths of all queues synchronously.

Returns:
Dict[str, int]: A dictionary mapping queue names to their lengths.
Returns empty dict if queue length information is unavailable.
"""
pass

async def enqueue_task_async(
self,
task_name: str,
195 changes: 194 additions & 1 deletion python/ray/serve/task_processor.py
@@ -1,10 +1,13 @@
import asyncio
import concurrent.futures
import logging
import threading
import time
from typing import Any, Dict, List, Optional

from celery import Celery
from celery.signals import task_failure, task_unknown
from flower.utils.broker import Broker

from ray.serve import get_replica_context
from ray.serve._private.constants import (
@@ -37,6 +40,141 @@
]


class FlowerQueueMonitor:
"""
Thread-safe queue length monitor using Flower's broker utility.

This class provides a broker-agnostic way to query queue lengths by:
1. Running a background asyncio event loop in a dedicated thread
2. Using Flower's Broker utility to query queue information
3. Bridging async/sync with run_coroutine_threadsafe
"""

def __init__(self, app: Celery):
"""Initialize the queue monitor with Celery app configuration."""
self.app = app

# Initialize Flower's Broker utility (broker-agnostic)
# This utility handles Redis, RabbitMQ, SQS, and other brokers
try:
self.broker = Broker(
app.connection().as_uri(include_password=True),
broker_options=app.conf.broker_transport_options or {},
broker_use_ssl=app.conf.broker_use_ssl or {},
)
except Exception as e:
logger.error(f"Failed to initialize Flower Broker: {e}")
raise

# Event loop management
self._loop = None
self._loop_thread = None
self._loop_ready = threading.Event()
self._should_stop = threading.Event()
self._executor = concurrent.futures.ThreadPoolExecutor(
max_workers=1, thread_name_prefix="FlowerQueueMonitor"
)

def start(self):
"""
Start the background event loop in a dedicated thread.

This creates a new thread that runs an asyncio event loop.
The thread acts as the "main thread" for the event loop,
avoiding signal handler issues.
"""
if self._loop is not None:
logger.warning("Queue monitor already started")
return

def _run_event_loop():
"""Run the event loop in the background thread."""
try:
# Create new event loop for this thread
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)

# Signal that loop is ready
self._loop_ready.set()

# Run loop until stop is called
while not self._should_stop.is_set():
self._loop.run_until_complete(asyncio.sleep(0.1))

except Exception as e:
logger.error(f"Error in event loop thread: {e}")
finally:
# Clean up loop
if self._loop and not self._loop.is_closed():
self._loop.close()

# Start event loop in background thread
self._loop_thread = self._executor.submit(_run_event_loop)

Bug: Concurrent Initialization Race Condition

The start() method has a race condition where self._loop is checked in the main thread but assigned in the background thread. Multiple concurrent calls to start() can pass the if self._loop is not None check before any thread assigns the value, causing multiple event loops and threads to be created. The check should use a thread-safe flag or lock to prevent concurrent initialization.
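A minimal sketch of the kind of guard the comment suggests (illustrative only, not part of this PR; the _start_lock and _started attributes are hypothetical additions to __init__):

    def start(self):
        # Hypothetical fix sketch: check-and-set under a lock so two
        # concurrent callers can never both observe "not started".
        # Assumes __init__ also creates:
        #     self._start_lock = threading.Lock()
        #     self._started = False
        with self._start_lock:
            if self._started:
                logger.warning("Queue monitor already started")
                return
            self._started = True
        # ... create and submit _run_event_loop exactly as in the current code ...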



# Wait for loop to be ready (with timeout)
if not self._loop_ready.wait(timeout=10):
raise RuntimeError("Failed to start event loop thread within 10 seconds")

logger.info("Queue monitor event loop started successfully")

def stop(self):
"""Stop the background event loop and cleanup resources."""
if self._loop is None:
logger.info("Flower queue monitor not running, nothing to stop")
return

try:
# Signal the loop to stop
self._should_stop.set()

# Wait for thread to finish (with timeout)
if self._loop_thread:
self._loop_thread.result(timeout=20)

# Shutdown executor
self._executor.shutdown(wait=True)

logger.info("Queue monitor stopped successfully")
self._should_stop.clear()
except Exception as e:
logger.error(f"Error stopping queue monitor: {e}")
finally:
self._loop = None
self._loop_thread = None
self._loop_ready.clear()

Bug: Race Condition: Loop Shutdown Inconsistency

The stop() method has a race condition where get_queue_lengths() can be called after the event loop is closed but before _loop is set to None. After line 132 waits for the thread to complete, the loop is closed, but _loop remains non-None and _loop_ready remains set until the finally block executes. A concurrent call to get_queue_lengths() during this window will pass the check on line 150 and attempt to schedule a coroutine on the closed loop, causing an error.
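One way to narrow that window (sketch only; a lock shared with get_queue_lengths() would close it completely) is to publish the shutdown to readers before the loop is torn down:

    def stop(self):
        if self._loop is None:
            return
        # Hypothetical fix sketch: flip the "ready" flag first so a concurrent
        # get_queue_lengths() bails out instead of scheduling a coroutine on a
        # loop that is about to close.
        self._loop_ready.clear()
        self._should_stop.set()
        if self._loop_thread:
            self._loop_thread.result(timeout=20)
        self._loop = None
        # ... executor shutdown and remaining cleanup as in the current code ...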


Bug: Incomplete Cleanup Blocks System Restart

The _should_stop.clear() call at line 138 is in the try block instead of the finally block. If _executor.shutdown() raises an exception, _should_stop remains set, preventing the monitor from restarting properly. The flag should be cleared in the finally block alongside other state cleanup to ensure consistent state regardless of exceptions.
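A minimal sketch of the suggested reordering (illustrative only, not part of this PR):

    def stop(self):
        ...
        try:
            self._should_stop.set()
            if self._loop_thread:
                self._loop_thread.result(timeout=20)
            self._executor.shutdown(wait=True)
            logger.info("Queue monitor stopped successfully")
        except Exception as e:
            logger.error(f"Error stopping queue monitor: {e}")
        finally:
            self._should_stop.clear()  # moved here so a failed shutdown cannot leave it set
            self._loop = None
            self._loop_thread = None
            self._loop_ready.clear()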


Bug: Monitor Fails to Restart After Shutdown

The ThreadPoolExecutor is shut down in stop() at line 135 but never recreated. If start() is called again after stop(), the submit() call at line 112 fails because the executor cannot accept new tasks after shutdown. The executor should be recreated in start() or not shut down in stop() to allow the monitor to be restarted.
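A minimal sketch of one option, recreating the executor on each start() (illustrative only, not part of this PR):

    def start(self):
        if self._loop is not None:
            logger.warning("Queue monitor already started")
            return
        # Hypothetical fix sketch: a fresh executor per start() lets the
        # monitor be restarted after stop() has shut the previous one down.
        self._executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="FlowerQueueMonitor"
        )
        # ... define _run_event_loop and submit it as in the current code ...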



def get_queue_lengths(
self, queue_names: List[str], timeout: float = 5.0
) -> Dict[str, int]:
"""Get queue lengths synchronously (thread-safe)."""
if self._loop is None or not self._loop_ready.is_set():
logger.error("Event loop not initialized. Call start() first.")
return {}

try:
# Schedule coroutine in background event loop thread
future = asyncio.run_coroutine_threadsafe(
self.broker.queues(queue_names), self._loop
)

# Wait for result with timeout
queue_stats = future.result(timeout=timeout)

# Convert to simple dict format
queue_lengths = {}
for queue_info in queue_stats:
queue_name = queue_info.get("name")
messages = queue_info.get("messages", 0)
if queue_name:
queue_lengths[queue_name] = messages

return queue_lengths

except Exception as e:
logger.error(f"Error getting queue lengths: {e}")
return {}


@PublicAPI(stability="alpha")
class CeleryTaskProcessorAdapter(TaskProcessorAdapter):
"""
@@ -50,6 +188,7 @@ class CeleryTaskProcessorAdapter(TaskProcessorAdapter):
_worker_thread: Optional[threading.Thread] = None
_worker_hostname: Optional[str] = None
_worker_concurrency: int = DEFAULT_CONSUMER_CONCURRENCY
_queue_monitor: Optional[FlowerQueueMonitor] = None

def __init__(self, config: TaskProcessorConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -136,6 +275,13 @@ def initialize(self, consumer_concurrency: int = DEFAULT_CONSUMER_CONCURRENCY):
if self._config.unprocessable_task_queue_name:
task_unknown.connect(self._handle_unknown_task)

try:
self._queue_monitor = FlowerQueueMonitor(self._app)
logger.info("Queue monitor initialized")
except Exception as e:
logger.warning(f"Failed to initialize flower queue monitor: {e}.")
self._queue_monitor = None

def register_task_handle(self, func, name=None):
task_options = {
"autoretry_for": (Exception,),
@@ -179,7 +325,7 @@ def get_task_status_sync(self, task_id) -> TaskResult:
)

def start_consumer(self, **kwargs):
"""Starts the Celery worker thread."""
"""Starts the Celery worker thread and queue monitoring."""
if self._worker_thread is not None and self._worker_thread.is_alive():
logger.info("Celery worker thread is already running.")
return
@@ -204,8 +350,27 @@ def start_consumer(self, **kwargs):
f"Celery worker thread started with hostname: {self._worker_hostname}"
)

# Start queue monitor
if self._queue_monitor:
try:
self._queue_monitor.start()
logger.info("Queue monitor started")
except Exception as e:
logger.error(f"Failed to start queue monitor: {e}")
self._queue_monitor = None

def stop_consumer(self, timeout: float = 10.0):
"""Signals the Celery worker to shut down and waits for it to terminate."""
# Stop queue monitor first
if self._queue_monitor:
try:
self._queue_monitor.stop()
logger.info("Queue monitor stopped")
except Exception as e:
logger.warning(f"Error stopping queue monitor: {e}")
finally:
self._queue_monitor = None

if self._worker_thread is None or not self._worker_thread.is_alive():
logger.info("Celery worker thread is not running.")
return
@@ -253,6 +418,34 @@ def health_check_sync(self) -> List[Dict]:
"""
return self._app.control.ping()

def get_queue_lengths_sync(self) -> Dict[str, int]:
"""
Get the lengths of all queues by querying broker directly.

Returns:
Dict[str, int]: A dictionary mapping queue names to their lengths.
Returns empty dict if monitoring is unavailable.
"""
if self._queue_monitor is None:
logger.warning(
"Flower queue monitor not initialized. Cannot retrieve queue lengths."
)
return {}

# Collect all queue names to check
queue_names = [self._config.queue_name]
if self._config.failed_task_queue_name:
queue_names.append(self._config.failed_task_queue_name)
if self._config.unprocessable_task_queue_name:
queue_names.append(self._config.unprocessable_task_queue_name)

try:
queue_lengths = self._queue_monitor.get_queue_lengths(queue_names)
return queue_lengths
except Exception as e:
logger.error(f"Flower failed to retrieve queue lengths: {e}")
return {}

def _handle_task_failure(
self,
sender: Any = None,
72 changes: 72 additions & 0 deletions python/ray/serve/tests/test_task_processor.py
@@ -706,6 +706,78 @@ def get_message_received(self):
lambda: "test_data_3" in handle.get_message_received.remote().result()
)

def test_flower_monitors_queue_with_redis(
self, external_redis, serve_instance # noqa: F811
):
"""Test that queue monitor can monitor queue lengths with Redis broker."""

redis_address = os.environ.get("RAY_REDIS_ADDRESS")
processor_config = TaskProcessorConfig(
queue_name="flower_test_queue",
adapter_config=CeleryAdapterConfig(
broker_url=f"redis://{redis_address}/0",
backend_url=f"redis://{redis_address}/1",
app_custom_config={"worker_prefetch_multiplier": 1},
),
)

signal = SignalActor.remote()

@serve.deployment(max_ongoing_requests=1)
@task_consumer(task_processor_config=processor_config)
class BlockingTaskConsumer:
def __init__(self, signal_actor):
self._signal = signal_actor
self.tasks_started = 0

@task_handler(name="blocking_task")
def blocking_task(self, data):
self.tasks_started += 1
ray.get(
self._signal.wait.remote()
) # Block indefinitely waiting for signal
return f"processed: {data}"

def get_tasks_started(self):
return self.tasks_started

def get_queue_lengths(self):
return self._adapter.get_queue_lengths_sync()

handle = serve.run(BlockingTaskConsumer.bind(signal))

# Push 10 tasks to the queue
num_tasks = 10
for i in range(num_tasks):
send_request_to_queue.remote(
processor_config, f"task_{i}", task_name="blocking_task"
)

# Wait for one task to start (will be blocked waiting for signal)
wait_for_condition(
lambda: handle.get_tasks_started.remote().result() == 1,
timeout=30,
retry_interval_ms=1000,
)

def check_queue_lengths():
result = handle.get_queue_lengths.remote().result()
# 8 should still be queued (1 executing, possibly 1 prefetched)
return result.get("flower_test_queue", 0) >= 8

wait_for_condition(check_queue_lengths, timeout=30)

# Get final queue lengths
queue_lengths = handle.get_queue_lengths.remote().result()

assert len(queue_lengths) > 0, "Should have queue length data"
assert "flower_test_queue" in queue_lengths, "Should report main queue"
assert (
queue_lengths["flower_test_queue"] == 8
), f"Expected 8 tasks in queue, got {queue_lengths['flower_test_queue']} instead"

ray.get(signal.send.remote())


@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on Windows.")
class TestTaskConsumerWithDLQsConfiguration:
3 changes: 3 additions & 0 deletions python/ray/serve/tests/unit/test_task_consumer.py
@@ -58,6 +58,9 @@ def get_metrics_sync(self) -> Dict[str, Any]:
def health_check_sync(self) -> List[Dict]:
pass

def get_queue_lengths_sync(self) -> Dict[str, int]:
pass


@pytest.fixture
def config():