Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3137,7 +3137,7 @@ end
(),
"unbatched_" * string(f),
false;
args_in_result=:none,
args_in_result=:result,
do_transpose=false,
argprefix,
)
Expand Down
97 changes: 83 additions & 14 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,14 +338,15 @@ end
function overloaded_mapreduce(
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=Base._InitialValue()
)
res = unwrapped_broadcast(f, A)
res, updated_dims, re = unwrapped_broadcast(f, A, dims)
# This means we are unable to use the optimized dispatches. For now we will
# unroll the mapreduce.
if typeof(res) == typeof(A)
@assert dims == Colon() "dims not supported for mapreduce currently."
@assert dims isa Colon "dims not supported for mapreduce currently."
return foldl(op, res; init)
end
return overloaded_mapreduce(identity, op, res; dims=:, init)

return re(overloaded_mapreduce(identity, op, res; dims=updated_dims, init))
end

function overloaded_mapreduce(
Expand All @@ -361,6 +362,7 @@ function overloaded_mapreduce(
dims isa Int && (dims = Int64[dims])
dims isa Colon && (dims = collect(Int64, 1:N))
dims isa Vector{Int64} || (dims = collect(Int64, dims))
dims = sort(dims)

op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T}))
reduce_init = __default_init(op_in_T, op)
Expand Down Expand Up @@ -733,7 +735,7 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs)
dims = dims isa Colon ? nothing : dims
res = []
prev_dims = nothing
for x in unwrapped_broadcast(identity, xs)
for x in first(unwrapped_broadcast(identity, xs, Colon()))
cur_dims = ndims(x)
if prev_dims === nothing
prev_dims = cur_dims
Expand Down Expand Up @@ -1358,24 +1360,91 @@ end

(fn::BroadcastIterator)(args...) = fn.f((args...,))

function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F}
abstract type AbstractUnwrappedBroadcastRestoreFunction end

struct __Identity <: AbstractUnwrappedBroadcastRestoreFunction end
(::__Identity)(x) = x

struct __EachSlice{D} <: AbstractUnwrappedBroadcastRestoreFunction
dims::D
drop::Bool
end
(s::__EachSlice{D})(x) where {D} = eachslice(x; dims=s.dims, drop=s.drop)

struct __DropDims{D} <: AbstractUnwrappedBroadcastRestoreFunction
dims::D
end
(s::__DropDims{D})(x) where {D} = dropdims(x; dims=s.dims)

function unwrapped_broadcast(f::F, x::Base.Iterators.Zip, original_dims) where {F}
min_length = Base.inferencebarrier(minimum)(length, x.is)
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
any(Base.Fix2(isa, AnyTracedRArray), itrs) || return unrolled_map(f, x)
return broadcast(BroadcastIterator(f), itrs...)
result = if any(Base.Fix2(isa, AnyTracedRArray), itrs)
broadcast(BroadcastIterator(f), itrs...)
else
unrolled_map(f, x)
end
return result, original_dims, __Identity()
end

function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
x.itr isa AnyTracedRArray || return unrolled_map(f, x)
return broadcast(
BroadcastIterator(f), Reactant.promote_to(TracedRArray, 1:length(x.itr)), x.itr
)
function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate, original_dims) where {F}
result = if x.itr isa AnyTracedRArray
broadcast(
BroadcastIterator(f),
Reactant.promote_to(TracedRArray, 1:length(x.itr)),
x.itr,
)
else
unrolled_map(f, x)
end
return result, original_dims, __Identity()
end

unwrapped_broadcast(f::F, xs) where {F} = unrolled_map(f, xs)
function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}
px = parent(x)
if ndims(x) != ndims(px) # drop=true
ordering, mapslices_dims = (), ()
for (i, s) in enumerate(x.slicemap)
s isa Colon && continue
ordering = (ordering..., s)
mapslices_dims = (mapslices_dims..., i)
end
mapslices_dims = Tuple(mapslices_dims[order] for order in ordering)

updated_dims = ()
if original_dims isa Colon
updated_dims = mapslices_dims
re = __DropDims(mapslices_dims)
else
for d in original_dims
idx = findfirst(isequal(d), x.slicemap)
@assert idx !== nothing "Expected dimension $d in $(x.slicemap)"
updated_dims = (updated_dims..., idx)
end
re = __EachSlice(mapslices_dims, true)
end

return mapslices(f, px; dims=mapslices_dims), updated_dims, re
else
mapslices_dims = Tuple(filter(i -> !(x.slicemap[i] isa Colon), 1:ndims(px)))
if original_dims isa Colon
updated_dims = mapslices_dims
re = __DropDims(mapslices_dims)
else
updated_dims = Tuple(d for d in original_dims if d in mapslices_dims)
re = __EachSlice(mapslices_dims, false)
end
return mapslices(f, px; dims=mapslices_dims), updated_dims, re
end
end

function unwrapped_broadcast(f::F, xs, original_dims) where {F}
mapped_xs = unrolled_map(f, xs)
applicable(size, xs) && (mapped_xs = reshape(mapped_xs, size(xs)))
return mapped_xs, original_dims, __Identity()
end

# TODO: once traced_call supports internal mutations, we can use traced_call here
# TODO: we should overload this for Slices and use mapslices instead
function unrolled_map(f::F, itr) where {F}
y = Reactant.call_with_reactant(iterate, itr)
y === nothing && return []
Expand Down
40 changes: 40 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1725,3 +1725,43 @@ end
fn = @compile sum(x_ra1)
@test_throws Reactant.Compiler.MisMatchedThunkTypeError fn(x_ra2)
end

@testset "Slices" begin
@testset "drop=true" begin
x = eachslice(
Reactant.TestUtils.construct_test_array(Float32, 2, 3, 4, 5); dims=(3, 1)
)
x_ra = Reactant.to_rarray(x)

@test @jit(sum(x_ra)) ≈ sum(x)

@testset for dims in (1, 2, (1, 2), (2, 1))
res_ra = @jit sum(x_ra; dims)
res = sum(x; dims)
@test size(res_ra) == size(res)
for (gt, comp) in zip(res_ra, res)
@test gt ≈ comp
end
end
end

@testset "drop=false" begin
x = eachslice(
Reactant.TestUtils.construct_test_array(Float32, 2, 3, 4, 5);
dims=(3, 1),
drop=false,
)
x_ra = Reactant.to_rarray(x)

@test @jit(sum(x_ra)) ≈ sum(x)

@testset for dims in (1, 2, 3, 4, (1, 2), (1, 2, 4), (3, 4, 1), (2, 1))
res_ra = @jit sum(x_ra; dims)
res = sum(x; dims)
@test size(res_ra) == size(res)
for (gt, comp) in zip(res_ra, res)
@test gt ≈ comp
end
end
end
end
Loading