Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ This release changes the license from `BSD-2-Clause` to `BSD-3-Clause`.
* Added implementation of `dpnp.linalg.lu_solve` for batch inputs (SciPy-compatible) [#2619](https://github.com/IntelPython/dpnp/pull/2619)
* Added `dpnp.exceptions` submodule to aggregate the generic exceptions used by dpnp [#2616](https://github.com/IntelPython/dpnp/pull/2616)
* Added implementation of `dpnp.scipy.special.erfcx` [#2596](https://github.com/IntelPython/dpnp/pull/2596)
* Added implementation of `dpnp.scipy.special.erfinv` and `dpnp.scipy.special.erfcinv` [#2624](https://github.com/IntelPython/dpnp/pull/2624)

### Changed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ static void populate(py::module_ m,
MACRO_DEFINE_IMPL(erf, Erf);
MACRO_DEFINE_IMPL(erfc, Erfc);
MACRO_DEFINE_IMPL(erfcx, Erfcx);
MACRO_DEFINE_IMPL(erfinv, Erfinv);
MACRO_DEFINE_IMPL(erfcinv, Erfcinv);
} // namespace impl

void init_erf_funcs(py::module_ m)
Expand All @@ -236,5 +238,13 @@ void init_erf_funcs(py::module_ m)
impl::populate<impl::ErfcxContigFactory, impl::ErfcxStridedFactory>(
m, "_erfcx", "", impl::erfcx_contig_dispatch_vector,
impl::erfcx_strided_dispatch_vector);

impl::populate<impl::ErfinvContigFactory, impl::ErfinvStridedFactory>(
m, "_erfinv", "", impl::erfinv_contig_dispatch_vector,
impl::erfinv_strided_dispatch_vector);

impl::populate<impl::ErfcinvContigFactory, impl::ErfcinvStridedFactory>(
m, "_erfcinv", "", impl::erfcinv_contig_dispatch_vector,
impl::erfcinv_strided_dispatch_vector);
}
} // namespace dpnp::extensions::ufunc
14 changes: 14 additions & 0 deletions dpnp/backend/extensions/vm/erf_funcs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
MACRO_DEFINE_IMPL(erf, Erf);
MACRO_DEFINE_IMPL(erfc, Erfc);
MACRO_DEFINE_IMPL(erfcx, Erfcx);
MACRO_DEFINE_IMPL(erfinv, Erfinv);
MACRO_DEFINE_IMPL(erfcinv, Erfcinv);

template <template <typename fnT, typename T> typename factoryT>
static void populate(py::module_ m,
Expand Down Expand Up @@ -194,5 +196,17 @@ void init_erf_funcs(py::module_ m)
"Call `erfcx` function from OneMKL VM library to compute the scaled "
"complementary error function value of vector elements",
impl::erfcx_contig_dispatch_vector);

impl::populate<impl::ErfinvContigFactory>(
m, "_erfinv",
"Call `erfinv` function from OneMKL VM library to compute the inverse "
"of the error function value of vector elements",
impl::erfinv_contig_dispatch_vector);

impl::populate<impl::ErfcinvContigFactory>(
m, "_erfcinv",
"Call `erfcinv` function from OneMKL VM library to compute the inverse "
"of the complementary error function value of vector elements",
impl::erfcinv_contig_dispatch_vector);
}
} // namespace dpnp::extensions::vm
15 changes: 9 additions & 6 deletions dpnp/backend/kernels/elementwise_functions/erf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include <sycl/ext/intel/math.hpp>
#else
#include "erfcx.hpp"
#include "erfinv.hpp"
#endif

namespace dpnp::kernels::erfs
Expand Down Expand Up @@ -85,13 +86,15 @@ struct BaseFunctor
template <typename ArgT, typename ResT> \
using __f_name__##Functor = BaseFunctor<__f_name__##Op, ArgT, ResT>;

MACRO_DEFINE_FUNCTOR(sycl::erf, Erf);
MACRO_DEFINE_FUNCTOR(sycl::erfc, Erfc);
MACRO_DEFINE_FUNCTOR(
#if defined(__SYCL_EXT_INTEL_MATH_SUPPORT)
sycl::ext::intel::math::erfcx,
using namespace sycl::ext::intel::math;
#else
impl::erfcx,
using namespace impl;
#endif
Erfcx);

