@@ -6,7 +6,6 @@ using CommonSolve: solve
66using DifferentiationInterface: DifferentiationInterface
77using FastClosures: @closure
88using ForwardDiff: ForwardDiff, Dual
9- using LinearAlgebra: mul!
109using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
1110 NonlinearProblem, NonlinearLeastSquaresProblem, remake
1211
@@ -20,11 +19,14 @@ function NonlinearSolveBase.additional_incompatible_backend_check(
2019end
2120
2221Utils. value (:: Type{Dual{T, V, N}} ) where {T, V, N} = V
23- Utils. value (x:: Dual ) = Utils . value ( ForwardDiff. value (x) )
22+ Utils. value (x:: Dual ) = ForwardDiff. value (x)
2423Utils. value (x:: AbstractArray{<:Dual} ) = Utils. value .(x)
2524
2625function NonlinearSolveBase. nonlinearsolve_forwarddiff_solve (
27- prob:: Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem} ,
26+ prob:: Union {
27+ IntervalNonlinearProblem, NonlinearProblem,
28+ ImmutableNonlinearProblem, NonlinearLeastSquaresProblem
29+ },
2830 alg, args... ; kwargs...
2931)
3032 p = Utils. value (prob. p)
@@ -35,98 +37,14 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
3537 newprob = remake (prob; p, u0 = Utils. value (prob. u0))
3638 end
3739
38- sol = solve (newprob, alg, args... ; kwargs... )
39-
40- uu = sol. u
41- Jₚ = NonlinearSolveBase. nonlinearsolve_∂f_∂p (prob, prob. f, uu, p)
42- Jᵤ = NonlinearSolveBase. nonlinearsolve_∂f_∂u (prob, prob. f, uu, p)
43- z = - Jᵤ \ Jₚ
44- pp = prob. p
45- sumfun = ((z, p),) -> map (Base. Fix2 (* , ForwardDiff. partials (p)), z)
46-
47- if uu isa Number
48- partials = sum (sumfun, zip (z, pp))
49- elseif p isa Number
50- partials = sumfun ((z, pp))
51- else
52- partials = sum (sumfun, zip (eachcol (z), pp))
53- end
54-
55- return sol, partials
56- end
57-
58- function NonlinearSolveBase. nonlinearsolve_forwarddiff_solve (
59- prob:: NonlinearLeastSquaresProblem , alg, args... ; kwargs...
60- )
61- p = Utils. value (prob. p)
62- newprob = remake (prob; p, u0 = Utils. value (prob. u0))
6340 sol = solve (newprob, alg, args... ; kwargs... )
6441 uu = sol. u
6542
66- # First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
67- # nested autodiff as the last resort
68- if SciMLBase. has_vjp (prob. f)
69- if SciMLBase. isinplace (prob)
70- vjp_fn = @closure (du, u, p) -> begin
71- resid = Utils. safe_similar (du, length (sol. resid))
72- prob. f (resid, u, p)
73- prob. f. vjp (du, resid, u, p)
74- du .*= 2
75- return nothing
76- end
77- else
78- vjp_fn = @closure (u, p) -> begin
79- resid = prob. f (u, p)
80- return reshape (2 .* prob. f. vjp (resid, u, p), size (u))
81- end
82- end
83- elseif SciMLBase. has_jac (prob. f)
84- if SciMLBase. isinplace (prob)
85- vjp_fn = @closure (du, u, p) -> begin
86- J = Utils. safe_similar (du, length (sol. resid), length (u))
87- prob. f. jac (J, u, p)
88- resid = Utils. safe_similar (du, length (sol. resid))
89- prob. f (resid, u, p)
90- mul! (reshape (du, 1 , :), vec (resid)' , J, 2 , false )
91- return nothing
92- end
93- else
94- vjp_fn = @closure (u, p) -> begin
95- return reshape (2 .* vec (prob. f (u, p))' * prob. f. jac (u, p), size (u))
96- end
97- end
98- else
99- # For small problems, nesting ForwardDiff is actually quite fast
100- autodiff = length (uu) + length (sol. resid) ≥ 50 ?
101- NonlinearSolveBase. select_reverse_mode_autodiff (prob, nothing ) :
102- AutoForwardDiff ()
103-
104- if SciMLBase. isinplace (prob)
105- vjp_fn = @closure (du, u, p) -> begin
106- resid = Utils. safe_similar (du, length (sol. resid))
107- prob. f (resid, u, p)
108- # Using `Constant` lead to dual ordering issues
109- ff = @closure (du, u) -> prob. f (du, u, p)
110- resid2 = copy (resid)
111- DI. pullback! (ff, resid2, (du,), autodiff, u, (resid,))
112- @. du *= 2
113- return nothing
114- end
115- else
116- vjp_fn = @closure (u, p) -> begin
117- v = prob. f (u, p)
118- # Using `Constant` lead to dual ordering issues
119- ff = Base. Fix2 (prob. f, p)
120- res = only (DI. pullback (ff, autodiff, u, (v,)))
121- ArrayInterface. can_setindex (res) || return 2 .* res
122- @. res *= 2
123- return res
124- end
125- end
126- end
43+ fn = prob isa NonlinearLeastSquaresProblem ?
44+ NonlinearSolveBase. nlls_generate_vjp_function (prob, sol, uu) : prob. f
12745
128- Jₚ = NonlinearSolveBase. nonlinearsolve_∂f_∂p (prob, vjp_fn , uu, newprob . p)
129- Jᵤ = NonlinearSolveBase. nonlinearsolve_∂f_∂u (prob, vjp_fn , uu, newprob . p)
46+ Jₚ = NonlinearSolveBase. nonlinearsolve_∂f_∂p (prob, fn , uu, p)
47+ Jᵤ = NonlinearSolveBase. nonlinearsolve_∂f_∂u (prob, fn , uu, p)
13048 z = - Jᵤ \ Jₚ
13149 pp = prob. p
13250 sumfun = ((z, p),) -> map (Base. Fix2 (* , ForwardDiff. partials (p)), z)
0 commit comments