File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -1515,7 +1515,7 @@ def create_env():
15151515 ].keys ()
15161516 for k in state_dict [f"worker{ worker } " ]["policy_state_dict" ]:
15171517 torch .testing .assert_close (
1518- state_dict [f"worker{ worker } " ]["policy_state_dict" ][k ],
1518+ state_dict [f"worker{ worker } " ]["policy_state_dict" ][k ]. cpu () ,
15191519 policy_state_dict [k ].cpu (),
15201520 )
15211521
@@ -1533,7 +1533,7 @@ def create_env():
15331533 AssertionError
15341534 ) if torch .cuda .is_available () else nullcontext ():
15351535 torch .testing .assert_close (
1536- state_dict [f"worker{ worker } " ]["policy_state_dict" ][k ],
1536+ state_dict [f"worker{ worker } " ]["policy_state_dict" ][k ]. cpu () ,
15371537 policy_state_dict [k ].cpu (),
15381538 )
15391539
@@ -1546,7 +1546,7 @@ def create_env():
15461546 for worker in range (3 ):
15471547 for k in state_dict [f"worker{ worker } " ]["policy_state_dict" ]:
15481548 torch .testing .assert_close (
1549- state_dict [f"worker{ worker } " ]["policy_state_dict" ][k ],
1549+ state_dict [f"worker{ worker } " ]["policy_state_dict" ][k ]. cpu () ,
15501550 policy_state_dict [k ].cpu (),
15511551 )
15521552 finally :
You can’t perform that action at this time.
0 commit comments