Skip to content

Commit 343dfa3

Browse files
authored
feat: support tie breaks (#83)
* feat: support tie breaks * chore: format
1 parent 2b60811 commit 343dfa3

File tree

5 files changed

+277
-81
lines changed

5 files changed

+277
-81
lines changed

README.md

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,6 @@ Defaulting to Nx.Defn.Evaluator is the safest option for now.
4949
Nx.Defn.default_options(compiler: EMLX)
5050
```
5151

52-
### Configuration
53-
54-
EMLX supports several configuration options that can be set in your application's config:
55-
56-
#### `:warn_unsupported_option`
57-
58-
Controls whether warnings are logged when unsupported options are used with certain operations.
59-
60-
- **Type**: `boolean`
61-
- **Default**: `true`
62-
- **Description**: When enabled, EMLX will log warnings for operations that receive options not supported by the MLX backend. For example, `Nx.argmax/2` and `Nx.argmin/2` with `tie_break: :high` will log a warning since MLX doesn't support this tie-breaking behavior.
63-
64-
```elixir
65-
# In config/config.exs
66-
config :emlx, :warn_unsupported_option, false
67-
```
68-
6952
### MLX binaries
7053

7154
EMLX relies on the [MLX](https://github.com/ml-explore/mlx) library to function, and currently EMLX will download precompiled builds from [mlx-build](https://github.com/cocoa-xu/mlx-build).

lib/emlx/backend.ex

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,14 @@ defmodule EMLX.Backend do
246246

247247
@impl true
248248
def reverse(out, tensor, axes) do
249-
shape = Tuple.to_list(tensor.shape)
249+
tensor
250+
|> from_nx()
251+
|> reverse_mlx(tensor.shape, axes)
252+
|> to_nx(out)
253+
end
254+
255+
defp reverse_mlx(tensor_mx, shape, axes) do
256+
shape = Tuple.to_list(shape)
250257

251258
{starts_stops, strides} =
252259
shape
@@ -264,10 +271,7 @@ defmodule EMLX.Backend do
264271

265272
{starts, stops} = Enum.unzip(starts_stops)
266273

267-
tensor
268-
|> from_nx()
269-
|> EMLX.slice(starts, stops, strides)
270-
|> to_nx(out)
274+
EMLX.slice(tensor_mx, starts, stops, strides)
271275
end
272276

273277
@impl true
@@ -584,21 +588,41 @@ defmodule EMLX.Backend do
584588
axis = opts[:axis]
585589
keep_axis = opts[:keep_axis] == true
586590

587-
if Application.get_env(:emlx, :warn_unsupported_option, true) and opts[:tie_break] == :high do
588-
Logger.warning(
589-
"Nx.Backend.#{unquote(op)}/3 with tie_break: :high is not supported in EMLX"
590-
)
591-
end
592-
593591
t_mx = from_nx(tensor)
594592

595-
result =
593+
t_mx =
594+
if opts[:tie_break] == :high do
595+
reverse_mlx(t_mx, tensor.shape, [axis] || Nx.axes(tensor))
596+
else
597+
t_mx
598+
end
599+
600+
{device, _} =
601+
result =
596602
if axis do
597603
EMLX.unquote(op)(t_mx, axis, keep_axis)
598604
else
599605
EMLX.unquote(op)(t_mx, keep_axis)
600606
end
601607

608+
# in case we had tie_break: :high, we need to subtract the result from the size of the sorted
609+
# set because in reversing the axis above, we will get the complement of the result
610+
result =
611+
case {axis, opts[:tie_break]} do
612+
{nil, :high} ->
613+
size = EMLX.scalar_tensor(Nx.size(tensor) - 1, to_mlx_type(out.type), device)
614+
EMLX.subtract(size, result)
615+
616+
{_, :high} ->
617+
size =
618+
EMLX.scalar_tensor(Nx.axis_size(tensor, axis) - 1, to_mlx_type(out.type), device)
619+
620+
EMLX.subtract(size, result)
621+
622+
{_, _} ->
623+
result
624+
end
625+
602626
result
603627
|> EMLX.astype(to_mlx_type(out.type))
604628
|> to_nx(out)

test/emlx/config_test.exs

Lines changed: 0 additions & 48 deletions
This file was deleted.

test/emlx/nx_doctest_test.exs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,8 @@ defmodule EMLX.Nx.DoctestTest do
4848
@to_be_fixed [
4949
:moduledoc,
5050
# MLX sorts NaNs lowest, Nx sorts them highest
51-
argmin: 2,
52-
argmax: 2,
5351
argsort: 2,
54-
# Missing support for window dilations and for tie_break: :high
52+
# Missing support for window dilations
5553
window_scatter_max: 5,
5654
window_scatter_min: 5
5755
]
@@ -61,7 +59,10 @@ defmodule EMLX.Nx.DoctestTest do
6159
window_reduce: 5,
6260
population_count: 1,
6361
count_leading_zeros: 1,
64-
sort: 2
62+
sort: 2,
63+
# We do not support the same ordering for NaNs as Nx
64+
argmin: 2,
65+
argmax: 2
6566
]
6667

6768
doctest Nx,

0 commit comments

Comments
 (0)