@@ -364,18 +364,26 @@ def on_train_batch_end(
 
         # For manual optimization, we save the model state that was captured in training_step
         # before the optimizer step. The test case saves this state in model.saved_models.
-        if hasattr(pl_module, "saved_models") and pl_module.saved_models and hasattr(pl_module, "layer"):
-            latest_step = max(pl_module.saved_models.keys())
+        if (
+            hasattr(pl_module, "saved_models")
+            and isinstance(pl_module.saved_models, dict)
+            and pl_module.saved_models
+            and hasattr(pl_module, "layer")
+            and isinstance(pl_module.layer, torch.nn.Module)
+        ):
+            # Get the latest saved state
+            saved_models = pl_module.saved_models
+            if not saved_models:  # Check if dictionary is not empty
+                return
+
+            latest_step = max(saved_models.keys())
             # Save the checkpoint with the pre-optimization state
             with torch.no_grad():
                 # Save the current state
-                if not isinstance(pl_module.layer, torch.nn.Module):
-                    raise TypeError("pl_module.layer must be a torch.nn.Module for state dict operations")
-
                 original_state = {k: v.detach().clone() for k, v in pl_module.layer.state_dict().items()}
                 try:
                     # Restore the pre-optimization state
-                    saved_state = pl_module.saved_models[latest_step]
+                    saved_state = saved_models[latest_step]
                     if not isinstance(saved_state, dict):
                         raise TypeError("Saved model state must be a dictionary")
 
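
For context, a minimal sketch of the manual-optimization LightningModule that the comment in this hunk refers to. The actual test model is not part of the hunk, so the class name ManualOptModel, the Linear layer shape, the loss, and keying saved_models by batch_idx are assumptions; only the saved_models dict of pre-optimizer-step state dicts and the layer attribute come from the diff above.

import torch
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    # Hypothetical sketch of the module the callback above expects to cooperate with.

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization, per the diff's comment
        self.layer = torch.nn.Linear(32, 2)
        self.saved_models = {}  # step -> pre-optimization state_dict, read by the callback

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()

        # Snapshot the weights *before* the optimizer step so the callback can
        # restore this exact state when it writes the checkpoint.
        self.saved_models[batch_idx] = {
            k: v.detach().clone() for k, v in self.layer.state_dict().items()
        }

        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)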