Skip to content

Commit 2447ce5

Browse files
committed
Simplify branching and fix type instability in setproperty!
1 parent b8aec5e commit 2447ce5

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -308,31 +308,37 @@ function SciMLBase.solve!(
308308
)
309309
end
310310

311-
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
312-
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val::AbstractArray)
311+
# Generalized function to set A, b or u ("x") for DualLinearCache
312+
function setx(dc::DualLinearCache, x::Symbol, dual_x::Symbol, partials_x::Symbol, val, invalidate::Bool)
313+
# Put the Dual-stripped versions in the LinearCache
314+
prop = nodual_value!(getproperty(dc.linear_cache, x), val) # Update in-place
315+
setproperty!(dc.linear_cache, x, prop) # Does additional invalidation logic etc.
316+
317+
# Update partials
318+
setfield!(dc, dual_x, val)
319+
partial_vals!(getfield(dc, partials_x), val) # Update in-place
320+
321+
# Invalidate cache (if setting A or b)
322+
invalidate && setfield!(dc, :rhs_cache_valid, false)
323+
end
324+
setA!(dc::DualLinearCache, A) = setx(dc, :A, :dual_A, :partials_A, A, true)
325+
setb!(dc::DualLinearCache, b) = setx(dc, :b, :dual_b, :partials_b, b, true)
326+
setu!(dc::DualLinearCache, u) = setx(dc, :u, :dual_u, :partials_u, u, false)
327+
328+
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
313329
# If the property is A or b, also update it in the LinearCache
314-
if sym === :A || sym === :b || sym === :u
315-
prop = nodual_value!(getproperty(dc.linear_cache, sym), val) # Update in-place
316-
setproperty!(dc.linear_cache, sym, prop) # Does additional invalidation logic etc.
330+
if sym === :A
331+
setA!(dc, val)
332+
elseif sym === :b
333+
setb!(dc, val)
334+
elseif sym === :u
335+
setu!(dc, val)
317336
elseif hasfield(DualLinearCache, sym)
318337
setfield!(dc, sym, val)
319338
elseif hasfield(LinearSolve.LinearCache, sym)
320339
setproperty!(dc.linear_cache, sym, val)
321340
end
322-
323-
# Update the partials and invalidate cache if setting A or b
324-
if sym === :A
325-
setfield!(dc, :dual_A, val)
326-
partial_vals!(getfield(dc, :partials_A), val) # Update in-place
327-
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
328-
elseif sym === :b
329-
setfield!(dc, :dual_b, val)
330-
partial_vals!(getfield(dc, :partials_b), val) # Update in-place
331-
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
332-
elseif sym === :u
333-
setfield!(dc, :dual_u, val)
334-
partial_vals!(getfield(dc, :partials_u), val) # Update in-place
335-
end
341+
nothing
336342
end
337343

338344
# "Forwards" getproperty to LinearCache if necessary

0 commit comments

Comments
 (0)