@@ -16,53 +16,27 @@ ArrayInterfaceCore.indices_do_not_alias(::Type{ComponentArray{T,N,A,Axes}}) wher
1616ArrayInterfaceCore. instances_do_not_alias (:: Type{ComponentArray{T,N,A,Axes}} ) where {T,N,A,Axes} = ArrayInterfaceCore. instances_do_not_alias (A)
1717
1818# Cats
19- # TODO : Make this a little less copy-pastey
20- function Base. hcat (x:: AbstractComponentVecOrMat , y:: AbstractComponentVecOrMat )
21- ax_x, ax_y = second_axis .((x,y))
22- if reduce ((accum, key) -> accum || (key in keys (ax_x)), keys (ax_y); init= false ) || getaxes (x)[1 ] != getaxes (y)[1 ]
23- return hcat (getdata (x), getdata (y))
19+ function Base. cat (inputs:: ComponentArray... ; dims:: Int )
20+ combined_data = cat (getdata .(inputs)... ; dims= dims)
21+ axes_to_merge = [(getaxes (i)... , FlatAxis ())[dims] for i in inputs]
22+ rest_axes = [getaxes (i)[1 : end .!= dims] for i in inputs]
23+ no_duplicate_keys = (length (inputs) == 1 || isempty (intersect (keys .(axes_to_merge)... )))
24+ if no_duplicate_keys && length (Set (rest_axes)) == 1
25+ offsets = cumsum (size .(inputs, 1 ) .- size (first (inputs), 1 ))
26+ merged_axis = Axis (merge (indexmap .(reindex .(axes_to_merge, offsets))... ))
27+ result_axes = (first (rest_axes)[1 : (dims - 1 )]. .. , merged_axis, first (rest_axes)[dims: end ]. .. )
28+ return ComponentArray (combined_data, result_axes... )
2429 else
25- data_x, data_y = getdata .((x, y))
26- ax_y = reindex (ax_y, size (x,2 ))
27- idxmap_x, idxmap_y = indexmap .((ax_x, ax_y))
28- axs = getaxes (x)
29- return ComponentArray (hcat (data_x, data_y), axs[1 ], Axis ((;idxmap_x... , idxmap_y... )), axs[3 : end ]. .. )
30+ return combined_data
3031 end
3132end
3233
33- second_axis (ca:: AbstractComponentVecOrMat ) = getaxes (ca)[2 ]
34- second_axis (:: ComponentVector ) = FlatAxis ()
35-
36- # Are all these methods necessary?
37- # TODO : See what we can reduce down to without getting ambiguity errors
38- Base. vcat (x:: ComponentVector , y:: AbstractVector ) = vcat (getdata (x), y)
39- Base. vcat (x:: AbstractVector , y:: ComponentVector ) = vcat (x, getdata (y))
40- function Base. vcat (x:: ComponentVector , y:: ComponentVector )
41- if reduce ((accum, key) -> accum || (key in keys (x)), keys (y); init= false )
42- return vcat (getdata (x), getdata (y))
43- else
44- data_x, data_y = getdata .((x, y))
45- ax_x, ax_y = getindex .(getaxes .((x, y)), 1 )
46- ax_y = reindex (ax_y, length (x))
47- idxmap_x, idxmap_y = indexmap .((ax_x, ax_y))
48- return ComponentArray (vcat (data_x, data_y), Axis ((;idxmap_x... , idxmap_y... )))
49- end
34+ function Base. _typed_hcat (:: Type{T} , inputs:: Base.AbstractVecOrTuple{ComponentArray} ) where {T}
35+ return Base. cat (map (i -> T .(i), inputs)... ; dims= 2 )
5036end
51- function Base. vcat (x:: AbstractComponentVecOrMat , y:: AbstractComponentVecOrMat )
52- ax_x, ax_y = getindex .(getaxes .((x, y)), 1 )
53- if reduce ((accum, key) -> accum || (key in keys (ax_x)), keys (ax_y); init= false ) || getaxes (x)[2 : end ] != getaxes (y)[2 : end ]
54- return vcat (getdata (x), getdata (y))
55- else
56- data_x, data_y = getdata .((x, y))
57- ax_y = reindex (ax_y, size (x,1 ))
58- idxmap_x, idxmap_y = indexmap .((ax_x, ax_y))
59- return ComponentArray (vcat (data_x, data_y), Axis ((;idxmap_x... , idxmap_y... )), getaxes (x)[2 : end ]. .. )
60- end
37+ function Base. _typed_vcat (:: Type{T} , inputs:: Base.AbstractVecOrTuple{ComponentArray} ) where {T}
38+ return Base. cat (map (i -> T .(i), inputs)... ; dims= 1 )
6139end
62- Base. vcat (x:: CV... ) where {CV<: AdjOrTransComponentArray } = ComponentArray (reduce (vcat, map (y-> getdata (y. parent)' , x)), getaxes (x[1 ]))
63- Base. vcat (x:: ComponentVector , args... ) = vcat (getdata (x), getdata .(args)... )
64- Base. vcat (x:: ComponentVector , args:: Union{Number, UniformScaling, AbstractVecOrMat} ...) = vcat (getdata (x), getdata .(args)... )
65- Base. vcat (x:: ComponentVector , args:: Vararg{AbstractVector{T}, N} ) where {T,N} = vcat (getdata (x), getdata .(args)... )
6640
6741function Base. hvcat (row_lengths:: NTuple{N,Int} , xs:: AbstractComponentVecOrMat... ) where {N}
6842 i = 1
145119Base. stride (x:: ComponentArray , k) = stride (getdata (x), k)
146120Base. stride (x:: ComponentArray , k:: Int64 ) = stride (getdata (x), k)
147121
148- ArrayInterfaceCore. parent_type (:: Type{ComponentArray{T,N,A,Axes}} ) where {T,N,A,Axes} = A
122+ ArrayInterfaceCore. parent_type (:: Type{ComponentArray{T,N,A,Axes}} ) where {T,N,A,Axes} = A
0 commit comments