@@ -308,30 +308,54 @@ function SciMLBase.solve!(
308308 )
309309end
310310
311- # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
311+ function setA! (dc:: DualLinearCache , A)
312+ # Put the Dual-stripped versions in the LinearCache
313+ prop = nodual_value! (getproperty (dc. linear_cache, :A ), A) # Update in-place
314+ setproperty! (dc. linear_cache, :A , prop) # Does additional invalidation logic etc.
315+
316+ # Update partials
317+ setfield! (dc, :dual_A , A)
318+ partial_vals! (getfield (dc, :partials_A ), A) # Update in-place
319+
320+ # Invalidate cache (if setting A or b)
321+ setfield! (dc, :rhs_cache_valid , false )
322+ end
323+ function setb! (dc:: DualLinearCache , b)
324+ # Put the Dual-stripped versions in the LinearCache
325+ prop = nodual_value! (getproperty (dc. linear_cache, :b ), b) # Update in-place
326+ setproperty! (dc. linear_cache, :b , prop) # Does additional invalidation logic etc.
327+
328+ # Update partials
329+ setfield! (dc, :dual_b , b)
330+ partial_vals! (getfield (dc, :partials_b ), b) # Update in-place
331+
332+ # Invalidate cache (if setting A or b)
333+ setfield! (dc, :rhs_cache_valid , false )
334+ end
335+ function setu! (dc:: DualLinearCache , u)
336+ # Put the Dual-stripped versions in the LinearCache
337+ prop = nodual_value! (getproperty (dc. linear_cache, :u ), u) # Update in-place
338+ setproperty! (dc. linear_cache, :u , prop) # Does additional invalidation logic etc.
339+
340+ # Update partials
341+ setfield! (dc, :dual_u , u)
342+ partial_vals! (getfield (dc, :partials_u ), u) # Update in-place
343+ end
344+
312345function Base. setproperty! (dc:: DualLinearCache , sym:: Symbol , val)
313346 # If the property is A or b, also update it in the LinearCache
314- if sym === :A || sym === :b || sym === :u
315- setproperty! (dc. linear_cache, sym, nodual_value (val))
347+ if sym === :A
348+ setA! (dc, val)
349+ elseif sym === :b
350+ setb! (dc, val)
351+ elseif sym === :u
352+ setu! (dc, val)
316353 elseif hasfield (DualLinearCache, sym)
317354 setfield! (dc, sym, val)
318355 elseif hasfield (LinearSolve. LinearCache, sym)
319356 setproperty! (dc. linear_cache, sym, val)
320357 end
321-
322- # Update the partials and invalidate cache if setting A or b
323- if sym === :A
324- setfield! (dc, :dual_A , val)
325- setfield! (dc, :partials_A , partial_vals (val))
326- setfield! (dc, :rhs_cache_valid , false ) # Invalidate cache
327- elseif sym === :b
328- setfield! (dc, :dual_b , val)
329- setfield! (dc, :partials_b , partial_vals (val))
330- setfield! (dc, :rhs_cache_valid , false ) # Invalidate cache
331- elseif sym === :u
332- setfield! (dc, :dual_u , val)
333- setfield! (dc, :partials_u , partial_vals (val))
334- end
358+ nothing
335359end
336360
337361# "Forwards" getproperty to LinearCache if necessary
@@ -360,30 +384,20 @@ partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.pa
360384partial_vals (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = ForwardDiff. partials (x)
361385partial_vals (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. partials, x)
362386partial_vals (x) = nothing
387+ partial_vals! (out, x) = map! (partial_vals, out, x) # Update in-place
363388
364389# Add recursive handling for nested dual values
365390nodual_value (x) = x
366391nodual_value (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = ForwardDiff. value (x)
367392nodual_value (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = x. value # Keep the inner dual intact
368-
369- function nodual_value (x:: AbstractArray{<:Dual} )
370- # Create a similar array with the appropriate element type
371- T = typeof (nodual_value (first (x)))
372- result = similar (x, T)
373-
374- # Fill the result array with values
375- for i in eachindex (x)
376- result[i] = nodual_value (x[i])
377- end
378-
379- return result
380- end
393+ nodual_value (x:: AbstractArray{<:Dual} ) = nodual_value! (similar (x, typeof (nodual_value (first (x)))), x)
394+ nodual_value! (out, x) = map! (nodual_value, out, x) # Update in-place
381395
382396function update_partials_list! (partial_matrix:: AbstractVector{T} , list_cache) where {T}
383397 p = eachindex (first (partial_matrix))
384398 for i in p
385399 for j in eachindex (partial_matrix)
386- list_cache[i][j] = partial_matrix[j][i]
400+ @inbounds list_cache[i][j] = partial_matrix[j][i]
387401 end
388402 end
389403 return list_cache
@@ -396,7 +410,7 @@ function update_partials_list!(partial_matrix, list_cache)
396410 for k in 1 : p
397411 for i in 1 : m
398412 for j in 1 : n
399- list_cache[k][i, j] = partial_matrix[i, j][k]
413+ @inbounds list_cache[k][i, j] = partial_matrix[i, j][k]
400414 end
401415 end
402416 end
0 commit comments