benchmarks/storage/benchmark_sample_latency_over_rpc.py (2 changes: 1 addition & 1 deletion)

@@ -144,7 +144,7 @@ def __init__(self, capacity: int):
rank = args.rank
storage_type = args.storage

torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}")
torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}")

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
docs/source/reference/collectors_weightsync.rst (16 changes: 11 additions & 5 deletions)

@@ -49,7 +49,7 @@ Weight update schemes can be used outside of collectors for custom synchronization
The new simplified API provides four core methods for weight synchronization:

- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side
- - ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side
+ - ``init_on_receiver(model_id, **kwargs)`` - Initialize on worker process side
- ``get_sender()`` - Get the configured sender instance
- ``get_receiver()`` - Get the configured receiver instance

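For orientation, here is a minimal sketch (not part of the diff) that strings the four methods together for a multiprocess scheme. The constructor call with no arguments, the ``sender.send(weights)`` signature, and the use of a plain ``state_dict()`` as the weights payload are assumptions inferred from the examples shown below; both pipe ends are kept in one process purely for brevity.

    import torch.multiprocessing as mp
    from torch import nn
    from torchrl.weight_update import MultiProcessWeightSyncScheme

    policy = nn.Linear(4, 2)
    parent_pipe, child_pipe = mp.Pipe()
    scheme = MultiProcessWeightSyncScheme()  # constructor args assumed; verify against the API

    # Trainer side: register the model/pipe, then grab the sender.
    scheme.init_on_sender(model_id="policy", pipe=parent_pipe)
    sender = scheme.get_sender()
    sender.send(policy.state_dict())  # blocking send (payload format assumed)
    # sender.send_async(policy.state_dict()); sender.wait_async()  # asynchronous variant

    # Worker side (normally a separate process): grab the receiver and poll.
    scheme.init_on_receiver(model_id="policy", pipe=child_pipe, model=policy)
    receiver = scheme.get_receiver()
    if receiver.receive(timeout=0.001):  # non-blocking check
        pass  # new weights were received and applied to `policy`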
@@ -85,16 +85,16 @@ Here's a basic example:
# or sender.send_async(weights); sender.wait_async() # Asynchronous send

# On the worker process side:
- # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy)
+ # scheme.init_on_receiver(model_id="policy", pipe=child_pipe, model=policy)
# receiver = scheme.get_receiver()
# # Non-blocking check for new weights
# if receiver.receive(timeout=0.001):
# # Weights were received and applied

# Example 2: Shared memory weight synchronization
# ------------------------------------------------
- # Create shared memory scheme with auto-registration
- shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
+ # Create shared memory scheme
+ shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict")

# Initialize with pipes for lazy registration
parent_pipe2, child_pipe2 = mp.Pipe()
@@ -159,7 +159,7 @@ across multiple inference workers:
# Example 2: Multiple collectors with shared memory
# --------------------------------------------------
# Shared memory is more efficient for frequent updates
- shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
+ shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict")

collector = MultiSyncDataCollector(
create_env_fn=[
@@ -198,6 +198,9 @@ Weight Senders
:template: rl_template.rst

WeightSender
+ MPWeightSender
+ RPCWeightSender
+ DistributedWeightSender
RayModuleTransformSender

Weight Receivers
@@ -208,6 +211,9 @@ Weight Receivers
:template: rl_template.rst

WeightReceiver
+ MPWeightReceiver
+ RPCWeightReceiver
+ DistributedWeightReceiver
RayModuleTransformReceiver

Transports
examples/collectors/multi_weight_updates.py (2 changes: 1 addition & 1 deletion)

@@ -25,7 +25,7 @@
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms.module import ModuleTransform
- from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme
+ from torchrl.weight_update import MultiProcessWeightSyncScheme


def make_module():
examples/collectors/weight_sync_collectors.py (2 changes: 1 addition & 1 deletion)

@@ -90,7 +90,7 @@ def example_multi_collector_shared_memory():
env.close()

# Shared memory is more efficient for frequent updates
- scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
+ scheme = SharedMemWeightSyncScheme(strategy="tensordict")

print("Creating multi-collector with shared memory...")
collector = MultiSyncDataCollector(
examples/collectors/weight_sync_standalone.py (4 changes: 2 additions & 2 deletions)

@@ -141,8 +141,8 @@ def example_shared_memory_sync():
# Create a simple policy
policy = nn.Linear(4, 2)

- # Create shared memory scheme with auto-registration
- scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
+ # Create shared memory scheme
+ scheme = SharedMemWeightSyncScheme(strategy="tensordict")
sender = scheme.create_sender()

# Create pipe for lazy registration
(additional file; name not shown in this view)

@@ -172,7 +172,7 @@ def __init__(self, capacity: int):
if __name__ == "__main__":
args = parser.parse_args()
rank = args.rank
torchrl_logger.info(f"Rank: {rank}")
torchrl_logger.debug(f"RANK: {rank}")

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
test/services/test_python_executor_service.py (16 changes: 9 additions & 7 deletions)

@@ -73,7 +73,7 @@ def test_service_execution(self, ray_init):
result = x + y
print(f"Result: {result}")
"""
- result = ray.get(executor.execute.remote(code), timeout=2)
+ result = ray.get(executor.execute.remote(code), timeout=10)

assert result["success"] is True
assert "Result: 30" in result["stdout"]
@@ -101,7 +101,7 @@ def test_service_execution_error(self, ray_init):

# Execute code with an error
code = "raise ValueError('Test error')"
- result = ray.get(executor.execute.remote(code), timeout=2)
+ result = ray.get(executor.execute.remote(code), timeout=10)

assert result["success"] is False
assert "ValueError: Test error" in result["stderr"]
@@ -119,7 +119,7 @@ def test_multiple_executions(self, ray_init):
"python_executor",
PythonExecutorService,
pool_size=4,
- timeout=5.0,
+ timeout=10.0,
num_cpus=4,
max_concurrency=4,
)
@@ -132,14 +132,16 @@
code = f"print('Execution {i}')"
futures.append(executor.execute.remote(code))

- # Wait for all to complete
- results = ray.get(futures, timeout=5)
+ # Wait for all to complete with longer timeout
+ results = ray.get(futures, timeout=30)

# All should succeed
assert len(results) == 8
for i, result in enumerate(results):
assert result["success"] is True
assert f"Execution {i}" in result["stdout"]
assert result["success"] is True, f"Execution {i} failed: {result}"
assert (
f"Execution {i}" in result["stdout"]
), f"Expected 'Execution {i}' in stdout, got: {result['stdout']!r}"

finally:
services.reset()