Skip to content

Commit 4e927d1

Browse files
authored
fix: MLX scalar item bit width bug (#87)
* fix: item read bug * fix: EMLX.item bit width * Update c_src/emlx_nif.cpp
1 parent 939fe9a commit 4e927d1

File tree

2 files changed

+178
-4
lines changed

2 files changed

+178
-4
lines changed

c_src/emlx_nif.cpp

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -911,14 +911,53 @@ NIF(isclose) {
911911
NIF(item) {
912912
TENSOR_PARAM(0, t);
913913
mlx::core::eval(*t);
914-
auto dtype_kind = mlx::core::kindof(t->dtype());
915914

916-
if (dtype_kind == mlx::core::Dtype::Kind::u ||
917-
dtype_kind == mlx::core::Dtype::Kind::i ||
918-
dtype_kind == mlx::core::Dtype::Kind::b) {
915+
// Fix for MLX scalar layout bug: Use the correct type when calling item<T>()
916+
// to avoid reading wrong number of bytes from potentially invalid memory
917+
// layouts.
918+
auto dtype = t->dtype();
919+
920+
// Handle integer and boolean types with proper dtype matching
921+
if (dtype == mlx::core::bool_) {
922+
bool value = t->item<bool>();
923+
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
924+
} else if (dtype == mlx::core::uint8) {
925+
uint8_t value = t->item<uint8_t>();
926+
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
927+
} else if (dtype == mlx::core::uint16) {
928+
uint16_t value = t->item<uint16_t>();
929+
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
930+
} else if (dtype == mlx::core::uint32) {
931+
uint32_t value = t->item<uint32_t>();
932+
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
933+
} else if (dtype == mlx::core::uint64) {
934+
uint64_t value = t->item<uint64_t>();
935+
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
936+
} else if (dtype == mlx::core::int8) {
937+
int8_t value = t->item<int8_t>();
938+
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
939+
} else if (dtype == mlx::core::int16) {
940+
int16_t value = t->item<int16_t>();
941+
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
942+
} else if (dtype == mlx::core::int32) {
943+
int32_t value = t->item<int32_t>();
944+
return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
945+
} else if (dtype == mlx::core::int64) {
919946
int64_t value = t->item<int64_t>();
920947
return nx::nif::ok(env, nx::nif::make(env, value));
948+
} else if (dtype == mlx::core::float16 || dtype == mlx::core::bfloat16) {
949+
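// NOTE(review): assumes item<float>() converts 16-bit floats rather than reinterpreting the 2-byte buffer as a 4-byte float - confirm against MLX's array::item<T>(); no test in this commit covers float16/bfloat16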
950+
float value = t->item<float>();
951+
return nx::nif::ok(env, nx::nif::make(env, static_cast<double>(value)));
952+
} else if (dtype == mlx::core::float32) {
953+
float value = t->item<float>();
954+
return nx::nif::ok(env, nx::nif::make(env, static_cast<double>(value)));
955+
} else if (dtype == mlx::core::complex64) {
956+
// Complex types need special handling - not supported via item()
957+
return nx::nif::error(env,
958+
"Complex scalar extraction not supported via item()");
921959
} else {
960+
// Fallback for any other types
922961
double value = t->item<double>();
923962
return nx::nif::ok(env, nx::nif::make(env, value));
924963
}

test/emlx_test.exs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,139 @@ defmodule EMLXTest do
3535
assert_equal(right, Nx.tensor(-1))
3636
end)
3737
end
38+
39+
describe "scalar item extraction (MLX layout bug fix)" do
40+
# Tests for the fix in emlx_nif.cpp:item()
41+
# The bug: MLX creates scalars with invalid memory layout after slice→squeeze
42+
# The fix: Call item<T>() with the correct dtype instead of always using int64/double
43+
44+
# Helper to call EMLX.item() directly (bypasses any Elixir workarounds); note that the tensor's own device is discarded and :cpu is always passed
45+
defp item_direct(tensor) do
46+
{_device, ref} = EMLX.Backend.from_nx(tensor)
47+
EMLX.item({:cpu, ref})
48+
end
49+
50+
test "extracts int32 scalar from slice→squeeze" do
51+
array = Nx.iota({1000}, type: :s32)
52+
53+
# Test various indices that previously failed
54+
for idx <- [0, 1, 100, 500, 900, 951, 998, 999] do
55+
sliced = Nx.slice_along_axis(array, idx, 1, axis: 0)
56+
scalar = Nx.squeeze(sliced, axes: [0])
57+
value = item_direct(scalar)
58+
59+
assert value == idx, "Expected #{idx}, got #{value} for int32 scalar"
60+
end
61+
end
62+
63+
test "extracts int8 scalar correctly" do
64+
array = Nx.iota({128}, type: :s8)
65+
66+
for idx <- [0, 1, 50, 100, 127] do
67+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
68+
assert item_direct(scalar) == idx
69+
end
70+
end
71+
72+
test "extracts int16 scalar correctly" do
73+
array = Nx.iota({1000}, type: :s16)
74+
75+
for idx <- [0, 1, 500, 999] do
76+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
77+
assert item_direct(scalar) == idx
78+
end
79+
end
80+
81+
test "extracts int64 scalar correctly" do
82+
array = Nx.iota({100}, type: :s64)
83+
84+
for idx <- [0, 1, 50, 99] do
85+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
86+
assert item_direct(scalar) == idx
87+
end
88+
end
89+
90+
test "extracts uint8 scalar correctly" do
91+
array = Nx.iota({200}, type: :u8)
92+
93+
for idx <- [0, 1, 100, 199] do
94+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
95+
assert item_direct(scalar) == idx
96+
end
97+
end
98+
99+
test "extracts uint16 scalar correctly" do
100+
array = Nx.iota({1000}, type: :u16)
101+
102+
for idx <- [0, 1, 500, 999] do
103+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
104+
assert item_direct(scalar) == idx
105+
end
106+
end
107+
108+
test "extracts uint32 scalar correctly" do
109+
array = Nx.iota({1000}, type: :u32)
110+
111+
for idx <- [0, 1, 500, 951, 999] do
112+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
113+
assert item_direct(scalar) == idx
114+
end
115+
end
116+
117+
test "extracts uint64 scalar correctly" do
118+
array = Nx.iota({100}, type: :u64)
119+
120+
for idx <- [0, 1, 50, 99] do
121+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
122+
assert item_direct(scalar) == idx
123+
end
124+
end
125+
126+
test "extracts float32 scalar correctly" do
127+
array = Nx.iota({100}, type: :f32)
128+
129+
for idx <- [0, 1, 50, 99] do
130+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
131+
assert_in_delta item_direct(scalar), idx * 1.0, 1.0e-6
132+
end
133+
end
134+
135+
test "extracts boolean scalar correctly" do
136+
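# Create array [0, 1, 0, 1, ...] as uint8 (0/1 values only; the tensor is u8, so this exercises the uint8 path rather than a bool-typed tensor)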
137+
array = Nx.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], type: :u8)
138+
139+
for idx <- [0, 1, 2, 3] do
140+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
141+
expected = rem(idx, 2)
142+
assert item_direct(scalar) == expected
143+
end
144+
end
145+
146+
test "direct scalar creation works (baseline)" do
147+
# Ensure direct scalar creation still works
148+
scalar = Nx.tensor(951, type: :s32)
149+
assert item_direct(scalar) == 951
150+
end
151+
152+
test "negative values work correctly" do
153+
array = Nx.tensor([-100, -50, 0, 50, 100], type: :s32)
154+
155+
for {expected, idx} <- Enum.with_index([-100, -50, 0, 50, 100]) do
156+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
157+
assert item_direct(scalar) == expected
158+
end
159+
end
160+
161+
test "edge values for int32" do
162+
# Test boundary values
163+
max_val = 2_147_483_647
164+
min_val = -2_147_483_648
165+
array = Nx.tensor([min_val, -1, 0, 1, max_val], type: :s32)
166+
167+
for {expected, idx} <- Enum.with_index([min_val, -1, 0, 1, max_val]) do
168+
scalar = array |> Nx.slice_along_axis(idx, 1, axis: 0) |> Nx.squeeze(axes: [0])
169+
assert item_direct(scalar) == expected
170+
end
171+
end
172+
end
38173
end

0 commit comments

Comments
 (0)