Skip to content

Commit e35529a

Browse files
Merge pull request #822 from hersle/less_dual_allocs
Trim allocations in dual linear problem
2 parents 7f6736a + b2ec6ff commit e35529a

File tree

1 file changed

+46
-32
lines changed

1 file changed

+46
-32
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -308,30 +308,54 @@ function SciMLBase.solve!(
308308
)
309309
end
310310

311-
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
311+
"""
    setA!(dc::DualLinearCache, A)

Assign a new `A` to the dual cache `dc`.

The Dual-stripped values of `A` are written into the `A` already held by
`dc.linear_cache`, the partials buffer `partials_A` is overwritten in place,
`A` itself is stored in `dual_A`, and `rhs_cache_valid` is reset to `false`.
"""
function setA!(dc::DualLinearCache, A)
    inner = dc.linear_cache

    # Reuse the buffer that is already stored instead of allocating a new one.
    # NOTE(review): assumes the stored `A` is compatible with the new `A`
    # (same shape, mutable) - confirm against the cache constructor.
    stripped = nodual_value!(inner.A, A)

    # Assign through the property setter rather than relying on the in-place
    # write alone; per the original note it performs additional invalidation
    # logic on the wrapped cache.
    inner.A = stripped

    # Remember the Dual-valued input and refresh its partials in place.
    setfield!(dc, :dual_A, A)
    partial_vals!(getfield(dc, :partials_A), A)

    # Changing A or b invalidates the cached right-hand side.
    return setfield!(dc, :rhs_cache_valid, false)
end
323+
"""
    setb!(dc::DualLinearCache, b)

Assign a new `b` to the dual cache `dc`.

The Dual-stripped values of `b` are written into the `b` already held by
`dc.linear_cache`, the partials buffer `partials_b` is overwritten in place,
`b` itself is stored in `dual_b`, and `rhs_cache_valid` is reset to `false`.
"""
function setb!(dc::DualLinearCache, b)
    inner = dc.linear_cache

    # Reuse the buffer that is already stored instead of allocating a new one.
    # NOTE(review): assumes the stored `b` is compatible with the new `b`
    # (same shape, mutable) - confirm against the cache constructor.
    stripped = nodual_value!(inner.b, b)

    # Assign through the property setter rather than relying on the in-place
    # write alone; per the original note it performs additional invalidation
    # logic on the wrapped cache.
    inner.b = stripped

    # Remember the Dual-valued input and refresh its partials in place.
    setfield!(dc, :dual_b, b)
    partial_vals!(getfield(dc, :partials_b), b)

    # Changing A or b invalidates the cached right-hand side.
    return setfield!(dc, :rhs_cache_valid, false)
end
335+
"""
    setu!(dc::DualLinearCache, u)

Assign a new `u` to the dual cache `dc`.

The Dual-stripped values of `u` are written into the `u` already held by
`dc.linear_cache`, `u` itself is stored in `dual_u`, and the partials buffer
`partials_u` is overwritten in place. Unlike the setters for `A` and `b`,
`rhs_cache_valid` is left untouched.
"""
function setu!(dc::DualLinearCache, u)
    inner = dc.linear_cache

    # Reuse the buffer that is already stored instead of allocating a new one.
    # NOTE(review): assumes the stored `u` is compatible with the new `u`
    # (same shape, mutable) - confirm against the cache constructor.
    stripped = nodual_value!(inner.u, u)

    # Assign through the property setter rather than relying on the in-place
    # write alone; per the original note it performs additional invalidation
    # logic on the wrapped cache.
    inner.u = stripped

    # Remember the Dual-valued input and refresh its partials in place.
    setfield!(dc, :dual_u, u)
    return partial_vals!(getfield(dc, :partials_u), u)
