Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
expand_dims,
isclose,
isin,
kron,
nan_to_num,
one_hot,
pad,
Expand All @@ -20,13 +21,11 @@
apply_where,
broadcast_shapes,
default_dtype,
kron,
nunique,
)
from ._lib._lazy import lazy_apply

__version__ = "0.9.1.dev0"

# pylint: disable=duplicate-code
__all__ = [
"__version__",
Expand Down
96 changes: 96 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"create_diagonal",
"expand_dims",
"isclose",
"kron",
"nan_to_num",
"one_hot",
"pad",
Expand Down Expand Up @@ -416,6 +417,101 @@ def isclose(
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)


def kron(
a: Array | complex,
b: Array | complex,
/,
*,
xp: ModuleType | None = None,
) -> Array:
"""
Kronecker product of two arrays.

Computes the Kronecker product, a composite array made of blocks of the
second array scaled by the first.

Equivalent to ``numpy.kron`` for NumPy arrays.

Parameters
----------
a, b : Array | int | float | complex
Input arrays or scalars. At least one must be an array.
xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer.

Returns
-------
array
The Kronecker product of `a` and `b`.

Notes
-----
The function assumes that the number of dimensions of `a` and `b`
are the same, if necessary prepending the smallest with ones.
If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
The elements are products of elements from `a` and `b`, organized
explicitly by::

kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]

where::

kt = it * st + jt, t = 0,...,N

In the common 2-D case (N=1), the block structure can be visualized::

[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
[ ... ... ],
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
Array([ 5, 6, 7, 50, 60, 70, 500,
600, 700], dtype=array_api_strict.int64)

>>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
Array([ 5, 50, 500, 6, 60, 600, 7,
70, 700], dtype=array_api_strict.int64)

>>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
Array([[1., 1., 0., 0.],
[1., 1., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 1., 1.]], dtype=array_api_strict.float64)

>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
>>> c = xpx.kron(a, b, xp=xp)
>>> c.shape
(2, 10, 6, 20)
>>> I = (1, 3, 0, 2)
>>> J = (0, 2, 1)
>>> J1 = (0,) + J # extend to ndim=4
>>> S1 = (1,) + b.shape
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
>>> c[K] == a[I]*b[J]
Array(True, dtype=array_api_strict.bool)
"""
if xp is None:
xp = array_namespace(a, b)

a, b = asarrays(a, b, xp=xp)

if (
is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_numpy_namespace(xp)
or is_torch_namespace(xp)
):
return xp.kron(a, b)

return _funcs.kron(a, b, xp=xp)


def nan_to_num(
x: Array | float | complex,
/,
Expand Down
84 changes: 5 additions & 79 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,87 +407,13 @@ def isclose(


def kron(
a: Array | complex,
b: Array | complex,
a: Array,
b: Array,
/,
*,
xp: ModuleType | None = None,
) -> Array:
"""
Kronecker product of two arrays.

Computes the Kronecker product, a composite array made of blocks of the
second array scaled by the first.

Equivalent to ``numpy.kron`` for NumPy arrays.

Parameters
----------
a, b : Array | int | float | complex
Input arrays or scalars. At least one must be an array.
xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer.

Returns
-------
array
The Kronecker product of `a` and `b`.

Notes
-----
The function assumes that the number of dimensions of `a` and `b`
are the same, if necessary prepending the smallest with ones.
If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
The elements are products of elements from `a` and `b`, organized
explicitly by::

kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]

where::

kt = it * st + jt, t = 0,...,N

In the common 2-D case (N=1), the block structure can be visualized::

[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
[ ... ... ],
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
Array([ 5, 6, 7, 50, 60, 70, 500,
600, 700], dtype=array_api_strict.int64)

>>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
Array([ 5, 50, 500, 6, 60, 600, 7,
70, 700], dtype=array_api_strict.int64)

>>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
Array([[1., 1., 0., 0.],
[1., 1., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 1., 1.]], dtype=array_api_strict.float64)

>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
>>> c = xpx.kron(a, b, xp=xp)
>>> c.shape
(2, 10, 6, 20)
>>> I = (1, 3, 0, 2)
>>> J = (0, 2, 1)
>>> J1 = (0,) + J # extend to ndim=4
>>> S1 = (1,) + b.shape
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
>>> c[K] == a[I]*b[J]
Array(True, dtype=array_api_strict.bool)
"""
if xp is None:
xp = array_namespace(a, b)
a, b = asarrays(a, b, xp=xp)
xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""

singletons = (1,) * (b.ndim - a.ndim)
a = cast(Array, xp.broadcast_to(a, singletons + a.shape))
Expand Down
Loading