@@ -70,29 +70,40 @@ struct ArrayOfSimilarArrays{
7070 data:: P
7171
7272 function ArrayOfSimilarArrays {T,M,N} (flat_data:: AbstractArray{U,L} ) where {T,M,N,L,U}
73- size_inner, size_outer = split_tuple (size (flat_data), Val {M} ())
7473 require_ndims (flat_data, _add_vals (Val {M} (), Val {N} ()))
7574 conv_parent = _convert_elype (T, flat_data)
7675 P = typeof (conv_parent)
7776 new {T,M,N,L,P} (conv_parent)
7877 end
78+ end
7979
80- function ArrayOfSimilarArrays {T,M} (flat_data:: AbstractArray{U,L} ) where {T,M,L,U}
81- size_inner, size_outer = split_tuple (size (flat_data), Val {M} ())
82- N = length (size_outer)
83- conv_parent = _convert_elype (T, flat_data)
84- P = typeof (conv_parent)
85- new {T,M,N,L,P} (conv_parent)
86- end
80+ function ArrayOfSimilarArrays {T,M} (flat_data:: AbstractArray{U,L} ) where {T,M,L,U}
81+ _, size_outer = split_tuple (size (flat_data), Val {M} ())
82+ N = length (size_outer)
83+ ArrayOfSimilarArrays {T,M,N} (flat_data)
8784end
8885
8986export ArrayOfSimilarArrays
9087
88+ function _aosa_ctor_fromflat_pullback (ΔΩ)
89+ NoTangent (), flatview (convert (ArrayOfSimilarArrays, unthunk (ΔΩ)))
90+ end
91+
92+ function ChainRulesCore. rrule (:: Type{ArrayOfSimilarArrays{T,M,N}} , flat_data:: AbstractArray{U,L} ) where {T,M,N,L,U}
93+ return ArrayOfSimilarArrays {T,M,N} (flat_data), _aosa_ctor_fromflat_pullback
94+ end
95+
9196function ArrayOfSimilarArrays {T,M,N} (A:: AbstractArray{<:AbstractArray{U,M},N} ) where {T,M,N,U}
9297 B = ArrayOfSimilarArrays {T,M,N} (Array {T} (undef, innersize (A)... , size (A)... ))
9398 copyto! (B, A)
9499end
95100
101+ _aosa_ctor_fromnested_pullback (ΔΩ) = NoTangent (), ΔΩ
102+
103+ function ChainRulesCore. rrule (:: Type{ArrayOfSimilarArrays{T,M,N}} , A:: AbstractArray{<:AbstractArray{U,M},N} ) where {T,M,N,U}
104+ return ArrayOfSimilarArrays {T,M,N} (A), _aosa_ctor_fromnested_pullback
105+ end
106+
96107ArrayOfSimilarArrays {T} (A:: AbstractArray{<:AbstractArray{U,M},N} ) where {T,M,N,U} =
97108 ArrayOfSimilarArrays {T,M,N} (A)
98109
0 commit comments