Skip to content

Commit d468c75

Browse files
authored
fix: dirty schedule certain nifs and eval on main thread (#84)
* fix: dirty schedule certain nifs and eval on main thread
* chore: format
1 parent 343dfa3 commit d468c75

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

c_src/emlx_nif.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,6 @@ NIF(to_blob) {
275275
int limit = 0;
276276
bool has_received_limit = (argc == 2);
277277

278-
// Evaluate to ensure data is available
279-
t->eval();
280-
281278
if (has_received_limit) {
282279
PARAM(1, int, param_limit);
283280
limit = param_limit;
@@ -1051,7 +1048,7 @@ static ErlNifFunc nif_funcs[] = {
10511048
{"slice", 5, slice},
10521049
{"slice_update", 5, slice_update},
10531050
{"squeeze", 3, squeeze},
1054-
{"item", 1, item},
1051+
{"item", 1, item, ERL_NIF_DIRTY_JOB_CPU_BOUND},
10551052
{"all", 4, all},
10561053
{"any", 4, any},
10571054
{"sum", 4, sum},
@@ -1067,8 +1064,8 @@ static ErlNifFunc nif_funcs[] = {
10671064
{"shape", 1, shape},
10681065
{"reshape", 3, reshape},
10691066
{"astype", 3, astype},
1070-
{"to_blob", 1, to_blob},
1071-
{"to_blob", 2, to_blob},
1067+
{"to_blob", 1, to_blob, ERL_NIF_DIRTY_JOB_CPU_BOUND},
1068+
{"to_blob", 2, to_blob, ERL_NIF_DIRTY_JOB_CPU_BOUND},
10721069
{"from_blob", 4, from_blob},
10731070
{"scalar_tensor", 3, scalar_tensor},
10741071
{"ones", 3, ones},
@@ -1154,7 +1151,8 @@ static ErlNifFunc nif_funcs[] = {
11541151
{"tri_inv", 3, tri_inv},
11551152
{"set_compile", 1, set_compile},
11561153
{"compile", 2, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND},
1157-
{"call_compiled", 2, call_compiled, ERL_NIF_DIRTY_JOB_CPU_BOUND}};
1154+
{"call_compiled_cpu", 2, call_compiled, ERL_NIF_DIRTY_JOB_CPU_BOUND},
1155+
{"call_compiled_gpu", 2, call_compiled, ERL_NIF_DIRTY_JOB_IO_BOUND}};
11581156

11591157
// Update the NIF initialization
11601158
ERL_NIF_INIT(Elixir.EMLX.NIF, nif_funcs, load, NULL, upgrade, NULL)

lib/emlx.ex

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,21 @@ defmodule EMLX do
255255
deftensor clip(tensor, tensor_min, tensor_max)
256256

257257
## Dirty non-tensor return values
258-
defvalue to_blob(tensor)
259-
defvalue to_blob(tensor, limit)
260258
defvalue scalar_type(tensor)
261259
defvalue shape(tensor)
262260

261+
def to_blob({device, ref} = tensor) when is_tensor(device, ref) do
262+
# Two-step to_blob: eval on main scheduler, then copy on dirty scheduler
263+
eval(tensor)
264+
EMLX.NIF.to_blob(ref) |> unwrap!()
265+
end
266+
267+
def to_blob({device, ref} = tensor, limit) when is_tensor(device, ref) do
268+
# Two-step to_blob: eval on main scheduler, then copy on dirty scheduler
269+
eval(tensor)
270+
EMLX.NIF.to_blob(ref, limit) |> unwrap!()
271+
end
272+
263273
defp unwrap!(:ok), do: :ok
264274
defp unwrap!({:ok, result}), do: result
265275
defp unwrap!({:error, error}), do: raise(EMLX.NIFError, List.to_string(error))
@@ -305,7 +315,6 @@ defmodule EMLX do
305315
defp merge_device(_, _), do: :cpu
306316

307317
defvalue deallocate(tensor_ref)
308-
309318
defvalue eval(tensor)
310319

311320
deftensor slice(tensor, starts, stops, strides)
@@ -409,9 +418,14 @@ defmodule EMLX do
409418
cached_fun
410419
end
411420

421+
nif_result =
422+
case device do
423+
:cpu -> EMLX.NIF.call_compiled_cpu(compiled_fun, nif_args)
424+
:gpu -> EMLX.NIF.call_compiled_gpu(compiled_fun, nif_args)
425+
end
426+
412427
results =
413-
compiled_fun
414-
|> EMLX.NIF.call_compiled(nif_args)
428+
nif_result
415429
|> unwrap!()
416430
|> Enum.map(fn ref ->
417431
EMLX.Backend.to_nx({device, ref})

lib/emlx/nif.ex

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,20 @@ defmodule EMLX.NIF do
2727
:erlang.nif_error(:nif_not_loaded)
2828
end
2929

30-
def call_compiled(_compiled_fun, _args) do
30+
# Device-specific NIFs for dirty scheduler optimization
31+
def call_compiled_cpu(_compiled_fun, _args) do
32+
:erlang.nif_error(:nif_not_loaded)
33+
end
34+
35+
def call_compiled_gpu(_compiled_fun, _args) do
36+
:erlang.nif_error(:nif_not_loaded)
37+
end
38+
39+
def to_blob(_tensor) do
40+
:erlang.nif_error(:nif_not_loaded)
41+
end
42+
43+
def to_blob(_tensor, _limit) do
3144
:erlang.nif_error(:nif_not_loaded)
3245
end
3346
end

0 commit comments

Comments
 (0)