33
44export CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixBSR, CuSparseMatrixCOO,
55 CuSparseMatrix, AbstractCuSparseMatrix,
6+ CuSparseArrayCSR,
67 CuSparseVector,
78 CuSparseVecOrMat
89
141142
# Identity conversion: a `CuSparseMatrixCOO` passed to its own constructor is
# returned as-is (the same object, no copy).
function CuSparseMatrixCOO(A::CuSparseMatrixCOO)
    return A
end
143144
"""
    CuSparseArrayCSR{Tv, Ti, N} <: AbstractCuSparseArray{Tv, Ti, N}

`N`-dimensional sparse array with value type `Tv` and index type `Ti`, stored as
three `CuArray`s in CSR layout (`rowPtr`, `colVal`, `nzVal`).

# Fields
- `rowPtr::CuArray{Ti}`: row pointer array.
- `colVal::CuArray{Ti}`: column index array.
- `nzVal::CuArray{Tv}`: stored values.
- `dims::NTuple{N,Int}`: size of the array.
- `nnz::Ti`: set by the constructor to `length(nzVal)`.

The inner constructor requires `rowPtr`, `colVal` and `nzVal` to all have
`N - 1` dimensions, where `N == length(dims)`, and throws an `AssertionError`
otherwise.
"""
mutable struct CuSparseArrayCSR{Tv, Ti, N} <: AbstractCuSparseArray{Tv, Ti, N}
    rowPtr::CuArray{Ti}
    colVal::CuArray{Ti}
    nzVal::CuArray{Tv}
    dims::NTuple{N,Int}
    nnz::Ti

    function CuSparseArrayCSR{Tv, Ti, N}(rowPtr::CuArray{<:Integer, M},
                                         colVal::CuArray{<:Integer, M},
                                         nzVal::CuArray{Tv, M},
                                         dims::NTuple{N,<:Integer}) where {Tv, Ti<:Integer, M, N}
        # This validates caller-supplied arguments, so it must never be compiled
        # out: `@assert` may be disabled at some optimization levels. The check is
        # therefore explicit, while the exception type stays `AssertionError` so
        # that existing callers observe the same error as before.
        M == N - 1 || throw(AssertionError(
            "CuSparseArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"))
        return new{Tv, Ti, N}(rowPtr, colVal, nzVal, dims, length(nzVal))
    end
end
157+
# A `CuSparseArrayCSR` passed to the bare constructor is returned unchanged
# (the same object, no copy).
function CuSparseArrayCSR(A::CuSparseArrayCSR)
    return A
end
159+
# Call `unsafe_free!` on each array backing `xs`: the row pointers, the column
# indices and the array returned by `nonzeros(xs)`. Returns `nothing`.
function CUDA.unsafe_free!(xs::CuSparseArrayCSR)
    for buffer in (xs.rowPtr, xs.colVal, nonzeros(xs))
        unsafe_free!(buffer)
    end
    return
end
166+
# broadcast over batch-dim if batchsize==1:
# when the second dimension of the underlying array has size 1, a stride of 0 is
# reported instead of the array's actual stride along that dimension.
function ptrstride(A::CuSparseArrayCSR)
    ptrs = A.rowPtr
    if size(ptrs, 2) > 1
        return stride(ptrs, 2)
    end
    return 0
end

function valstride(A::CuSparseArrayCSR)
    vals = A.nzVal
    if size(vals, 2) > 1
        return stride(vals, 2)
    end
    return 0
