@@ -94,6 +94,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        policy_delay: int = 1,
         ent_coef: Union[str, float] = "auto",
         target_update_interval: int = 1,
         target_entropy: Union[str, float] = "auto",
@@ -145,6 +146,7 @@ def __init__(
         self.target_update_interval = target_update_interval
         self.ent_coef_optimizer: Optional[th.optim.Adam] = None
         self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
+        self.policy_delay = policy_delay
 
         if _init_setup_model:
             self._setup_model()
@@ -190,7 +192,7 @@ def _create_aliases(self) -> None:
         self.critic = self.policy.critic
         self.critic_target = self.policy.critic_target
 
-    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
+    def train(self, gradient_steps: int, batch_size: int = 64, train_freq: int = 1) -> None:
         # Switch to train mode (this affects batch norm / dropout)
         self.policy.set_training_mode(True)
         # Update optimizers learning rate
@@ -205,6 +207,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         actor_losses, critic_losses = [], []
 
         for gradient_step in range(gradient_steps):
+            self._n_updates += 1
+            update_actor = self._n_updates % self.policy_delay == 0
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
 
@@ -222,8 +226,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
                 # so we don't change it with other losses
                 # see https://github.com/rail-berkeley/softlearning/issues/60
                 ent_coef = th.exp(self.log_ent_coef.detach())
-                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
-                ent_coef_losses.append(ent_coef_loss.item())
+                if update_actor:
+                    ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
+                    ent_coef_losses.append(ent_coef_loss.item())
             else:
                 ent_coef = self.ent_coef_tensor
 
@@ -265,24 +270,23 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
             critic_loss.backward()
             self.critic.optimizer.step()
 
-            # Compute actor loss
-            qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
-            actor_loss = (ent_coef * log_prob - qf_pi).mean()
-            actor_losses.append(actor_loss.item())
+            if update_actor:
+                qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
+                actor_loss = (ent_coef * log_prob - qf_pi).mean()
+                actor_losses.append(actor_loss.item())
+
+                # Optimize the actor
+                self.actor.optimizer.zero_grad()
+                actor_loss.backward()
+                self.actor.optimizer.step()
 
-            # Optimize the actor
-            self.actor.optimizer.zero_grad()
-            actor_loss.backward()
-            self.actor.optimizer.step()
 
             # Update target networks
             if gradient_step % self.target_update_interval == 0:
                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                 # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
                 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
 
-        self._n_updates += gradient_steps
-
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/ent_coef", np.mean(ent_coefs))
         self.logger.record("train/actor_loss", np.mean(actor_losses))