end
344+
312345
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
313346
# If the property is A, b, or u, also update it in the LinearCache
314-
if sym === :A || sym === :b || sym === :u
315-
setproperty!(dc.linear_cache, sym, nodual_value(val))
347+
if sym === :A
348+
setA!(dc, val)
349+
elseif sym === :b
350+
setb!(dc, val)
351+
elseif sym === :u
352+
setu!(dc, val)
316353
elseif hasfield(DualLinearCache, sym)
317354
setfield!(dc, sym, val)
318355
elseif hasfield(LinearSolve.LinearCache, sym)
319356
setproperty!(dc.linear_cache, sym, val)
320357
end
321-
322-
# Update the partials and invalidate cache if setting A or b
323-
if sym === :A
324-
setfield!(dc, :dual_A, val)
325-
setfield!(dc, :partials_A, partial_vals(val))
326-
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
327-
elseif sym === :b
328-
setfield!(dc, :dual_b, val)
329-
setfield!(dc, :partials_b, partial_vals(val))
330-
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
331-
elseif sym === :u
332-
setfield!(dc, :dual_u, val)
333-
setfield!(dc, :partials_u, partial_vals(val))
334-
end
358+
nothing
335359
end
336360

337361
# "Forwards" getproperty to LinearCache if necessary
@@ -360,30 +384,20 @@ partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.pa
360384
# Partials of a Dual whose value type is itself a Dual (nested duals).
function partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P}
    return ForwardDiff.partials(x)
end

# Element-wise partials of an array of Duals, collected into a new array.
function partial_vals(x::AbstractArray{<:Dual})
    return map(ForwardDiff.partials, x)
end

# Fallback: anything without a more specific method yields `nothing`.
function partial_vals(x)
    return nothing
end

# In-place variant: store `partial_vals` of every element of `x` in `out` and
# return `out`.
# NOTE(review): `out` has to be a writable collection; if a partials buffer can
# ever be the `nothing` produced by the fallback above, this call would fail -
# confirm against the code that initialises those buffers.
function partial_vals!(out, x)
    return map!(partial_vals, out, x)
end
363388

364389
# Add recursive handling for nested dual values
365390
# Fallback: a value with no more specific method is returned unchanged.
function nodual_value(x)
    return x
end

# A Dual over a plain float: strip the dual layer and return its value.
function nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P}
    return ForwardDiff.value(x)
end

# A Dual over another Dual: peel off only the outer layer so that the inner
# dual is kept intact.
function nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P}
    return x.value
end
368-
369-
function nodual_value(x::AbstractArray{<:Dual})
370-
# Create a similar array with the appropriate element type
371-
T = typeof(nodual_value(first(x)))
372-
result = similar(x, T)
373-
374-
# Fill the result array with values
375-
for i in eachindex(x)
376-
result[i] = nodual_value(x[i])
377-
end
378-
379-
return result
380-
end
393+
# Dual-stripped copy of an array of Duals.
#
# The element type of the result is normally taken from the first element.
# An empty array has no first element (`first` would throw), so in that case
# the element type is taken from `Base.promote_op`, which exists for exactly
# the "no elements to inspect" situation.
function nodual_value(x::AbstractArray{<:Dual})
    T = if isempty(x)
        Base.promote_op(nodual_value, eltype(x))
    else
        typeof(nodual_value(first(x)))
    end
    return nodual_value!(similar(x, T), x)
end

# In-place variant: store `nodual_value` of every element of `x` in `out` and
# return `out`.
function nodual_value!(out, x)
    return map!(nodual_value, out, x)
end
381395

382396
function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T}
383397
p = eachindex(first(partial_matrix))
384398
for i in p
385399
for j in eachindex(partial_matrix)
386-
list_cache[i][j] = partial_matrix[j][i]
400+
@inbounds list_cache[i][j] = partial_matrix[j][i]
387401
end
388402
end
389403
return list_cache
@@ -396,7 +410,7 @@ function update_partials_list!(partial_matrix, list_cache)
396410
for k in 1:p
397411
for i in 1:m
398412
for j in 1:n
399-
list_cache[k][i, j] = partial_matrix[i, j][k]
413+
@inbounds list_cache[k][i, j] = partial_matrix[i, j][k]
400414
end
401415
end
402416
end

0 commit comments

Comments
 (0)