@@ -35,4 +35,139 @@ defmodule EMLXTest do
3535 assert_equal ( right , Nx . tensor ( - 1 ) )
3636 end )
3737 end
38+
39+ describe "scalar item extraction (MLX layout bug fix)" do
40+ # Tests for the fix in emlx_nif.cpp:item()
41+ # The bug: MLX creates scalars with invalid memory layout after slice→squeeze
42+ # The fix: Call item<T>() with the correct dtype instead of always using int64/double
43+
44+ # Helper to call EMLX.item() directly (bypasses any Elixir workarounds)
45+ defp item_direct ( tensor ) do
46+ { _device , ref } = EMLX.Backend . from_nx ( tensor )
47+ EMLX . item ( { :cpu , ref } )
48+ end
49+
50+ test "extracts int32 scalar from slice→squeeze" do
51+ array = Nx . iota ( { 1000 } , type: :s32 )
52+
53+ # Test various indices that previously failed
54+ for idx <- [ 0 , 1 , 100 , 500 , 900 , 951 , 998 , 999 ] do
55+ sliced = Nx . slice_along_axis ( array , idx , 1 , axis: 0 )
56+ scalar = Nx . squeeze ( sliced , axes: [ 0 ] )
57+ value = item_direct ( scalar )
58+
59+ assert value == idx , "Expected #{ idx } , got #{ value } for int32 scalar"
60+ end
61+ end
62+
63+ test "extracts int8 scalar correctly" do
64+ array = Nx . iota ( { 128 } , type: :s8 )
65+
66+ for idx <- [ 0 , 1 , 50 , 100 , 127 ] do
67+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
68+ assert item_direct ( scalar ) == idx
69+ end
70+ end
71+
72+ test "extracts int16 scalar correctly" do
73+ array = Nx . iota ( { 1000 } , type: :s16 )
74+
75+ for idx <- [ 0 , 1 , 500 , 999 ] do
76+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
77+ assert item_direct ( scalar ) == idx
78+ end
79+ end
80+
81+ test "extracts int64 scalar correctly" do
82+ array = Nx . iota ( { 100 } , type: :s64 )
83+
84+ for idx <- [ 0 , 1 , 50 , 99 ] do
85+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
86+ assert item_direct ( scalar ) == idx
87+ end
88+ end
89+
90+ test "extracts uint8 scalar correctly" do
91+ array = Nx . iota ( { 200 } , type: :u8 )
92+
93+ for idx <- [ 0 , 1 , 100 , 199 ] do
94+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
95+ assert item_direct ( scalar ) == idx
96+ end
97+ end
98+
99+ test "extracts uint16 scalar correctly" do
100+ array = Nx . iota ( { 1000 } , type: :u16 )
101+
102+ for idx <- [ 0 , 1 , 500 , 999 ] do
103+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
104+ assert item_direct ( scalar ) == idx
105+ end
106+ end
107+
108+ test "extracts uint32 scalar correctly" do
109+ array = Nx . iota ( { 1000 } , type: :u32 )
110+
111+ for idx <- [ 0 , 1 , 500 , 951 , 999 ] do
112+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
113+ assert item_direct ( scalar ) == idx
114+ end
115+ end
116+
117+ test "extracts uint64 scalar correctly" do
118+ array = Nx . iota ( { 100 } , type: :u64 )
119+
120+ for idx <- [ 0 , 1 , 50 , 99 ] do
121+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
122+ assert item_direct ( scalar ) == idx
123+ end
124+ end
125+
126+ test "extracts float32 scalar correctly" do
127+ array = Nx . iota ( { 100 } , type: :f32 )
128+
129+ for idx <- [ 0 , 1 , 50 , 99 ] do
130+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
131+ assert_in_delta item_direct ( scalar ) , idx * 1.0 , 1.0e-6
132+ end
133+ end
134+
135+ test "extracts boolean scalar correctly" do
136+ # Create array [0, 1, 0, 1, ...] as uint8
137+ array = Nx . tensor ( [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 1 , 0 , 1 ] , type: :u8 )
138+
139+ for idx <- [ 0 , 1 , 2 , 3 ] do
140+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
141+ expected = rem ( idx , 2 )
142+ assert item_direct ( scalar ) == expected
143+ end
144+ end
145+
146+ test "direct scalar creation works (baseline)" do
147+ # Ensure direct scalar creation still works
148+ scalar = Nx . tensor ( 951 , type: :s32 )
149+ assert item_direct ( scalar ) == 951
150+ end
151+
152+ test "negative values work correctly" do
153+ array = Nx . tensor ( [ - 100 , - 50 , 0 , 50 , 100 ] , type: :s32 )
154+
155+ for { expected , idx } <- Enum . with_index ( [ - 100 , - 50 , 0 , 50 , 100 ] ) do
156+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
157+ assert item_direct ( scalar ) == expected
158+ end
159+ end
160+
161+ test "edge values for int32" do
162+ # Test boundary values
163+ max_val = 2_147_483_647
164+ min_val = - 2_147_483_648
165+ array = Nx . tensor ( [ min_val , - 1 , 0 , 1 , max_val ] , type: :s32 )
166+
167+ for { expected , idx } <- Enum . with_index ( [ min_val , - 1 , 0 , 1 , max_val ] ) do
168+ scalar = array |> Nx . slice_along_axis ( idx , 1 , axis: 0 ) |> Nx . squeeze ( axes: [ 0 ] )
169+ assert item_direct ( scalar ) == expected
170+ end
171+ end
172+ end
38173end
0 commit comments