Merged
30 changes: 25 additions & 5 deletions doc/source/serve/advanced-guides/replica-ranks.md
@@ -6,7 +6,7 @@
This API is experimental and may change between Ray minor versions.
:::

Replica ranks provide a unique identifier for **each replica within a deployment**. Each replica receives a **rank (an integer from 0 to N-1)** and **a world size (the total number of replicas)**.
Replica ranks provide a unique identifier for **each replica within a deployment**. Each replica receives a **`ReplicaRank` object** containing rank information and **a world size (the total number of replicas)**. The rank object includes a global rank (an integer from 0 to N-1), a node rank, and a local rank on the node.

## Access replica ranks

@@ -28,9 +28,29 @@ The following example shows how to access replica rank information:

The [`ReplicaContext`](../api/doc/ray.serve.context.ReplicaContext.rst) provides two key fields:

- `rank`: An integer from 0 to N-1 representing this replica's unique identifier.
- `rank`: A [`ReplicaRank`](../api/doc/ray.serve.schema.ReplicaRank.rst) object containing rank information for this replica. Access the integer rank value with `.rank`.
- `world_size`: The target number of replicas for the deployment.

The `ReplicaRank` object contains three fields:
- `rank`: The global rank (an integer from 0 to N-1) representing this replica's unique identifier across all nodes.
- `node_rank`: The rank of the node this replica runs on (an integer from 0 to M-1 where M is the number of nodes).
- `local_rank`: The rank of this replica on its node (an integer from 0 to K-1 where K is the number of replicas on this node).

:::{note}
**Accessing rank values:**

To use the rank in your code, access the `.rank` attribute to get the integer value:

```python
context = serve.get_replica_context()
my_rank = context.rank.rank # Get the integer rank value
my_node_rank = context.rank.node_rank # Get the node rank
my_local_rank = context.rank.local_rank # Get the local rank on this node
```

Most use cases only need the global `rank` value. The `node_rank` and `local_rank` are useful for advanced scenarios such as coordinating replicas on the same node.
:::
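
The relationship between the three fields can be illustrated with a small, self-contained sketch. It does not use Ray Serve: `RankInfo` and `assign_ranks` are hypothetical names, and the numbering order (nodes by first appearance, replicas by arrival on each node) is an assumption made only to show what "0 to N-1", "0 to M-1", and "0 to K-1" mean. How Ray Serve actually orders nodes and replicas is not specified in this document.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class RankInfo:
    """Hypothetical stand-in for ReplicaRank; not the Ray Serve class."""

    rank: int        # global rank, 0 to N-1
    node_rank: int   # rank of the hosting node, 0 to M-1
    local_rank: int  # rank of the replica on its node, 0 to K-1


def assign_ranks(replica_nodes: list) -> list:
    """Derive illustrative ranks; replica_nodes[i] hosts global rank i."""
    node_ranks = {}
    replicas_seen = defaultdict(int)
    result = []
    for global_rank, node_id in enumerate(replica_nodes):
        # Assumption: nodes are numbered in order of first appearance.
        node_rank = node_ranks.setdefault(node_id, len(node_ranks))
        # Assumption: replicas on a node are numbered in arrival order.
        local_rank = replicas_seen[node_id]
        replicas_seen[node_id] += 1
        result.append(RankInfo(global_rank, node_rank, local_rank))
    return result


ranks = assign_ranks(["node-a", "node-a", "node-b", "node-a"])
for info in ranks:
    print(info)
```

With four replicas on two nodes, the global ranks stay unique across the deployment, while `local_rank` restarts at 0 on each node.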

## Handle rank changes with reconfigure

When a replica's rank changes (such as during downscaling), Ray Serve can automatically call the `reconfigure` method on your deployment class to notify it of the new rank. This allows you to update replica-specific state when ranks change.
@@ -54,15 +74,15 @@ The following example shows how to implement `reconfigure` to handle rank changes
Ray Serve automatically calls your `reconfigure` method in the following situations:

1. **At replica startup:** When a replica starts, if your deployment has both a `reconfigure` method and a `user_config`, Ray Serve calls `reconfigure` after running `__init__`. This lets you initialize rank-aware state without duplicating code between `__init__` and `reconfigure`.
2. **When you update user_config:** When you redeploy with a new `user_config`, Ray Serve calls `reconfigure` on all running replicas. If your `reconfigure` method includes `rank` as a parameter, Ray Serve passes both the new `user_config` and the current rank.
3. **When a replica's rank changes:** During downscaling, ranks may be reassigned to maintain contiguity (0 to N-1). If your `reconfigure` method includes `rank` as a parameter and your deployment has a `user_config`, Ray Serve calls `reconfigure` with the existing `user_config` and the new rank.
2. **When you update user_config:** When you redeploy with a new `user_config`, Ray Serve calls `reconfigure` on all running replicas. If your `reconfigure` method includes `rank` as a parameter, Ray Serve passes both the new `user_config` and the current rank as a `ReplicaRank` object.
3. **When a replica's rank changes:** During downscaling, ranks may be reassigned to maintain contiguity (0 to N-1). If your `reconfigure` method includes `rank` as a parameter and your deployment has a `user_config`, Ray Serve calls `reconfigure` with the existing `user_config` and the new rank as a `ReplicaRank` object.

:::{note}
**Requirements to receive rank updates:**

To get rank changes through `reconfigure`, your deployment needs:
- A class-based deployment (function deployments don't support `reconfigure`)
- A `reconfigure` method with `rank` as a parameter: `def reconfigure(self, user_config, rank: int)`
- A `reconfigure` method with `rank` as a parameter: `def reconfigure(self, user_config, rank: ReplicaRank)`
- A `user_config` in your deployment (even if it's just an empty dict: `user_config={}`)

Without a `user_config`, Ray Serve won't call `reconfigure` for rank changes.
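
Because the `rank` argument changes from `int` to a `ReplicaRank` object, user code that treated it as an integer needs to read `.rank` instead. One possible way to tolerate both forms during a migration is sketched below. The helper `global_rank_of` and the class `FakeReplicaRank` are hypothetical and exist only for illustration; the sketch assumes nothing about the object beyond an integer `.rank` attribute, which is what the documentation above describes.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FakeReplicaRank:
    """Hypothetical stand-in exposing the same three field names."""

    rank: int
    node_rank: int
    local_rank: int


def global_rank_of(rank: Any) -> int:
    """Return the integer global rank from either argument form."""
    if isinstance(rank, int):
        # Old form: the rank was passed as a plain integer.
        return rank
    # New form: an object whose `.rank` attribute holds the integer.
    return int(rank.rank)


old_style = global_rank_of(3)
new_style = global_rank_of(FakeReplicaRank(rank=2, node_rank=0, local_rank=1))
print(old_style, new_style)
```

Inside a deployment, a `reconfigure(self, user_config, rank)` method could call such a helper before storing the value, so the rest of the class keeps working with a plain integer.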
12 changes: 7 additions & 5 deletions doc/source/serve/doc_code/replica_rank.py
@@ -5,9 +5,10 @@
@serve.deployment(num_replicas=4)
class ModelShard:
    def __call__(self):
        context = serve.get_replica_context()
        return {
            "rank": serve.get_replica_context().rank,
            "world_size": serve.get_replica_context().world_size,
            "rank": context.rank.rank,  # Access the integer rank value
            "world_size": context.world_size,
        }


@@ -17,20 +18,21 @@ def __call__(self):
# __reconfigure_rank_start__
from typing import Any
from ray import serve
from ray.serve.schema import ReplicaRank


@serve.deployment(num_replicas=4, user_config={"name": "model_v1"})
class RankAwareModel:
    def __init__(self):
        context = serve.get_replica_context()
        self.rank = context.rank
        self.rank = context.rank.rank  # Extract integer rank value
        self.world_size = context.world_size
        self.model_name = None
        print(f"Replica rank: {self.rank}/{self.world_size}")

    async def reconfigure(self, user_config: Any, rank: int):
    async def reconfigure(self, user_config: Any, rank: ReplicaRank):
        """Called when user_config or rank changes."""
        self.rank = rank
        self.rank = rank.rank  # Extract integer rank value from ReplicaRank object
        self.world_size = serve.get_replica_context().world_size
        self.model_name = user_config.get("name")
        print(f"Reconfigured: rank={self.rank}, model={self.model_name}")
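
A common reason to track the integer rank and world size is to split work across replicas. The following sketch is a plain-Python illustration, not part of this PR: `shard_for_rank` is a hypothetical helper that gives each rank every `world_size`-th item. Since ranks may be reassigned during downscaling, a deployment would likely need to recompute its shard whenever `reconfigure` delivers a new rank.

```python
def shard_for_rank(items: list, rank: int, world_size: int) -> list:
    """Return the slice of `items` owned by the replica with this rank."""
    if world_size <= 0:
        raise ValueError("world_size must be positive")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in the range 0 to world_size - 1")
    # Strided slicing: rank r takes items r, r + world_size, r + 2 * world_size, ...
    return items[rank::world_size]


items = list(range(10))
shards = [shard_for_rank(items, r, 4) for r in range(4)]
for r, shard in enumerate(shards):
    print(f"rank {r}: {shard}")
```

Every item lands in exactly one shard as long as the ranks are contiguous from 0 to N-1, which is the property the rank reassignment described above maintains.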
37 changes: 16 additions & 21 deletions python/ray/serve/_private/deployment_state.py
@@ -254,7 +254,7 @@ def __init__(
self._docs_path: Optional[str] = None
self._route_patterns: Optional[List[str]] = None
# Rank assigned to the replica.
self._rank: Optional[int] = None
self._rank: Optional[ReplicaRank] = None
# Populated in `on_scheduled` or `recover`.
self._actor_handle: ActorHandle = None
self._placement_group: PlacementGroup = None
@@ -290,7 +290,7 @@ def deployment_name(self) -> str:
return self._deployment_id.name

@property
def rank(self) -> Optional[int]:
def rank(self) -> Optional[ReplicaRank]:
return self._rank

@property
@@ -442,7 +442,7 @@ def initialization_latency_s(self) -> Optional[float]:
return self._initialization_latency_s

def start(
self, deployment_info: DeploymentInfo, rank: int
self, deployment_info: DeploymentInfo, rank: ReplicaRank
) -> ReplicaSchedulingRequest:
"""Start the current DeploymentReplica instance.

@@ -609,11 +609,7 @@ def _format_user_config(self, user_config: Any):
temp = msgpack_deserialize(temp)
return temp

def reconfigure(
self,
version: DeploymentVersion,
rank: int,
) -> bool:
def reconfigure(self, version: DeploymentVersion, rank: ReplicaRank) -> bool:
"""
Update replica version. Also, updates the deployment config on the actor
behind this DeploymentReplica instance if necessary.
@@ -1170,7 +1166,7 @@ def initialization_latency_s(self) -> Optional[float]:
return self._actor.initialization_latency_s

def start(
self, deployment_info: DeploymentInfo, rank: int
self, deployment_info: DeploymentInfo, rank: ReplicaRank
) -> ReplicaSchedulingRequest:
"""
Start a new actor for current DeploymentReplica instance.
@@ -1184,7 +1180,7 @@ def start(
def reconfigure(
self,
version: DeploymentVersion,
rank: int,
rank: ReplicaRank,
) -> bool:
"""
Update replica version. Also, updates the deployment config on the actor
@@ -1211,7 +1207,7 @@ def recover(self) -> bool:
return True

@property
def rank(self) -> Optional[int]:
def rank(self) -> Optional[ReplicaRank]:
"""Get the rank assigned to the replica."""
return self._actor.rank

@@ -1695,9 +1691,11 @@ def _assign_rank_impl():
# Assign global rank
rank = self._replica_rank_manager.assign_rank(replica_id)

return ReplicaRank(rank=rank)
return ReplicaRank(rank=rank, node_rank=-1, local_rank=-1)

return self._execute_with_error_handling(_assign_rank_impl, ReplicaRank(rank=0))
return self._execute_with_error_handling(
_assign_rank_impl, ReplicaRank(rank=0, node_rank=-1, local_rank=-1)
)

def release_rank(self, replica_id: str) -> None:
"""Release rank for a replica.
@@ -1776,10 +1774,10 @@ def _get_replica_rank_impl():
raise RuntimeError(f"Rank for {replica_id} not assigned")

global_rank = self._replica_rank_manager.get_rank(replica_id)
return ReplicaRank(rank=global_rank)
return ReplicaRank(rank=global_rank, node_rank=-1, local_rank=-1)

return self._execute_with_error_handling(
_get_replica_rank_impl, ReplicaRank(rank=0)
_get_replica_rank_impl, ReplicaRank(rank=0, node_rank=-1, local_rank=-1)
)

def check_rank_consistency_and_reassign_minimally(
@@ -2547,7 +2545,7 @@ def scale_deployment_replicas(
self._target_state.version,
)
scheduling_request = new_deployment_replica.start(
self._target_state.info, rank=assigned_rank.rank
self._target_state.info, rank=assigned_rank
)

upscale.append(scheduling_request)
@@ -2665,10 +2663,7 @@ def _check_startup_replicas(
# data structure with RUNNING state.
# Recover rank from the replica actor during controller restart
replica_id = replica.replica_id.unique_id
recovered_rank = replica.rank
self._rank_manager.recover_rank(
replica_id, ReplicaRank(rank=recovered_rank)
)
self._rank_manager.recover_rank(replica_id, replica.rank)
# This replica should now be added to the handle's replica
# set.
self._replicas.add(ReplicaState.RUNNING, replica)
@@ -2951,7 +2946,7 @@ def _reconfigure_replicas_with_new_ranks(
# World size is calculated automatically from deployment config
_ = replica.reconfigure(
self._target_state.version,
rank=new_rank.rank,
rank=new_rank,
)
updated_count += 1

16 changes: 8 additions & 8 deletions python/ray/serve/_private/replica.py
@@ -116,7 +116,7 @@
RayServeException,
)
from ray.serve.handle import DeploymentHandle
from ray.serve.schema import EncodingType, LoggingConfig
from ray.serve.schema import EncodingType, LoggingConfig, ReplicaRank

logger = logging.getLogger(SERVE_LOGGER_NAME)

@@ -129,7 +129,7 @@
Optional[str],
int,
int,
int, # rank
ReplicaRank, # rank
Optional[List[str]], # route_patterns
]

@@ -510,7 +510,7 @@ def __init__(
version: DeploymentVersion,
ingress: bool,
route_prefix: str,
rank: int,
rank: ReplicaRank,
):
self._version = version
self._replica_id = replica_id
@@ -610,7 +610,7 @@ def get_dynamically_created_handles(self) -> Set[DeploymentID]:
return self._dynamically_created_handles

def _set_internal_replica_context(
self, *, servable_object: Callable = None, rank: int = None
self, *, servable_object: Callable = None, rank: ReplicaRank = None
):
# Calculate world_size from deployment config instead of storing it
world_size = self._deployment_config.num_replicas
@@ -961,7 +961,7 @@ async def initialize(self, deployment_config: DeploymentConfig):
async def reconfigure(
self,
deployment_config: DeploymentConfig,
rank: int,
rank: ReplicaRank,
route_prefix: Optional[str] = None,
):
try:
@@ -1186,7 +1186,7 @@ async def __init__(
version: DeploymentVersion,
ingress: bool,
route_prefix: str,
rank: int,
rank: ReplicaRank,
):
deployment_config = DeploymentConfig.from_proto_bytes(
deployment_config_proto_bytes
@@ -1305,7 +1305,7 @@ async def record_routing_stats(self) -> Dict[str, Any]:
return await self._replica_impl.record_routing_stats()

async def reconfigure(
self, deployment_config, rank: int, route_prefix: Optional[str] = None
self, deployment_config, rank: ReplicaRank, route_prefix: Optional[str] = None
) -> ReplicaMetadata:
await self._replica_impl.reconfigure(deployment_config, rank, route_prefix)
return self._replica_impl.get_metadata()
@@ -1802,7 +1802,7 @@ async def _call_user_autoscaling_stats(self) -> Dict[str, Union[int, float]]:
return result

@_run_user_code
async def call_reconfigure(self, user_config: Optional[Any], rank: int):
async def call_reconfigure(self, user_config: Optional[Any], rank: ReplicaRank):
self._raise_if_not_initialized("call_reconfigure")

# NOTE(edoakes): there is the possibility of a race condition in user code if
5 changes: 3 additions & 2 deletions python/ray/serve/context.py
@@ -23,6 +23,7 @@
from ray.serve._private.replica_result import ReplicaResult
from ray.serve.exceptions import RayServeException
from ray.serve.grpc_util import RayServegRPCContext
from ray.serve.schema import ReplicaRank
from ray.util.annotations import DeveloperAPI

logger = logging.getLogger(SERVE_LOGGER_NAME)
@@ -48,7 +49,7 @@ class ReplicaContext:
replica_id: ReplicaID
servable_object: Callable
_deployment_config: DeploymentConfig
rank: int
rank: ReplicaRank
world_size: int
_handle_registration_callback: Optional[Callable[[DeploymentID], None]] = None

@@ -113,7 +114,7 @@ def _set_internal_replica_context(
replica_id: ReplicaID,
servable_object: Callable,
_deployment_config: DeploymentConfig,
rank: int,
rank: ReplicaRank,
world_size: int,
handle_registration_callback: Optional[Callable[[str, str], None]] = None,
):
6 changes: 6 additions & 0 deletions python/ray/serve/schema.py
@@ -1600,3 +1600,9 @@ class ReplicaRank(BaseModel):
rank: int = Field(
description="Global rank of the replica across all nodes scoped to the deployment."
)

node_rank: int = Field(description="Rank of the node in the deployment.")

local_rank: int = Field(
description="Rank of the replica on the node scoped to the deployment."
)
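
The call sites in `deployment_state.py` in this diff construct `ReplicaRank` with `node_rank=-1` and `local_rank=-1`. The sketch below assumes, without confirmation from the source, that `-1` acts as a "not populated" placeholder, and shows how calling code might guard against it. `ReplicaRankSketch`, `UNSET`, and `has_node_info` are hypothetical names, and the stand-in is a plain dataclass rather than the Pydantic model above, so the example runs without Ray installed.

```python
from dataclasses import dataclass

# Assumption: -1 marks a field as not populated, inferred from the call
# sites in this diff. The schema itself does not document this meaning.
UNSET = -1


@dataclass(frozen=True)
class ReplicaRankSketch:
    """Hypothetical stand-in with the same field names as ReplicaRank."""

    rank: int
    node_rank: int
    local_rank: int


def has_node_info(replica_rank: ReplicaRankSketch) -> bool:
    """Return True only if both node-level fields hold real values."""
    return replica_rank.node_rank != UNSET and replica_rank.local_rank != UNSET


placeholder = ReplicaRankSketch(rank=0, node_rank=UNSET, local_rank=UNSET)
populated = ReplicaRankSketch(rank=5, node_rank=1, local_rank=2)
print(has_node_info(placeholder), has_node_info(populated))
```

Under that assumption, code that relies on `node_rank` or `local_rank` would do well to check for the placeholder first and fall back to the global `rank` when node-level information is absent.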