|
9 | 9 | #include <string> |
10 | 10 | #include <cstring> |
11 | 11 |
|
12 | | -#define NIF_CALL_IMPLEMENTATION |
13 | | -#include "nif_call.h" |
14 | | - |
15 | 12 | using namespace mlx::core; |
16 | 13 |
|
17 | 14 | std::map<const std::string, const mlx::core::Dtype> dtypes = { |
@@ -495,98 +492,6 @@ NIF(eval) { |
495 | 492 | return nx::nif::ok(env); |
496 | 493 | } |
497 | 494 |
|
498 | | -NIF(set_compile) { |
499 | | - PARAM(0, bool, compile); |
500 | | - |
501 | | - if (compile) { |
502 | | - mlx::core::enable_compile(); |
503 | | - } else { |
504 | | - mlx::core::disable_compile(); |
505 | | - } |
506 | | - |
507 | | - return nx::nif::ok(env); |
508 | | -} |
509 | | - |
510 | | -void move_between_envs(ERL_NIF_TERM from_term, ERL_NIF_TERM *to_term, |
511 | | - ErlNifEnv *from_env, ErlNifEnv *to_env) { |
512 | | - ErlNifBinary serialized; |
513 | | - enif_term_to_binary(from_env, from_term, &serialized); |
514 | | - enif_binary_to_term(to_env, serialized.data, serialized.size, to_term, 0); |
515 | | -} |
516 | | - |
517 | | -class EMLXCompileError : public std::runtime_error { |
518 | | -public: |
519 | | - EMLXCompileError(ERL_NIF_TERM term) |
520 | | - : std::runtime_error("EMLXCompileError"), error_term(term) {} |
521 | | - |
522 | | - ERL_NIF_TERM get_error_term() const { return error_term; } |
523 | | - |
524 | | -private: |
525 | | - ERL_NIF_TERM error_term; |
526 | | -}; |
527 | | - |
528 | | -NIF(compile) { |
529 | | - LIST_PARAM(0, std::vector<mlx::core::array>, arrays); |
530 | | - ERL_NIF_TERM tag = argv[1]; |
531 | | - |
532 | | - ErlNifEnv *closure_env = enif_alloc_env(); |
533 | | - |
534 | | - auto fun = [env = closure_env, outer_env = env, tag_outer = tag]( |
535 | | - const std::vector<mlx::core::array> &compile_args) { |
536 | | - ERL_NIF_TERM tag = enif_make_copy(env, tag_outer); |
537 | | - ERL_NIF_TERM tensor_list = nx::nif::make_list(env, compile_args); |
538 | | - ERL_NIF_TERM arg_list = enif_make_list1(env, tensor_list); |
539 | | - |
540 | | - NifCallResult result = make_nif_call(env, tag, arg_list); |
541 | | - enif_clear_env(env); |
542 | | - |
543 | | - if (!result.is_ok()) { |
544 | | - ERL_NIF_TERM error_term = |
545 | | - enif_make_tuple2(env, enif_make_atom(env, "error"), result.get_err()); |
546 | | - throw EMLXCompileError(enif_make_copy(outer_env, error_term)); |
547 | | - } |
548 | | - |
549 | | - ERL_NIF_TERM output_list = result.get_value(); |
550 | | - |
551 | | - // Convert output_list back to vector of MLX arrays |
552 | | - std::vector<mlx::core::array> output_tensors; |
553 | | - if (!nx::nif::get_list(env, output_list, output_tensors)) { |
554 | | - ERL_NIF_TERM error_string = enif_make_string( |
555 | | - env, "Failed to convert callback result to tensors", ERL_NIF_LATIN1); |
556 | | - ERL_NIF_TERM error_term = |
557 | | - enif_make_tuple2(env, enif_make_atom(env, "error"), error_string); |
558 | | - throw EMLXCompileError(enif_make_copy(outer_env, error_term)); |
559 | | - } |
560 | | - |
561 | | - enif_free_env(env); |
562 | | - |
563 | | - return output_tensors; |
564 | | - }; |
565 | | - |
566 | | - emlx::function compiled_function_ptr; |
567 | | - |
568 | | - try { |
569 | | - compiled_function_ptr = mlx::core::compile(fun); |
570 | | - } catch (const EMLXCompileError &e) { |
571 | | - return e.get_error_term(); |
572 | | - } |
573 | | - |
574 | | - return nx::nif::ok(env, create_function_resource(env, compiled_function_ptr)); |
575 | | -} |
576 | | - |
577 | | -NIF(call_compiled) { |
578 | | - emlx::function *compiled_function_ptr = nullptr; |
579 | | - |
580 | | - if (!nx::nif::get(env, argv[0], compiled_function_ptr)) { |
581 | | - return nx::nif::error(env, "Unable to get compiled function pointer"); |
582 | | - } |
583 | | - LIST_PARAM(1, std::vector<mlx::core::array>, args); |
584 | | - |
585 | | - std::vector<mlx::core::array> result = (*compiled_function_ptr)(args); |
586 | | - |
587 | | - return nx::nif::ok(env, nx::nif::make_list(env, result)); |
588 | | -} |
589 | | - |
590 | 495 | NIF(stack) { |
591 | 496 | LIST_PARAM(0, std::vector<mlx::core::array>, arrays); |
592 | 497 | PARAM(1, int, axis); |
@@ -794,10 +699,6 @@ static int load(ErlNifEnv *env, void **priv_data, ERL_NIF_TERM load_info) { |
794 | 699 | return -1; |
795 | 700 | } |
796 | 701 |
|
797 | | - if (nif_call_onload(env) != 0) { |
798 | | - return -1; |
799 | | - } |
800 | | - |
801 | 702 | return 0; |
802 | 703 | } |
803 | 704 |
|
@@ -1070,7 +971,6 @@ NIF(as_strided) { |
1070 | 971 | } |
1071 | 972 |
|
1072 | 973 | static ErlNifFunc nif_funcs[] = { |
1073 | | - NIF_CALL_NIF_FUNC(nif_call_evaluated), |
1074 | 974 | {"strides", 1, strides}, |
1075 | 975 | {"as_strided", 5, as_strided}, |
1076 | 976 | {"scalar_type", 1, scalar_type}, |
@@ -1187,11 +1087,8 @@ static ErlNifFunc nif_funcs[] = { |
1187 | 1087 | {"max", 4, max}, |
1188 | 1088 | {"min", 4, min}, |
1189 | 1089 | {"clip", 4, clip}, |
1190 | | - {"tri_inv", 3, tri_inv}, |
1191 | | - {"set_compile", 1, set_compile}, |
1192 | | - {"compile", 2, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND}, |
1193 | | - {"call_compiled_cpu", 2, call_compiled, ERL_NIF_DIRTY_JOB_CPU_BOUND}, |
1194 | | - {"call_compiled_gpu", 2, call_compiled, ERL_NIF_DIRTY_JOB_IO_BOUND}}; |
| 1090 | + {"tri_inv", 3, tri_inv} |
| 1091 | +}; |
1195 | 1092 |
|
1196 | 1093 | // Update the NIF initialization |
1197 | 1094 | ERL_NIF_INIT(Elixir.EMLX.NIF, nif_funcs, load, NULL, upgrade, NULL) |
0 commit comments