MACRO_DEFINE_FUNCTOR(sycl::erf, Erf);
MACRO_DEFINE_FUNCTOR(sycl::erfc, Erfc);
MACRO_DEFINE_FUNCTOR(erfcx, Erfcx);
MACRO_DEFINE_FUNCTOR(erfinv, Erfinv);
MACRO_DEFINE_FUNCTOR(erfcinv, Erfcinv);
} // namespace dpnp::kernels::erfs
8 changes: 3 additions & 5 deletions dpnp/backend/kernels/elementwise_functions/erfcx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1622,19 +1622,17 @@ For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfc(x), with the
usual checks for overflow etcetera.
*/
template <typename Tp>
Tp erfcx(Tp x)
inline Tp erfcx(Tp x)
{
static_assert(std::is_floating_point_v<Tp>,
"erfcx requires a floating-point type");

if (x >= 0) {
if (x > 50) // continued-fraction expansion is faster
{
if (x > 50) { // continued-fraction expansion is faster
// 1/sqrt(pi)
constexpr Tp inv_sqrtpi = 0.564189583547756286948079451560772586L;

if (x > 5e7) // 1-term expansion, important to avoid overflow
{
if (x > 5e7) { // 1-term expansion, important to avoid overflow
return inv_sqrtpi / x;
}

Expand Down
219 changes: 219 additions & 0 deletions dpnp/backend/kernels/elementwise_functions/erfinv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
//*****************************************************************************
// Copyright (c) 2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
// - Neither the name of the copyright holder nor the names of its contributors
// may be used to endorse or promote products derived from this software
// without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <limits>
#include <sycl/sycl.hpp>
#include <type_traits>

namespace dpnp::kernels::erfs::impl
{
/**
 * @brief Evaluates a polynomial with Horner's scheme.
 *
 * The coefficients are stored from the highest power down to the constant
 * term, i.e. the result is
 *   coeff[0] * x^i + coeff[1] * x^(i-1) + ... + coeff[i].
 *
 * @tparam Tp    Arithmetic type of the argument and the coefficients.
 * @param x      Point at which the polynomial is evaluated.
 * @param coeff  Pointer to at least `i + 1` coefficients (at least one
 *               coefficient when `i <= 0`).
 * @param i      Degree of the polynomial. A degree of zero (or less) yields
 *               the constant `coeff[0]` instead of reading past the array.
 * @return       Value of the polynomial at `x`.
 */
template <typename Tp>
inline Tp polevl(Tp x, const Tp *coeff, int i)
{
    Tp p = coeff[0];

    // A counted loop keeps the number of reads bounded by the degree, so a
    // degree-0 polynomial no longer touches coeff[1].
    for (int k = 1; k <= i; ++k) {
        p = p * x + coeff[k];
    }
    return p;
}

/**
 * @brief Horner evaluation of a monic polynomial.
 *
 * Same as a regular polynomial evaluation, except that the leading
 * coefficient is implicitly 1 and therefore not stored:
 *   x^i + coeff[0] * x^(i-1) + ... + coeff[i-1].
 *
 * @tparam Tp    Arithmetic type of the argument and the coefficients.
 * @param x      Point at which the polynomial is evaluated.
 * @param coeff  Pointer to the `i` stored coefficients.
 * @param i      Degree of the polynomial (number of stored coefficients).
 * @return       Value of the polynomial at `x`.
 */
template <typename Tp>
inline Tp p1evl(Tp x, const Tp *coeff, int i)
{
    // The implicit leading 1 turns the first Horner step into `x + c0`.
    const Tp *next = coeff;
    Tp acc = x + *next;
    ++next;

    // One Horner step per remaining stored coefficient.
    for (int remaining = i - 1; remaining != 0; --remaining, ++next) {
        acc = acc * x + *next;
    }
    return acc;
}

/**
 * @brief Inverse of the standard normal cumulative distribution function.
 *
 * Returns the value x for which the area under the unit Gaussian density,
 * taken from minus infinity up to x, equals `y0`. Three rational
 * approximations are used: one around the centre of the distribution and two
 * for the tails, selected by the magnitude of sqrt(-2 * log(y)).
 * The coefficient tables appear to follow the Cephes `ndtri` routine.
 *
 * Every constant is materialised as `Tp`, so a single-precision
 * instantiation performs no double-precision arithmetic.
 *
 * @tparam Tp  Floating-point type of the argument, the result and all
 *             intermediate computations.
 * @param y0   Probability, expected to lie in [0, 1].
 * @return     -inf for `y0 == 0`, +inf for `y0 == 1`, NaN for `y0` outside
 *             [0, 1], otherwise the corresponding quantile. A NaN argument
 *             fails every range check and reaches the tail formula
 *             unchanged.
 */
template <typename Tp>
inline Tp ndtri(Tp y0)
{
    constexpr Tp zero = 0;
    constexpr Tp half = 0.5;
    constexpr Tp one = 1;
    constexpr Tp two = 2;
    constexpr Tp eight = 8;

    if (y0 == zero) {
        return -std::numeric_limits<Tp>::infinity();
    }
    else if (y0 == one) {
        return std::numeric_limits<Tp>::infinity();
    }
    else if (y0 < zero || y0 > one) {
        return std::numeric_limits<Tp>::quiet_NaN();
    }

    // exp(-2)
    constexpr Tp exp_minus2 = 0.13533528323661269189399949497248L;

    // The upper tail is mirrored onto the lower one (y -> 1 - y); the tail
    // formula below yields a positive magnitude that has to be negated only
    // for the lower tail.
    const bool upper_tail = y0 > (one - exp_minus2);
    const Tp y = upper_tail ? (one - y0) : y0;

    if (y > exp_minus2) {
        // sqrt(2*pi)
        constexpr Tp root_2_pi = 2.50662827463100050241576528481105L;

        // approximation for 0 <= |y - 0.5| <= 3/8
        constexpr Tp p[] = {
            -5.99633501014107895267E1, 9.80010754185999661536E1,
            -5.66762857469070293439E1, 1.39312609387279679503E1,
            -1.23916583867381258016E0,
        };
        constexpr Tp q[] = {
            1.95448858338141759834E0,  4.67627912898881538453E0,
            8.63602421390890590575E1,  -2.25462687854119370527E2,
            2.00260212380060660359E2,  -8.20372256168333339912E1,
            1.59056225126211695515E1,  -1.18331621121330003142E0,
        };

        const Tp yc = y - half;
        const Tp y2 = yc * yc;
        const Tp x = yc + yc * (y2 * polevl(y2, p, 4) / p1evl(y2, q, 8));
        return x * root_2_pi;
    }

    const Tp x = sycl::sqrt(-two * sycl::log(y));
    const Tp x0 = x - sycl::log(x) / x;
    const Tp z = one / x;

    Tp x1;
    if (x < eight) {
        // approximation for 2 <= sqrt(-2*log(y)) < 8
        constexpr Tp p[] = {
            4.05544892305962419923E0,  3.15251094599893866154E1,
            5.71628192246421288162E1,  4.40805073893200834700E1,
            1.46849561928858024014E1,  2.18663306850790267539E0,
            -1.40256079171354495875E-1, -3.50424626827848203418E-2,
            -8.57456785154685413611E-4,
        };

        constexpr Tp q[] = {
            1.57799883256466749731E1,  4.53907635128879210584E1,
            4.13172038254672030440E1,  1.50425385692907503408E1,
            2.50464946208309415979E0,  -1.42182922854787788574E-1,
            -3.80806407691578277194E-2, -9.33259480895457427372E-4,
        };

        x1 = z * polevl(z, p, 8) / p1evl(z, q, 8);
    }
    else {
        // approximation for 8 <= sqrt(-2*log(y)) < 64
        constexpr Tp p[] = {
            3.23774891776946035970E0,  6.91522889068984211695E0,
            3.93881025292474443415E0,  1.33303460815807542389E0,
            2.01485389549179081538E-1, 1.23716634817820021358E-2,
            3.01581553508235416007E-4, 2.65806974686737550832E-6,
            6.23974539184983293730E-9,
        };

        constexpr Tp q[] = {
            6.02427039364742014255E0,  3.67983563856160859403E0,
            1.37702099489081330271E0,  2.16236993594496635890E-1,
            1.34204006088543189037E-2, 3.28014464682127739104E-4,
            2.89247864745380683936E-6, 6.79019408009981274425E-9,
        };

        x1 = z * polevl(z, p, 8) / p1evl(z, q, 8);
    }

    const Tp magnitude = x0 - x1;
    return upper_tail ? magnitude : -magnitude;
}

template <typename Tp>
inline Tp erfinv(Tp y)
{
static_assert(std::is_floating_point_v<Tp>,
"erfinv requires a floating-point type");

constexpr Tp lower = -1;
constexpr Tp upper = 1;

constexpr Tp thresh = 1e-7;

// For small arguments, use the Taylor expansion.
// Otherwise, y + 1 loses precision for |y| << 1.
if ((-thresh < y) && (y < thresh)) {
// 2/sqrt(pi)
constexpr Tp inv_sqrtpi = 1.1283791670955125738961589031215452L;
return y / inv_sqrtpi;
}

if ((lower < y) && (y < upper)) {
// 1/sqrt(2)
constexpr Tp one_div_root_2 = 0.7071067811865475244008443621048490L;
return ndtri(0.5 * (y + 1)) * one_div_root_2;
}
else if (y == lower) {
return -std::numeric_limits<Tp>::infinity();
}
else if (y == upper) {
return std::numeric_limits<Tp>::infinity();
}
else if (sycl::isnan(y)) {
return y;
}
return std::numeric_limits<Tp>::quiet_NaN();
}

template <typename Tp>
inline Tp erfcinv(Tp y)
{
static_assert(std::is_floating_point_v<Tp>,
"erfcinv requires a floating-point type");

constexpr Tp lower = 0;
constexpr Tp upper = 2;

if ((lower < y) && (y < upper)) {
// 1/sqrt(2)
constexpr Tp one_div_root_2 = 0.7071067811865475244008443621048490L;
return -ndtri(0.5 * y) * one_div_root_2;
}
else if (y == lower) {
return std::numeric_limits<Tp>::infinity();
}
else if (y == upper) {
return -std::numeric_limits<Tp>::infinity();
}
else if (sycl::isnan(y)) {
return y;
}
return std::numeric_limits<Tp>::quiet_NaN();
}
} // namespace dpnp::kernels::erfs::impl
4 changes: 4 additions & 0 deletions dpnp/scipy/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,15 @@
from ._erf import (
erf,
erfc,
erfcinv,
erfcx,
erfinv,
)

__all__ = [
"erf",
"erfc",
"erfcx",
"erfinv",
"erfcinv",
]
Loading
Loading