end
170+
144171"""
145172Utility union type of [`CuSparseMatrixCSC`](@ref), [`CuSparseMatrixCSR`](@ref),
146173[`CuSparseMatrixBSR`](@ref), [`CuSparseMatrixCOO`](@ref).
@@ -154,7 +181,6 @@ const CuSparseMatrix{Tv, Ti} = Union{
154181
# Either any `CuSparseMatrix` or a `CuSparseVector`.
const CuSparseVecOrMat = Union{CuSparseMatrix, CuSparseVector}
156183
157-
# NOTE: we use Cint as default Ti on CUDA instead of Int to provide
#       maximum compatibility with old CUSPARSE APIs
160186function CuSparseVector {Tv} (iPtr:: CuVector{<:Integer} , nzVal:: CuVector , len:: Integer ) where {Tv}
@@ -183,6 +209,11 @@ function CuSparseMatrixCOO{Tv}(rowInd::CuVector{<:Integer}, colInd::CuVector{<:I
183209 CuSparseMatrixCOO {Tv, Cint} (rowInd,colInd,nzVal,dims,nnz)
184210end
185211
# Constructor taking only the value type `Tv`: the index type parameter is fixed
# to `Cint` and the dimensionality `N` is taken from the length of `dims`.
function CuSparseArrayCSR{Tv}(rowPtr::CuArray{<:Integer, M}, colVal::CuArray{<:Integer, M},
                              nzVal::CuArray{Tv, M}, dims::NTuple{N,<:Integer}) where {Tv, M, N}
    IndexT = Cint
    return CuSparseArrayCSR{Tv, IndexT, N}(rowPtr, colVal, nzVal, dims)
end
216+
186217# # convenience constructors
# Infer the value type `T` from `nzVal` and forward to `CuSparseVector{T}`.
function CuSparseVector(iPtr::DenseCuArray{<:Integer}, nzVal::DenseCuArray{T},
                        len::Integer) where {T}
    return CuSparseVector{T}(iPtr, nzVal, len)
end
@@ -201,6 +232,9 @@ CuSparseMatrixBSR(rowPtr::DenseCuArray, colVal::DenseCuArray, nzVal::DenseCuArra
# Infer the value type `T` from `nzVal` and forward to `CuSparseMatrixCOO{T}`.
# `nnz` defaults to `length(nzVal)`.
function CuSparseMatrixCOO(rowInd::DenseCuArray, colInd::DenseCuArray, nzVal::DenseCuArray{T},
                           dims::NTuple{2,<:Integer}, nnz::Integer=length(nzVal)) where {T}
    return CuSparseMatrixCOO{T}(rowInd, colInd, nzVal, dims, nnz)
end
203234
# Infer the value type `T` from `nzVal` and forward to `CuSparseArrayCSR{T}`.
function CuSparseArrayCSR(rowPtr::DenseCuArray, colVal::DenseCuArray, nzVal::DenseCuArray{T},
                          dims::NTuple{N,<:Integer}) where {T,N}
    return CuSparseArrayCSR{T}(rowPtr, colVal, nzVal, dims)
end
237+
# `similar` without extra arguments: the index arrays are copied, while the value
# array is obtained from `similar(nonzeros(...))`; the size is kept.
function Base.similar(Vec::CuSparseVector)
    inds = copy(nonzeroinds(Vec))
    vals = similar(nonzeros(Vec))
    return CuSparseVector(inds, vals, length(Vec))
end

function Base.similar(Mat::CuSparseMatrixCSC)
    colptr = copy(Mat.colPtr)
    rowval = copy(rowvals(Mat))
    vals = similar(nonzeros(Mat))
    return CuSparseMatrixCSC(colptr, rowval, vals, size(Mat))
end

function Base.similar(Mat::CuSparseMatrixCSR)
    rowptr = copy(Mat.rowPtr)
    colval = copy(Mat.colVal)
    vals = similar(nonzeros(Mat))
    return CuSparseMatrixCSR(rowptr, colval, vals, size(Mat))
end
@@ -216,6 +250,7 @@ Base.similar(Mat::CuSparseMatrixCOO, T::Type) = CuSparseMatrixCOO(copy(Mat.rowIn
# `similar` with an element type `T` and an `N`×`M` size: allocates uninitialized
# `Int32` index vectors and a `T` value vector, with the number of stored entries
# taken from `nnz(Mat)`.
function Base.similar(Mat::CuSparseMatrixCSC, T::Type, N::Int, M::Int)
    stored = nnz(Mat)
    colptr = CuVector{Int32}(undef, M + 1)
    rowval = CuVector{Int32}(undef, stored)
    vals = CuVector{T}(undef, stored)
    return CuSparseMatrixCSC(colptr, rowval, vals, (N, M))
end

function Base.similar(Mat::CuSparseMatrixCSR, T::Type, N::Int, M::Int)
    stored = nnz(Mat)
    rowptr = CuVector{Int32}(undef, N + 1)
    colval = CuVector{Int32}(undef, stored)
    vals = CuVector{T}(undef, stored)
    return CuSparseMatrixCSR(rowptr, colval, vals, (N, M))
end

function Base.similar(Mat::CuSparseMatrixCOO, T::Type, N::Int, M::Int)
    stored = nnz(Mat)
    rowind = CuVector{Int32}(undef, stored)
    colind = CuVector{Int32}(undef, stored)
    vals = CuVector{T}(undef, stored)
    return CuSparseMatrixCOO(rowind, colind, vals, (N, M))
end

# `similar` without extra arguments: row pointers and column indices are copied,
# the value array is obtained from `similar(nonzeros(Mat))`; the size is kept.
function Base.similar(Mat::CuSparseArrayCSR)
    rowptr = copy(Mat.rowPtr)
    colval = copy(Mat.colVal)
    vals = similar(nonzeros(Mat))
    return CuSparseArrayCSR(rowptr, colval, vals, size(Mat))
end
219254
220255# # array interface
221256
@@ -225,6 +260,9 @@ Base.size(g::CuSparseVector) = (g.len,)
# Both the matrix types and the N-dimensional CSR array keep their size in a
# `dims` field; `length` is the product of its entries.
function Base.length(g::CuSparseMatrix)
    return prod(g.dims)
end

function Base.size(g::CuSparseMatrix)
    return g.dims
end

function Base.length(g::CuSparseArrayCSR)
    return prod(g.dims)
end

function Base.size(g::CuSparseArrayCSR)
    return g.dims
end
265+
228266function Base. size (g:: CuSparseVector , d:: Integer )
229267 if d == 1
230268 return g. len
@@ -245,6 +283,15 @@ function Base.size(g::CuSparseMatrix, d::Integer)
245283 end
246284end
247285
# Size of `g` along dimension `d`.
# Returns `g.dims[d]` for `1 <= d <= N` and `1` for any larger `d`; throws an
# `ArgumentError` otherwise.
function Base.size(g::CuSparseArrayCSR{Tv,Ti,N}, d::Integer) where {Tv,Ti,N}
    1 <= d <= N && return g.dims[d]
    d > 1 && return 1
    throw(ArgumentError("dimension must be ≥ 1, got $d"))
end
248295
249296# # sparse array interface
250297
@@ -348,6 +395,16 @@ function Base.getindex(A::CuSparseMatrixBSR{T}, i0::Integer, i1::Integer) where
348395 nonzeros (A)[c1+ block_idx]
349396end
350397
# matrix slices
# `A[:, :, idxs...]` indexes `rowPtr`, `colVal` and the stored values with
# `[:, idxs...]` and wraps the results in a `CuSparseMatrixCSR` whose size is the
# first two dimensions of `A`.
function Base.getindex(A::CuSparseArrayCSR{Tv, Ti, N}, ::Colon, ::Colon, idxs::Integer...) where {Tv, Ti, N}
    @boundscheck checkbounds(A, :, :, idxs...)
    rowptr = A.rowPtr[:, idxs...]
    colval = A.colVal[:, idxs...]
    vals = nonzeros(A)[:, idxs...]
    return CuSparseMatrixCSR(rowptr, colval, vals, size(A)[1:2])
end
403+
# Scalar access `A[i0, i1, idxs...]`: builds the `CuSparseMatrixCSR` selected by the
# trailing indices `idxs` and then reads element `(i0, i1)` from it.
function Base.getindex(A::CuSparseArrayCSR{Tv, Ti, N}, i0::Integer, i1::Integer, idxs::Integer...) where {Tv, Ti, N}
    @boundscheck checkbounds(A, i0, i1, idxs...)
    rowptr = A.rowPtr[:, idxs...]
    colval = A.colVal[:, idxs...]
    vals = nonzeros(A)[:, idxs...]
    slice = CuSparseMatrixCSR(rowptr, colval, vals, size(A)[1:2])
    return slice[i0, i1]
end
351408
352409# # interop with sparse CPU arrays
353410
@@ -502,7 +559,7 @@ Base.copy(Mat::CuSparseMatrixCSC) = copyto!(similar(Mat), Mat)
# The matrix types copy by filling a `similar` container via `copyto!`.
function Base.copy(Mat::CuSparseMatrixCSR)
    return copyto!(similar(Mat), Mat)
end

function Base.copy(Mat::CuSparseMatrixBSR)
    return copyto!(similar(Mat), Mat)
end

function Base.copy(Mat::CuSparseMatrixCOO)
    return copyto!(similar(Mat), Mat)
end

# The N-dimensional CSR array is copied field by field: row pointers, column
# indices and stored values are each copied; the size is kept.
function Base.copy(Mat::CuSparseArrayCSR)
    rowptr = copy(Mat.rowPtr)
    colval = copy(Mat.colVal)
    vals = copy(nonzeros(Mat))
    return CuSparseArrayCSR(rowptr, colval, vals, size(Mat))
end
506563
507564# input/output
508565
@@ -543,6 +600,24 @@ for (gpu, cpu) in [:CuSparseMatrixCSC => :SparseMatrixCSC,
543600 end
544601end
545602
# Multi-line `text/plain` display of a `CuSparseArrayCSR`.
#
# Prints a one-line summary (size, type, number of stored entries). If every
# dimension is non-zero, it then prints each matrix slice `A[:, :, c...]`, where
# `c` runs over the trailing dimensions, after converting the slice to a
# `SparseMatrixCSC`.
function Base.show(io::IOContext, ::MIME"text/plain", A::CuSparseArrayCSR)
    xnnz = nnz(A)
    sizestr = join(size(A), "×")

    # `sizestr` is already one string: print it directly rather than splatting it
    # into one `Char` argument per character (the output is the same).
    print(io, sizestr, " ", typeof(A), " with ", xnnz, " stored ", xnnz == 1 ? "entry" : "entries")

    if all(size(A) .> 0)
        println(io, ":")
        # Separate binding instead of reassigning `io`, so `io` keeps one type.
        slice_io = IOContext(io, :typeinfo => eltype(A))
        for (k, c) in enumerate(CartesianIndices(size(A)[3:end]))
            # Separate consecutive slices.
            k > 1 && println(slice_io, "\n")
            println(slice_io, "[:, :, ", join(c.I, ", "), "] =")
            Base.print_array(slice_io, SparseMatrixCSC(A[:, :, c.I...]))
        end
    end
    return nothing
end
620+
546621
547622# interop with device arrays
548623
@@ -590,3 +665,13 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCOO)
590665 size (x), x. nnz
591666 )
592667end
668+
# Kernel adaptor conversion: adapts `rowPtr`, `colVal` and `nzVal` with `to` and
# wraps them, together with the size and `nnz` of `x`, in a `CuSparseDeviceArrayCSR`.
function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseArrayCSR)
    rowptr = adapt(to, x.rowPtr)
    colval = adapt(to, x.colVal)
    vals = adapt(to, x.nzVal)
    return CuSparseDeviceArrayCSR(rowptr, colval, vals, size(x), x.nnz)
end
677+
0 commit comments