
Commit 6ab8207

Chapaman and polvalente authored
feat: add interior padding (#81)
Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
1 parent 3f1e190 commit 6ab8207
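
In Nx terms, the padding configuration passed to Nx.pad/3 is a {pre, post, interior} tuple per axis, where interior inserts that many copies of the pad value between adjacent elements along that axis. Before this commit the EMLX backend raised "Interior padding not supported in EMLX yet" whenever interior was non-zero. A minimal, backend-agnostic illustration of the semantics (the expected values below are worked out by hand, not taken from the repo):

    # pre: 1, post: 1, interior: 1 on a one-dimensional tensor
    Nx.pad(Nx.tensor([1, 2, 3]), 0, [{1, 1, 1}])
    # one leading pad, one pad between each pair of elements, one trailing pad:
    # [0, 1, 0, 2, 0, 3, 0]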


2 files changed: +56, -10 lines changed


lib/emlx/backend.ex

Lines changed: 56 additions & 9 deletions
@@ -276,25 +276,72 @@ defmodule EMLX.Backend do
       input_config
       |> Enum.with_index()
       |> Enum.reduce({[], [], []}, fn
-        {{low, high, 0}, i}, {axes, lows, highs} ->
+        {{low, high, _}, i}, {axes, lows, highs} ->
           {[i | axes], [max(low, 0) | lows], [max(high, 0) | highs]}
-
-        _, _ ->
-          raise "Interior padding not supported in EMLX yet"
       end)
 
-    pad_value =
-      pad_value
-      |> from_nx()
-      |> elem(1)
+    {_device, pad_value_mx} = from_nx(pad_value)
+
+    interior_padding = Enum.map(input_config, fn {_low, _high, interior} -> interior end)
 
     tensor
     |> from_nx()
+    |> interior_padding_mlx(pad_value_mx, interior_padding)
     |> slice_negative_padding(input_config)
-    |> EMLX.pad(axes, low_pad_size, high_pad_size, pad_value)
+    |> EMLX.pad(axes, low_pad_size, high_pad_size, pad_value_mx)
     |> to_nx(out)
   end
 
+  defp interior_padding_mlx(tensor, value, padding_config) do
+    new_shape = Tuple.insert_at(EMLX.shape(tensor), tuple_size(EMLX.shape(tensor)), 1)
+    tensor = EMLX.reshape(tensor, new_shape)
+
+    {final_tensor, _} =
+      Enum.reduce(padding_config, {tensor, 0}, fn interior, {acc, axis_index} ->
+        new_tensor = apply_interior_padding(acc, value, axis_index, interior, EMLX.shape(acc))
+        {new_tensor, axis_index + 1}
+      end)
+
+    final_tensor
+    |> EMLX.squeeze([-1])
+  end
+
+  defp apply_interior_padding(tensor, _value, _axis_index, 0, _shape) do
+    tensor
+  end
+
+  defp apply_interior_padding(tensor, value, axis_index, interior_padding, shape) do
+    rank = tuple_size(shape)
+    next_axis = axis_index + 1
+    axis_size = elem(shape, axis_index)
+    next_axis_size = elem(shape, next_axis)
+
+    lows = [0]
+    highs = [next_axis_size * interior_padding]
+    axes = [next_axis]
+
+    padded_tensor = EMLX.pad(tensor, axes, lows, highs, value)
+
+    new_axis_size = axis_size + axis_size * interior_padding
+
+    new_shape =
+      shape
+      |> put_elem(axis_index, new_axis_size)
+      |> put_elem(axis_index + 1, next_axis_size)
+
+    lengths =
+      new_shape
+      |> put_elem(axis_index, new_axis_size - interior_padding)
+      |> Tuple.to_list()
+
+    starts = List.duplicate(0, rank)
+    strides = List.duplicate(1, rank)
+
+    padded_tensor
+    |> EMLX.reshape(new_shape)
+    |> mlx_slice(new_shape, starts, lengths, strides)
+  end
+
   defp slice_negative_padding(t_mx, input_config) do
     if Enum.any?(input_config, fn {pre, post, _} -> pre < 0 or post < 0 end) do
       shape = EMLX.shape(t_mx)
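
How the new code works: interior_padding_mlx/3 appends a trailing singleton axis, then walks the axes in order. For each axis with a non-zero interior value, apply_interior_padding/5 pads the following axis so every element along the current axis is followed by the required pad values, folds that padding back into the current axis with a reshape, and slices off the run of pads that lands after the last element. The singleton axis is squeezed away at the end, and ordinary low/high padding is applied afterwards by EMLX.pad. The sketch below retraces the same trick with the public Nx API on a shape-{3} tensor with interior = 1; the EMLX NIF calls in the diff operate on device references instead, so this is an illustration of the idea rather than the backend code itself:

    t = Nx.tensor([1, 2, 3])                # shape {3}

    t
    |> Nx.reshape({3, 1})                   # append a trailing singleton axis
    |> Nx.pad(0, [{0, 0, 0}, {0, 1, 0}])    # pad the trailing axis: each element gains one 0 -> {3, 2}
    |> Nx.reshape({6, 1})                   # fold the pads back into the leading axis
    |> Nx.slice([0, 0], [5, 1])             # drop the pad after the last element -> {5, 1}
    |> Nx.squeeze(axes: [1])                # rank 1 again: [1, 0, 2, 0, 3]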

test/emlx/nx_doctest_test.exs

Lines changed: 0 additions & 1 deletion
@@ -54,7 +54,6 @@ defmodule EMLX.Nx.DoctestTest do
       window_product: 3,
       window_mean: 3,
       # missing support for inner padding
-      pad: 3,
       # MLX sorts NaNs lowest, Nx sorts them highest
       argmin: 2,
       argmax: 2,
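
With interior padding supported, the pad: 3 doctests can run on EMLX again, so the exclusion is dropped. For context, the entries above feed ExUnit's doctest macro through its :except option; a minimal sketch of that pattern, assuming the file follows the standard ExUnit.DocTest setup (the real module contains more exclusions and backend configuration than shown here):

    defmodule EMLX.Nx.DoctestTest do
      use ExUnit.Case

      doctest Nx,
        except: [
          window_product: 3,
          window_mean: 3,
          # MLX sorts NaNs lowest, Nx sorts them highest
          argmin: 2,
          argmax: 2
          # pad: 3 no longer needs to appear here
        ]
    end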
