@@ -377,26 +377,24 @@ def auto_plan_config(
377377
378378 # Collect manual overrides (values explicitly set in config)
379379 manual_overrides = {}
380+ training_cfg = getattr (config .system , "training" , None ) if hasattr (config , "system" ) else None
380381 if hasattr (config , "data" ):
381- if hasattr ( config . data , "batch_size" ) and config . data . batch_size is not None :
382- manual_overrides ["batch_size" ] = config . data .batch_size
383- if hasattr ( config . data , "num_workers" ) and config . data . num_workers is not None :
384- manual_overrides ["num_workers" ] = config . data .num_workers
382+ if training_cfg and getattr ( training_cfg , "batch_size" , None ) is not None :
383+ manual_overrides ["batch_size" ] = training_cfg .batch_size
384+ if training_cfg and getattr ( training_cfg , "num_workers" , None ) is not None :
385+ manual_overrides ["num_workers" ] = training_cfg .num_workers
385386 if hasattr (config .data , "patch_size" ) and config .data .patch_size is not None :
386387 manual_overrides ["patch_size" ] = config .data .patch_size
387388
388- if hasattr (config , "training" ):
389- if hasattr (config .training , "precision" ) and config .training .precision is not None :
390- manual_overrides ["precision" ] = config .training .precision
391- if (
392- hasattr (config .training , "accumulate_grad_batches" )
393- and config .training .accumulate_grad_batches is not None
394- ):
395- manual_overrides ["accumulate_grad_batches" ] = config .training .accumulate_grad_batches
389+ if hasattr (config , "optimization" ):
390+ if getattr (config .optimization , "precision" , None ) is not None :
391+ manual_overrides ["precision" ] = config .optimization .precision
392+ if getattr (config .optimization , "accumulate_grad_batches" , None ) is not None :
393+ manual_overrides ["accumulate_grad_batches" ] = config .optimization .accumulate_grad_batches
396394
397- if hasattr (config , "optimizer" ):
398- if hasattr ( config . optimizer , "lr" ) and config . optimizer . lr is not None :
399- manual_overrides ["lr" ] = config . optimizer .lr
395+ opt_cfg = getattr (config . optimization , "optimizer" , None )
396+ if opt_cfg and getattr ( opt_cfg , "lr" , None ) is not None :
397+ manual_overrides ["lr" ] = opt_cfg .lr
400398
401399 # Create planner
402400 planner = AutoConfigPlanner (
@@ -408,9 +406,8 @@ def auto_plan_config(
408406
409407 # Plan
410408 use_mixed_precision = not (
411- hasattr (config , "training" )
412- and hasattr (config .training , "precision" )
413- and config .training .precision == "32"
409+ hasattr (config , "optimization" )
410+ and getattr (config .optimization , "precision" , None ) == "32"
414411 )
415412
416413 result = planner .plan (
@@ -423,20 +420,20 @@ def auto_plan_config(
423420 # Update config with planned values (if not manually overridden)
424421 OmegaConf .set_struct (config , False ) # Allow adding new fields
425422
426- if "batch_size" not in manual_overrides :
427- config . data .batch_size = result .batch_size
428- if "num_workers" not in manual_overrides :
429- config . data .num_workers = result .num_workers
423+ if "batch_size" not in manual_overrides and training_cfg is not None :
424+ training_cfg .batch_size = result .batch_size
425+ if "num_workers" not in manual_overrides and training_cfg is not None :
426+ training_cfg .num_workers = result .num_workers
430427 if "patch_size" not in manual_overrides :
431428 config .data .patch_size = result .patch_size
432429
433430 if "precision" not in manual_overrides :
434- config .training .precision = result .precision
431+ config .optimization .precision = result .precision
435432 if "accumulate_grad_batches" not in manual_overrides :
436- config .training .accumulate_grad_batches = result .accumulate_grad_batches
433+ config .optimization .accumulate_grad_batches = result .accumulate_grad_batches
437434
438- if "lr" not in manual_overrides :
439- config .optimizer .lr = result .lr
435+ if "lr" not in manual_overrides and hasattr ( config , "optimization" ) :
436+ config .optimization . optimizer .lr = result .lr
440437
441438 OmegaConf .set_struct (config , True ) # Re-enable struct mode
442439
@@ -460,7 +457,7 @@ def auto_plan_config(
460457 cfg = auto_plan_config (cfg , print_results = True )
461458
462459 print ("\n Final Config Values:" )
463- print (f" batch_size: { cfg .data .batch_size } " )
460+ print (f" batch_size: { cfg .system . training .batch_size } " )
464461 print (f" patch_size: { cfg .data .patch_size } " )
465462 print (f" precision: { cfg .optimization .precision } " )
466463 print (f" lr: { cfg .optimization .optimizer .lr } " )
0 commit comments