@@ -2077,7 +2077,7 @@ defmodule Nx.BinaryBackend do
20772077 for << match! ( x , 0 ) <- data >> , into: << >> do
20782078 x = read! ( x , 0 )
20792079
2080- case x do
2080+ generated_case x do
20812081 % Complex { re: re } when float_output? and real_output? ->
20822082 number_to_binary ( re , output_type )
20832083
@@ -2253,14 +2253,13 @@ defmodule Nx.BinaryBackend do
22532253 end
22542254 end
22552255
2256- output_data =
2257- match_types [ out . type ] do
2258- for row <- result , % Complex { re: re , im: im } <- row , into: << >> do
2259- re = if abs ( re ) <= eps , do: 0 , else: re
2260- im = if abs ( im ) <= eps , do: 0 , else: im
2256+ % { type: { _ , output_size } } = out
22612257
2262- << write! ( Complex . new ( re , im ) , 0 ) >>
2263- end
2258+ output_data =
2259+ for row <- result , % Complex { re: re , im: im } <- row , into: << >> do
2260+ re = if abs ( re ) <= eps , do: 0 , else: re
2261+ im = if abs ( im ) <= eps , do: 0 , else: im
2262+ << write_complex ( re , im , div ( output_size , 2 ) ) :: binary >>
22642263 end
22652264
22662265 intermediate_shape = out . shape |> Tuple . delete_at ( axis ) |> Tuple . append ( n )
@@ -2391,20 +2390,6 @@ defmodule Nx.BinaryBackend do
23912390 end
23922391 end
23932392
2394- defp bin_zip_reduce ( t1 , [ ] , t2 , [ ] , type , acc , fun ) do
2395- % { type: { _ , s1 } } = t1
2396- % { type: { _ , s2 } } = t2
2397- b1 = to_binary ( t1 )
2398- b2 = to_binary ( t2 )
2399-
2400- match_types [ t1 . type , t2 . type ] do
2401- for << d1 :: size ( s1 ) - bitstring <- b1 >> , << d2 :: size ( s2 ) - bitstring <- b2 >> , into: << >> do
2402- { result , _ } = fun . ( d1 , d2 , acc )
2403- scalar_to_binary! ( result , type )
2404- end
2405- end
2406- end
2407-
24082393 defp bin_zip_reduce ( t1 , [ _ | _ ] = axes1 , t2 , [ _ | _ ] = axes2 , type , acc , fun ) do
24092394 { _ , s1 } = t1 . type
24102395 { _ , s2 } = t2 . type
0 commit comments