Skip to content

Commit 212b433

Browse files
committed
Refactor out-of-place nodual_value with new in-place dispatch
1 parent 0388574 commit 212b433

File tree

1 file changed

+1
-13
lines changed

1 file changed

+1
-13
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -367,21 +367,9 @@ partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place
367367
nodual_value(x) = x
368368
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
369369
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact
370+
nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
370371
nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place
371372

372-
function nodual_value(x::AbstractArray{<:Dual})
373-
# Create a similar array with the appropriate element type
374-
T = typeof(nodual_value(first(x)))
375-
result = similar(x, T)
376-
377-
# Fill the result array with values
378-
for i in eachindex(x)
379-
result[i] = nodual_value(x[i])
380-
end
381-
382-
return result
383-
end
384-
385373
function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T}
386374
p = eachindex(first(partial_matrix))
387375
for i in p

0 commit comments

Comments
 (0)