@@ -590,6 +590,32 @@ defmodule EMLX.Backend do
590590
591591 t_mx = from_nx ( tensor )
592592
593+ # Check for NaNs in the original tensor (before any reversal)
594+ is_nan_mx = EMLX . is_nan ( t_mx )
595+
596+ nan_index_mx =
597+ if axis do
598+ EMLX . argmax ( is_nan_mx , axis , keep_axis )
599+ else
600+ EMLX . argmax ( is_nan_mx , keep_axis )
601+ end
602+
603+ # Check if any NaN exists along the axis
604+ has_nan_mx =
605+ cond do
606+ axis ->
607+ EMLX . any ( is_nan_mx , [ axis ] , keep_axis )
608+
609+ tuple_size ( tensor . shape ) == 0 ->
610+ # For scalar input, is_nan_mx is already a scalar boolean
611+ is_nan_mx
612+
613+ true ->
614+ # For full reduction over non-scalar tensors
615+ EMLX . any ( is_nan_mx , Nx . axes ( tensor ) , keep_axis )
616+ end
617+
618+ # Apply reversal for tie_break after NaN check
593619 t_mx =
594620 if opts [ :tie_break ] == :high do
595621 reverse_mlx ( t_mx , tensor . shape , [ axis ] || Nx . axes ( tensor ) )
@@ -623,6 +649,9 @@ defmodule EMLX.Backend do
623649 result
624650 end
625651
652+ # Use NaN index if any NaN exists, otherwise use regular result
653+ result = EMLX . where ( has_nan_mx , nan_index_mx , result )
654+
626655 result
627656 |> EMLX . astype ( to_mlx_type ( out . type ) )
628657 |> to_nx ( out )
@@ -1196,36 +1225,78 @@ defmodule EMLX.Backend do
11961225 axis = opts [ :axis ]
11971226 asc? = opts [ :direction ] == :asc
11981227
1199- t = tensor |> from_nx ( ) |> EMLX . sort ( axis )
1228+ t_mx = from_nx ( tensor )
12001229
1201- if asc? do
1202- to_nx ( t , out )
1203- else
1204- t
1205- |> to_nx ( out )
1206- |> Nx . reverse ( axes: [ axis ] )
1207- end
1230+ # Get the sorting indices
1231+ sort_mx =
1232+ if asc? do
1233+ EMLX . argsort ( t_mx , axis )
1234+ else
1235+ t_mx
1236+ |> EMLX . negate ( )
1237+ |> EMLX . argsort ( axis )
1238+ end
1239+
1240+ # Gather values at sorted positions to identify NaNs
1241+ sorted_values_mx = EMLX . take_along_axis ( t_mx , sort_mx , axis )
1242+ is_nan_mx = EMLX . is_nan ( sorted_values_mx )
1243+
1244+ # Partition indices to place NaNs correctly (NaNs are treated as highest):
1245+ # - For ascending: NaNs (highest) go to end: sort by is_nan (0 < 1)
1246+ # - For descending: NaNs (highest) go to beginning: sort by !is_nan (1 < 0)
1247+ partition_indices_mx =
1248+ if asc? do
1249+ EMLX . argsort ( is_nan_mx , axis )
1250+ else
1251+ is_nan_mx
1252+ |> EMLX . logical_not ( )
1253+ |> EMLX . argsort ( axis )
1254+ end
1255+
1256+ # Reorder the sorted values to move NaNs to the correct position
1257+ sorted_values_mx
1258+ |> EMLX . take_along_axis ( partition_indices_mx , axis )
1259+ |> EMLX . astype ( to_mlx_type ( out . type ) )
1260+ |> to_nx ( out )
12081261 end
12091262
12101263 @ impl true
12111264 def argsort ( out , tensor , opts ) do
12121265 axis = opts [ :axis ]
12131266 asc? = opts [ :direction ] == :asc
12141267
1215- if asc? do
1216- tensor
1217- |> from_nx ( )
1218- |> EMLX . argsort ( axis )
1219- |> EMLX . astype ( to_mlx_type ( out . type ) )
1220- |> to_nx ( out )
1221- else
1222- tensor
1223- |> from_nx ( )
1224- |> EMLX . negate ( )
1225- |> EMLX . argsort ( axis )
1226- |> EMLX . astype ( to_mlx_type ( out . type ) )
1227- |> to_nx ( out )
1228- end
1268+ t_mx = from_nx ( tensor )
1269+ # Get the initial sorting indices
1270+ sort_mx =
1271+ if asc? do
1272+ EMLX . argsort ( t_mx , axis )
1273+ else
1274+ t_mx
1275+ |> EMLX . negate ( )
1276+ |> EMLX . argsort ( axis )
1277+ end
1278+
1279+ # Gather values at sorted positions to identify NaNs
1280+ sorted_values_mx = EMLX . take_along_axis ( t_mx , sort_mx , axis )
1281+ is_nan_mx = EMLX . is_nan ( sorted_values_mx )
1282+
1283+ # Partition indices to place NaNs correctly (NaNs are treated as highest):
1284+ # - For ascending: NaNs (highest) go to end: sort by is_nan (0 < 1)
1285+ # - For descending: NaNs (highest) go to beginning: sort by !is_nan (1 < 0)
1286+ partition_indices_mx =
1287+ if asc? do
1288+ EMLX . argsort ( is_nan_mx , axis )
1289+ else
1290+ is_nan_mx
1291+ |> EMLX . logical_not ( )
1292+ |> EMLX . argsort ( axis )
1293+ end
1294+
1295+ # Reorder the sorting indices to move NaN indices to the correct position
1296+ sort_mx
1297+ |> EMLX . take_along_axis ( partition_indices_mx , axis )
1298+ |> EMLX . astype ( to_mlx_type ( out . type ) )
1299+ |> to_nx ( out )
12291300 end
12301301
12311302 defp maybe_upcast ( % T { type: t } = left , % T { type: t } = right ) ,
0 commit comments