Skip to content

Commit f93dea0

Browse files
authored
fix lstm ut of bf16 cpu path (#148)
1 parent ccf8cb8 commit f93dea0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/cpu/test_autocast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,10 @@ def _test_lstm(self, training, bf16, prec = 1e-5):
261261

262262
with ipex.amp.autocast(enabled=bf16, configure=ipex.conf.AmpConf(torch.bfloat16)):
263263
if empty_state:
264-
y, hy = model(self._cast_dtype(input, bf16))
264+
y, hy = self._cast_dtype(model, bf16)(self._cast_dtype(input, bf16))
265265
y_ipex, hy_ipex = model_ipex(input)
266266
else:
267-
y, hy = model(input, (self._cast_dtype(h, bf16), self._cast_dtype(c, bf16)))
267+
y, hy = self._cast_dtype(model, bf16)(self._cast_dtype(input, bf16), (self._cast_dtype(h, bf16), self._cast_dtype(c, bf16)))
268268
y_ipex, hy_ipex = model_ipex(input, (h, c))
269269

270270
if not training and bf16:

0 commit comments

Comments
 (0)