Skip to content

Commit 847bcc5

Browse files
committed
feat: better lowering for Base.Slices
1 parent 6b07526 commit 847bcc5

File tree

2 files changed

+56
-12
lines changed

2 files changed

+56
-12
lines changed

src/Ops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3137,7 +3137,7 @@ end
31373137
(),
31383138
"unbatched_" * string(f),
31393139
false;
3140-
args_in_result=:none,
3140+
args_in_result=:result,
31413141
do_transpose=false,
31423142
argprefix,
31433143
)

src/TracedRArray.jl

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -338,14 +338,15 @@ end
338338
function overloaded_mapreduce(
339339
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=Base._InitialValue()
340340
)
341-
res = unwrapped_broadcast(f, A)
341+
res, updated_dims, re = unwrapped_broadcast(f, A, dims)
342342
# This means we are unable to use the optimized dispatches. For now we will
343343
# unroll the mapreduce.
344344
if typeof(res) == typeof(A)
345-
@assert dims == Colon() "dims not supported for mapreduce currently."
345+
@assert dims isa Colon "dims not supported for mapreduce currently."
346346
return foldl(op, res; init)
347347
end
348-
return overloaded_mapreduce(identity, op, res; dims=:, init)
348+
349+
return re(overloaded_mapreduce(identity, op, res; dims=updated_dims, init))
349350
end
350351

351352
function overloaded_mapreduce(
@@ -361,6 +362,7 @@ function overloaded_mapreduce(
361362
dims isa Int && (dims = Int64[dims])
362363
dims isa Colon && (dims = collect(Int64, 1:N))
363364
dims isa Vector{Int64} || (dims = collect(Int64, dims))
365+
dims = sort(dims)
364366

365367
op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T}))
366368
reduce_init = __default_init(op_in_T, op)
@@ -733,7 +735,7 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs)
733735
dims = dims isa Colon ? nothing : dims
734736
res = []
735737
prev_dims = nothing
736-
for x in unwrapped_broadcast(identity, xs)
738+
for x in first(unwrapped_broadcast(identity, xs, Colon()))
737739
cur_dims = ndims(x)
738740
if prev_dims === nothing
739741
prev_dims = cur_dims
@@ -1358,24 +1360,66 @@ end
13581360

13591361
(fn::BroadcastIterator)(args...) = fn.f((args...,))
13601362

1361-
function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F}
1363+
function unwrapped_broadcast(f::F, x::Base.Iterators.Zip, original_dims) where {F}
13621364
min_length = Base.inferencebarrier(minimum)(length, x.is)
13631365
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
13641366
any(Base.Fix2(isa, AnyTracedRArray), itrs) || return unrolled_map(f, x)
1365-
return broadcast(BroadcastIterator(f), itrs...)
1367+
return broadcast(BroadcastIterator(f), itrs...), original_dims, identity
13661368
end
13671369

1368-
function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
1370+
function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate, original_dims) where {F}
13691371
x.itr isa AnyTracedRArray || return unrolled_map(f, x)
1370-
return broadcast(
1371-
BroadcastIterator(f), Reactant.promote_to(TracedRArray, 1:length(x.itr)), x.itr
1372+
return (
1373+
broadcast(
1374+
BroadcastIterator(f), Reactant.promote_to(TracedRArray, 1:length(x.itr)), x.itr
1375+
),
1376+
original_dims,
1377+
identity,
13721378
)
13731379
end
13741380

1375-
unwrapped_broadcast(f::F, xs) where {F} = unrolled_map(f, xs)
1381+
function unwrapped_broadcast(f::F, x::Slices, original_dims) where {F}
1382+
px = parent(x)
1383+
if ndims(x) != ndims(px) # drop=true
1384+
ordering, mapslices_dims = (), ()
1385+
for (i, s) in enumerate(x.slicemap)
1386+
s isa Colon && continue
1387+
ordering = (ordering..., s)
1388+
mapslices_dims = (mapslices_dims..., i)
1389+
end
1390+
mapslices_dims = Tuple(mapslices_dims[order] for order in ordering)
1391+
1392+
updated_dims = ()
1393+
if original_dims isa Colon
1394+
updated_dims = mapslices_dims
1395+
else
1396+
for d in original_dims
1397+
idx = findfirst(isequal(d), x.slicemap)
1398+
@assert idx !== nothing "Expected dimension $d in $(x.slicemap)"
1399+
updated_dims = (updated_dims..., idx)
1400+
end
1401+
end
1402+
1403+
return (
1404+
mapslices(f, px; dims=mapslices_dims),
1405+
updated_dims,
1406+
x -> eachslice(x; dims=mapslices_dims, drop=true),
1407+
)
1408+
else
1409+
mapslices_dims = Tuple(filter(i -> !(x.slicemap[i] isa Colon), 1:ndims(px)))
1410+
return (
1411+
mapslices(f, px; dims=mapslices_dims),
1412+
original_dims,
1413+
x -> eachslice(x; dims=mapslices_dims, drop=false),
1414+
)
1415+
end
1416+
end
1417+
1418+
function unwrapped_broadcast(f::F, xs, original_dims) where {F}
1419+
return reshape(unrolled_map(f, xs), size(xs)), original_dims, identity
1420+
end
13761421

13771422
# TODO: once traced_call supports internal mutations, we can use traced_call here
1378-
# TODO: we should overload this for Slices and use mapslices instead
13791423
function unrolled_map(f::F, itr) where {F}
13801424
y = Reactant.call_with_reactant(iterate, itr)
13811425
y === nothing && return []

0 commit comments

Comments
 (0)