Skip to content

Commit 5c87a24

Browse files
Merge pull request #828 from ChrisRackauckas-Claude/fix-sparse-cuda-defaults-827
Fix default algorithm for sparse CUDA matrices to LUFactorization
2 parents bbfa99d + f6c26d8 commit 5c87a24

File tree

2 files changed

+44
-19
lines changed

2 files changed

+44
-19
lines changed

ext/LinearSolveCUDAExt.jl

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
55
DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache,
66
needs_concrete_A,
77
error_no_cudss_lu, init_cacheval, OperatorAssumptions,
8-
CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
8+
CudaOffloadFactorization, CudaOffloadLUFactorization,
9+
CudaOffloadQRFactorization,
910
CUDAOffload32MixedLUFactorization,
10-
SparspakFactorization, KLUFactorization, UMFPACKFactorization, LinearVerbosity
11+
SparspakFactorization, KLUFactorization, UMFPACKFactorization,
12+
LinearVerbosity
1113
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
1214
using SciMLBase: AbstractSciMLOperator
1315

@@ -23,11 +25,16 @@ function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
2325
if LinearSolve.cudss_loaded(A)
2426
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
2527
else
26-
if !LinearSolve.ALREADY_WARNED_CUDSS[]
27-
@warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
28-
LinearSolve.ALREADY_WARNED_CUDSS[] = true
29-
end
30-
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
28+
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")
29+
end
30+
end
31+
32+
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSC{Tv, Ti}, b,
33+
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
34+
if LinearSolve.cudss_loaded(A)
35+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
36+
else
37+
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
3138
end
3239
end
3340

@@ -38,6 +45,13 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
3845
nothing
3946
end
4047

48+
function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSC)
49+
if !LinearSolve.cudss_loaded(A)
50+
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
51+
end
52+
nothing
53+
end
54+
4155
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
4256
kwargs...)
4357
if cache.isfresh
@@ -52,14 +66,15 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
5266
SciMLBase.build_linear_solution(alg, y, nothing, cache)
5367
end
5468

55-
function LinearSolve.init_cacheval(alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
69+
function LinearSolve.init_cacheval(
70+
alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
5671
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
5772
assumptions::OperatorAssumptions)
5873
# Check if CUDA is functional before creating CUDA arrays
5974
if !CUDA.functional()
6075
return nothing
6176
end
62-
77+
6378
T = eltype(A)
6479
noUnitT = typeof(zero(T))
6580
luT = LinearAlgebra.lutype(noUnitT)
@@ -87,7 +102,7 @@ function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl,
87102
if !CUDA.functional()
88103
return nothing
89104
end
90-
105+
91106
qr(CUDA.CuArray(A))
92107
end
93108

@@ -104,35 +119,42 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
104119
SciMLBase.build_linear_solution(alg, y, nothing, cache)
105120
end
106121

107-
function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
122+
function LinearSolve.init_cacheval(
123+
alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
108124
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
109125
assumptions::OperatorAssumptions)
110126
qr(CUDA.CuArray(A))
111127
end
112128

113129
function LinearSolve.init_cacheval(
114130
::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
115-
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
131+
Pl, Pr, maxiters::Int, abstol, reltol,
132+
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
116133
nothing
117134
end
118135

119136
function LinearSolve.init_cacheval(
120137
::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
121-
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
138+
Pl, Pr, maxiters::Int, abstol, reltol,
139+
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
122140
nothing
123141
end
124142

125143
function LinearSolve.init_cacheval(
126144
::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
127-
Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
145+
Pl, Pr, maxiters::Int, abstol, reltol,
146+
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
128147
nothing
129148
end
130149

131150
# Mixed precision CUDA LU implementation
132-
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
151+
function SciMLBase.solve!(
152+
cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
133153
kwargs...)
134154
if cache.isfresh
135-
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
155+
fact, A_gpu_f32,
156+
b_gpu_f32,
157+
u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
136158
# Compute 32-bit type on demand and convert
137159
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
138160
A_f32 = T32.(cache.A)
@@ -141,12 +163,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32Mixe
141163
cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
142164
cache.isfresh = false
143165
end
144-
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
145-
166+
fact, A_gpu_f32,
167+
b_gpu_f32,
168+
u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
169+
146170
# Compute types on demand for conversions
147171
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
148172
Torig = eltype(cache.u)
149-
173+
150174
# Convert b to Float32, solve, then convert back to original precision
151175
b_f32 = T32.(cache.b)
152176
copyto!(b_gpu_f32, b_f32)

ext/LinearSolveCUDSSExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ using LinearSolve: LinearSolve, cudss_loaded
44
using CUDSS
55

66
LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSR) = true
7+
LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSC) = true
78

89
end

0 commit comments

Comments
 (0)