Commit 32cb0fd

fix: type warnings in case unquote (#80)
1 parent a020ccf commit 32cb0fd

File tree

1 file changed: +33 -29 lines changed

lib/emlx/backend.ex

Lines changed: 33 additions & 29 deletions
@@ -1290,42 +1290,46 @@ defmodule EMLX.Backend do
   for op <- [:sum, :product, :max, :min] do
     @impl true
     def unquote(:"window_#{op}")(out, tensor, window_shape, opts) do
-      # TODO: window dilations can be implemented after we support internal padding
-      # in Nx.pad (we should have pad_internal as a shared defp)
-      tensor_rank = tuple_size(tensor.shape)
-
-      axes =
-        0..(tuple_size(window_shape) - 1)
-        |> Enum.to_list()
-        |> Enum.map(fn axis ->
-          tensor_rank + axis
-        end)
+      window_op(unquote(op), out, tensor, window_shape, opts)
+    end
+  end
 
-      {low_pad, high_pad} = Enum.unzip(opts[:padding])
-      {device, _} = t_mx = from_nx(tensor)
+  defp window_op(op, out, tensor, window_shape, opts) do
+    # TODO: window dilations can be implemented after we support internal padding
+    # in Nx.pad (we should have pad_internal as a shared defp)
+    tensor_rank = tuple_size(tensor.shape)
 
-      {_device, pad_mx} =
-        case unquote(op) do
-          :sum ->
-            EMLX.scalar_tensor(0, to_mlx_type(out.type), device)
+    axes =
+      0..(tuple_size(window_shape) - 1)
+      |> Enum.to_list()
+      |> Enum.map(fn axis ->
+        tensor_rank + axis
+      end)
 
-          :product ->
-            EMLX.scalar_tensor(1, to_mlx_type(out.type), device)
+    {low_pad, high_pad} = Enum.unzip(opts[:padding])
+    {device, _} = t_mx = from_nx(tensor)
 
-          :max ->
-            Nx.Constants.min(tensor.type, backend: {EMLX.Backend, device: device}) |> from_nx()
+    {_device, pad_mx} =
+      case op do
+        :sum ->
+          EMLX.scalar_tensor(0, to_mlx_type(out.type), device)
 
-          :min ->
-            Nx.Constants.max(tensor.type, backend: {EMLX.Backend, device: device}) |> from_nx()
-        end
+        :product ->
+          EMLX.scalar_tensor(1, to_mlx_type(out.type), device)
 
-      padded_mx = EMLX.pad(t_mx, Nx.axes(tensor), low_pad, high_pad, pad_mx)
+        :max ->
+          Nx.Constants.min(tensor.type, backend: {EMLX.Backend, device: device}) |> from_nx()
 
-      padded_mx
-      |> sliding_window_view(EMLX.shape(padded_mx), window_shape, opts[:strides])
-      |> EMLX.unquote(op)(axes, false)
-      |> to_nx(out)
-    end
+        :min ->
+          Nx.Constants.max(tensor.type, backend: {EMLX.Backend, device: device}) |> from_nx()
+      end
+
+    padded_mx = EMLX.pad(t_mx, Nx.axes(tensor), low_pad, high_pad, pad_mx)
+
+    padded_mx
+    |> sliding_window_view(EMLX.shape(padded_mx), window_shape, opts[:strides])
+    |> then(&apply(EMLX, op, [&1, axes, false]))
+    |> to_nx(out)
   end
 
   defp sliding_window_view(t, tensor_shape, window_shape, opt_strides) do
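Background on the fix: inside the `for op <- [:sum, :product, :max, :min]` comprehension, `unquote(op)` expands to a literal atom at compile time, so each generated `window_*` function carried a `case` over a compile-time constant in which all but one clause could never match; that is the kind of pattern the compiler and Dialyzer flag. Delegating to the runtime `window_op/5` makes `op` an ordinary variable, so every `case` clause is reachable. A minimal sketch of the before/after pattern, using hypothetical names (`WarnSketch`, `old_*`, `new_*`, `do_op/2`) rather than the EMLX functions:

    defmodule WarnSketch do
      # Before: `unquote(op)` injects a literal atom, so the case below
      # contains a clause that can never match, and the compiler warns.
      for op <- [:double, :negate] do
        def unquote(:"old_#{op}")(x) do
          case unquote(op) do
            :double -> x * 2
            :negate -> -x
          end
        end
      end

      # After: the generated function delegates to a runtime helper where
      # `op` is an ordinary variable, so every clause is reachable.
      for op <- [:double, :negate] do
        def unquote(:"new_#{op}")(x), do: do_op(unquote(op), x)
      end

      defp do_op(op, x) do
        case op do
          :double -> x * 2
          :negate -> -x
        end
      end
    end

The same constraint explains the `then(&apply(EMLX, op, [&1, axes, false]))` line in the diff: once `op` is a runtime value, the pipeline can no longer call `EMLX.unquote(op)(axes, false)`, so the commit dispatches to the matching `EMLX` function dynamically with `apply/3`.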
