@@ -5,9 +5,11 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
55 DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache,
66 needs_concrete_A,
77 error_no_cudss_lu, init_cacheval, OperatorAssumptions,
8- CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
8+ CudaOffloadFactorization, CudaOffloadLUFactorization,
9+ CudaOffloadQRFactorization,
910 CUDAOffload32MixedLUFactorization,
10- SparspakFactorization, KLUFactorization, UMFPACKFactorization, LinearVerbosity
11+ SparspakFactorization, KLUFactorization, UMFPACKFactorization,
12+ LinearVerbosity
1113using LinearSolve. LinearAlgebra, LinearSolve. SciMLBase, LinearSolve. ArrayInterface
1214using SciMLBase: AbstractSciMLOperator
1315
@@ -56,14 +58,15 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
5658 SciMLBase. build_linear_solution (alg, y, nothing , cache)
5759end
5860
59- function LinearSolve. init_cacheval (alg:: CudaOffloadLUFactorization , A:: AbstractArray , b, u, Pl, Pr,
61+ function LinearSolve. init_cacheval (
62+ alg:: CudaOffloadLUFactorization , A:: AbstractArray , b, u, Pl, Pr,
6063 maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} ,
6164 assumptions:: OperatorAssumptions )
6265 # Check if CUDA is functional before creating CUDA arrays
6366 if ! CUDA. functional ()
6467 return nothing
6568 end
66-
69+
6770 T = eltype (A)
6871 noUnitT = typeof (zero (T))
6972 luT = LinearAlgebra. lutype (noUnitT)
@@ -91,7 +94,7 @@ function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl,
9194 if ! CUDA. functional ()
9295 return nothing
9396 end
94-
97+
9598 qr (CUDA. CuArray (A))
9699end
97100
@@ -108,35 +111,42 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
108111 SciMLBase. build_linear_solution (alg, y, nothing , cache)
109112end
110113
111- function LinearSolve. init_cacheval (alg:: CudaOffloadFactorization , A:: AbstractArray , b, u, Pl, Pr,
114+ function LinearSolve. init_cacheval (
115+ alg:: CudaOffloadFactorization , A:: AbstractArray , b, u, Pl, Pr,
112116 maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} ,
113117 assumptions:: OperatorAssumptions )
114118 qr (CUDA. CuArray (A))
115119end
116120
117121function LinearSolve. init_cacheval (
118122 :: SparspakFactorization , A:: CUDA.CUSPARSE.CuSparseMatrixCSR , b, u,
119- Pl, Pr, maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
123+ Pl, Pr, maxiters:: Int , abstol, reltol,
124+ verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
120125 nothing
121126end
122127
123128function LinearSolve. init_cacheval (
124129 :: KLUFactorization , A:: CUDA.CUSPARSE.CuSparseMatrixCSR , b, u,
125- Pl, Pr, maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
130+ Pl, Pr, maxiters:: Int , abstol, reltol,
131+ verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
126132 nothing
127133end
128134
129135function LinearSolve. init_cacheval (
130136 :: UMFPACKFactorization , A:: CUDA.CUSPARSE.CuSparseMatrixCSR , b, u,
131- Pl, Pr, maxiters:: Int , abstol, reltol, verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
137+ Pl, Pr, maxiters:: Int , abstol, reltol,
138+ verbose:: Union{LinearVerbosity, Bool} , assumptions:: OperatorAssumptions )
132139 nothing
133140end
134141
135142# Mixed precision CUDA LU implementation
136- function SciMLBase. solve! (cache:: LinearSolve.LinearCache , alg:: CUDAOffload32MixedLUFactorization ;
143+ function SciMLBase. solve! (
144+ cache:: LinearSolve.LinearCache , alg:: CUDAOffload32MixedLUFactorization ;
137145 kwargs... )
138146 if cache. isfresh
139- fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
147+ fact, A_gpu_f32,
148+ b_gpu_f32,
149+ u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
140150 # Compute 32-bit type on demand and convert
141151 T32 = eltype (cache. A) <: Complex ? ComplexF32 : Float32
142152 A_f32 = T32 .(cache. A)
@@ -145,12 +155,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32Mixe
145155 cache. cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
146156 cache. isfresh = false
147157 end
148- fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
149-
158+ fact, A_gpu_f32,
159+ b_gpu_f32,
160+ u_gpu_f32 = LinearSolve. @get_cacheval (cache, :CUDAOffload32MixedLUFactorization )
161+
150162 # Compute types on demand for conversions
151163 T32 = eltype (cache. A) <: Complex ? ComplexF32 : Float32
152164 Torig = eltype (cache. u)
153-
165+
154166 # Convert b to Float32, solve, then convert back to original precision
155167 b_f32 = T32 .(cache. b)
156168 copyto! (b_gpu_f32, b_f32)
0 commit comments