@@ -2,7 +2,7 @@ module LinearSolveForwardDiffExt
22
33using LinearSolve
44using LinearSolve: SciMLLinearSolveAlgorithm, __init, LinearVerbosity, DefaultLinearSolver,
5- DefaultAlgorithmChoice, defaultalg
5+ DefaultAlgorithmChoice, defaultalg, reinit!
66using LinearAlgebra
77using ForwardDiff
88using ForwardDiff: Dual, Partials
@@ -342,6 +342,38 @@ function setu!(dc::DualLinearCache, u)
342342 partial_vals! (getfield (dc, :partials_u ), u) # Update in-place
343343end
344344
345+ function SciMLBase. reinit! (cache:: DualLinearCache ;
346+ A = nothing ,
347+ b = nothing ,
348+ u = nothing ,
349+ p = nothing ,
350+ reuse_precs = false )
351+ if ! isnothing (A)
352+ setA! (cache, A)
353+ end
354+
355+ if ! isnothing (b)
356+ setb! (cache, b)
357+ end
358+
359+ if ! isnothing (u)
360+ setu! (cache, u)
361+ end
362+
363+ if ! isnothing (p)
364+ cache. linear_cache. p= p
365+ end
366+
367+ isfresh = ! isnothing (A)
368+ precsisfresh = ! reuse_precs && (isfresh || ! isnothing (p))
369+ isfresh |= cache. isfresh
370+ precsisfresh |= cache. linear_cache. precsisfresh
371+ cache. linear_cache. isfresh = true
372+ cache. linear_cache. precsisfresh = precsisfresh
373+
374+ nothing
375+ end
376+
345377function Base. setproperty! (dc:: DualLinearCache , sym:: Symbol , val)
346378 # If the property is A or b, also update it in the LinearCache
347379 if sym === :A
@@ -390,7 +422,9 @@ partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place
390422nodual_value (x) = x
391423nodual_value (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = ForwardDiff. value (x)
392424nodual_value (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = x. value # Keep the inner dual intact
393- nodual_value (x:: AbstractArray{<:Dual} ) = nodual_value! (similar (x, typeof (nodual_value (first (x)))), x)
425+ function nodual_value (x:: AbstractArray{<:Dual} )
426+ nodual_value! (similar (x, typeof (nodual_value (first (x)))), x)
427+ end
394428nodual_value! (out, x) = map! (nodual_value, out, x) # Update in-place
395429
396430function update_partials_list! (partial_matrix:: AbstractVector{T} , list_cache) where {T}
0 commit comments