Skip to content

Commit 7f15ace

Browse files
committed
Opt out from generig dual number handling for SparspakFactorization
Sparspak handles sparse linear system solution for generic number types including dual numbers
1 parent 0c0f1e4 commit 7f15ace

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactoriza
200200
return __init(prob, alg, args...; kwargs...)
201201
end
202202

203+
# Opt out for SparspakFactorization
204+
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SparspakFactorization, args...; kwargs...)
205+
return __init(prob, alg, args...; kwargs...)
206+
end
207+
203208
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::DefaultLinearSolver, args...; kwargs...)
204209
if alg.alg === DefaultAlgorithmChoice.GenericLUFactorization
205210
return __init(prob, alg, args...; kwargs...)

test/forwarddiff_overloads.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ForwardDiff
33
using Test
44
using SparseArrays
55
using ComponentArrays
6+
using Sparspak
67

78
function h(p)
89
(A = [p[1] p[2]+1 p[2]^3;
@@ -188,7 +189,6 @@ backslash_x_p = A \ b
188189

189190
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
190191

191-
192192
# Test that GenericLU doesn't create a DualLinearCache
193193
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
194194

@@ -197,6 +197,12 @@ prob = LinearProblem(A, b)
197197

198198
@test init(prob) isa LinearSolve.LinearCache
199199

200+
# Test that SparspakFactorization doesn't create a DualLinearCache
201+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
202+
203+
prob = LinearProblem(sparse(A), b)
204+
@test init(prob, SparspakFactorization()) isa LinearSolve.LinearCache
205+
200206
# Test ComponentArray with ForwardDiff (Issue SciML/DifferentialEquations.jl#1110)
201207
# This tests that ArrayInterface.restructure preserves ComponentArray structure
202208

@@ -234,4 +240,4 @@ grad = ForwardDiff.gradient(component_linsolve, p_test)
234240
@test grad isa Vector
235241
@test length(grad) == 2
236242
@test !any(isnan, grad)
237-
@test !any(isinf, grad)
243+
@test !any(isinf, grad)

0 commit comments

Comments
 (0)