11import Base: start, next, done
22
3- export cg, cg!, CGIterable, PCGIterable, cg_iterator!
3+ export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables
44
55mutable struct CGIterable{matT, solT, vecT, numT <: Real }
66 A:: matT
9090
9191# Utility functions
9292
93+ """
94+ Intermediate CG state variables to be used inside cg and cg!. `u`, `r` and `c` should be of the same type as the solution of `cg` or `cg!`.
95+ ```
96+ struct CGStateVariables{T,Tx<:AbstractArray{T}}
97+ u::Tx
98+ r::Tx
99+ c::Tx
100+ end
101+ ```
102+ """
103+ struct CGStateVariables{T,Tx<: AbstractArray{T} }
104+ u:: Tx
105+ r:: Tx
106+ c:: Tx
107+ end
108+
93109function cg_iterator! (x, A, b, Pl = Identity ();
94110 tol = sqrt (eps (real (eltype (b)))),
95111 maxiter:: Int = size (A, 2 ),
112+ statevars:: CGStateVariables = CGStateVariables {eltype(x),typeof(x)} (zeros (x), similar (x), similar (x)),
96113 initially_zero:: Bool = false
97114)
98- u = zeros (x)
99- r = similar (x)
115+ u = statevars. u
116+ r = statevars. r
117+ c = statevars. c
118+ u .= zero (eltype (x))
100119 copy! (r, b)
101120
102121 # Compute r with an MV-product or not.
@@ -107,7 +126,7 @@ function cg_iterator!(x, A, b, Pl = Identity();
107126 reltol = residual * tol # Save one dot product
108127 else
109128 mv_products = 1
110- c = A * x
129+ A_mul_B! (c, A, x)
111130 r .- = c
112131 residual = norm (r)
113132 reltol = norm (b) * tol
@@ -145,15 +164,16 @@ cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...)
145164
146165## Keywords
147166
167+ - `statevars::CGStateVariables`: Has 3 arrays similar to `x` to hold intermediate results;
148168- `initially_zero::Bool`: If `true` assumes that `iszero(x)` so that one
149169 matrix-vector product can be saved when computing the initial
150170 residual vector;
151171- `Pl = Identity()`: left preconditioner of the method. Should be symmetric,
152- positive-definite like `A`.
172+ positive-definite like `A`;
153173- `tol::Real = sqrt(eps(real(eltype(b))))`: tolerance for stopping condition `|r_k| / |r_0| ≤ tol`;
154174- `maxiter::Int = size(A,2)`: maximum number of iterations;
155175- `verbose::Bool = false`: print method information;
156- - `log::Bool = false`: keep track of the residual norm in each iteration;
176+ - `log::Bool = false`: keep track of the residual norm in each iteration.
157177
158178# Output
159179
@@ -175,6 +195,7 @@ function cg!(x, A, b;
175195 tol = sqrt (eps (real (eltype (b)))),
176196 maxiter:: Int = size (A, 2 ),
177197 log:: Bool = false ,
198+ statevars:: CGStateVariables = CGStateVariables {eltype(x), typeof(x)} (zeros (x), similar (x), similar (x)),
178199 verbose:: Bool = false ,
179200 Pl = Identity (),
180201 kwargs...
@@ -184,7 +205,7 @@ function cg!(x, A, b;
184205 log && reserve! (history, :resnorm , maxiter + 1 )
185206
186207 # Actually perform CG
187- iterable = cg_iterator! (x, A, b, Pl; tol = tol, maxiter = maxiter, kwargs... )
208+ iterable = cg_iterator! (x, A, b, Pl; tol = tol, maxiter = maxiter, statevars = statevars, kwargs... )
188209 if log
189210 history. mvps = iterable. mv_products
190211 end
0 commit comments