Skip to content

Commit 81b4b4c

Browse files
committed
Attempt the implementation of reinit! for dual cache
1 parent 0c0f1e4 commit 81b4b4c

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module LinearSolveForwardDiffExt
22

33
using LinearSolve
44
using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver,
5-
DefaultAlgorithmChoice, defaultalg
5+
DefaultAlgorithmChoice, defaultalg, reinit!
66
using LinearAlgebra
77
using ForwardDiff
88
using ForwardDiff: Dual, Partials
@@ -342,6 +342,38 @@ function setu!(dc::DualLinearCache, u)
342342
partial_vals!(getfield(dc, :partials_u), u) # Update in-place
343343
end
344344

345+
function SciMLBase.reinit!(cache::DualLinearCache;
346+
A = nothing,
347+
b = nothing,
348+
u = nothing,
349+
p = nothing,
350+
reuse_precs = false)
351+
if !isnothing(A)
352+
setA!(cache, A)
353+
end
354+
355+
if !isnothing(b)
356+
setb!(cache, b)
357+
end
358+
359+
if !isnothing(u)
360+
setu!(cache, u)
361+
end
362+
363+
if !isnothing(p)
364+
cache.linear_cache.p=p
365+
end
366+
367+
isfresh = !isnothing(A)
368+
precsisfresh = !reuse_precs && (isfresh || !isnothing(p))
369+
isfresh |= cache.isfresh
370+
precsisfresh |= cache.linear_cache.precsisfresh
371+
cache.linear_cache.isfresh = true
372+
cache.linear_cache.precsisfresh = precsisfresh
373+
374+
nothing
375+
end
376+
345377
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
346378
# If the property is A or b, also update it in the LinearCache
347379
if sym === :A
@@ -390,7 +422,9 @@ partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place
390422
nodual_value(x) = x
391423
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
392424
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact
393-
nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
425+
function nodual_value(x::AbstractArray{<:Dual})
426+
nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
427+
end
394428
nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place
395429

396430
function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T}

test/forwarddiff_overloads.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,12 @@ backslash_x_p = A \ b
188188

189189
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
190190

191+
A[1, 1]+=2
192+
cache = overload_x_p.cache
193+
reinit!(cache; A = sparse(A))
194+
overload_x_p = solve!(cache, UMFPACKFactorization())
195+
backslash_x_p = A \ b
196+
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
191197

192198
# Test that GenericLU doesn't create a DualLinearCache
193199
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
@@ -234,4 +240,4 @@ grad = ForwardDiff.gradient(component_linsolve, p_test)
234240
@test grad isa Vector
235241
@test length(grad) == 2
236242
@test !any(isnan, grad)
237-
@test !any(isinf, grad)
243+
@test !any(isinf, grad)

0 commit comments

Comments
 (0)