File tree — 4 files changed: +13 −5 lines changed
lines changed Original file line number Diff line number Diff line change 120120parser .add_argument ('--reparam' , default = False , action = 'store_true' ,
121121 help = 'Reparameterize model' )
122122parser .add_argument ('--model-kwargs' , nargs = '*' , default = {}, action = ParseKwargs )
123+ parser .add_argument ('--torchcompile-mode' , type = str , default = None ,
124+ help = "torch.compile mode (default: None)." )
123125
124126# codegen (model compilation) options
125127scripting_group = parser .add_mutually_exclusive_group ()
@@ -224,6 +226,7 @@ def __init__(
224226 device = 'cuda' ,
225227 torchscript = False ,
226228 torchcompile = None ,
229+ torchcompile_mode = None ,
227230 aot_autograd = False ,
228231 reparam = False ,
229232 precision = 'float32' ,
@@ -278,7 +281,7 @@ def __init__(
278281 elif torchcompile :
279282 assert has_compile , 'A version of torch w/ torch.compile() is required, possibly a nightly.'
280283 torch ._dynamo .reset ()
281- self .model = torch .compile (self .model , backend = torchcompile )
 284+ self.model = torch.compile(self.model, backend=torchcompile, mode=torchcompile_mode)
282285 self .compiled = True
283286 elif aot_autograd :
284287 assert has_functorch , "functorch is needed for --aot-autograd"
Original file line number Diff line number Diff line change 114114parser .add_argument ('--fuser' , default = '' , type = str ,
115115 help = "Select jit fuser. One of ('', 'te', 'old', 'nvfuser')" )
116116parser .add_argument ('--model-kwargs' , nargs = '*' , default = {}, action = ParseKwargs )
117+ parser .add_argument ('--torchcompile-mode' , type = str , default = None ,
118+ help = "torch.compile mode (default: None)." )
117119
118120scripting_group = parser .add_mutually_exclusive_group ()
119121scripting_group .add_argument ('--torchscript' , default = False , action = 'store_true' ,
@@ -216,7 +218,7 @@ def main():
216218 elif args .torchcompile :
217219 assert has_compile , 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
218220 torch ._dynamo .reset ()
219- model = torch .compile (model , backend = args .torchcompile )
 221+ model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
220222 elif args .aot_autograd :
221223 assert has_functorch , "functorch is needed for --aot-autograd"
222224 model = memory_efficient_fusion (model )
Original file line number Diff line number Diff line change 161161 help = 'Head initialization scale' )
162162group .add_argument ('--head-init-bias' , default = None , type = float ,
163163 help = 'Head initialization bias value' )
164+ group .add_argument ('--torchcompile-mode' , type = str , default = None ,
165+ help = "torch.compile mode (default: None)." )
164166
165167# scripting / codegen
166168scripting_group = group .add_mutually_exclusive_group ()
@@ -627,7 +629,7 @@ def main():
627629 if args .torchcompile :
628630 # torch compile should be done after DDP
629631 assert has_compile , 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
630- model = torch .compile (model , backend = args .torchcompile )
 632+ model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
631633
632634 # create the train and eval datasets
633635 if args .data and not args .data_dir :
Original file line number Diff line number Diff line change 139139parser .add_argument ('--reparam' , default = False , action = 'store_true' ,
140140 help = 'Reparameterize model' )
141141parser .add_argument ('--model-kwargs' , nargs = '*' , default = {}, action = ParseKwargs )
142-
142+ parser .add_argument ('--torchcompile-mode' , type = str , default = None ,
143+ help = "torch.compile mode (default: None)." )
143144
144145scripting_group = parser .add_mutually_exclusive_group ()
145146scripting_group .add_argument ('--torchscript' , default = False , action = 'store_true' ,
@@ -246,7 +247,7 @@ def validate(args):
246247 elif args .torchcompile :
247248 assert has_compile , 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
248249 torch ._dynamo .reset ()
249- model = torch .compile (model , backend = args .torchcompile )
 250+ model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
250251 elif args .aot_autograd :
251252 assert has_functorch , "functorch is needed for --aot-autograd"
252253 model = memory_efficient_fusion (model )
You can’t perform that action at this time.
0 commit comments