Skip to content

Commit aca3737

Browse files
Fix default algorithm for sparse CUDA matrices to LUFactorization
Previously: - CuSparseMatrixCSC defaulted to GenericLU (not GPU compatible) - CuSparseMatrixCSR fell back to Krylov when CUDSS unavailable (not GPU compatible) - Symmetric sparse matrices defaulted to Cholesky (not GPU compatible) Changes: - Both CuSparseMatrixCSC and CuSparseMatrixCSR now default to LUFactorization - Removed Krylov fallback - if CUDSS is not loaded, clear error message is shown - Added cudss_loaded() support for CuSparseMatrixCSC - Added error_no_cudss_lu() for CuSparseMatrixCSC Fixes #827 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent bbfa99d commit aca3737

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

ext/LinearSolveCUDAExt.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,12 @@ end
2020

2121
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
2222
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
23-
if LinearSolve.cudss_loaded(A)
24-
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
25-
else
26-
if !LinearSolve.ALREADY_WARNED_CUDSS[]
27-
@warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
28-
LinearSolve.ALREADY_WARNED_CUDSS[] = true
29-
end
30-
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
31-
end
23+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
24+
end
25+
26+
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSC{Tv, Ti}, b,
27+
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
28+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
3229
end
3330

3431
function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
@@ -38,6 +35,13 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
3835
nothing
3936
end
4037

38+
function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSC)
39+
if !LinearSolve.cudss_loaded(A)
40+
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
41+
end
42+
nothing
43+
end
44+
4145
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
4246
kwargs...)
4347
if cache.isfresh

ext/LinearSolveCUDSSExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ using LinearSolve: LinearSolve, cudss_loaded
44
using CUDSS
55

66
LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSR) = true
7+
LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSC) = true
78

89
end

0 commit comments

Comments
 (0)