diff --git a/doc/source/serve/advanced-guides/replica-ranks.md b/doc/source/serve/advanced-guides/replica-ranks.md index f8f7269e2baa..fa0abb8688b8 100644 --- a/doc/source/serve/advanced-guides/replica-ranks.md +++ b/doc/source/serve/advanced-guides/replica-ranks.md @@ -6,7 +6,7 @@ This API is experimental and may change between Ray minor versions. ::: -Replica ranks provide a unique identifier for **each replica within a deployment**. Each replica receives a **rank (an integer from 0 to N-1)** and **a world size (the total number of replicas)**. +Replica ranks provide a unique identifier for **each replica within a deployment**. Each replica receives a **`ReplicaRank` object** containing rank information and **a world size (the total number of replicas)**. The rank object includes a global rank (an integer from 0 to N-1), a node rank, and a local rank on the node. ## Access replica ranks @@ -28,9 +28,29 @@ The following example shows how to access replica rank information: The [`ReplicaContext`](../api/doc/ray.serve.context.ReplicaContext.rst) provides two key fields: -- `rank`: An integer from 0 to N-1 representing this replica's unique identifier. +- `rank`: A [`ReplicaRank`](../api/doc/ray.serve.schema.ReplicaRank.rst) object containing rank information for this replica. Access the integer rank value with `.rank`. - `world_size`: The target number of replicas for the deployment. +The `ReplicaRank` object contains three fields: +- `rank`: The global rank (an integer from 0 to N-1) representing this replica's unique identifier across all nodes. +- `node_rank`: The rank of the node this replica runs on (an integer from 0 to M-1 where M is the number of nodes). +- `local_rank`: The rank of this replica on its node (an integer from 0 to K-1 where K is the number of replicas on this node). 
+ +:::{note} +**Accessing rank values:** + +To use the rank in your code, access the `.rank` attribute to get the integer value: + +```python +context = serve.get_replica_context() +my_rank = context.rank.rank # Get the integer rank value +my_node_rank = context.rank.node_rank # Get the node rank +my_local_rank = context.rank.local_rank # Get the local rank on this node +``` + +Most use cases only need the global `rank` value. The `node_rank` and `local_rank` are useful for advanced scenarios such as coordinating replicas on the same node. As of this change, the controller assigns `node_rank` and `local_rank` the placeholder value `-1`, so check for `-1` before relying on them. +::: + ## Handle rank changes with reconfigure When a replica's rank changes (such as during downscaling), Ray Serve can automatically call the `reconfigure` method on your deployment class to notify it of the new rank. This allows you to update replica-specific state when ranks change. @@ -54,15 +74,15 @@ The following example shows how to implement `reconfigure` to handle rank change Ray Serve automatically calls your `reconfigure` method in the following situations: 1. **At replica startup:** When a replica starts, if your deployment has both a `reconfigure` method and a `user_config`, Ray Serve calls `reconfigure` after running `__init__`. This lets you initialize rank-aware state without duplicating code between `__init__` and `reconfigure`. -2. **When you update user_config:** When you redeploy with a new `user_config`, Ray Serve calls `reconfigure` on all running replicas. If your `reconfigure` method includes `rank` as a parameter, Ray Serve passes both the new `user_config` and the current rank. -3. **When a replica's rank changes:** During downscaling, ranks may be reassigned to maintain contiguity (0 to N-1). If your `reconfigure` method includes `rank` as a parameter and your deployment has a `user_config`, Ray Serve calls `reconfigure` with the existing `user_config` and the new rank. +2. 
**When you update user_config:** When you redeploy with a new `user_config`, Ray Serve calls `reconfigure` on all running replicas. If your `reconfigure` method includes `rank` as a parameter, Ray Serve passes both the new `user_config` and the current rank as a `ReplicaRank` object. +3. **When a replica's rank changes:** During downscaling, ranks may be reassigned to maintain contiguity (0 to N-1). If your `reconfigure` method includes `rank` as a parameter and your deployment has a `user_config`, Ray Serve calls `reconfigure` with the existing `user_config` and the new rank as a `ReplicaRank` object. :::{note} **Requirements to receive rank updates:** To get rank changes through `reconfigure`, your deployment needs: - A class-based deployment (function deployments don't support `reconfigure`) -- A `reconfigure` method with `rank` as a parameter: `def reconfigure(self, user_config, rank: int)` +- A `reconfigure` method with `rank` as a parameter: `def reconfigure(self, user_config, rank: ReplicaRank)` - A `user_config` in your deployment (even if it's just an empty dict: `user_config={}`) Without a `user_config`, Ray Serve won't call `reconfigure` for rank changes. 
diff --git a/doc/source/serve/doc_code/replica_rank.py b/doc/source/serve/doc_code/replica_rank.py index f9a43ae9414f..3ff444c420e8 100644 --- a/doc/source/serve/doc_code/replica_rank.py +++ b/doc/source/serve/doc_code/replica_rank.py @@ -5,9 +5,10 @@ @serve.deployment(num_replicas=4) class ModelShard: def __call__(self): + context = serve.get_replica_context() return { - "rank": serve.get_replica_context().rank, - "world_size": serve.get_replica_context().world_size, + "rank": context.rank.rank, # Access the integer rank value + "world_size": context.world_size, } @@ -17,20 +18,21 @@ def __call__(self): # __reconfigure_rank_start__ from typing import Any from ray import serve +from ray.serve.schema import ReplicaRank @serve.deployment(num_replicas=4, user_config={"name": "model_v1"}) class RankAwareModel: def __init__(self): context = serve.get_replica_context() - self.rank = context.rank + self.rank = context.rank.rank # Extract integer rank value self.world_size = context.world_size self.model_name = None print(f"Replica rank: {self.rank}/{self.world_size}") - async def reconfigure(self, user_config: Any, rank: int): + async def reconfigure(self, user_config: Any, rank: ReplicaRank): """Called when user_config or rank changes.""" - self.rank = rank + self.rank = rank.rank # Extract integer rank value from ReplicaRank object self.world_size = serve.get_replica_context().world_size self.model_name = user_config.get("name") print(f"Reconfigured: rank={self.rank}, model={self.model_name}") diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 9460147e791a..0beec58c7815 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -254,7 +254,7 @@ def __init__( self._docs_path: Optional[str] = None self._route_patterns: Optional[List[str]] = None # Rank assigned to the replica. 
- self._rank: Optional[int] = None + self._rank: Optional[ReplicaRank] = None # Populated in `on_scheduled` or `recover`. self._actor_handle: ActorHandle = None self._placement_group: PlacementGroup = None @@ -290,7 +290,7 @@ def deployment_name(self) -> str: return self._deployment_id.name @property - def rank(self) -> Optional[int]: + def rank(self) -> Optional[ReplicaRank]: return self._rank @property @@ -442,7 +442,7 @@ def initialization_latency_s(self) -> Optional[float]: return self._initialization_latency_s def start( - self, deployment_info: DeploymentInfo, rank: int + self, deployment_info: DeploymentInfo, rank: ReplicaRank ) -> ReplicaSchedulingRequest: """Start the current DeploymentReplica instance. @@ -609,11 +609,7 @@ def _format_user_config(self, user_config: Any): temp = msgpack_deserialize(temp) return temp - def reconfigure( - self, - version: DeploymentVersion, - rank: int, - ) -> bool: + def reconfigure(self, version: DeploymentVersion, rank: ReplicaRank) -> bool: """ Update replica version. Also, updates the deployment config on the actor behind this DeploymentReplica instance if necessary. @@ -1170,7 +1166,7 @@ def initialization_latency_s(self) -> Optional[float]: return self._actor.initialization_latency_s def start( - self, deployment_info: DeploymentInfo, rank: int + self, deployment_info: DeploymentInfo, rank: ReplicaRank ) -> ReplicaSchedulingRequest: """ Start a new actor for current DeploymentReplica instance. @@ -1184,7 +1180,7 @@ def start( def reconfigure( self, version: DeploymentVersion, - rank: int, + rank: ReplicaRank, ) -> bool: """ Update replica version. 
Also, updates the deployment config on the actor @@ -1211,7 +1207,7 @@ def recover(self) -> bool: return True @property - def rank(self) -> Optional[int]: + def rank(self) -> Optional[ReplicaRank]: """Get the rank assigned to the replica.""" return self._actor.rank @@ -1695,9 +1691,11 @@ def _assign_rank_impl(): # Assign global rank rank = self._replica_rank_manager.assign_rank(replica_id) - return ReplicaRank(rank=rank) + return ReplicaRank(rank=rank, node_rank=-1, local_rank=-1) - return self._execute_with_error_handling(_assign_rank_impl, ReplicaRank(rank=0)) + return self._execute_with_error_handling( + _assign_rank_impl, ReplicaRank(rank=0, node_rank=-1, local_rank=-1) + ) def release_rank(self, replica_id: str) -> None: """Release rank for a replica. @@ -1776,10 +1774,10 @@ def _get_replica_rank_impl(): raise RuntimeError(f"Rank for {replica_id} not assigned") global_rank = self._replica_rank_manager.get_rank(replica_id) - return ReplicaRank(rank=global_rank) + return ReplicaRank(rank=global_rank, node_rank=-1, local_rank=-1) return self._execute_with_error_handling( - _get_replica_rank_impl, ReplicaRank(rank=0) + _get_replica_rank_impl, ReplicaRank(rank=0, node_rank=-1, local_rank=-1) ) def check_rank_consistency_and_reassign_minimally( @@ -2547,7 +2545,7 @@ def scale_deployment_replicas( self._target_state.version, ) scheduling_request = new_deployment_replica.start( - self._target_state.info, rank=assigned_rank.rank + self._target_state.info, rank=assigned_rank ) upscale.append(scheduling_request) @@ -2665,10 +2663,7 @@ def _check_startup_replicas( # data structure with RUNNING state. # Recover rank from the replica actor during controller restart replica_id = replica.replica_id.unique_id - recovered_rank = replica.rank - self._rank_manager.recover_rank( - replica_id, ReplicaRank(rank=recovered_rank) - ) + self._rank_manager.recover_rank(replica_id, replica.rank) # This replica should be now be added to handle's replica # set. 
self._replicas.add(ReplicaState.RUNNING, replica) @@ -2951,7 +2946,7 @@ def _reconfigure_replicas_with_new_ranks( # World size is calculated automatically from deployment config _ = replica.reconfigure( self._target_state.version, - rank=new_rank.rank, + rank=new_rank, ) updated_count += 1 diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index c468fb55f1ff..32b5721f3b71 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -116,7 +116,7 @@ RayServeException, ) from ray.serve.handle import DeploymentHandle -from ray.serve.schema import EncodingType, LoggingConfig +from ray.serve.schema import EncodingType, LoggingConfig, ReplicaRank logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -129,7 +129,7 @@ Optional[str], int, int, - int, # rank + ReplicaRank, # rank Optional[List[str]], # route_patterns ] @@ -510,7 +510,7 @@ def __init__( version: DeploymentVersion, ingress: bool, route_prefix: str, - rank: int, + rank: ReplicaRank, ): self._version = version self._replica_id = replica_id @@ -610,7 +610,7 @@ def get_dynamically_created_handles(self) -> Set[DeploymentID]: return self._dynamically_created_handles def _set_internal_replica_context( - self, *, servable_object: Callable = None, rank: int = None + self, *, servable_object: Callable = None, rank: ReplicaRank = None ): # Calculate world_size from deployment config instead of storing it world_size = self._deployment_config.num_replicas @@ -961,7 +961,7 @@ async def initialize(self, deployment_config: DeploymentConfig): async def reconfigure( self, deployment_config: DeploymentConfig, - rank: int, + rank: ReplicaRank, route_prefix: Optional[str] = None, ): try: @@ -1186,7 +1186,7 @@ async def __init__( version: DeploymentVersion, ingress: bool, route_prefix: str, - rank: int, + rank: ReplicaRank, ): deployment_config = DeploymentConfig.from_proto_bytes( deployment_config_proto_bytes @@ -1305,7 +1305,7 @@ async def 
record_routing_stats(self) -> Dict[str, Any]: return await self._replica_impl.record_routing_stats() async def reconfigure( - self, deployment_config, rank: int, route_prefix: Optional[str] = None + self, deployment_config, rank: ReplicaRank, route_prefix: Optional[str] = None ) -> ReplicaMetadata: await self._replica_impl.reconfigure(deployment_config, rank, route_prefix) return self._replica_impl.get_metadata() @@ -1802,7 +1802,7 @@ async def _call_user_autoscaling_stats(self) -> Dict[str, Union[int, float]]: return result @_run_user_code - async def call_reconfigure(self, user_config: Optional[Any], rank: int): + async def call_reconfigure(self, user_config: Optional[Any], rank: ReplicaRank): self._raise_if_not_initialized("call_reconfigure") # NOTE(edoakes): there is the possibility of a race condition in user code if diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py index 3986430d4d18..1c9b92ad64ac 100644 --- a/python/ray/serve/context.py +++ b/python/ray/serve/context.py @@ -23,6 +23,7 @@ from ray.serve._private.replica_result import ReplicaResult from ray.serve.exceptions import RayServeException from ray.serve.grpc_util import RayServegRPCContext +from ray.serve.schema import ReplicaRank from ray.util.annotations import DeveloperAPI logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -48,7 +49,7 @@ class ReplicaContext: replica_id: ReplicaID servable_object: Callable _deployment_config: DeploymentConfig - rank: int + rank: ReplicaRank world_size: int _handle_registration_callback: Optional[Callable[[DeploymentID], None]] = None @@ -113,7 +114,7 @@ def _set_internal_replica_context( replica_id: ReplicaID, servable_object: Callable, _deployment_config: DeploymentConfig, - rank: int, + rank: ReplicaRank, world_size: int, handle_registration_callback: Optional[Callable[[str, str], None]] = None, ): diff --git a/python/ray/serve/schema.py b/python/ray/serve/schema.py index f7638815443b..6318207303cb 100644 --- a/python/ray/serve/schema.py +++ 
b/python/ray/serve/schema.py @@ -1600,3 +1600,9 @@ class ReplicaRank(BaseModel): rank: int = Field( description="Global rank of the replica across all nodes scoped to the deployment." ) + + node_rank: int = Field(description="Rank of the node in the deployment.") + + local_rank: int = Field( + description="Rank of the replica on the node scoped to the deployment." + ) diff --git a/python/ray/serve/tests/test_replica_ranks.py b/python/ray/serve/tests/test_replica_ranks.py index 74e8ec20124c..ce74f643936f 100644 --- a/python/ray/serve/tests/test_replica_ranks.py +++ b/python/ray/serve/tests/test_replica_ranks.py @@ -22,6 +22,7 @@ check_deployment_status, check_num_replicas_eq, ) +from ray.serve.schema import ReplicaRank def get_controller() -> ServeController: @@ -97,7 +98,7 @@ def __init__(self): def __call__(self): context = serve.get_replica_context() - self.replica_rank = context.rank + self.replica_rank = context.rank.rank if context.rank else None self.world_size = context.world_size return { "rank": self.replica_rank, @@ -156,7 +157,7 @@ async def __call__(self): await signal_actor.wait.remote() context = serve.get_replica_context() return { - "rank": context.rank, + "rank": context.rank.rank if context.rank else None, "world_size": context.world_size, } @@ -214,7 +215,7 @@ class PersistentRankTracker: def __call__(self): context = serve.get_replica_context() return { - "rank": context.rank, + "rank": context.rank.rank if context.rank else None, "world_size": context.world_size, } @@ -265,7 +266,7 @@ class SingleReplicaTracker: def __call__(self): context = serve.get_replica_context() return { - "rank": context.rank, + "rank": context.rank.rank if context.rank else None, "world_size": context.world_size, } @@ -296,7 +297,7 @@ def __call__(self): context = serve.get_replica_context() return { "deployment": "deployment1", - "rank": context.rank, + "rank": context.rank.rank if context.rank else None, "world_size": context.world_size, } @@ -309,7 +310,7 @@ def 
__call__(self): context = serve.get_replica_context() return { "deployment": "deployment2", - "rank": context.rank, + "rank": context.rank.rank if context.rank else None, "world_size": context.world_size, } @@ -410,8 +411,9 @@ async def __call__(self): await signal_actor.wait.remote() return self.my_rank - async def reconfigure(self, user_config: Any, rank: int): - self.my_rank = rank + async def reconfigure(self, user_config: Any, rank: ReplicaRank): + # rank parameter is actually a ReplicaRank object, extract the integer value + self.my_rank = rank.rank handle = serve.run(ReconfigureRankTracker.bind()) wait_for_condition( diff --git a/python/ray/serve/tests/unit/test_deployment_rank_manager.py b/python/ray/serve/tests/unit/test_deployment_rank_manager.py index e7be0e515308..3ff3b04fa4f3 100644 --- a/python/ray/serve/tests/unit/test_deployment_rank_manager.py +++ b/python/ray/serve/tests/unit/test_deployment_rank_manager.py @@ -100,7 +100,9 @@ def test_release_rank_nonexistent_replica(self): def test_recover_rank_basic(self, rank_manager): """Test basic rank recovery.""" - rank_manager.recover_rank("replica_1", ReplicaRank(rank=5)) + rank_manager.recover_rank( + "replica_1", ReplicaRank(rank=5, node_rank=0, local_rank=0) + ) assert rank_manager.has_replica_rank("replica_1") assert rank_manager.get_replica_rank("replica_1").rank == 5 @@ -108,7 +110,9 @@ def test_recover_rank_basic(self, rank_manager): def test_recover_rank_updates_next_rank(self, rank_manager): """Test that recovering a high rank updates next_rank appropriately.""" rank_manager.assign_rank("replica_1") # Gets rank 0 - rank_manager.recover_rank("replica_2", ReplicaRank(rank=10)) + rank_manager.recover_rank( + "replica_2", ReplicaRank(rank=10, node_rank=0, local_rank=0) + ) # New replica should get rank 11 (next available after 10) rank = rank_manager.assign_rank("replica_3") @@ -124,7 +128,9 @@ def test_recover_rank_removes_from_available(self, rank_manager): rank_manager.release_rank("replica_1") # 
Rank 0 becomes available # Recover rank 0 for a new replica - rank_manager.recover_rank("replica_3", ReplicaRank(rank=0)) + rank_manager.recover_rank( + "replica_3", ReplicaRank(rank=0, node_rank=0, local_rank=0) + ) # Verify replica_3 has rank 0 assert rank_manager.has_replica_rank("replica_3") @@ -140,7 +146,9 @@ def test_recover_rank_duplicate_fails(self): rank_manager.assign_rank("replica_1") with pytest.raises(RuntimeError, match="already assigned"): - rank_manager.recover_rank("replica_1", ReplicaRank(rank=5)) + rank_manager.recover_rank( + "replica_1", ReplicaRank(rank=5, node_rank=0, local_rank=0) + ) def test_get_replica_rank_existing(self, rank_manager): """Test getting rank for existing replica.""" @@ -218,9 +226,15 @@ def test_check_rank_consistency_non_contiguous_ranks(self, rank_manager): replica3 = MockDeploymentReplica("replica_3") # Manually assign non-contiguous ranks using recover_rank - rank_manager.recover_rank("replica_1", ReplicaRank(rank=0)) - rank_manager.recover_rank("replica_2", ReplicaRank(rank=2)) # Gap at rank 1 - rank_manager.recover_rank("replica_3", ReplicaRank(rank=3)) + rank_manager.recover_rank( + "replica_1", ReplicaRank(rank=0, node_rank=0, local_rank=0) + ) + rank_manager.recover_rank( + "replica_2", ReplicaRank(rank=2, node_rank=0, local_rank=0) + ) # Gap at rank 1 + rank_manager.recover_rank( + "replica_3", ReplicaRank(rank=3, node_rank=0, local_rank=0) + ) result = rank_manager.check_rank_consistency_and_reassign_minimally( [replica1, replica2, replica3] @@ -243,13 +257,17 @@ def test_minimal_reassignment_keeps_existing_when_possible(self, rank_manager): replica4 = MockDeploymentReplica("replica_4") # Set up ranks: 0, 2, 5, 7 (non-contiguous) using recover_rank - rank_manager.recover_rank("replica_1", ReplicaRank(rank=0)) # Should keep this - rank_manager.recover_rank("replica_2", ReplicaRank(rank=2)) # Should keep this rank_manager.recover_rank( - "replica_3", ReplicaRank(rank=5) + "replica_1", ReplicaRank(rank=0, 
node_rank=0, local_rank=0) + ) # Should keep this + rank_manager.recover_rank( + "replica_2", ReplicaRank(rank=2, node_rank=0, local_rank=0) + ) # Should keep this + rank_manager.recover_rank( + "replica_3", ReplicaRank(rank=5, node_rank=0, local_rank=0) ) # Should be reassigned to 1 rank_manager.recover_rank( - "replica_4", ReplicaRank(rank=7) + "replica_4", ReplicaRank(rank=7, node_rank=0, local_rank=0) ) # Should be reassigned to 3 result = rank_manager.check_rank_consistency_and_reassign_minimally( @@ -297,8 +315,12 @@ def test_check_rank_consistency_duplicate_ranks_fails(self): replica2 = MockDeploymentReplica("replica_2") # Manually create duplicate ranks using recover_rank (this should never happen in normal operation) - rank_manager.recover_rank("replica_1", ReplicaRank(rank=0)) - rank_manager.recover_rank("replica_2", ReplicaRank(rank=0)) # Duplicate! + rank_manager.recover_rank( + "replica_1", ReplicaRank(rank=0, node_rank=0, local_rank=0) + ) + rank_manager.recover_rank( + "replica_2", ReplicaRank(rank=0, node_rank=0, local_rank=0) + ) # Duplicate! 
with pytest.raises(RuntimeError, match="Rank system is in an invalid state"): rank_manager.check_rank_consistency_and_reassign_minimally( @@ -356,7 +378,9 @@ def test_recover_rank_error_with_fail_on_rank_error_true(self): # Should raise RuntimeError for duplicate recovery with pytest.raises(RuntimeError, match="already assigned"): - rank_manager.recover_rank("replica_1", ReplicaRank(rank=5)) + rank_manager.recover_rank( + "replica_1", ReplicaRank(rank=5, node_rank=-1, local_rank=-1) + ) def test_recover_rank_error_with_fail_on_rank_error_false(self): """Test that recover_rank returns safe default when fail_on_rank_error=False.""" @@ -364,7 +388,9 @@ def test_recover_rank_error_with_fail_on_rank_error_false(self): rank_manager.assign_rank("replica_1") # Should return None instead of raising - result = rank_manager.recover_rank("replica_1", ReplicaRank(rank=5)) + result = rank_manager.recover_rank( + "replica_1", ReplicaRank(rank=5, node_rank=-1, local_rank=-1) + ) assert result is None def test_get_replica_rank_error_with_fail_on_rank_error_true(self): @@ -423,8 +449,12 @@ def test_check_rank_consistency_with_duplicate_ranks_error_handling(self): replica2 = MockDeploymentReplica("replica_2") # Manually create duplicate ranks - rank_manager.recover_rank("replica_1", ReplicaRank(rank=0)) - rank_manager.recover_rank("replica_2", ReplicaRank(rank=0)) + rank_manager.recover_rank( + "replica_1", ReplicaRank(rank=0, node_rank=-1, local_rank=-1) + ) + rank_manager.recover_rank( + "replica_2", ReplicaRank(rank=0, node_rank=-1, local_rank=-1) + ) # Should return empty list instead of raising result = rank_manager.check_rank_consistency_and_reassign_minimally( @@ -449,7 +479,10 @@ def test_normal_operations_work_with_fail_on_rank_error_false(self): assert not rank_manager.has_replica_rank("replica_1") # Test normal recover - rank_manager.recover_rank("replica_2", ReplicaRank(rank=5)) + rank_manager.recover_rank( + "replica_2", ReplicaRank(rank=5, node_rank=-1, local_rank=-1) + 
) + assert rank_manager.get_replica_rank("replica_2").rank == 5 # Test normal consistency check diff --git a/python/ray/serve/tests/unit/test_deployment_state.py b/python/ray/serve/tests/unit/test_deployment_state.py index df145bf0de31..c6891626570e 100644 --- a/python/ray/serve/tests/unit/test_deployment_state.py +++ b/python/ray/serve/tests/unit/test_deployment_state.py @@ -54,6 +54,7 @@ get_capacity_adjusted_num_replicas, get_random_string, ) +from ray.serve.schema import ReplicaRank from ray.util.placement_group import validate_placement_group # Global variable that is fetched during controller recovery that @@ -64,7 +65,7 @@ # loop, so we can't "mark" a replica dead through a method. This global # state is cleared after each test that uses the fixtures in this file. dead_replicas_context = set() -replica_rank_context = {} +replica_rank_context: Dict[str, ReplicaRank] = {} TEST_DEPLOYMENT_ID = DeploymentID(name="test_deployment", app_name="test_app") TEST_DEPLOYMENT_ID_2 = DeploymentID(name="test_deployment_2", app_name="test_app") @@ -225,7 +226,7 @@ def set_node_id(self, node_id: str): def set_actor_id(self, actor_id: str): self._actor_id = actor_id - def start(self, deployment_info: DeploymentInfo, rank: int): + def start(self, deployment_info: DeploymentInfo, rank: ReplicaRank): self.started = True self._rank = rank replica_rank_context[self._replica_id.unique_id] = rank @@ -246,13 +247,13 @@ def _on_scheduled_stub(*args, **kwargs): ) @property - def rank(self) -> Optional[int]: + def rank(self) -> Optional[ReplicaRank]: return self._rank def reconfigure( self, version: DeploymentVersion, - rank: int = None, + rank: ReplicaRank = None, ): self.started = True updating = self.version.requires_actor_reconfigure(version) @@ -5434,10 +5435,18 @@ def test_complex_reassignment_scenario(self, mock_deployment_state_manager): # Simulate very scattered ranks in global context: 0, 3, 7, 10 global replica_rank_context replica_rank_context.clear() - 
replica_rank_context[replica_ids[0].unique_id] = 0 - replica_rank_context[replica_ids[1].unique_id] = 3 - replica_rank_context[replica_ids[2].unique_id] = 7 - replica_rank_context[replica_ids[3].unique_id] = 10 + replica_rank_context[replica_ids[0].unique_id] = ReplicaRank( + rank=0, node_rank=-1, local_rank=-1 + ) + replica_rank_context[replica_ids[1].unique_id] = ReplicaRank( + rank=3, node_rank=-1, local_rank=-1 + ) + replica_rank_context[replica_ids[2].unique_id] = ReplicaRank( + rank=7, node_rank=-1, local_rank=-1 + ) + replica_rank_context[replica_ids[3].unique_id] = ReplicaRank( + rank=10, node_rank=-1, local_rank=-1 + ) # Simulate controller crashed! Create a new deployment state manager # with the existing replica IDs to trigger recovery