
Commit f851770

add reference counting for get_autocast_dtype (intel#238)
* add reference counting for get_autocast_dtype
* fix test error
1 parent 8d0c479

2 files changed: 8 additions & 4 deletions


torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 3 additions & 2 deletions

@@ -52,8 +52,9 @@ void InitIpexModuleBindings(py::module m) {
   // ipex amp autocast
   m.def("get_autocast_dtype", []() {
     at::ScalarType current_dtype = torch_ipex::autocast::get_autocast_dtype();
-    return py::reinterpret_steal<py::object>(
-        (PyObject*)torch::getTHPDtype(current_dtype));
+    auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
+    Py_INCREF(dtype);
+    return py::reinterpret_steal<py::object>(dtype);
   });
   m.def("set_autocast_dtype", [](py::object dtype) {
     at::ScalarType target_dtype =
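
Why the extra Py_INCREF: py::reinterpret_steal takes ownership of a reference, but torch::getTHPDtype returns a borrowed pointer to the interned dtype object. The old binding therefore handed Python a reference it never owned, decrementing the dtype's true refcount by one on every call until a live object could be freed. Below is a minimal sketch of how the symptom would show up from Python, assuming an IPEX build that exposes these bindings on intel_extension_for_pytorch._C (the binding names come from the diff above); it is an illustration, not a test from this repo.

import sys

import torch
import intel_extension_for_pytorch._C as core

# Pin the autocast dtype so get_autocast_dtype() keeps returning the same
# interned torch.dtype object.
core.set_autocast_dtype(torch.bfloat16)

baseline = sys.getrefcount(torch.bfloat16)
for _ in range(1000):
    core.get_autocast_dtype()  # returns the cached dtype object each time
after = sys.getrefcount(torch.bfloat16)

# Before this commit: 'after' would drift toward baseline - 1000, an
# over-decref. After it, the binding owns the reference it steals, so the
# count stays flat.
assert after == baseline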

torch_ipex_py/quantization/quantization_utils.py

Lines changed: 5 additions & 2 deletions

@@ -1,6 +1,7 @@
 import torch
 import functools
 import warnings
+import copy
 import numpy as np
 import intel_extension_for_pytorch._C as core
 from .. import conf
@@ -167,14 +168,16 @@ def decorate_autocast(*args, **kwargs):
 def convert(model, conf, inputs):
     # pre-conver model's parameters dtype if it has conv, linear
     # and Embedding for bfloat16 path.
+    model_ = model
     if torch.is_autocast_cpu_enabled() and core.get_autocast_dtype() == torch.bfloat16:
-        model = utils._convert_module_data_type(model, torch.bfloat16)
+        model_ = utils._convert_module_data_type(copy.deepcopy(model), torch.bfloat16)
+
     core.disable_jit_opt()
     core._jit_set_llga_enabled(True)
     torch._C._jit_set_profiling_mode(True)
     torch._C._jit_set_profiling_executor(True)
     with torch.no_grad(), _quantization_int8():
-        trace_model = torch.jit.trace(model, inputs, check_trace=False)
+        trace_model = torch.jit.trace(model_, inputs, check_trace=False)
         trace_model = torch.jit.freeze(trace_model)

     return trace_model
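Why the deepcopy: _convert_module_data_type appears to rewrite the module's parameter dtypes in place, so passing the caller's model meant convert() silently flipped the user's FP32 module to bfloat16 as a side effect, which is plausibly the "test error" the commit message mentions. The fix binds model_ to a deep copy and traces that instead, leaving the input model untouched. A minimal sketch of the same copy-before-convert pattern, using generic names rather than the IPEX helpers:

import copy

import torch

def to_bf16_copy(model: torch.nn.Module) -> torch.nn.Module:
    # Convert a deep copy so the caller's parameters are never mutated.
    model_ = copy.deepcopy(model)
    for p in model_.parameters():
        p.data = p.data.to(torch.bfloat16)
    return model_

m = torch.nn.Linear(4, 4)
m_bf16 = to_bf16_copy(m)
assert m.weight.dtype == torch.float32        # original is untouched
assert m_bf16.weight.dtype == torch.bfloat16  # only the copy is converted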
