|
1 | 1 | module LinearSolveCUDAExt |
2 | 2 |
|
3 | | -using CUDA, LinearAlgebra, LinearSolve, SciMLBase |
| 3 | +using CUDA |
| 4 | +using LinearSolve |
| 5 | +using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface |
4 | 6 | using SciMLBase: AbstractSciMLOperator |
5 | 7 |
|
6 | 8 | function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization; |
7 | 9 | kwargs...) |
8 | 10 | if cache.isfresh |
9 | | - fact = LinearSolve.do_factorization(alg, CUDA.CuArray(cache.A), cache.b, cache.u) |
10 | | - cache = LinearSolve.set_cacheval(cache, fact) |
| 11 | + fact = qr(CUDA.CuArray(cache.A)) |
| 12 | + cache.cacheval = fact |
11 | 13 | cache.isfresh = false |
12 | 14 | end |
13 | | - |
14 | | - copyto!(cache.u, cache.b) |
15 | | - y = Array(ldiv!(cache.cacheval, CUDA.CuArray(cache.u))) |
| 15 | + y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b))) |
| 16 | + cache.u .= y |
16 | 17 | SciMLBase.build_linear_solution(alg, y, nothing, cache) |
17 | 18 | end |
18 | 19 |
|
19 | | -function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u) |
20 | | - A isa Union{AbstractMatrix, AbstractSciMLOperator} || |
21 | | - error("LU is not defined for $(typeof(A))") |
22 | | - |
23 | | - if A isa Union{MatrixOperator, DiffEqArrayOperator} |
24 | | - A = A.A |
25 | | - end |
26 | | - |
27 | | - fact = qr(CUDA.CuArray(A)) |
28 | | - return fact |
| 20 | +function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr, |
| 21 | + maxiters::Int, abstol, reltol, verbose::Bool, |
| 22 | + assumptions::OperatorAssumptions) |
| 23 | + qr(CUDA.CuArray(A)) |
29 | 24 | end |
30 | 25 |
|
31 | 26 | end |
0 commit comments