@@ -5,9 +5,11 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
     DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache,
     needs_concrete_A,
     error_no_cudss_lu, init_cacheval, OperatorAssumptions,
-    CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
+    CudaOffloadFactorization, CudaOffloadLUFactorization,
+    CudaOffloadQRFactorization,
     CUDAOffload32MixedLUFactorization,
-    SparspakFactorization, KLUFactorization, UMFPACKFactorization, LinearVerbosity
+    SparspakFactorization, KLUFactorization, UMFPACKFactorization,
+    LinearVerbosity
 using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
 using SciMLBase: AbstractSciMLOperator

@@ -23,11 +25,16 @@ function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
     if LinearSolve.cudss_loaded(A)
         LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
     else
-        if !LinearSolve.ALREADY_WARNED_CUDSS[]
-            @warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
-            LinearSolve.ALREADY_WARNED_CUDSS[] = true
-        end
-        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
+        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")
+    end
+end
+
+function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSC{Tv, Ti}, b,
+        assump::OperatorAssumptions{Bool}) where {Tv, Ti}
+    if LinearSolve.cudss_loaded(A)
+        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
+    else
+        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
     end
 end

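Note on the dispatch added above: with CUDSS.jl loaded, the default algorithm for a `CuSparseMatrixCSC` now resolves to sparse LU, and without it the call errors instead of silently warning and falling back to GMRES. A minimal sketch of the resulting behavior (assuming a functional CUDA device and the CUDSS.jl package; the 2x2 system is illustrative, not from this PR):

```julia
using CUDA, CUDSS, SparseArrays, LinearSolve

A = CUDA.CUSPARSE.CuSparseMatrixCSC(sparse([4.0 1.0; 1.0 3.0]))
b = CUDA.CuArray([1.0, 2.0])

# With CUDSS loaded this resolves to DefaultAlgorithmChoice.LUFactorization;
# without CUDSS the same call now throws instead of picking KrylovJL_GMRES.
alg = LinearSolve.defaultalg(A, b, LinearSolve.OperatorAssumptions(true))

sol = solve(LinearProblem(A, b))  # uses the default selected above
```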
@@ -38,6 +45,13 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
     nothing
 end

+function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSC)
+    if !LinearSolve.cudss_loaded(A)
+        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
+    end
+    nothing
+end
+
 function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
         kwargs...)
     if cache.isfresh
@@ -52,14 +66,15 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end

-function LinearSolve.init_cacheval(alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(
+        alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
         assumptions::OperatorAssumptions)
     # Check if CUDA is functional before creating CUDA arrays
     if !CUDA.functional()
         return nothing
     end
-
+
     T = eltype(A)
     noUnitT = typeof(zero(T))
     luT = LinearAlgebra.lutype(noUnitT)
@@ -87,7 +102,7 @@ function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl,
     if !CUDA.functional()
         return nothing
     end
-
+
     qr(CUDA.CuArray(A))
 end

@@ -104,35 +119,42 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end

-function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(
+        alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
         assumptions::OperatorAssumptions)
     qr(CUDA.CuArray(A))
 end

 function LinearSolve.init_cacheval(
         ::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end

 function LinearSolve.init_cacheval(
         ::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end

 function LinearSolve.init_cacheval(
         ::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end

 # Mixed precision CUDA LU implementation
-function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
+function SciMLBase.solve!(
+        cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
         kwargs...)
     if cache.isfresh
-        fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+        fact, A_gpu_f32,
+        b_gpu_f32,
+        u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
         # Compute 32-bit type on demand and convert
         T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
         A_f32 = T32.(cache.A)
@@ -141,12 +163,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32Mixe
         cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
         cache.isfresh = false
     end
-    fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
-
+    fact, A_gpu_f32,
+    b_gpu_f32,
+    u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+
     # Compute types on demand for conversions
     T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
     Torig = eltype(cache.u)
-
+
     # Convert b to Float32, solve, then convert back to original precision
     b_f32 = T32.(cache.b)
     copyto!(b_gpu_f32, b_f32)
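For context on the mixed-precision path reflowed above: the solve downcasts `A` and `b` to `Float32` (or `ComplexF32`), factorizes and solves on the GPU, then converts the result back to the original element type. A hedged usage sketch (assuming a functional CUDA device and that `CUDAOffload32MixedLUFactorization` is exported by the LinearSolve version in use; the problem data is illustrative):

```julia
using LinearAlgebra, LinearSolve, CUDA

A = rand(100, 100) + 10.0I   # well-conditioned Float64 system on the CPU
b = rand(100)

prob = LinearProblem(A, b)
# Factorization and triangular solves run on the GPU in Float32;
# the solution is converted back to Float64 before being returned.
sol = solve(prob, CUDAOffload32MixedLUFactorization())
```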