Skip to content

Commit 939fe9a

Browse files
authored
feat: sort/argsort/argmin/argmax with nans (#86)
* feat: argsort nans * feat: sort nans * feat: support nans in argmin argmax * chore: revert stray files * --version
1 parent 035f6d2 commit 939fe9a

File tree

2 files changed

+95
-30
lines changed

2 files changed

+95
-30
lines changed

lib/emlx/backend.ex

Lines changed: 93 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,32 @@ defmodule EMLX.Backend do
590590

591591
t_mx = from_nx(tensor)
592592

593+
# Check for NaNs in the original tensor (before any reversal)
594+
is_nan_mx = EMLX.is_nan(t_mx)
595+
596+
nan_index_mx =
597+
if axis do
598+
EMLX.argmax(is_nan_mx, axis, keep_axis)
599+
else
600+
EMLX.argmax(is_nan_mx, keep_axis)
601+
end
602+
603+
# Check if any NaN exists along the axis
604+
has_nan_mx =
605+
cond do
606+
axis ->
607+
EMLX.any(is_nan_mx, [axis], keep_axis)
608+
609+
tuple_size(tensor.shape) == 0 ->
610+
# For scalar input, is_nan_mx is already a scalar boolean
611+
is_nan_mx
612+
613+
true ->
614+
# For full reduction over non-scalar tensors
615+
EMLX.any(is_nan_mx, Nx.axes(tensor), keep_axis)
616+
end
617+
618+
# Apply reversal for tie_break after NaN check
593619
t_mx =
594620
if opts[:tie_break] == :high do
595621
reverse_mlx(t_mx, tensor.shape, [axis] || Nx.axes(tensor))
@@ -623,6 +649,9 @@ defmodule EMLX.Backend do
623649
result
624650
end
625651

652+
# Use NaN index if any NaN exists, otherwise use regular result
653+
result = EMLX.where(has_nan_mx, nan_index_mx, result)
654+
626655
result
627656
|> EMLX.astype(to_mlx_type(out.type))
628657
|> to_nx(out)
@@ -1196,36 +1225,78 @@ defmodule EMLX.Backend do
11961225
axis = opts[:axis]
11971226
asc? = opts[:direction] == :asc
11981227

1199-
t = tensor |> from_nx() |> EMLX.sort(axis)
1228+
t_mx = from_nx(tensor)
12001229

1201-
if asc? do
1202-
to_nx(t, out)
1203-
else
1204-
t
1205-
|> to_nx(out)
1206-
|> Nx.reverse(axes: [axis])
1207-
end
1230+
# Get the sorting indices
1231+
sort_mx =
1232+
if asc? do
1233+
EMLX.argsort(t_mx, axis)
1234+
else
1235+
t_mx
1236+
|> EMLX.negate()
1237+
|> EMLX.argsort(axis)
1238+
end
1239+
1240+
# Gather values at sorted positions to identify NaNs
1241+
sorted_values_mx = EMLX.take_along_axis(t_mx, sort_mx, axis)
1242+
is_nan_mx = EMLX.is_nan(sorted_values_mx)
1243+
1244+
# Partition indices to place NaNs correctly (NaNs are treated as highest):
1245+
# - For ascending: NaNs (highest) go to end: sort by is_nan (0 < 1)
1246+
# - For descending: NaNs (highest) go to beginning: sort by !is_nan (1 < 0)
1247+
partition_indices_mx =
1248+
if asc? do
1249+
EMLX.argsort(is_nan_mx, axis)
1250+
else
1251+
is_nan_mx
1252+
|> EMLX.logical_not()
1253+
|> EMLX.argsort(axis)
1254+
end
1255+
1256+
# Reorder the sorted values to move NaNs to the correct position
1257+
sorted_values_mx
1258+
|> EMLX.take_along_axis(partition_indices_mx, axis)
1259+
|> EMLX.astype(to_mlx_type(out.type))
1260+
|> to_nx(out)
12081261
end
12091262

12101263
@impl true
12111264
def argsort(out, tensor, opts) do
12121265
axis = opts[:axis]
12131266
asc? = opts[:direction] == :asc
12141267

1215-
if asc? do
1216-
tensor
1217-
|> from_nx()
1218-
|> EMLX.argsort(axis)
1219-
|> EMLX.astype(to_mlx_type(out.type))
1220-
|> to_nx(out)
1221-
else
1222-
tensor
1223-
|> from_nx()
1224-
|> EMLX.negate()
1225-
|> EMLX.argsort(axis)
1226-
|> EMLX.astype(to_mlx_type(out.type))
1227-
|> to_nx(out)
1228-
end
1268+
t_mx = from_nx(tensor)
1269+
# Get the initial sorting indices
1270+
sort_mx =
1271+
if asc? do
1272+
EMLX.argsort(t_mx, axis)
1273+
else
1274+
t_mx
1275+
|> EMLX.negate()
1276+
|> EMLX.argsort(axis)
1277+
end
1278+
1279+
# Gather values at sorted positions to identify NaNs
1280+
sorted_values_mx = EMLX.take_along_axis(t_mx, sort_mx, axis)
1281+
is_nan_mx = EMLX.is_nan(sorted_values_mx)
1282+
1283+
# Partition indices to place NaNs correctly (NaNs are treated as highest):
1284+
# - For ascending: NaNs (highest) go to end: sort by is_nan (0 < 1)
1285+
# - For descending: NaNs (highest) go to beginning: sort by !is_nan (1 < 0)
1286+
partition_indices_mx =
1287+
if asc? do
1288+
EMLX.argsort(is_nan_mx, axis)
1289+
else
1290+
is_nan_mx
1291+
|> EMLX.logical_not()
1292+
|> EMLX.argsort(axis)
1293+
end
1294+
1295+
# Reorder the sorting indices to move NaN indices to the correct position
1296+
sort_mx
1297+
|> EMLX.take_along_axis(partition_indices_mx, axis)
1298+
|> EMLX.astype(to_mlx_type(out.type))
1299+
|> to_nx(out)
12291300
end
12301301

12311302
defp maybe_upcast(%T{type: t} = left, %T{type: t} = right),

test/emlx/nx_doctest_test.exs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,14 @@ defmodule EMLX.Nx.DoctestTest do
4646
]
4747

4848
@to_be_fixed [
49-
:moduledoc,
50-
# MLX sorts NaNs lowest, Nx sorts them highest
51-
argsort: 2
49+
:moduledoc
5250
]
5351

5452
@not_supported [
5553
reduce: 4,
5654
window_reduce: 5,
5755
population_count: 1,
58-
count_leading_zeros: 1,
59-
sort: 2,
60-
# We do not support the same ordering for NaNs as Nx
61-
argmin: 2,
62-
argmax: 2
56+
count_leading_zeros: 1
6357
]
6458

6559
doctest Nx,

0 commit comments

Comments
 (0)