-#include <torch/csrc/autograd/function.h>
+#include "ExtendOPs.h"
+#include "cpu/bf16/vec/bf16_vec_kernel.h"
+#include "cpu/int8/vec/int8_vec_kernel.h"
+#include "torch_ipex/csrc/autocast_mode.h"
+#include "torch_ipex/csrc/autocast_verbose.h"
+#include "torch_ipex/csrc/cpu/CustomOPs.h"
+#include "torch_ipex/csrc/quantization/AutoCast.hpp"
+#include "torch_ipex/csrc/rw_lock.h"
 #include <ATen/Parallel.h>
 #include <ATen/Tensor.h>
+#include <ATen/quantized/Quantizer.h>
+#include <algorithm>
+#include <c10/util/Exception.h>
 #include <c10/util/Optional.h>
 #include <torch/csrc/autograd/custom_function.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/script.h>
-#include <c10/util/Exception.h>
-#include <algorithm>
-#include "cpu/bf16/vec/bf16_vec_kernel.h"
-#include "ExtendOPs.h"
-#include "torch_ipex/csrc/autocast_mode.h"
-#include "torch_ipex/csrc/autocast_verbose.h"
 
 namespace torch_ipex {
 
@@ -40,9 +44,10 @@ bool AtenIpexTypeExt::embedding_bag_fast_path_sum(const at::Tensor weight, const
   return true;
 }
 
-template <typename T>
-static inline at::Tensor _embedding_bag_index_add_select_fast(const at::Tensor select_indices,
-    const at::Tensor src, const at::Tensor offsets, bool include_last_offset) {
+template <typename T>
+static inline at::Tensor _embedding_bag_index_add_select_fast(
+    const at::Tensor indices, const at::Tensor src, const at::Tensor offsets,
+    bool include_last_offset) {
   int64_t ddim = src.size(1);
   auto* src_data = src.data_ptr<T>();
   int64_t output_size = offsets.numel() - 1;
@@ -66,13 +71,13 @@ static inline at::Tensor _embedding_bag_index_add_select_fast(const at::Tensor s
     if (left_size > 0) {
       move_ker(&offsets_include_last_data[align32_size], &offsets_data[align32_size], left_size);
     }
-    offsets_include_last[output_size] = select_indices.numel();
+    offsets_include_last[output_size] = indices.numel();
     offsets_data = offsets_include_last.data();
   }
 
   at::Tensor output = at::empty({output_size, src.size(1)}, src.options());
   auto* output_data = output.data_ptr<T>();
-  auto indices_accessor = select_indices.accessor<int64_t, 1>();
+  auto indices_accessor = indices.accessor<int64_t, 1>();
   at::parallel_for(0, output_size, 16, [&](int64_t start, int64_t end) {
     for (int64_t i = start; i < end; i++) {
       auto* out_data_ptr = &output_data[i * ddim];
@@ -380,6 +385,112 @@ at::Tensor AtenIpexTypeExt::embedding_bag(
   return NewEmbeddingBagOp::_forward(weight, indices, offsets, sparse, include_last_offset);
 }
 
+namespace cpu {
+using weakref_type =
+    c10::weak_intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>;
+using val_type = std::tuple<weakref_type, at::Tensor>;
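+// Cache of quantized weights, keyed by the FP32 weight's TensorImpl, so each
+// weight tensor is quantized to qint8 only once; rwmutex guards the map
+// against concurrent lookups and insertions.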
+std::unordered_map<c10::TensorImpl *, val_type> cached_qweight;
+torch_ipex::ReadWriteMutex rwmutex;
+
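+// Sum-mode EmbeddingBag over a per-tensor qint8 weight: for each bag the
+// selected rows are copied / accumulated directly in int8 (move_ker /
+// add_ker), and the output tensor reuses the weight's quantization scale.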
+at::Tensor embedding_bag_int8_impl(const at::Tensor &qweight,
+                                   const at::Tensor &indices,
+                                   const at::Tensor &offsets,
+                                   bool include_last_offset) {
+  int64_t ddim = qweight.size(1);
+  double scale = at::native::q_scale_quant(qweight);
+  int8_t *qweight_data =
+      reinterpret_cast<int8_t *>(qweight.data_ptr<at::qint8>());
+  int64_t output_size = offsets.numel() - 1;
+  int64_t *offsets_data = offsets.data_ptr<int64_t>();
+  std::vector<int64_t> offsets_include_last;
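+  // When the caller does not provide the trailing offset, build a copy of
+  // offsets with indices.numel() appended so every bag has an end bound.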
+  if (!include_last_offset) {
+    output_size = offsets.numel();
+    offsets_include_last.resize(output_size + 1);
+    int64_t *offsets_include_last_data = offsets_include_last.data();
+    int64_t iter_time = (output_size >> 5);
+    int64_t align32_size = (iter_time << 5);
+    int64_t left_size = output_size - align32_size;
+    at::parallel_for(0, iter_time, 16, [&](int64_t start, int64_t end) {
+      for (int64_t i = start; i < end; i += 1) {
+        auto start_offset = i << 5;
+        move_ker(&offsets_include_last_data[start_offset],
+                 &offsets_data[start_offset], 32);
+      }
+    });
+    if (left_size > 0) {
+      move_ker(&offsets_include_last_data[align32_size],
+               &offsets_data[align32_size], left_size);
+    }
+    offsets_include_last[output_size] = indices.numel();
+    offsets_data = offsets_include_last.data();
+  }
+  // init output tensor
+  at::QuantizerPtr output_quantizer =
+      at::make_per_tensor_affine_quantizer(scale, /*zp=*/0, at::kQInt8);
+  at::Tensor output = at::new_qtensor(/*sizes=*/{output_size, qweight.size(1)},
+                                      qweight.options(), output_quantizer);
+  int8_t *output_data =
+      reinterpret_cast<int8_t *>(output.data_ptr<at::qint8>());
+  auto indices_accessor = indices.accessor<int64_t, 1>();
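+  // For each bag: zero the output row if the bag is empty, otherwise copy the
+  // first selected row and add the remaining rows on top of it.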
+  at::parallel_for(0, output_size, 16, [&](int64_t start, int64_t end) {
+    for (int64_t i = start; i < end; i++) {
+      int8_t *out_data_ptr = &output_data[i * ddim];
+      auto inputs_start = offsets_data[i];
+      auto inputs_end = offsets_data[i + 1];
+      if (inputs_start >= inputs_end) {
+        zero_ker(out_data_ptr, ddim);
+      } else {
+        int8_t *select_data_ptr =
+            &qweight_data[indices_accessor[inputs_start] * ddim];
+        move_ker(out_data_ptr, select_data_ptr, ddim);
+      }
+      for (int64_t s = (inputs_start + 1); s < inputs_end; s++) {
+        int8_t *select_data_ptr = &qweight_data[indices_accessor[s] * ddim];
+        add_ker(out_data_ptr, select_data_ptr, ddim);
+      }
+    }
+  });
+
+  return output;
+}
+
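+// JIT op for quantized EmbeddingBag: quantizes the FP32 weight to qint8 on
+// first use (double-checked under the read/write lock), caches it, and then
+// forwards to embedding_bag_int8_impl.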
+at::Tensor AtenIpexJITDev::dil_qembeddingbag(
+    const at::Tensor weight, const at::Tensor indices, const at::Tensor offsets,
+    bool sparse, bool include_last_offset, double w_scale, int64_t w_zp,
+    at::ScalarType w_dtype, double o_scale, int64_t o_zp,
+    at::ScalarType o_dtype) {
+  at::Tensor qweight;
+  {
+    torch_ipex::UniqueReadLock<torch_ipex::ReadWriteMutex> lock(rwmutex);
+    auto it = cached_qweight.find(weight.unsafeGetTensorImpl());
+    if (it != cached_qweight.end()) {
+      // cache hit
+      qweight = std::get<1>(it->second);
+    }
+  }
+
+  if (!qweight.defined()) {
+    // cache miss
+    torch_ipex::UniqueWriteLock<torch_ipex::ReadWriteMutex> lock(rwmutex);
+    auto it = cached_qweight.find(weight.unsafeGetTensorImpl());
+    if (it == cached_qweight.end()) {
+      // check again if qweight is still not cached
+      qweight = at::quantize_per_tensor(weight, w_scale, 0, at::kQInt8);
+      cached_qweight.emplace(
+          weight.unsafeGetTensorImpl(),
+          val_type{weakref_type(weight.getIntrusivePtr()), qweight});
+    } else {
+      // qweight is cached
+      qweight = std::get<1>(it->second);
+    }
+  }
+
+  return embedding_bag_int8_impl(qweight, indices, offsets,
+                                 include_last_offset);
+}
+
+} // namespace cpu
+
 } // namespace torch_ipex
 
 namespace {
@@ -405,6 +516,10 @@ at::Tensor embedding_bag(
   verbose::OpNameGuard op_name("embedding_bag");
 #endif
   auto target_type = get_autocast_dtype();
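+  // int8 autocast target: route to the int8-specific embedding_bag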
+  if (at::ScalarType::Char == target_type) {
+    return int8::embedding_bag(weight, indices, offsets, sparse,
+                               include_last_offset);
+  }
   // only have bf16 support now, keep fp32 for other target_type
   bool cast_to_bfloat16 = !at::GradMode::is_enabled() && at::kBFloat16 == target_type;
   auto casted_weight = cast_to_bfloat16 ? cpu_cached_cast(at::kBFloat16, weight) : weight;