Skip to content

Commit 92a4b1b

Browse files
committed
fix state dict device
1 parent 484eb9c commit 92a4b1b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/test_collector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,7 +1515,7 @@ def create_env():
15151515
].keys()
15161516
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
15171517
torch.testing.assert_close(
1518-
state_dict[f"worker{worker}"]["policy_state_dict"][k],
1518+
state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(),
15191519
policy_state_dict[k].cpu(),
15201520
)
15211521

@@ -1533,7 +1533,7 @@ def create_env():
15331533
AssertionError
15341534
) if torch.cuda.is_available() else nullcontext():
15351535
torch.testing.assert_close(
1536-
state_dict[f"worker{worker}"]["policy_state_dict"][k],
1536+
state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(),
15371537
policy_state_dict[k].cpu(),
15381538
)
15391539

@@ -1546,7 +1546,7 @@ def create_env():
15461546
for worker in range(3):
15471547
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
15481548
torch.testing.assert_close(
1549-
state_dict[f"worker{worker}"]["policy_state_dict"][k],
1549+
state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(),
15501550
policy_state_dict[k].cpu(),
15511551
)
15521552
finally:

0 commit comments

Comments
 (0)