Skip to content

Commit e42d9ff

Browse files
committed
chore: bump to v0.2.0
1 parent 495d57e commit e42d9ff

File tree

7 files changed

+5
-397
lines changed

7 files changed

+5
-397
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ MAKE_JOBS ?= $(MAKE_DEFAULT_JOBS)
4040

4141
# Source files
4242
SOURCES = c_src/emlx_nif.cpp
43-
HEADERS = c_src/nif_call.h c_src/nx_nif_utils.hpp
43+
HEADERS = c_src/nx_nif_utils.hpp
4444
OBJECTS = $(patsubst c_src/%.cpp,$(BUILD_DIR)/%.o,$(SOURCES))
4545

4646
# Main targets

c_src/emlx_nif.cpp

Lines changed: 2 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99
#include <string>
1010
#include <cstring>
1111

12-
#define NIF_CALL_IMPLEMENTATION
13-
#include "nif_call.h"
14-
1512
using namespace mlx::core;
1613

1714
std::map<const std::string, const mlx::core::Dtype> dtypes = {
@@ -495,98 +492,6 @@ NIF(eval) {
495492
return nx::nif::ok(env);
496493
}
497494

498-
NIF(set_compile) {
499-
PARAM(0, bool, compile);
500-
501-
if (compile) {
502-
mlx::core::enable_compile();
503-
} else {
504-
mlx::core::disable_compile();
505-
}
506-
507-
return nx::nif::ok(env);
508-
}
509-
510-
void move_between_envs(ERL_NIF_TERM from_term, ERL_NIF_TERM *to_term,
511-
ErlNifEnv *from_env, ErlNifEnv *to_env) {
512-
ErlNifBinary serialized;
513-
enif_term_to_binary(from_env, from_term, &serialized);
514-
enif_binary_to_term(to_env, serialized.data, serialized.size, to_term, 0);
515-
}
516-
517-
class EMLXCompileError : public std::runtime_error {
518-
public:
519-
EMLXCompileError(ERL_NIF_TERM term)
520-
: std::runtime_error("EMLXCompileError"), error_term(term) {}
521-
522-
ERL_NIF_TERM get_error_term() const { return error_term; }
523-
524-
private:
525-
ERL_NIF_TERM error_term;
526-
};
527-
528-
NIF(compile) {
529-
LIST_PARAM(0, std::vector<mlx::core::array>, arrays);
530-
ERL_NIF_TERM tag = argv[1];
531-
532-
ErlNifEnv *closure_env = enif_alloc_env();
533-
534-
auto fun = [env = closure_env, outer_env = env, tag_outer = tag](
535-
const std::vector<mlx::core::array> &compile_args) {
536-
ERL_NIF_TERM tag = enif_make_copy(env, tag_outer);
537-
ERL_NIF_TERM tensor_list = nx::nif::make_list(env, compile_args);
538-
ERL_NIF_TERM arg_list = enif_make_list1(env, tensor_list);
539-
540-
NifCallResult result = make_nif_call(env, tag, arg_list);
541-
enif_clear_env(env);
542-
543-
if (!result.is_ok()) {
544-
ERL_NIF_TERM error_term =
545-
enif_make_tuple2(env, enif_make_atom(env, "error"), result.get_err());
546-
throw EMLXCompileError(enif_make_copy(outer_env, error_term));
547-
}
548-
549-
ERL_NIF_TERM output_list = result.get_value();
550-
551-
// Convert output_list back to vector of MLX arrays
552-
std::vector<mlx::core::array> output_tensors;
553-
if (!nx::nif::get_list(env, output_list, output_tensors)) {
554-
ERL_NIF_TERM error_string = enif_make_string(
555-
env, "Failed to convert callback result to tensors", ERL_NIF_LATIN1);
556-
ERL_NIF_TERM error_term =
557-
enif_make_tuple2(env, enif_make_atom(env, "error"), error_string);
558-
throw EMLXCompileError(enif_make_copy(outer_env, error_term));
559-
}
560-
561-
enif_free_env(env);
562-
563-
return output_tensors;
564-
};
565-
566-
emlx::function compiled_function_ptr;
567-
568-
try {
569-
compiled_function_ptr = mlx::core::compile(fun);
570-
} catch (const EMLXCompileError &e) {
571-
return e.get_error_term();
572-
}
573-
574-
return nx::nif::ok(env, create_function_resource(env, compiled_function_ptr));
575-
}
576-
577-
NIF(call_compiled) {
578-
emlx::function *compiled_function_ptr = nullptr;
579-
580-
if (!nx::nif::get(env, argv[0], compiled_function_ptr)) {
581-
return nx::nif::error(env, "Unable to get compiled function pointer");
582-
}
583-
LIST_PARAM(1, std::vector<mlx::core::array>, args);
584-
585-
std::vector<mlx::core::array> result = (*compiled_function_ptr)(args);
586-
587-
return nx::nif::ok(env, nx::nif::make_list(env, result));
588-
}
589-
590495
NIF(stack) {
591496
LIST_PARAM(0, std::vector<mlx::core::array>, arrays);
592497
PARAM(1, int, axis);
@@ -794,10 +699,6 @@ static int load(ErlNifEnv *env, void **priv_data, ERL_NIF_TERM load_info) {
794699
return -1;
795700
}
796701

797-
if (nif_call_onload(env) != 0) {
798-
return -1;
799-
}
800-
801702
return 0;
802703
}
803704

@@ -1070,7 +971,6 @@ NIF(as_strided) {
1070971
}
1071972

1072973
static ErlNifFunc nif_funcs[] = {
1073-
NIF_CALL_NIF_FUNC(nif_call_evaluated),
1074974
{"strides", 1, strides},
1075975
{"as_strided", 5, as_strided},
1076976
{"scalar_type", 1, scalar_type},
@@ -1187,11 +1087,8 @@ static ErlNifFunc nif_funcs[] = {
11871087
{"max", 4, max},
11881088
{"min", 4, min},
11891089
{"clip", 4, clip},
1190-
{"tri_inv", 3, tri_inv},
1191-
{"set_compile", 1, set_compile},
1192-
{"compile", 2, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND},
1193-
{"call_compiled_cpu", 2, call_compiled, ERL_NIF_DIRTY_JOB_CPU_BOUND},
1194-
{"call_compiled_gpu", 2, call_compiled, ERL_NIF_DIRTY_JOB_IO_BOUND}};
1090+
{"tri_inv", 3, tri_inv}
1091+
};
11951092

11961093
// Update the NIF initialization
11971094
ERL_NIF_INIT(Elixir.EMLX.NIF, nif_funcs, load, NULL, upgrade, NULL)

0 commit comments

Comments
 (0)