@@ -276,25 +276,72 @@ defmodule EMLX.Backend do
276276 input_config
277277 |> Enum . with_index ( )
278278 |> Enum . reduce ( { [ ] , [ ] , [ ] } , fn
279- { { low , high , 0 } , i } , { axes , lows , highs } ->
279+ { { low , high , _ } , i } , { axes , lows , highs } ->
280280 { [ i | axes ] , [ max ( low , 0 ) | lows ] , [ max ( high , 0 ) | highs ] }
281-
282- _ , _ ->
283- raise "Interior padding not supported in EMLX yet"
284281 end )
285282
286- pad_value =
287- pad_value
288- |> from_nx ( )
289- |> elem ( 1 )
283+ { _device , pad_value_mx } = from_nx ( pad_value )
284+
285+ interior_padding = Enum . map ( input_config , fn { _low , _high , interior } -> interior end )
290286
291287 tensor
292288 |> from_nx ( )
289+ |> interior_padding_mlx ( pad_value_mx , interior_padding )
293290 |> slice_negative_padding ( input_config )
294- |> EMLX . pad ( axes , low_pad_size , high_pad_size , pad_value )
291+ |> EMLX . pad ( axes , low_pad_size , high_pad_size , pad_value_mx )
295292 |> to_nx ( out )
296293 end
297294
295+ defp interior_padding_mlx ( tensor , value , padding_config ) do
296+ new_shape = Tuple . insert_at ( EMLX . shape ( tensor ) , tuple_size ( EMLX . shape ( tensor ) ) , 1 )
297+ tensor = EMLX . reshape ( tensor , new_shape )
298+
299+ { final_tensor , _ } =
300+ Enum . reduce ( padding_config , { tensor , 0 } , fn interior , { acc , axis_index } ->
301+ new_tensor = apply_interior_padding ( acc , value , axis_index , interior , EMLX . shape ( acc ) )
302+ { new_tensor , axis_index + 1 }
303+ end )
304+
305+ final_tensor
306+ |> EMLX . squeeze ( [ - 1 ] )
307+ end
308+
309+ defp apply_interior_padding ( tensor , _value , _axis_index , 0 , _shape ) do
310+ tensor
311+ end
312+
313+ defp apply_interior_padding ( tensor , value , axis_index , interior_padding , shape ) do
314+ rank = tuple_size ( shape )
315+ next_axis = axis_index + 1
316+ axis_size = elem ( shape , axis_index )
317+ next_axis_size = elem ( shape , next_axis )
318+
319+ lows = [ 0 ]
320+ highs = [ next_axis_size * interior_padding ]
321+ axes = [ next_axis ]
322+
323+ padded_tensor = EMLX . pad ( tensor , axes , lows , highs , value )
324+
325+ new_axis_size = axis_size + axis_size * interior_padding
326+
327+ new_shape =
328+ shape
329+ |> put_elem ( axis_index , new_axis_size )
330+ |> put_elem ( axis_index + 1 , next_axis_size )
331+
332+ lengths =
333+ new_shape
334+ |> put_elem ( axis_index , new_axis_size - interior_padding )
335+ |> Tuple . to_list ( )
336+
337+ starts = List . duplicate ( 0 , rank )
338+ strides = List . duplicate ( 1 , rank )
339+
340+ padded_tensor
341+ |> EMLX . reshape ( new_shape )
342+ |> mlx_slice ( new_shape , starts , lengths , strides )
343+ end
344+
298345 defp slice_negative_padding ( t_mx , input_config ) do
299346 if Enum . any? ( input_config , fn { pre , post , _ } -> pre < 0 or post < 0 end ) do
300347 shape = EMLX . shape ( t_mx )
0 commit comments