diff --git a/src/Ops.jl b/src/Ops.jl index ed162501ba..474ef2e32f 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3137,7 +3137,7 @@ end (), "unbatched_" * string(f), false; - args_in_result=:none, + args_in_result=:result, do_transpose=false, argprefix, ) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index f01c57231b..a2b3bd7062 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -338,14 +338,15 @@ end function overloaded_mapreduce( @nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=Base._InitialValue() ) - res = unwrapped_broadcast(f, A) + res, updated_dims, re = unwrapped_broadcast(f, A, dims) # This means we are unable to use the optimized dispatches. For now we will # unroll the mapreduce. if typeof(res) == typeof(A) - @assert dims == Colon() "dims not supported for mapreduce currently." + @assert dims isa Colon "dims not supported for mapreduce currently." return foldl(op, res; init) end - return overloaded_mapreduce(identity, op, res; dims=:, init) + + return re(overloaded_mapreduce(identity, op, res; dims=updated_dims, init)) end function overloaded_mapreduce( @@ -361,6 +362,7 @@ function overloaded_mapreduce( dims isa Int && (dims = Int64[dims]) dims isa Colon && (dims = collect(Int64, 1:N)) dims isa Vector{Int64} || (dims = collect(Int64, dims)) + dims = sort(dims) op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T})) reduce_init = __default_init(op_in_T, op) @@ -733,7 +735,7 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs) dims = dims isa Colon ? nothing : dims res = [] prev_dims = nothing - for x in unwrapped_broadcast(identity, xs) + for x in first(unwrapped_broadcast(identity, xs, Colon())) cur_dims = ndims(x) if prev_dims === nothing prev_dims = cur_dims @@ -1358,24 +1360,91 @@ end (fn::BroadcastIterator)(args...) = fn.f((args...,)) -function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F} +abstract type AbstractUnwrappedBroadcastRestoreFunction end + +struct __Identity <: AbstractUnwrappedBroadcastRestoreFunction end +(::__Identity)(x) = x + +struct __EachSlice{D} <: AbstractUnwrappedBroadcastRestoreFunction + dims::D + drop::Bool +end +(s::__EachSlice{D})(x) where {D} = eachslice(x; dims=s.dims, drop=s.drop) + +struct __DropDims{D} <: AbstractUnwrappedBroadcastRestoreFunction + dims::D +end +(s::__DropDims{D})(x) where {D} = dropdims(x; dims=s.dims) + +function unwrapped_broadcast(f::F, x::Base.Iterators.Zip, original_dims) where {F} min_length = Base.inferencebarrier(minimum)(length, x.is) itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is] - any(Base.Fix2(isa, AnyTracedRArray), itrs) || return unrolled_map(f, x) - return broadcast(BroadcastIterator(f), itrs...) + result = if any(Base.Fix2(isa, AnyTracedRArray), itrs) + broadcast(BroadcastIterator(f), itrs...) + else + unrolled_map(f, x) + end + return result, original_dims, __Identity() end -function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F} - x.itr isa AnyTracedRArray || return unrolled_map(f, x) - return broadcast( - BroadcastIterator(f), Reactant.promote_to(TracedRArray, 1:length(x.itr)), x.itr - ) +function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate, original_dims) where {F} + result = if x.itr isa AnyTracedRArray + broadcast( + BroadcastIterator(f), + Reactant.promote_to(TracedRArray, 1:length(x.itr)), + x.itr, + ) + else + unrolled_map(f, x) + end + return result, original_dims, __Identity() end -unwrapped_broadcast(f::F, xs) where {F} = unrolled_map(f, xs) +function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F} + px = parent(x) + if ndims(x) != ndims(px) # drop=true + ordering, mapslices_dims = (), () + for (i, s) in enumerate(x.slicemap) + s isa Colon && continue + ordering = (ordering..., s) + mapslices_dims = (mapslices_dims..., i) + end + mapslices_dims = Tuple(mapslices_dims[order] for order in ordering) + + updated_dims = () + if original_dims isa Colon + updated_dims = mapslices_dims + re = __DropDims(mapslices_dims) + else + for d in original_dims + idx = findfirst(isequal(d), x.slicemap) + @assert idx !== nothing "Expected dimension $d in $(x.slicemap)" + updated_dims = (updated_dims..., idx) + end + re = __EachSlice(mapslices_dims, true) + end + + return mapslices(f, px; dims=mapslices_dims), updated_dims, re + else + mapslices_dims = Tuple(filter(i -> !(x.slicemap[i] isa Colon), 1:ndims(px))) + if original_dims isa Colon + updated_dims = mapslices_dims + re = __DropDims(mapslices_dims) + else + updated_dims = Tuple(d for d in original_dims if d in mapslices_dims) + re = __EachSlice(mapslices_dims, false) + end + return mapslices(f, px; dims=mapslices_dims), updated_dims, re + end +end + +function unwrapped_broadcast(f::F, xs, original_dims) where {F} + mapped_xs = unrolled_map(f, xs) + applicable(size, xs) && (mapped_xs = reshape(mapped_xs, size(xs))) + return mapped_xs, original_dims, __Identity() +end # TODO: once traced_call supports internal mutations, we can use traced_call here -# TODO: we should overload this for Slices and use mapslices instead function unrolled_map(f::F, itr) where {F} y = Reactant.call_with_reactant(iterate, itr) y === nothing && return [] diff --git a/test/basic.jl b/test/basic.jl index 3e50c002ac..6c5646d585 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1725,3 +1725,43 @@ end fn = @compile sum(x_ra1) @test_throws Reactant.Compiler.MisMatchedThunkTypeError fn(x_ra2) end + +@testset "Slices" begin + @testset "drop=true" begin + x = eachslice( + Reactant.TestUtils.construct_test_array(Float32, 2, 3, 4, 5); dims=(3, 1) + ) + x_ra = Reactant.to_rarray(x) + + @test @jit(sum(x_ra)) ≈ sum(x) + + @testset for dims in (1, 2, (1, 2), (2, 1)) + res_ra = @jit sum(x_ra; dims) + res = sum(x; dims) + @test size(res_ra) == size(res) + for (gt, comp) in zip(res_ra, res) + @test gt ≈ comp + end + end + end + + @testset "drop=false" begin + x = eachslice( + Reactant.TestUtils.construct_test_array(Float32, 2, 3, 4, 5); + dims=(3, 1), + drop=false, + ) + x_ra = Reactant.to_rarray(x) + + @test @jit(sum(x_ra)) ≈ sum(x) + + @testset for dims in (1, 2, 3, 4, (1, 2), (1, 2, 4), (3, 4, 1), (2, 1)) + res_ra = @jit sum(x_ra; dims) + res = sum(x; dims) + @test size(res_ra) == size(res) + for (gt, comp) in zip(res_ra, res) + @test gt ≈ comp + end + end + end +end