Commit a0668de

Apply JuliaFormatter with SciMLStyle to changed files
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent aca3737 commit a0668de
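
This commit is a pure formatting pass; no code behavior changes. For reference, a pass like this can be reproduced locally with JuliaFormatter.jl. A minimal sketch, assuming JuliaFormatter is installed (the path is the one file this commit touches):

```julia
# Reformat the changed file in place using the SciML style guide.
# `format` returns `true` if the file was already correctly formatted.
using JuliaFormatter

format("ext/LinearSolveCUDAExt.jl", SciMLStyle())
```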

File tree

1 file changed (+26 / -14 lines)

ext/LinearSolveCUDAExt.jl

Lines changed: 26 additions & 14 deletions

@@ -5,9 +5,11 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
                    DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache,
                    needs_concrete_A,
                    error_no_cudss_lu, init_cacheval, OperatorAssumptions,
-                   CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
+                   CudaOffloadFactorization, CudaOffloadLUFactorization,
+                   CudaOffloadQRFactorization,
                    CUDAOffload32MixedLUFactorization,
-                   SparspakFactorization, KLUFactorization, UMFPACKFactorization, LinearVerbosity
+                   SparspakFactorization, KLUFactorization, UMFPACKFactorization,
+                   LinearVerbosity
 using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
 using SciMLBase: AbstractSciMLOperator
 
@@ -56,14 +58,15 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
-function LinearSolve.init_cacheval(alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(
+        alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
         assumptions::OperatorAssumptions)
     # Check if CUDA is functional before creating CUDA arrays
     if !CUDA.functional()
         return nothing
     end
-    
+
     T = eltype(A)
     noUnitT = typeof(zero(T))
     luT = LinearAlgebra.lutype(noUnitT)
@@ -91,7 +94,7 @@ function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl,
     if !CUDA.functional()
         return nothing
     end
-    
+
     qr(CUDA.CuArray(A))
 end
 
@@ -108,35 +111,42 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
-function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(
+        alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
         assumptions::OperatorAssumptions)
     qr(CUDA.CuArray(A))
 end
 
 function LinearSolve.init_cacheval(
         ::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 function LinearSolve.init_cacheval(
         ::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 function LinearSolve.init_cacheval(
         ::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 # Mixed precision CUDA LU implementation
-function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
+function SciMLBase.solve!(
+        cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
         kwargs...)
     if cache.isfresh
-        fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+        fact, A_gpu_f32,
+        b_gpu_f32,
+        u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
         # Compute 32-bit type on demand and convert
         T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
         A_f32 = T32.(cache.A)
@@ -145,12 +155,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32Mixe
         cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
         cache.isfresh = false
     end
-    fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
-    
+    fact, A_gpu_f32,
+    b_gpu_f32,
+    u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+
     # Compute types on demand for conversions
     T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
     Torig = eltype(cache.u)
-    
+
     # Convert b to Float32, solve, then convert back to original precision
     b_f32 = T32.(cache.b)
     copyto!(b_gpu_f32, b_f32)
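
The last two hunks above only re-wrap the destructuring of the mixed-precision cache, but the underlying pattern is worth a standalone illustration. Below is a minimal sketch of the same idea, assuming a dense matrix and a functional CUDA device; it is not LinearSolve's internal code path, and the function name is illustrative, not part of any API:

```julia
# Sketch of the mixed-precision GPU offload pattern: factor and solve in
# Float32 (or ComplexF32) on the device, then convert the solution back
# to the input's precision.
using CUDA, LinearAlgebra

function offload32_lu_solve(A::AbstractMatrix, b::AbstractVector)
    T32 = eltype(A) <: Complex ? ComplexF32 : Float32  # same rule as the diff
    A_gpu = CuArray(T32.(A))          # single-precision copy on the device
    b_gpu = CuArray(T32.(b))
    fact = lu(A_gpu)                  # LU factorization on the GPU
    u_gpu = fact \ b_gpu              # triangular solves on the GPU
    return eltype(b).(Array(u_gpu))   # promote back to the original precision
end
```

The O(n^3) factorization runs in single precision, which is where the offload pays off; the host-device copies and precision conversions are O(n^2) and comparatively cheap.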
