Skip to content

Commit 035f6d2

Browse files
Chapamanpolvalente
andauthored
refactor(test): promote window_scatter_min/window_scatter_max tests to nx_test.exs (#85)
* window_scatter_min not working * added window_scatter_min with vectorized input * Updated window_scatter_min and max to feature in nx_test.exs; removed comments regarding the tie breaker in emlx.ex * fix: simplify pad invocations --------- Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
1 parent d468c75 commit 035f6d2

File tree

7 files changed

+219
-64
lines changed

7 files changed

+219
-64
lines changed

lib/emlx.ex

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ defmodule EMLX do
177177
deftensor tensordot(tensorA, tensorB, axesA, axesB)
178178
deftensor einsum(tensorA, tensorB, spec_string)
179179
deftensor transpose(tensor, axes)
180-
deftensor pad(tensor, axes, low_pad_size, high_pad_size, pad_value)
180+
deftensor pad(tensor, axes, low_pad_size, high_pad_size, tensor_pad_value)
181181
deftensor sort(tensor, axis)
182182
deftensor argsort(tensor, axis)
183183
deftensor tri_inv(tensor, upper)
@@ -449,9 +449,6 @@ defmodule EMLX do
449449
@impl Nx.Defn.Compiler
450450
defdelegate __partitions_options__(opts), to: Nx.Defn.Evaluator
451451

452-
@impl Nx.Defn.Compiler
453-
defdelegate __stream__(key, input, acc, vars, fun, args, opts), to: Nx.Defn.Evaluator
454-
455452
@impl Nx.Defn.Compiler
456453
def __to_backend__(opts) do
457454
device = Keyword.get(opts, :device, :gpu)

lib/emlx/backend.ex

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ defmodule EMLX.Backend do
284284
{[i | axes], [max(low, 0) | lows], [max(high, 0) | highs]}
285285
end)
286286

287-
{_device, pad_value_mx} = from_nx(pad_value)
287+
pad_value_mx = from_nx(pad_value)
288288

289289
interior_padding = Enum.map(input_config, fn {_low, _high, interior} -> interior end)
290290

@@ -1381,7 +1381,7 @@ defmodule EMLX.Backend do
13811381
window_dilations = opts[:window_dilations] || List.duplicate(1, tuple_size(window_shape))
13821382
interior_padding_config = Enum.map(window_dilations, &(&1 - 1))
13831383

1384-
{_device, zero_mx} = EMLX.scalar_tensor(0, :bool, device)
1384+
zero_mx = EMLX.scalar_tensor(0, :bool, device)
13851385

13861386
window =
13871387
1
@@ -1391,7 +1391,7 @@ defmodule EMLX.Backend do
13911391

13921392
window_shape = EMLX.shape(window)
13931393

1394-
{device, pad_mx} =
1394+
pad_value_mx =
13951395
case op do
13961396
:sum ->
13971397
EMLX.scalar_tensor(0, to_mlx_type(out.type), device)
@@ -1406,11 +1406,11 @@ defmodule EMLX.Backend do
14061406
Nx.Constants.max(tensor.type, backend: {EMLX.Backend, device: device}) |> from_nx()
14071407
end
14081408

1409-
padded_mx = EMLX.pad(t_mx, Nx.axes(tensor), low_pad, high_pad, pad_mx)
1409+
padded_mx = EMLX.pad(t_mx, Nx.axes(tensor), low_pad, high_pad, pad_value_mx)
14101410

14111411
padded_mx
14121412
|> sliding_window_view(EMLX.shape(padded_mx), window_shape, opts[:strides])
1413-
|> then(&EMLX.where(window, &1, {device, pad_mx}))
1413+
|> then(&EMLX.where(window, &1, pad_value_mx))
14141414
|> then(&apply(EMLX, op, [&1, axes, false]))
14151415
|> to_nx(out)
14161416
end
@@ -1468,14 +1468,13 @@ defmodule EMLX.Backend do
14681468
end
14691469

