Skip to content

Commit 2b60811

Browse files
authored
feat: add window dilations (#82)
* first iteration of window dilations implementation followin Torchx steps * implemented window dilations on EMLX.window_op()
1 parent 6ab8207 commit 2b60811

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

lib/emlx/backend.ex

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,8 +1342,6 @@ defmodule EMLX.Backend do
13421342
end
13431343

13441344
defp window_op(op, out, tensor, window_shape, opts) do
1345-
# TODO: window dilations can be implemented after we support internal padding
1346-
# in Nx.pad (we should have pad_internal as a shared defp)
13471345
tensor_rank = tuple_size(tensor.shape)
13481346

13491347
axes =
@@ -1356,7 +1354,20 @@ defmodule EMLX.Backend do
13561354
{low_pad, high_pad} = Enum.unzip(opts[:padding])
13571355
{device, _} = t_mx = from_nx(tensor)
13581356

1359-
{_device, pad_mx} =
1357+
window_dilations = opts[:window_dilations] || List.duplicate(1, tuple_size(window_shape))
1358+
interior_padding_config = Enum.map(window_dilations, &(&1 - 1))
1359+
1360+
{_device, zero_mx} = EMLX.scalar_tensor(0, :bool, device)
1361+
1362+
window =
1363+
1
1364+
|> EMLX.scalar_tensor(:bool, device)
1365+
|> EMLX.broadcast_to(window_shape)
1366+
|> interior_padding_mlx(zero_mx, interior_padding_config)
1367+
1368+
window_shape = EMLX.shape(window)
1369+
1370+
{device, pad_mx} =
13601371
case op do
13611372
:sum ->
13621373
EMLX.scalar_tensor(0, to_mlx_type(out.type), device)
@@ -1375,6 +1386,7 @@ defmodule EMLX.Backend do
13751386

13761387
padded_mx
13771388
|> sliding_window_view(EMLX.shape(padded_mx), window_shape, opts[:strides])
1389+
|> then(&EMLX.where(window, &1, {device, pad_mx}))
13781390
|> then(&apply(EMLX, op, [&1, axes, false]))
13791391
|> to_nx(out)
13801392
end

test/emlx/nx_doctest_test.exs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,6 @@ defmodule EMLX.Nx.DoctestTest do
4747

4848
@to_be_fixed [
4949
:moduledoc,
50-
# window_* do not support window_dilations yet
51-
window_sum: 3,
52-
window_max: 3,
53-
window_min: 3,
54-
window_product: 3,
55-
window_mean: 3,
56-
# missing support for inner padding
5750
# MLX sorts NaNs lowest, Nx sorts them highest
5851
argmin: 2,
5952
argmax: 2,

0 commit comments

Comments
 (0)