@@ -338,14 +338,15 @@ end
338338function overloaded_mapreduce (
339339 @nospecialize (f), @nospecialize (op), @nospecialize (A); dims= :, init= Base. _InitialValue ()
340340)
341- res = unwrapped_broadcast (f, A)
341+ res, updated_dims, re = unwrapped_broadcast (f, A, dims )
342342 # This means we are unable to use the optimized dispatches. For now we will
343343 # unroll the mapreduce.
344344 if typeof (res) == typeof (A)
345- @assert dims == Colon () " dims not supported for mapreduce currently."
345+ @assert dims isa Colon " dims not supported for mapreduce currently."
346346 return foldl (op, res; init)
347347 end
348- return overloaded_mapreduce (identity, op, res; dims= :, init)
348+
349+ return re (overloaded_mapreduce (identity, op, res; dims= updated_dims, init))
349350end
350351
351352function overloaded_mapreduce (
@@ -361,6 +362,7 @@ function overloaded_mapreduce(
361362 dims isa Int && (dims = Int64[dims])
362363 dims isa Colon && (dims = collect (Int64, 1 : N))
363364 dims isa Vector{Int64} || (dims = collect (Int64, dims))
365+ dims = sort (dims)
364366
365367 op_in_T = unwrapped_eltype (Core. Compiler. return_type (f, Tuple{T}))
366368 reduce_init = __default_init (op_in_T, op)
@@ -733,7 +735,7 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs)
733735 dims = dims isa Colon ? nothing : dims
734736 res = []
735737 prev_dims = nothing
736- for x in unwrapped_broadcast (identity, xs)
738+ for x in first ( unwrapped_broadcast (identity, xs, Colon ()) )
737739 cur_dims = ndims (x)
738740 if prev_dims === nothing
739741 prev_dims = cur_dims
@@ -1358,24 +1360,66 @@ end
13581360
13591361(fn:: BroadcastIterator )(args... ) = fn. f ((args... ,))
13601362
1361- function unwrapped_broadcast (f:: F , x:: Base.Iterators.Zip ) where {F}
1363+ function unwrapped_broadcast (f:: F , x:: Base.Iterators.Zip , original_dims ) where {F}
13621364 min_length = Base. inferencebarrier (minimum)(length, x. is)
13631365 itrs = [length (itr) > min_length ? itr[1 : min_length] : itr for itr in x. is]
13641366 any (Base. Fix2 (isa, AnyTracedRArray), itrs) || return unrolled_map (f, x)
1365- return broadcast (BroadcastIterator (f), itrs... )
1367+ return broadcast (BroadcastIterator (f), itrs... ), original_dims, identity
13661368end
13671369
1368- function unwrapped_broadcast (f:: F , x:: Base.Iterators.Enumerate ) where {F}
1370+ function unwrapped_broadcast (f:: F , x:: Base.Iterators.Enumerate , original_dims ) where {F}
13691371 x. itr isa AnyTracedRArray || return unrolled_map (f, x)
1370- return broadcast (
1371- BroadcastIterator (f), Reactant. promote_to (TracedRArray, 1 : length (x. itr)), x. itr
1372+ return (
1373+ broadcast (
1374+ BroadcastIterator (f), Reactant. promote_to (TracedRArray, 1 : length (x. itr)), x. itr
1375+ ),
1376+ original_dims,
1377+ identity,
13721378 )
13731379end
13741380
1375- unwrapped_broadcast (f:: F , xs) where {F} = unrolled_map (f, xs)
1381+ function unwrapped_broadcast (f:: F , x:: Slices , original_dims) where {F}
1382+ px = parent (x)
1383+ if ndims (x) != ndims (px) # drop=true
1384+ ordering, mapslices_dims = (), ()
1385+ for (i, s) in enumerate (x. slicemap)
1386+ s isa Colon && continue
1387+ ordering = (ordering... , s)
1388+ mapslices_dims = (mapslices_dims... , i)
1389+ end
1390+ mapslices_dims = Tuple (mapslices_dims[order] for order in ordering)
1391+
1392+ updated_dims = ()
1393+ if original_dims isa Colon
1394+ updated_dims = mapslices_dims
1395+ else
1396+ for d in original_dims
1397+ idx = findfirst (isequal (d), x. slicemap)
1398+ @assert idx != = nothing " Expected dimension $d in $(x. slicemap) "
1399+ updated_dims = (updated_dims... , idx)
1400+ end
1401+ end
1402+
1403+ return (
1404+ mapslices (f, px; dims= mapslices_dims),
1405+ updated_dims,
1406+ x -> eachslice (x; dims= mapslices_dims, drop= true ),
1407+ )
1408+ else
1409+ mapslices_dims = Tuple (filter (i -> ! (x. slicemap[i] isa Colon), 1 : ndims (px)))
1410+ return (
1411+ mapslices (f, px; dims= mapslices_dims),
1412+ original_dims,
1413+ x -> eachslice (x; dims= mapslices_dims, drop= false ),
1414+ )
1415+ end
1416+ end
1417+
1418+ function unwrapped_broadcast (f:: F , xs, original_dims) where {F}
1419+ return reshape (unrolled_map (f, xs), size (xs)), original_dims, identity
1420+ end
13761421
13771422# TODO : once traced_call supports internal mutations, we can use traced_call here
1378- # TODO : we should overload this for Slices and use mapslices instead
13791423function unrolled_map (f:: F , itr) where {F}
13801424 y = Reactant. call_with_reactant (iterate, itr)
13811425 y === nothing && return []
0 commit comments