
Commit 7f1aa62

dlrm int8 support (#59)
* add rwlock for quantized weight cache
* enable emb
* enable int8 interaction
* modify code format
1 parent 07b9ae7 commit 7f1aa62
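
The rwlock mentioned in the first bullet guards a process-wide cache of quantized weights (see torch_ipex/csrc/cpu/embeddingbag.cpp below): lookups take a read lock, and a miss re-checks under a write lock before quantizing once and inserting. A minimal standalone sketch of that double-checked pattern, with std::shared_mutex standing in for torch_ipex::ReadWriteMutex and a placeholder Tensor/quantize:

#include <shared_mutex>
#include <unordered_map>

struct Tensor { bool defined = false; };                  // stand-in for at::Tensor
static Tensor quantize(const Tensor &) { return {true}; } // placeholder quantizer

static std::unordered_map<const void *, Tensor> cached_qweight;
static std::shared_mutex rwmutex;

Tensor get_or_quantize(const void *key, const Tensor &weight) {
  {
    std::shared_lock<std::shared_mutex> read(rwmutex);  // readers run concurrently
    auto it = cached_qweight.find(key);
    if (it != cached_qweight.end())
      return it->second;                                // fast path: cache hit
  }
  std::unique_lock<std::shared_mutex> write(rwmutex);   // exclusive writer
  auto it = cached_qweight.find(key);                   // re-check: another writer
  if (it != cached_qweight.end())                       // may have won the race
    return it->second;
  Tensor q = quantize(weight);                          // one-time quantization
  cached_qweight.emplace(key, q);
  return q;
}

The real cache keys on the weight's TensorImpl pointer and stores a weak reference to it next to the quantized tensor (the val_type tuple in the diff), so cache entries do not keep dead weights alive.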

File tree

15 files changed: +1462 −59 lines changed

tests/cpu/test_emb.py

Lines changed: 2 additions & 2 deletions
@@ -38,10 +38,10 @@ def _test_emb(self, mode):
         self.assertEqual(aten_emb.weight.grad.data._values(), ipex_emb.weight.grad.data._values())
         self.assertEqual(aten_emb.weight.grad.data._values(), ipex_emb.weight.grad.data._values(), 0.01)
 
-    def test_emb_fast_path(self):
+    def test_emb_fallback_path(self):
         self._test_emb(mode='mean')
 
-    def test_emb_fallback_path(self):
+    def test_emb_fast_path(self):
         self._test_emb(mode='sum')
 
 if __name__ == '__main__':

tests/cpu/test_jit_ipex_quantization.py

Lines changed: 26 additions & 0 deletions
@@ -79,7 +79,33 @@ def forward(self, x):
             .check("aten::flatten") \
             .run(graph)
 
+    @llga_test_env
+    def test_embeddingbag_int8(self):
+        m = nn.EmbeddingBag(10, 3, mode='sum', sparse=True)
+        input = torch.LongTensor([1,2,4,5,4,3,2,9])
+        offsets = torch.LongTensor([0,1,2,3,4,5,6,7])
+        for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, [input, offsets], config_name="emb", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, 'ipex::qembedding_bag', 1)
+
+    @llga_test_env
+    def test_interaction_int8(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.f = ipex.interaction
 
+            def forward(self, *x):
+                x = self.f(*x)
+                return x
+
+        m = M()
+        inputs = []
+        for i in range(0, 27):
+            inputs.append(torch.randn([128, 128]))
+        for qscheme in [torch.per_tensor_symmetric]:
+            graph = self.checkQuantizeTrace(m, inputs, config_name="interaction", qscheme=qscheme)
+            self.assertGraphContainsExactly(graph, 'ipex::qinteraction', 1)
 
 if __name__ == '__main__':
     run_tests()

torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 19 additions & 5 deletions
@@ -127,6 +127,12 @@ class AtenIpexJITDev {
       const at::Tensor &right, at::Tensor out_opt,
       const c10::Scalar &div_input);
 
+  static at::Tensor dil_layernorm(const at::Tensor &input,
+                                  at::IntArrayRef normalized_shape,
+                                  const c10::optional<at::Tensor> &weight_opt,
+                                  const c10::optional<at::Tensor> &bias_opt,
+                                  float eps, bool cudnn_enable);
+
   // n-dims tensor op
   static at::Tensor dil_convolution_nd_weight_base(
       const at::Tensor &input, const at::Tensor &weight,

@@ -187,11 +193,19 @@ class AtenIpexJITDev {
       at::IntArrayRef kernel_size, int64_t groups, int64_t output_channel,
       bool weight_channels_last, bool weight_prepacked, at::Tensor &accumu,
       at::Scalar alpha);
-  static at::Tensor dil_layernorm(const at::Tensor &input,
-                                  at::IntArrayRef normalized_shape,
-                                  const c10::optional<at::Tensor> &weight_opt,
-                                  const c10::optional<at::Tensor> &bias_opt,
-                                  float eps, bool cudnn_enable);
+
+  // int8 op
+  static at::Tensor dil_qembeddingbag(const at::Tensor weight,
+                                      const at::Tensor indices,
+                                      const at::Tensor offsets, bool sparse,
+                                      bool include_last_offset, double w_scale,
+                                      int64_t w_zp, at::ScalarType w_dtype,
+                                      double o_scale, int64_t o_zp,
+                                      at::ScalarType o_dtype);
+
+  static at::Tensor dil_qinteraction(const std::vector<at::Tensor> input,
+                                     double o_scale, int64_t o_zp,
+                                     at::ScalarType o_dtype);
 };
 
 } // namespace cpu
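
The (scale, zero_point, dtype) triples in these declarations describe per-tensor affine quantization: w_scale/w_zp for the cached int8 weight and o_scale/o_zp for the op's output. A minimal sketch of that element-wise mapping; quantize_elem/dequantize_elem are illustrative helpers, not functions declared in this header:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Per-tensor affine mapping assumed by the int8 ops above:
//   q = clamp(round(x / scale) + zero_point, -128, 127)
//   x ~= (q - zero_point) * scale
inline int8_t quantize_elem(float x, double scale, int64_t zp) {
  int64_t q = std::llround(x / scale) + zp;
  return static_cast<int8_t>(std::max<int64_t>(-128, std::min<int64_t>(127, q)));
}

inline float dequantize_elem(int8_t q, double scale, int64_t zp) {
  return static_cast<float>((q - zp) * scale);
}

Note that dil_qembeddingbag in the .cpp below passes a zero point of 0 to at::quantize_per_tensor, i.e. the weights are quantized symmetrically.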

torch_ipex/csrc/cpu/embeddingbag.cpp

Lines changed: 127 additions & 12 deletions
@@ -1,17 +1,21 @@
-#include <torch/csrc/autograd/function.h>
+#include "ExtendOPs.h"
+#include "cpu/bf16/vec/bf16_vec_kernel.h"
+#include "cpu/int8/vec/int8_vec_kernel.h"
+#include "torch_ipex/csrc/autocast_mode.h"
+#include "torch_ipex/csrc/autocast_verbose.h"
+#include "torch_ipex/csrc/cpu/CustomOPs.h"
+#include "torch_ipex/csrc/quantization/AutoCast.hpp"
+#include "torch_ipex/csrc/rw_lock.h"
 #include <ATen/Parallel.h>
 #include <ATen/Tensor.h>
+#include <ATen/quantized/Quantizer.h>
+#include <algorithm>
+#include <c10/util/Exception.h>
 #include <c10/util/Optional.h>
 #include <torch/csrc/autograd/custom_function.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/script.h>
-#include <c10/util/Exception.h>
-#include <algorithm>
-#include "cpu/bf16/vec/bf16_vec_kernel.h"
-#include "ExtendOPs.h"
-#include "torch_ipex/csrc/autocast_mode.h"
-#include "torch_ipex/csrc/autocast_verbose.h"
 
 namespace torch_ipex {
 
@@ -40,9 +44,10 @@ bool AtenIpexTypeExt::embedding_bag_fast_path_sum(const at::Tensor weight, const
   return true;
 }
 
-template<typename T>
-static inline at::Tensor _embedding_bag_index_add_select_fast(const at::Tensor select_indices,
-    const at::Tensor src, const at::Tensor offsets, bool include_last_offset) {
+template <typename T>
+static inline at::Tensor _embedding_bag_index_add_select_fast(
+    const at::Tensor indices, const at::Tensor src, const at::Tensor offsets,
+    bool include_last_offset) {
   int64_t ddim = src.size(1);
   auto* src_data = src.data_ptr<T>();
   int64_t output_size = offsets.numel() - 1;
@@ -66,13 +71,13 @@ static inline at::Tensor _embedding_bag_index_add_select_fast(const at::Tensor s
     if (left_size > 0) {
       move_ker(&offsets_include_last_data[align32_size], &offsets_data[align32_size], left_size);
     }
-    offsets_include_last[output_size] = select_indices.numel();
+    offsets_include_last[output_size] = indices.numel();
     offsets_data = offsets_include_last.data();
   }
 
   at::Tensor output = at::empty({output_size, src.size(1)}, src.options());
   auto* output_data = output.data_ptr<T>();
-  auto indices_accessor = select_indices.accessor<int64_t, 1>();
+  auto indices_accessor = indices.accessor<int64_t, 1>();
   at::parallel_for(0, output_size, 16, [&](int64_t start, int64_t end) {
     for (int64_t i = start; i < end; i++) {
       auto* out_data_ptr = &output_data[i * ddim];
@@ -380,6 +385,112 @@ at::Tensor AtenIpexTypeExt::embedding_bag(
   return NewEmbeddingBagOp::_forward(weight, indices, offsets, sparse, include_last_offset);
 }
 
+namespace cpu {
+using weakref_type =
+    c10::weak_intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>;
+using val_type = std::tuple<weakref_type, at::Tensor>;
+std::unordered_map<c10::TensorImpl *, val_type> cached_qweight;
+torch_ipex::ReadWriteMutex rwmutex;
+
+at::Tensor embedding_bag_int8_impl(const at::Tensor &qweight,
+                                   const at::Tensor &indices,
+                                   const at::Tensor &offsets,
+                                   bool include_last_offset) {
+  int64_t ddim = qweight.size(1);
+  double scale = at::native::q_scale_quant(qweight);
+  int8_t *qweight_data =
+      reinterpret_cast<int8_t *>(qweight.data_ptr<at::qint8>());
+  int64_t output_size = offsets.numel() - 1;
+  int64_t *offsets_data = offsets.data_ptr<int64_t>();
+  std::vector<int64_t> offsets_include_last;
+  if (!include_last_offset) {
+    output_size = offsets.numel();
+    offsets_include_last.resize(output_size + 1);
+    int64_t *offsets_include_last_data = offsets_include_last.data();
+    int64_t iter_time = (output_size >> 5);
+    int64_t align32_size = (iter_time << 5);
+    int64_t left_size = output_size - align32_size;
+    at::parallel_for(0, iter_time, 16, [&](int64_t start, int64_t end) {
+      for (int64_t i = start; i < end; i += 1) {
+        auto start_offset = i << 5;
+        move_ker(&offsets_include_last_data[start_offset],
+                 &offsets_data[start_offset], 32);
+      }
+    });
+    if (left_size > 0) {
+      move_ker(&offsets_include_last_data[align32_size],
+               &offsets_data[align32_size], left_size);
+    }
+    offsets_include_last[output_size] = indices.numel();
+    offsets_data = offsets_include_last.data();
+  }
+  // init output tensor
+  at::QuantizerPtr output_quantizer =
+      at::make_per_tensor_affine_quantizer(scale, /*zp=*/0, at::kQInt8);
+  at::Tensor output = at::new_qtensor(/*sizes=*/{output_size, qweight.size(1)},
+                                      qweight.options(), output_quantizer);
+  int8_t *output_data =
+      reinterpret_cast<int8_t *>(output.data_ptr<at::qint8>());
+  auto indices_accessor = indices.accessor<int64_t, 1>();
+  at::parallel_for(0, output_size, 16, [&](int64_t start, int64_t end) {
+    for (int64_t i = start; i < end; i++) {
+      int8_t *out_data_ptr = &output_data[i * ddim];
+      auto inputs_start = offsets_data[i];
+      auto inputs_end = offsets_data[i + 1];
+      if (inputs_start >= inputs_end) {
+        zero_ker(out_data_ptr, ddim);
+      } else {
+        int8_t *select_data_ptr =
+            &qweight_data[indices_accessor[inputs_start] * ddim];
+        move_ker(out_data_ptr, select_data_ptr, ddim);
+      }
+      for (int64_t s = (inputs_start + 1); s < inputs_end; s++) {
+        int8_t *select_data_ptr = &qweight_data[indices_accessor[s] * ddim];
+        add_ker(out_data_ptr, select_data_ptr, ddim);
+      }
+    }
+  });
+
+  return output;
+}
+
+at::Tensor AtenIpexJITDev::dil_qembeddingbag(
+    const at::Tensor weight, const at::Tensor indices, const at::Tensor offsets,
+    bool sparse, bool include_last_offset, double w_scale, int64_t w_zp,
+    at::ScalarType w_dtype, double o_scale, int64_t o_zp,
+    at::ScalarType o_dtype) {
+  at::Tensor qweight;
+  {
+    torch_ipex::UniqueReadLock<torch_ipex::ReadWriteMutex> lock(rwmutex);
+    auto it = cached_qweight.find(weight.unsafeGetTensorImpl());
+    if (it != cached_qweight.end()) {
+      // cache hit
+      qweight = std::get<1>(it->second);
+    }
+  }
+
+  if (!qweight.defined()) {
+    // cache miss
+    torch_ipex::UniqueWriteLock<torch_ipex::ReadWriteMutex> lock(rwmutex);
+    auto it = cached_qweight.find(weight.unsafeGetTensorImpl());
+    if (it == cached_qweight.end()) {
+      // check again if qweight is still not cached
+      qweight = at::quantize_per_tensor(weight, w_scale, 0, at::kQInt8);
+      cached_qweight.emplace(
+          weight.unsafeGetTensorImpl(),
+          val_type{weakref_type(weight.getIntrusivePtr()), qweight});
+    } else {
+      // qweight is cached
+      qweight = std::get<1>(it->second);
+    }
+  }
+
+  return embedding_bag_int8_impl(qweight, indices, offsets,
+                                 include_last_offset);
+}
+
+} // namespace cpu
+
 } // namespace torch_ipex
 
 namespace {
@@ -405,6 +516,10 @@ at::Tensor embedding_bag(
   verbose::OpNameGuard op_name("embedding_bag");
 #endif
   auto target_type = get_autocast_dtype();
+  if (at::ScalarType::Char == target_type) {
+    return int8::embedding_bag(weight, indices, offsets, sparse,
+                               include_last_offset);
+  }
   // only have bf16 support now, keep fp32 for other target_type
   bool cast_to_bfloat16 = !at::GradMode::is_enabled() && at::kBFloat16 == target_type;
   auto casted_weight = cast_to_bfloat16 ? cpu_cached_cast(at::kBFloat16, weight) : weight;
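
For reference, the per-bag arithmetic in embedding_bag_int8_impl reduces to: copy the first selected int8 row into the output row, then accumulate the remaining rows element-wise (move_ker/add_ker/zero_ker are the vectorized kernels from int8_vec_kernel.h). A scalar sketch of one bag, under the assumption that add_ker saturates at the int8 bounds, which this diff does not show:

#include <algorithm>
#include <cstdint>

// Scalar model of one output bag: sum rows qweight[indices[begin..end)] into out.
void bag_sum_int8(const int8_t *qweight, const int64_t *indices, int64_t begin,
                  int64_t end, int64_t ddim, int8_t *out) {
  if (begin >= end) {                       // empty bag -> zeroed row (zero_ker)
    std::fill(out, out + ddim, int8_t{0});
    return;
  }
  const int8_t *row = qweight + indices[begin] * ddim;
  std::copy(row, row + ddim, out);          // first row copied (move_ker)
  for (int64_t s = begin + 1; s < end; ++s) {
    const int8_t *r = qweight + indices[s] * ddim;
    for (int64_t d = 0; d < ddim; ++d) {    // element-wise add (add_ker); the
      int v = static_cast<int>(out[d]) + static_cast<int>(r[d]);
      out[d] = static_cast<int8_t>(std::max(-128, std::min(127, v)));  // saturation assumed
    }
  }
}

Because the output tensor reuses the weight's scale with zero point 0 (see make_per_tensor_affine_quantizer above), bags with many lookups can drive the int8 sum to the clamp bounds.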