14701470
defp window_scatter_function(function, out, tensor, source, init_value, window_dims_tuple, opts) do
1471-
# TODO: support window dilations
14721471
unfold_flat = fn tensor ->
14731472
{device, _} = t_mx = from_nx(tensor)
1474-
{_, pad_mx} = EMLX.scalar_tensor(0, EMLX.scalar_type(t_mx), device)
1473+
pad_value_mx = EMLX.scalar_tensor(0, EMLX.scalar_type(t_mx), device)
14751474

14761475
{low_pad, high_pad} = Enum.unzip(opts[:padding])
14771476

1478-
padded_mx = EMLX.pad(t_mx, Nx.axes(tensor), low_pad, high_pad, pad_mx)
1477+
padded_mx = EMLX.pad(t_mx, Nx.axes(tensor), low_pad, high_pad, pad_value_mx)
14791478

14801479
unfolded_mx =
14811480
sliding_window_view(

mix.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ defmodule EMLX.MixProject do
5858
defp deps do
5959
[
6060
{:elixir_make, "~> 0.6"},
61-
{:nx, "~> 0.9.2"},
61+
{:nx, "~> 0.10"},
6262
{:nif_call, "~> 0.1.3"}
6363
]
6464
end

mix.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
%{
2-
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
2+
"complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
33
"elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"},
44
"nif_call": {:hex, :nif_call, "0.1.3", "bb4af0d28d1a2f10602d50246155b95b4ef6a389025c12d830dc9924ef06d324", [:mix], [], "hexpm", "48ba2e66c7d5aab4ca1a9656d3fa7326d317ba63bac64d24b22a05c8c8a9aff0"},
5-
"nx": {:hex, :nx, "0.9.2", "17563029c01bf749aad3c31234326d7665abd0acc33ee2acbe531a4759f29a8a", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "914d74741617d8103de8ab1f8c880353e555263e1c397b8a1109f79a3716557f"},
5+
"nx": {:hex, :nx, "0.10.0", "128e4a094cb790f663e20e1334b127c1f2a4df54edfb8b13c22757ec33133b4f", [:mix], [{:complex, "~> 0.6", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3db8892c124aeee091df0e6fbf8e5bf1b81f502eb0d4f5ba63e6378ebcae7da4"},
66
"telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"},
77
}

test/emlx/nx_doctest_test.exs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@ defmodule EMLX.Nx.DoctestTest do
4848
@to_be_fixed [
4949
:moduledoc,
5050
# MLX sorts NaNs lowest, Nx sorts them highest
51-
argsort: 2,
52-
# Missing support for window dilations
53-
window_scatter_max: 5,
54-
window_scatter_min: 5
51+
argsort: 2
5552
]
5653

5754
@not_supported [

test/emlx/nx_test.exs

Lines changed: 206 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,8 +1055,6 @@ defmodule EMLX.NxTest do
10551055
)
10561056
end
10571057

1058-
# does not support window_dilations yet
1059-
@tag :skip
10601058
test "works with non-default options" do
10611059
t = Nx.tensor([[[4, 2, 1, 3], [4, 2, 1, 7]], [[1, 2, 5, 7], [1, 8, 9, 2]]])
10621060
opts = [strides: [2, 1, 1], padding: :valid, window_dilations: [1, 2, 2]]
@@ -1117,8 +1115,6 @@ defmodule EMLX.NxTest do
11171115
)
11181116
end
11191117

1120-
# window_dilations are not supported yet
1121-
@tag :skip
11221118
test "works with non-default options" do
11231119
t = Nx.tensor([[[4, 2, 1, 3], [4, 2, 1, 7]], [[1, 2, 5, 7], [1, 8, 9, 2]]])
11241120
opts = [strides: [2, 1, 1], padding: :valid, window_dilations: [1, 2, 2]]
@@ -1244,8 +1240,6 @@ defmodule EMLX.NxTest do
12441240
)
12451241
end
12461242

1247-
# window dilations are not supported yet
1248-
@tag :skip
12491243
test "supports window dilations" do
12501244
result = Nx.window_sum(Nx.iota({4, 4}), {2, 2}, window_dilations: [2, 1])
12511245

@@ -1298,8 +1292,6 @@ defmodule EMLX.NxTest do
12981292
)
12991293
end
13001294

1301-
# window dilations are not supported yet
1302-
@tag :skip
13031295
test "supports window dilations" do
13041296
result = Nx.window_product(Nx.iota({4, 4}), {2, 2}, window_dilations: [2, 1])
13051297

