11import Base: start, next, done
22
3- export cg, cg!, CGIterable, PCGIterable, cg_iterator!
3+ export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables
44
55mutable struct CGIterable{matT, solT, vecT, numT <: Real }
66 A:: matT
9090
9191# Utility functions
9292
93+ struct CGStateVariables{T}
94+ u:: Vector{T}
95+ r:: Vector{T}
96+ c:: Vector{T}
97+ end
98+
9399function cg_iterator! (x, A, b, Pl = Identity ();
94100 tol = sqrt (eps (real (eltype (b)))),
95101 maxiter:: Int = size (A, 2 ),
102+ statevars:: CGStateVariables = CGStateVariables {eltype(x)} (zeros (x), similar (x), similar (x)),
96103 initially_zero:: Bool = false
97104)
98- u = zeros (x)
99- r = similar (x)
105+ u = statevars. u
106+ r = statevars. r
107+ c = statevars. c
108+ u .= zero (eltype (x))
100109 copy! (r, b)
101110
102111 # Compute r with an MV-product or not.
@@ -107,7 +116,7 @@ function cg_iterator!(x, A, b, Pl = Identity();
107116 reltol = residual * tol # Save one dot product
108117 else
109118 mv_products = 1
110- c = A * x
119+ A_mul_B! (c, A, x)
111120 r .- = c
112121 residual = norm (r)
113122 reltol = norm (b) * tol
@@ -175,6 +184,7 @@ function cg!(x, A, b;
175184 tol = sqrt (eps (real (eltype (b)))),
176185 maxiter:: Int = size (A, 2 ),
177186 log:: Bool = false ,
187+ statevars:: CGStateVariables = CGStateVariables {eltype(x)} (zeros (x), similar (x), similar (x)),
178188 verbose:: Bool = false ,
179189 Pl = Identity (),
180190 kwargs...
@@ -184,7 +194,7 @@ function cg!(x, A, b;
184194 log && reserve! (history, :resnorm , maxiter + 1 )
185195
186196 # Actually perform CG
187- iterable = cg_iterator! (x, A, b, Pl; tol = tol, maxiter = maxiter, kwargs... )
197+ iterable = cg_iterator! (x, A, b, Pl; tol = tol, maxiter = maxiter, statevars = statevars, kwargs... )
188198 if log
189199 history. mvps = iterable. mv_products
190200 end
0 commit comments