@@ -1290,42 +1290,46 @@ defmodule EMLX.Backend do
12901290 for op <- [ :sum , :product , :max , :min ] do
12911291 @ impl true
12921292 def unquote ( :"window_#{ op } " ) ( out , tensor , window_shape , opts ) do
1293- # TODO: window dilations can be implemented after we support internal padding
1294- # in Nx.pad (we should have pad_internal as a shared defp)
1295- tensor_rank = tuple_size ( tensor . shape )
1296-
1297- axes =
1298- 0 .. ( tuple_size ( window_shape ) - 1 )
1299- |> Enum . to_list ( )
1300- |> Enum . map ( fn axis ->
1301- tensor_rank + axis
1302- end )
1293+ window_op ( unquote ( op ) , out , tensor , window_shape , opts )
1294+ end
1295+ end
13031296
1304- { low_pad , high_pad } = Enum . unzip ( opts [ :padding ] )
1305- { device , _ } = t_mx = from_nx ( tensor )
1297+ defp window_op ( op , out , tensor , window_shape , opts ) do
1298+ # TODO: window dilations can be implemented after we support internal padding
1299+ # in Nx.pad (we should have pad_internal as a shared defp)
1300+ tensor_rank = tuple_size ( tensor . shape )
13061301
1307- { _device , pad_mx } =
1308- case unquote ( op ) do
1309- :sum ->
1310- EMLX . scalar_tensor ( 0 , to_mlx_type ( out . type ) , device )
1302+ axes =
1303+ 0 .. ( tuple_size ( window_shape ) - 1 )
1304+ |> Enum . to_list ( )
1305+ |> Enum . map ( fn axis ->
1306+ tensor_rank + axis
1307+ end )
13111308
1312- :product ->
1313- EMLX . scalar_tensor ( 1 , to_mlx_type ( out . type ) , device )
1309+ { low_pad , high_pad } = Enum . unzip ( opts [ :padding ] )
1310+ { device , _ } = t_mx = from_nx ( tensor )
13141311
1315- :max ->
1316- Nx.Constants . min ( tensor . type , backend: { EMLX.Backend , device: device } ) |> from_nx ( )
1312+ { _device , pad_mx } =
1313+ case op do
1314+ :sum ->
1315+ EMLX . scalar_tensor ( 0 , to_mlx_type ( out . type ) , device )
13171316
1318- :min ->
1319- Nx.Constants . max ( tensor . type , backend: { EMLX.Backend , device: device } ) |> from_nx ( )
1320- end
1317+ :product ->
1318+ EMLX . scalar_tensor ( 1 , to_mlx_type ( out . type ) , device )
13211319
1322- padded_mx = EMLX . pad ( t_mx , Nx . axes ( tensor ) , low_pad , high_pad , pad_mx )
1320+ :max ->
1321+ Nx.Constants . min ( tensor . type , backend: { EMLX.Backend , device: device } ) |> from_nx ( )
13231322
1324- padded_mx
1325- |> sliding_window_view ( EMLX . shape ( padded_mx ) , window_shape , opts [ :strides ] )
1326- |> EMLX . unquote ( op ) ( axes , false )
1327- |> to_nx ( out )
1328- end
1323+ :min ->
1324+ Nx.Constants . max ( tensor . type , backend: { EMLX.Backend , device: device } ) |> from_nx ( )
1325+ end
1326+
1327+ padded_mx = EMLX . pad ( t_mx , Nx . axes ( tensor ) , low_pad , high_pad , pad_mx )
1328+
1329+ padded_mx
1330+ |> sliding_window_view ( EMLX . shape ( padded_mx ) , window_shape , opts [ :strides ] )
1331+ |> then ( & apply ( EMLX , op , [ & 1 , axes , false ] ) )
1332+ |> to_nx ( out )
13291333 end
13301334
13311335 defp sliding_window_view ( t , tensor_shape , window_shape , opt_strides ) do
0 commit comments