Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,31 +130,24 @@ void exp2_kernel(TensorIteratorBase& iter) {
}

template <typename scalar_t>
struct Logit0Functor {
using T_ACC = acc_type_device<scalar_t, c10::DeviceType::XPU>;
struct LogitFunctor {
scalar_t operator()(scalar_t x) const {
const T_ACC x_acc = static_cast<T_ACC>(x);
// suppress compiler optimization on data type promotion.
volatile T_ACC res = std::log(x_acc / (T_ACC(1) - x_acc));
return res;
return std::log(x / (1 - x));
}
};

template <typename scalar_t>
struct Logit1Functor {
struct LogitEpsFunctor {
using T_ACC = acc_type_device<scalar_t, c10::DeviceType::XPU>;
scalar_t operator()(scalar_t x) const {
const T_ACC x_acc = static_cast<T_ACC>(x);
T_ACC z = x_acc < lo_ ? lo_ : (x_acc > hi_ ? hi_ : x_acc);
// suppress compiler optimization on data type promotion.
volatile T_ACC res = std::log(z / (T_ACC(1) - z));
return res;
scalar_t x_clamped = x < low_ ? low_ : (x > high_ ? high_ : x);
return std::log(x_clamped / (1 - x_clamped));
}
Logit1Functor(const T_ACC lo, const T_ACC hi) : lo_(lo), hi_(hi) {}
LogitEpsFunctor(const T_ACC low, const T_ACC high) : low_(low), high_(high) {}

private:
T_ACC lo_;
T_ACC hi_;
scalar_t low_;
scalar_t high_;
};

void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
Expand All @@ -167,11 +160,11 @@ void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
using T_ACC = acc_type_device<scalar_t, c10::DeviceType::XPU>;
const T_ACC eps = eps_scalar.to<T_ACC>();
if (eps < T_ACC(0)) {
gpu_kernel(iter, Logit0Functor<scalar_t>());
gpu_kernel(iter, LogitFunctor<scalar_t>());
} else {
const T_ACC lo = eps;
const T_ACC hi = T_ACC(1) - eps;
gpu_kernel(iter, Logit1Functor<scalar_t>(lo, hi));
const T_ACC low = eps;
const T_ACC high = T_ACC(1) - eps;
gpu_kernel(iter, LogitEpsFunctor<scalar_t>(low, high));
}
});
}
Expand Down
7 changes: 0 additions & 7 deletions test/xpu/extended/skip_list_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,6 @@
"test_compare_cpu_nn_functional_huber_loss_xpu_bfloat16",
"test_compare_cpu_nansum_xpu_bfloat16",
"test_compare_cpu_nanmean_xpu_bfloat16",
# Align with CUDA impl by using accumulate type, but CPU doesn't use it.
# When XPU uses the original data type, the case passes.
"test_compare_cpu_logit_xpu_bfloat16",
# precision error
# Mismatched elements: 1 / 24 (4.2%)
# Greatest absolute difference: 0.03125 at index (0, 1, 0, 1) (up to 0.001 allowed)
# Greatest relative difference: 0.0048828125 at index (0, 1, 0, 1) (up to 0.001 allowed)
"test_compare_cpu_nn_functional_interpolate_bilinear_xpu_bfloat16",
# RuntimeError: "compute_index_ranges_weights" not implemented for 'Half'
"test_compare_cpu_nn_functional_interpolate_bilinear_xpu_float16",
Expand Down
Loading