Skip to content

Commit b2ec6ff

Browse files
committed
Simplify branching and fix type instability in setproperty!
1 parent b8aec5e commit b2ec6ff

File tree

1 file changed

+42
-19
lines changed

1 file changed

+42
-19
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -308,31 +308,54 @@ function SciMLBase.solve!(
308308
)
309309
end
310310

311-
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
312-
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val::AbstractArray)
311+
function setA!(dc::DualLinearCache, A)
312+
# Put the Dual-stripped versions in the LinearCache
313+
prop = nodual_value!(getproperty(dc.linear_cache, :A), A) # Update in-place
314+
setproperty!(dc.linear_cache, :A, prop) # Does additional invalidation logic etc.
315+
316+
# Update partials
317+
setfield!(dc, :dual_A, A)
318+
partial_vals!(getfield(dc, :partials_A), A) # Update in-place
319+
320+
# Invalidate cache (if setting A or b)
321+
setfield!(dc, :rhs_cache_valid, false)
322+
end
323+
function setb!(dc::DualLinearCache, b)
324+
# Put the Dual-stripped versions in the LinearCache
325+
prop = nodual_value!(getproperty(dc.linear_cache, :b), b) # Update in-place
326+
setproperty!(dc.linear_cache, :b, prop) # Does additional invalidation logic etc.
327+
328+
# Update partials
329+
setfield!(dc, :dual_b, b)
330+
partial_vals!(getfield(dc, :partials_b), b) # Update in-place
331+
332+
# Invalidate cache (if setting A or b)
333+
setfield!(dc, :rhs_cache_valid, false)
334+
end
335+
function setu!(dc::DualLinearCache, u)
336+
# Put the Dual-stripped versions in the LinearCache
337+
prop = nodual_value!(getproperty(dc.linear_cache, :u), u) # Update in-place
338+
setproperty!(dc.linear_cache, :u, prop) # Does additional invalidation logic etc.
339+
340+
# Update partials
341+
setfield!(dc, :dual_u, u)
342+
partial_vals!(getfield(dc, :partials_u), u) # Update in-place
343+
end
344+
345+
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
313346
# If the property is A or b, also update it in the LinearCache
314-
if sym === :A || sym === :b || sym === :u
315-
prop = nodual_value!(getproperty(dc.linear_cache, sym), val) # Update in-place
316-
setproperty!(dc.linear_cache, sym, prop) # Does additional invalidation logic etc.
347+
if sym === :A
348+
setA!(dc, val)
349+
elseif sym === :b
350+
setb!(dc, val)
351+
elseif sym === :u
352+
setu!(dc, val)
317353
elseif hasfield(DualLinearCache, sym)
318354
setfield!(dc, sym, val)
319355
elseif hasfield(LinearSolve.LinearCache, sym)
320356
setproperty!(dc.linear_cache, sym, val)
321357
end
322-
323-
# Update the partials and invalidate cache if setting A or b
324-
if sym === :A
325-
setfield!(dc, :dual_A, val)
326-
partial_vals!(getfield(dc, :partials_A), val) # Update in-place
327-
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
328-
elseif sym === :b
329-
setfield!(dc, :dual_b, val)
330-
partial_vals!(getfield(dc, :partials_b), val) # Update in-place
331-
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
332-
elseif sym === :u
333-
setfield!(dc, :dual_u, val)
334-
partial_vals!(getfield(dc, :partials_u), val) # Update in-place
335-
end
358+
nothing
336359
end
337360

338361
# "Forwards" getproperty to LinearCache if necessary

0 commit comments

Comments
 (0)