@@ -1576,4 +1568,210 @@ defmodule EMLX.NxTest do
15761568
)
15771569
end
15781570
end
1571+
1572+
describe "window_scatter_max" do
1573+
test "window_scatter_max with strides [2, 3]" do
1574+
t =
1575+
Nx.tensor([
1576+
[7, 2, 5, 3, 10, 2],
1577+
[3, 8, 9, 3, 4, 2],
1578+
[1, 5, 7, 5, 6, 1],
1579+
[0, 6, 2, 7, 2, 8]
1580+
])
1581+
1582+
opts = [strides: [2, 3], padding: :valid]
1583+
result = Nx.window_scatter_max(t, Nx.tensor([[2, 6], [3, 1]]), 0, {2, 3}, opts)
1584+
1585+
assert_all_close(
1586+
result,
1587+
Nx.tensor([
1588+
[0, 0, 0, 0, 6, 0],
1589+
[0, 0, 2, 0, 0, 0],
1590+
[0, 0, 3, 0, 0, 0],
1591+
[0, 0, 0, 0, 0, 1]
1592+
])
1593+
)
1594+
end
1595+
1596+
test "window_scatter_max with strides [2, 2]" do
1597+
t =
1598+
Nx.tensor([
1599+
[7, 2, 5, 3, 8],
1600+
[3, 8, 9, 3, 4],
1601+
[1, 5, 7, 5, 6],
1602+
[0, 6, 2, 10, 2]
1603+
])
1604+
1605+
opts = [strides: [2, 2], padding: :valid]
1606+
result = Nx.window_scatter_max(t, Nx.tensor([[2, 6], [3, 1]]), 0, {2, 3}, opts)
1607+
1608+
assert_all_close(
1609+
result,
1610+
Nx.tensor([
1611+
[0, 0, 0, 0, 0],
1612+
[0, 0, 8, 0, 0],
1613+
[0, 0, 3, 0, 0],
1614+
[0, 0, 0, 1, 0]
1615+
])
1616+
)
1617+
end
1618+
1619+
test "window_scatter_max with vectorized input" do
1620+
t =
1621+
Nx.tensor([
1622+
[
1623+
[7, 2, 5, 3],
1624+
[3, 8, 9, 3]
1625+
],
1626+
[
1627+
[1, 5, 7, 5],
1628+
[0, 6, 2, 8]
1629+
]
1630+
])
1631+
|> Nx.vectorize(:x)
1632+
1633+
opts = [strides: [1, 2], padding: :valid]
1634+
1635+
source =
1636+
Nx.tensor([
1637+
[[2, 6]],
1638+
[[3, 1]]
1639+
])
1640+
|> Nx.vectorize(:y)
1641+
1642+
result = Nx.window_scatter_max(t, source, 0, {2, 2}, opts)
1643+
1644+
Nx.Testing.assert_equal(
1645+
result,
1646+
Nx.tensor([
1647+
[
1648+
[
1649+
[0, 0, 0, 0],
1650+
[0, 2, 6, 0]
1651+
],
1652+
[
1653+
[0, 0, 0, 0],
1654+
[0, 3, 1, 0]
1655+
]
1656+
],
1657+
[
1658+
[
1659+
[0, 0, 0, 0],
1660+
[0, 2, 0, 6]
1661+
],
1662+
[
1663+
[0, 0, 0, 0],
1664+
[0, 3, 0, 1]
1665+
]
1666+
]
1667+
])
1668+
|> Nx.vectorize([:x, :y])
1669+
)
1670+
end
1671+
end
1672+
1673+
describe "window_scatter_min" do
1674+
test "window_scatter_min with strides [2, 3]" do
1675+
t =
1676+
Nx.tensor([
1677+
[7, 2, 5, 3, 10, 2],
1678+
[3, 8, 9, 3, 4, 2],
1679+
[1, 5, 7, 5, 6, 1],
1680+
[0, 6, 2, 7, 2, 8]
1681+
])
1682+
1683+
opts = [strides: [2, 3], padding: :valid]
1684+
result = Nx.window_scatter_min(t, Nx.tensor([[2, 6], [3, 1]]), 0, {2, 3}, opts)
1685+
1686+
assert_all_close(
1687+
result,
1688+
Nx.tensor([
1689+
[0, 2, 0, 0, 0, 0],
1690+
[0, 0, 0, 0, 0, 6],
1691+
[0, 0, 0, 0, 0, 1],
1692+
[3, 0, 0, 0, 0, 0]
1693+
])
1694+
)
1695+
end
1696+
1697+
test "window_scatter_min with strides [2, 2]" do
1698+
t =
1699+
Nx.tensor([
1700+
[7, 2, 5, 3, 8],
1701+
[3, 8, 9, 3, 4],
1702+
[1, 5, 7, 5, 6],
1703+
[0, 6, 2, 10, 2]
1704+
])
1705+
1706+
opts = [strides: [2, 2], padding: :valid]
1707+
result = Nx.window_scatter_min(t, Nx.tensor([[2, 6], [3, 1]]), 0, {2, 3}, opts)
1708+
1709+
assert_all_close(
1710+
result,
1711+
Nx.tensor([
1712+
[0, 2, 0, 0, 0],
1713+
[0, 0, 0, 6, 0],
1714+
[0, 0, 0, 0, 0],
1715+
[3, 0, 0, 0, 1]
1716+
])
1717+
)
1718+
end
1719+
1720+
test "window_scatter_min with vectorized input" do
1721+
t =
1722+
Nx.tensor([
1723+
[
1724+
[7, 2, 5, 3],
1725+
[3, 8, 9, 3]
1726+
],
1727+
[
1728+
[1, 5, 7, 5],
1729+
[0, 6, 2, 8]
1730+
]
1731+
])
1732+
|> Nx.vectorize(:x)
1733+
1734+
opts = [strides: [1, 2], padding: :valid]
1735+
1736+
source =
1737+
Nx.tensor([
1738+
[
1739+
[2, 6]
1740+
],
1741+
[
1742+
[3, 1]
1743+
]
1744+
])
1745+
|> Nx.vectorize(:y)
1746+
1747+
result = Nx.window_scatter_min(t, source, 0, {2, 2}, opts)
1748+
1749+
assert_all_close(
1750+
result,
1751+
Nx.tensor([
1752+
[
1753+
[
1754+
[0, 2, 0, 0],
1755+
[0, 0, 0, 6]
1756+
],
1757+
[
1758+
[0, 3, 0, 0],
1759+
[0, 0, 0, 1]
1760+
]
1761+
],
1762+
[
1763+
[
1764+
[0, 0, 0, 0],
1765+
[2, 0, 6, 0]
1766+
],
1767+
[
1768+
[0, 0, 0, 0],
1769+
[3, 0, 1, 0]
1770+
]
1771+
]
1772+
])
1773+
|> Nx.vectorize([:x, :y])
1774+
)
1775+
end
1776+
end
15791777
end

