@@ -2,64 +2,80 @@ module LinearSolveEnzymeExt
22
33using LinearSolve
44using LinearSolve. LinearAlgebra
5- isdefined (Base, :get_extension ) ? (import Enzyme) : (import .. Enzyme)
6-
7- using Enzyme
8-
95using EnzymeCore
6+ using EnzymeCore: EnzymeRules
107
11- function EnzymeCore . EnzymeRules. forward (
8+ function EnzymeRules. forward (config :: EnzymeRules.FwdConfigWidth{1} ,
129 func:: Const{typeof(LinearSolve.init)} , :: Type{RT} , prob:: EnzymeCore.Annotation{LP} ,
1310 alg:: Const ; kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
1411 @assert ! (prob isa Const)
1512 res = func. val (prob. val, alg. val; kwargs... )
1613 if RT <: Const
17- return res
14+ if EnzymeRules. needs_primal (config)
15+ return res
16+ else
17+ return nothing
18+ end
1819 end
20+
1921 dres = func. val (prob. dval, alg. val; kwargs... )
20- dres. b .= res. b == dres. b ? zero (dres. b) : dres. b
21- dres. A .= res. A == dres. A ? zero (dres. A) : dres. A
22- if RT <: DuplicatedNoNeed
23- return dres
24- elseif RT <: Duplicated
22+
23+ if dres. b == res. b
24+ dres. b .= false
25+ end
26+ if dres. A == res. A
27+ dres. A .= false
28+ end
29+
30+ if EnzymeRules. needs_primal (config) && EnzymeRules. needs_shadow (config)
2531 return Duplicated (res, dres)
32+ elseif EnzymeRules. needs_shadow (config)
33+ return dres
34+ elseif EnzymeRules. needs_primal (config)
35+ return res
36+ else
37+ return nothing
2638 end
27- error (" Unsupported return type $RT " )
2839end
2940
30- function EnzymeCore. EnzymeRules. forward (func:: Const{typeof(LinearSolve.solve!)} ,
41+ function EnzymeRules. forward (
42+ config:: EnzymeRules.FwdConfigWidth{1} , func:: Const{typeof(LinearSolve.solve!)} ,
3143 :: Type{RT} , linsolve:: EnzymeCore.Annotation{LP} ;
3244 kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
3345 @assert ! (linsolve isa Const)
3446
3547 res = func. val (linsolve. val; kwargs... )
3648
3749 if RT <: Const
38- return res
50+ if EnzymeRules. needs_primal (config)
51+ return res
52+ else
53+ return nothing
54+ end
3955 end
4056 if linsolve. val. alg isa LinearSolve. AbstractKrylovSubspaceMethod
4157 error (" Algorithm $(_linsolve. alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling" )
4258 end
43- b = deepcopy (linsolve. val. b)
4459
45- db = linsolve. dval. b
46- dA = linsolve. dval. A
60+ res = deepcopy (res) # Without this copy, the next solve will end up mutating the result
4761
48- linsolve. val. b = db - dA * res. u
62+ b = linsolve. val. b
63+ linsolve. val. b = linsolve. dval. b - linsolve. dval. A * res. u
4964 dres = func. val (linsolve. val; kwargs... )
50-
5165 linsolve. val. b = b
5266
53- if RT <: DuplicatedNoNeed
54- return dres
55- elseif RT <: Duplicated
67+ if EnzymeRules. needs_primal (config) && EnzymeRules. needs_shadow (config)
5668 return Duplicated (res, dres)
69+ elseif EnzymeRules. needs_shadow (config)
70+ return dres
71+ elseif EnzymeRules. needs_primal (config)
72+ return res
73+ else
74+ return nothing
5775 end
58-
59- return Duplicated (res, dres)
6076end
6177
62- function EnzymeCore . EnzymeRules. augmented_primal (
78+ function EnzymeRules. augmented_primal (
6379 config, func:: Const{typeof(LinearSolve.init)} ,
6480 :: Type{RT} , prob:: EnzymeCore.Annotation{LP} , alg:: Const ;
6581 kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
@@ -94,10 +110,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
94110 (dval. b for dval in prob. dval)
95111 end
96112
97- return EnzymeCore . EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b, prob_d_A, prob_d_b))
113+ return EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b, prob_d_A, prob_d_b))
98114end
99115
100- function EnzymeCore . EnzymeRules. reverse (
116+ function EnzymeRules. reverse (
101117 config, func:: Const{typeof(LinearSolve.init)} , :: Type{RT} ,
102118 cache, prob:: EnzymeCore.Annotation{LP} , alg:: Const ;
103119 kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
131147# y=inv(A) B
132148# dA −= z y^T
133149# dB += z, where z = inv(A^T) dy
134- function EnzymeCore . EnzymeRules. augmented_primal (
150+ function EnzymeRules. augmented_primal (
135151 config, func:: Const{typeof(LinearSolve.solve!)} ,
136152 :: Type{RT} , linsolve:: EnzymeCore.Annotation{LP} ;
137153 kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
@@ -184,10 +200,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
184200 cachesolve = deepcopy (linsolve. val)
185201
186202 cache = (copy (res. u), resvals, cachesolve, dAs, dbs)
187- return EnzymeCore . EnzymeRules. AugmentedReturn (res, dres, cache)
203+ return EnzymeRules. AugmentedReturn (res, dres, cache)
188204end
189205
190- function EnzymeCore . EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve!)} ,
206+ function EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve!)} ,
191207 :: Type{RT} , cache, linsolve:: EnzymeCore.Annotation{LP} ;
192208 kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
193209 y, dys, _linsolve, dAs, dbs = cache
0 commit comments