Skip to content

Commit 1219fa6

Browse files
authored
correct pseudo code in examples.md (#4172)
* correct pseudo code in examples.md * update to a training example
1 parent 2cbf5c6 commit 1219fa6

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

docs/tutorials/examples.md

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,30 @@ model = Model()
3838
criterion = ...
3939
optimizer = ...
4040
model.train()
41+
# Move model and loss criterion to xpu before calling ipex.optimize()
42+
model = model.to("xpu")
43+
criterion = criterion.to("xpu")
44+
4145
# For Float32
4246
model, optimizer = ipex.optimize(model, optimizer=optimizer)
4347
# For BFloat16
4448
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
4549
...
46-
# For Float32
47-
output = model(data)
48-
...
49-
# For BFloat16
50-
with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
51-
output = model(input)
50+
dataloader = ...
51+
for (input, target) in dataloader:
52+
input = input.to("xpu")
53+
target = target.to("xpu")
54+
optimizer.zero_grad()
55+
# For Float32
56+
output = model(input)
57+
58+
# For BFloat16
59+
with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
60+
output = model(input)
61+
62+
loss = criterion(output, target)
63+
loss.backward()
64+
optimizer.step()
5265
...
5366
```
5467

0 commit comments

Comments
 (0)