test/support/emlx_case.ex

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,43 +8,7 @@ defmodule EMLX.Case do
88
using do
99
quote do
1010
import EMLX.Case
11-
end
12-
end
13-
14-
def assert_all_close(left, right, opts \\ []) do
15-
atol = opts[:atol] || 1.0e-4
16-
rtol = opts[:rtol] || 1.0e-4
17-
18-
equals =
19-
left
20-
|> Nx.all_close(right, atol: atol, rtol: rtol)
21-
|> Nx.backend_transfer(Nx.BinaryBackend)
22-
23-
if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do
24-
flunk("""
25-
Tensor assertion failed.
26-
left: #{inspect(left)}
27-
right: #{inspect(right)}
28-
""")
29-
end
30-
end
31-
32-
def assert_equal(left, right) do
33-
both_nan = Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right))
34-
35-
equals =
36-
left
37-
|> Nx.equal(right)
38-
|> Nx.logical_or(both_nan)
39-
|> Nx.all()
40-
|> Nx.to_number()
41-
42-
if equals != 1 do
43-
flunk("""
44-
Tensor assertion failed.
45-
left: #{inspect(left)}
46-
right: #{inspect(right)}
47-
""")
11+
import Nx.Testing
4812
end
4913
end
5014
end

0 commit comments

Comments
 (0)