File tree Expand file tree Collapse file tree 1 file changed +19
-6
lines changed Expand file tree Collapse file tree 1 file changed +19
-6
lines changed Original file line number Diff line number Diff line change @@ -38,17 +38,30 @@ model = Model()
3838criterion = ...
3939optimizer = ...
4040model.train()
41+ # Move model and loss criterion to xpu before calling ipex.optimize()
42+ model = model.to("xpu")
43+ criterion = criterion.to("xpu")
44+
4145# For Float32
4246model, optimizer = ipex.optimize(model, optimizer=optimizer)
4347# For BFloat16
4448model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
4549...
46- # For Float32
47- output = model(data)
48- ...
49- # For BFloat16
50- with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
51- output = model(input)
50+ dataloader = ...
51+ for (input, target) in dataloader:
52+ input = input.to("xpu")
53+ target = target.to("xpu")
54+ optimizer.zero_grad()
55+ # For Float32
56+ output = model(input)
57+
58+ # For BFloat16
59+ with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
60+ output = model(input)
61+
62+ loss = criterion(output, target)
63+ loss.backward()
64+ optimizer.step()
5265...
5366```
5467
You can’t perform that action at this time.
0 commit comments