We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8630947 commit 2d09181Copy full SHA for 2d09181
train.py
@@ -512,7 +512,7 @@ def main():
512
**args.model_kwargs,
513
)
514
if args.head_init_scale is not None:
515
- with torch.inference_mode():
+ with torch.no_grad():
516
model.get_classifier().weight.mul_(args.head_init_scale)
517
model.get_classifier().bias.mul_(args.head_init_scale)
518
if args.head_init_bias is not None:
0 commit comments