@@ -30,29 +30,41 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
3030 (dval. b for dval in dres)
3131 end
3232
33- return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b))
33+
34+ prob_d_A = if EnzymeRules. width (config) == 1
35+ prob. dval. A
36+ else
37+ (dval. A for dval in prob. dval)
38+ end
39+ prob_d_b = if EnzymeRules. width (config) == 1
40+ prob. dval. b
41+ else
42+ (dval. b for dval in prob. dval)
43+ end
44+
45+ return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b, prob_d_A, prob_d_b))
3446end
3547
3648function EnzymeCore. EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.init)} , :: Type{RT} , cache, prob:: EnzymeCore.Annotation{LP} , alg:: Const ; kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
37- d_A, d_b = cache
49+ d_A, d_b, prob_d_A, prob_d_b = cache
3850
3951 if EnzymeRules. width (config) == 1
40- if d_A != = prob . dval . A
41- prob . dval . A .+ = d_A
52+ if d_A != = prob_d_A
53+ prob_d_A .+ = d_A
4254 d_A .= 0
4355 end
44- if d_b != = prob . dval . b
45- prob . dval . b .+ = d_b
56+ if d_b != = prob_d_b
57+ prob_d_b .+ = d_b
4658 d_b .= 0
4759 end
4860 else
4961 for i in 1 : EnzymeRules. width (config)
50- if d_A != = prob . dval . A
51- prob . dval . A [i] .+ = d_A[i]
62+ if d_A != = prob_d_A[i]
63+ prob_d_A [i] .+ = d_A[i]
5264 d_A[i] .= 0
5365 end
54- if d_b != = prob . dval . b
55- prob . dval . b [i] .+ = d_b[i]
66+ if d_b != = prob_d_b[i]
67+ prob_d_b [i] .+ = d_b[i]
5668 d_b[i] .= 0
5769 end
5870 end
@@ -87,22 +99,33 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
8799 resvals = if EnzymeRules. width (config) == 1
88100 dres. u
89101 else
90- (dr. u for dr in dres)
102+ ntuple (Val (EnzymeRules. width (config))) do i
103+ Base. @_inline_meta
104+ dres[i]. u
105+ end
91106 end
92107
93108 dAs = if EnzymeRules. width (config) == 1
94109 (linsolve. dval. A,)
95110 else
96- (dval. A for dval in linsolve. dval)
111+ ntuple (Val (EnzymeRules. width (config))) do i
112+ Base. @_inline_meta
113+ linsolve. dval[i]. A
114+ end
97115 end
98116
99117 dbs = if EnzymeRules. width (config) == 1
100118 (linsolve. dval. b,)
101119 else
102- (dval. b for dval in linsolve. dval)
120+ ntuple (Val (EnzymeRules. width (config))) do i
121+ Base. @_inline_meta
122+ linsolve. dval[i]. b
123+ end
103124 end
104125
105- cache = (res, resvals, deepcopy (linsolve. val), dAs, dbs)
126+ cachesolve = deepcopy (linsolve. val)
127+
128+ cache = (copy (res. u), resvals, cachesolve, dAs, dbs)
106129 return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, cache)
107130end
108131
0 commit comments