@@ -1342,8 +1342,6 @@ defmodule EMLX.Backend do
13421342 end
13431343
13441344 defp window_op ( op , out , tensor , window_shape , opts ) do
1345- # TODO: window dilations can be implemented after we support internal padding
1346- # in Nx.pad (we should have pad_internal as a shared defp)
13471345 tensor_rank = tuple_size ( tensor . shape )
13481346
13491347 axes =
@@ -1356,7 +1354,20 @@ defmodule EMLX.Backend do
13561354 { low_pad , high_pad } = Enum . unzip ( opts [ :padding ] )
13571355 { device , _ } = t_mx = from_nx ( tensor )
13581356
1359- { _device , pad_mx } =
1357+ window_dilations = opts [ :window_dilations ] || List . duplicate ( 1 , tuple_size ( window_shape ) )
1358+ interior_padding_config = Enum . map ( window_dilations , & ( & 1 - 1 ) )
1359+
1360+ { _device , zero_mx } = EMLX . scalar_tensor ( 0 , :bool , device )
1361+
1362+ window =
1363+ 1
1364+ |> EMLX . scalar_tensor ( :bool , device )
1365+ |> EMLX . broadcast_to ( window_shape )
1366+ |> interior_padding_mlx ( zero_mx , interior_padding_config )
1367+
1368+ window_shape = EMLX . shape ( window )
1369+
1370+ { device , pad_mx } =
13601371 case op do
13611372 :sum ->
13621373 EMLX . scalar_tensor ( 0 , to_mlx_type ( out . type ) , device )
@@ -1375,6 +1386,7 @@ defmodule EMLX.Backend do
13751386
13761387 padded_mx
13771388 |> sliding_window_view ( EMLX . shape ( padded_mx ) , window_shape , opts [ :strides ] )
1389+ |> then ( & EMLX . where ( window , & 1 , { device , pad_mx } ) )
13781390 |> then ( & apply ( EMLX , op , [ & 1 , axes , false ] ) )
13791391 |> to_nx ( out )
13801392 end
0 commit comments