Add COCG method for complex symmetric linear systems #289
src/cg.jl:

```diff
@@ -2,20 +2,29 @@ import Base: iterate
 using Printf
 export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables
 
-mutable struct CGIterable{matT, solT, vecT, numT <: Real}
+# Conjugated dot product
+_dot(x, ::Val{true}) = sum(abs2, x) # for x::Complex, returns Real
+_dot(x, y, ::Val{true}) = dot(x, y)
+
+# Unconjugated dot product
+_dot(x, ::Val{false}) = sum(xₖ^2 for xₖ in x)
+_dot(x, y, ::Val{false}) = sum(prod, zip(x,y))
+
+mutable struct CGIterable{matT, solT, vecT, numT <: Real, paramT <: Number, boolT <: Union{Val{true},Val{false}}}
     A::matT
     x::solT
     r::vecT
     c::vecT
     u::vecT
     tol::numT
     residual::numT
-    prev_residual::numT
+    ρ_prev::paramT
     maxiter::Int
     mv_products::Int
+    conjugate_dot::boolT
 end
 
-mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number}
+mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number, boolT <: Union{Val{true},Val{false}}}
     Pl::precT
     A::matT
     x::solT
```
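For intuition, the only difference between the two `_dot` flavors is whether the first argument is conjugated; a quick sanity check, with the definitions copied from the diff above so the snippet runs standalone:

```julia
using LinearAlgebra

# Definitions reproduced from the diff above, for illustration only.
_dot(x, ::Val{true})     = sum(abs2, x)
_dot(x, y, ::Val{true})  = dot(x, y)
_dot(x, ::Val{false})    = sum(xₖ^2 for xₖ in x)
_dot(x, y, ::Val{false}) = sum(prod, zip(x, y))

x = [1.0 + 2.0im, 3.0 - 1.0im]
_dot(x, Val(true))   # x'x = 15.0, always real
_dot(x, Val(false))  # xᵀx = 5.0 - 2.0im, complex in general
```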
```diff
@@ -24,9 +33,10 @@ mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Numb
     u::vecT
     tol::numT
     residual::numT
-    ρ::paramT
+    ρ_prev::paramT
     maxiter::Int
     mv_products::Int
+    conjugate_dot::boolT
 end
 
 @inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.tol
```
```diff
@@ -47,18 +57,19 @@ function iterate(it::CGIterable, iteration::Int=start(it))
     end
 
+    ρ = _dot(it.r, it.conjugate_dot)
     # u := r + βu (almost an axpy)
-    β = it.residual^2 / it.prev_residual^2
+    β = ρ / it.ρ_prev
     it.u .= it.r .+ β .* it.u
 
     # c = A * u
     mul!(it.c, it.A, it.u)
-    α = it.residual^2 / dot(it.u, it.c)
+    α = ρ / _dot(it.u, it.c, it.conjugate_dot)
 
     # Improve solution and residual
+    it.ρ_prev = ρ
     it.x .+= α .* it.u
     it.r .-= α .* it.c
 
-    it.prev_residual = it.residual
     it.residual = norm(it.r)
 
     # Return the residual at item and iteration number as state
```
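Read with `conjugate_dot = Val(false)`, this loop is the COCG recurrence from the PR title: ρ = rᵀr and α = ρ / uᵀAu replace the Hermitian inner products of ordinary CG. A self-contained sketch of that special case (unpreconditioned, assuming `A == transpose(A)`; this mirrors the diff but is not the PR's code):

```julia
using LinearAlgebra

# Minimal COCG sketch: CG with unconjugated dot products,
# for complex symmetric A (A == transpose(A)).
function cocg(A, b; maxiter = size(A, 2), reltol = sqrt(eps(real(eltype(b)))))
    x = zero(b)
    r = copy(b)                   # r = b - A*x with x = 0
    u = zero(b)
    c = similar(b)
    ρ_prev = one(eltype(b))
    tol = reltol * norm(b)
    for _ in 1:maxiter
        norm(r) ≤ tol && break
        ρ = sum(rₖ^2 for rₖ in r)        # unconjugated rᵀr
        β = ρ / ρ_prev
        u .= r .+ β .* u
        mul!(c, A, u)
        α = ρ / sum(prod, zip(u, c))     # unconjugated uᵀ(Au)
        x .+= α .* u
        r .-= α .* c
        ρ_prev = ρ
    end
    return x
end

n = 50
B = rand(ComplexF64, n, n)
A = B + transpose(B) + 15I   # complex symmetric test matrix, as in the tests below
b = rand(ComplexF64, n)
x̂ = cocg(A, b)
norm(A * x̂ - b) / norm(b)    # small relative residual
```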
```diff
@@ -78,18 +89,17 @@ function iterate(it::PCGIterable, iteration::Int=start(it))
     # Apply left preconditioner
     ldiv!(it.c, it.Pl, it.r)
 
-    ρ_prev = it.ρ
-    it.ρ = dot(it.c, it.r)
-
     # u := c + βu (almost an axpy)
-    β = it.ρ / ρ_prev
+    ρ = _dot(it.r, it.c, it.conjugate_dot)
+    β = ρ / it.ρ_prev
     it.u .= it.c .+ β .* it.u
 
     # c = A * u
     mul!(it.c, it.A, it.u)
-    α = it.ρ / dot(it.u, it.c)
+    α = ρ / _dot(it.u, it.c, it.conjugate_dot)
 
     # Improve solution and residual
+    it.ρ_prev = ρ
     it.x .+= α .* it.u
     it.r .-= α .* it.c
```
```diff
@@ -122,7 +132,8 @@ function cg_iterator!(x, A, b, Pl = Identity();
     reltol::Real = sqrt(eps(real(eltype(b)))),
     maxiter::Int = size(A, 2),
     statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
-    initially_zero::Bool = false)
+    initially_zero::Bool = false,
+    conjugate_dot::Bool = true)
     u = statevars.u
     r = statevars.r
     c = statevars.c
```
```diff
@@ -142,15 +153,13 @@ function cg_iterator!(x, A, b, Pl = Identity();
 
     # Return the iterable
     if isa(Pl, Identity)
-        return CGIterable(A, x, r, c, u,
-            tolerance, residual, one(residual),
-            maxiter, mv_products
-        )
+        return CGIterable(A, x, r, c, u, tolerance, residual,
+            conjugate_dot ? one(real(eltype(r))) : one(eltype(r)), # for conjugated dot, ρ_prev remains real
+            maxiter, mv_products, Val(conjugate_dot))
     else
         return PCGIterable(Pl, A, x, r, c, u,
-            tolerance, residual, one(eltype(x)),
-            maxiter, mv_products
-        )
+            tolerance, residual, one(eltype(r)),
+            maxiter, mv_products, Val(conjugate_dot))
     end
 end
```
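The `ρ_prev` initialization above encodes a small but important type distinction: the conjugated ρ = r'r is real even for complex `r`, while the unconjugated ρ = rᵀr is complex. A quick illustration:

```julia
r = rand(ComplexF64, 4)
typeof(sum(abs2, r))           # Float64: conjugated ρ stays real
typeof(sum(rₖ^2 for rₖ in r))  # ComplexF64: unconjugated ρ is complex
```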
```diff
@@ -211,6 +220,7 @@ function cg!(x, A, b;
     statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
     verbose::Bool = false,
     Pl = Identity(),
+    conjugate_dot::Bool = true,
     kwargs...)
     history = ConvergenceHistory(partial = !log)
     history[:abstol] = abstol
```
```diff
@@ -219,7 +229,7 @@ function cg!(x, A, b;
 
     # Actually perform CG
     iterable = cg_iterator!(x, A, b, Pl; abstol = abstol, reltol = reltol, maxiter = maxiter,
-        statevars = statevars, kwargs...)
+        statevars = statevars, conjugate_dot = conjugate_dot, kwargs...)
     if log
         history.mvps = iterable.mv_products
     end
```
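End to end, the new keyword makes COCG a one-flag switch on the existing `cg` interface; a usage sketch mirroring the tests below:

```julia
using IterativeSolvers, LinearAlgebra

n = 100
B = rand(ComplexF64, n, n)
A = B + transpose(B) + 15I   # complex symmetric, not Hermitian
b = rand(ComplexF64, n)

# conjugate_dot = false selects the unconjugated (COCG) recurrence
x, history = cg(A, b; reltol = sqrt(eps(Float64)), maxiter = 100,
                log = true, conjugate_dot = false)
@assert norm(A * x - b) / norm(b) ≤ sqrt(eps(Float64))
```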
test/cg.jl:

```diff
@@ -21,10 +21,26 @@ ldiv!(y, P::JacobiPrec, x) = y .= x ./ P.diagonal
 
 Random.seed!(1234321)
 
+@testset "Vector{$T}, conjugated and unconjugated dot products" for T in (ComplexF32, ComplexF64)
+    n = 100
+    x = rand(T, n)
+    y = rand(T, n)
+
+    # Conjugated dot product
+    @test IterativeSolvers._dot(x, Val(true)) ≈ x'x
+    @test IterativeSolvers._dot(x, y, Val(true)) ≈ x'y
+    @test IterativeSolvers._dot(x, Val(true)) ≈ IterativeSolvers._dot(x, x, Val(true))
+
+    # Unconjugated dot product
+    @test IterativeSolvers._dot(x, Val(false)) ≈ transpose(x) * x
+    @test IterativeSolvers._dot(x, y, Val(false)) ≈ transpose(x) * y
+    @test IterativeSolvers._dot(x, Val(false)) ≈ IterativeSolvers._dot(x, x, Val(false))
+end
+
 @testset "Small full system" begin
     n = 10
 
-    @testset "Matrix{$T}" for T in (Float32, Float64, ComplexF32, ComplexF64)
+    @testset "Matrix{$T}, conjugated dot product" for T in (Float32, Float64, ComplexF32, ComplexF64)
         A = rand(T, n, n)
         A = A' * A + I
         b = rand(T, n)
```
```diff
@@ -50,6 +66,37 @@ Random.seed!(1234321)
         x0 = cg(A, zeros(T, n))
         @test x0 == zeros(T, n)
     end
 
+    @testset "Matrix{$T}, unconjugated dot product" for T in (Float32, Float64, ComplexF32, ComplexF64)
+        A = rand(T, n, n)
+        A = A + transpose(A) + 15I
+        x = ones(T, n)
+        b = A * x
+
+        reltol = √eps(real(T))
+
+        # Solve without preconditioner
+        x1, his1 = cg(A, b, reltol = reltol, maxiter = 100, log = true, conjugate_dot = false)
+        @test isa(his1, ConvergenceHistory)
+        @test norm(A * x1 - b) / norm(b) ≤ reltol
+
+        # With an initial guess
+        x_guess = rand(T, n)
+        x2, his2 = cg!(x_guess, A, b, reltol = reltol, maxiter = 100, log = true, conjugate_dot = false)
+        @test isa(his2, ConvergenceHistory)
+        @test x2 == x_guess
+        @test norm(A * x2 - b) / norm(b) ≤ reltol
+
+        # The following test fails CI on Windows and Ubuntu due to a
+        # `SingularException(4)`
+        if T == Float32 && (Sys.iswindows() || Sys.islinux())
+            continue
+        end
+        # Do an exact LU decomp of a nearby matrix
+        F = lu(A + rand(T, n, n))
+        x3, his3 = cg(A, b, Pl = F, maxiter = 100, reltol = reltol, log = true, conjugate_dot = false)
+        @test norm(A * x3 - b) / norm(b) ≤ reltol
+    end
 end
 
 @testset "Sparse Laplacian" begin
```
Maybe better to just use `_norm`/`norm`; that's what was used before, and `norm` is slightly more stable. Also, I'm not sure if `sum(abs2, x)` works as efficiently on GPUs; `norm` is definitely optimized.
What do you mean by "`norm` is more stable"? If you are talking about the possibility of overflow, I think it doesn't matter, because the calculated norm is squared again in the algorithm.

Also, for the unconjugated dot product, the `norm(x)` function would return `sqrt(xᵀx)`, which is not a norm because `xᵀx` is complex. In COCG the quantity we use is `xᵀx`, not `sqrt(xᵀx)`, so I don't feel it is a good idea to store the square-rooted quantity and then square it again when using it.

I verified that the GPU performance of `sum` does not degrade.
The main advantage of falling back to `norm` and `dot` in the conjugated case is that they generalize better: if the user has defined some specialized Banach-space type (and corresponding self-adjoint linear operators `A` with overloaded `*`), they should have overloaded `norm` and `dot`, whereas `sum(abs2, x)` might no longer work. (Even as simple a generalization as an array of arrays will fail with `sum(abs2, x)`, whereas `dot` and `norm` work.)

The overhead of the additional `sqrt` should be negligible.
This is because it dispatches to BLAS, which does a stable norm computation. But generally, BLAS libraries are free to choose how to implement `norm`.
This is a very convincing argument. Additionally, considering that a user-defined Banach-space type should overload `norm` and `dot` but not the unconjugated dot, it would be nice to be able to define the unconjugated dot in terms of `dot`, roughly along these lines:
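```julia
# Plausible form of the snippet (an assumption, reconstructed from context):
# build the unconjugated dot from the conjugating `dot`. The conjugation of
# `conj(x)` inside `dot` cancels the outer one, but `x` ends up conjugated
# twice and `conj(x)` allocates a temporary array.
_dot(x, y, ::UnconjugatedDot) = dot(conj(x), y)
```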
but of course this is inefficient, as it conjugates `x` twice, not to mention the extra allocations. I wish `dotu` were exported, as you suggested in https://github.com/JuliaLang/julia/issues/22227#issuecomment-306224429.

Users can overload `_dot(x, y, ::UnconjugatedDot)` for the Banach-space type they define, but what would be the best way to minimize such a requirement? I have pushed a commit implementing the unconjugated dot with `sum`, but if there is a better approach, please let me know.
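For instance, here is a sketch (hypothetical `BlockVec` alias, not in the PR) of what such an overload could look like for an array-of-arrays type, where the generic `sum(prod, zip(x, y))` fails because `prod` would try to multiply two vectors:

```julia
# Hypothetical user extension: unconjugated dot for a vector-of-vectors,
# applying the elementwise unconjugated dot block by block.
const BlockVec{T} = Vector{Vector{T}}

_dot(x::BlockVec, y::BlockVec, ::Val{false}) =
    sum(sum(prod, zip(xb, yb)) for (xb, yb) in zip(x, y))
_dot(x::BlockVec, ::Val{false}) = _dot(x, x, Val(false))
```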
You could implement the unconjugated dot as `transpose(x) * y`... this should have no extra allocations, since `transpose` is a "view" type by default?
Though `transpose(x) * y` wouldn't work if `x` and `y` are matrices (and `A` is some kind of multilinear operator); maybe it's best to stick with the `sum(prod, zip(x, y))` solution you have now.
Can we assume that `getindex()` is always defined for `x` and `y`? Then how about using `transpose(@view(x[:])) * @view(y[:])`? This uses allocations, but in my benchmarks it is much faster than `sum(prod, zip(x, y))` and nearly on a par with `dot`. Also, when `x` and `y` are `AbstractVector`s, the proposed method does not allocate at all.
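A sketch of the kind of comparison behind these timings (assuming `BenchmarkTools`; the measured numbers are machine dependent and omitted here):

```julia
using LinearAlgebra, BenchmarkTools

n = 1000
x = rand(ComplexF64, n)
y = rand(ComplexF64, n)

dot_zip(x, y)  = sum(prod, zip(x, y))                  # current implementation
dot_view(x, y) = transpose(@view(x[:])) * @view(y[:])  # proposed alternative

@btime dot_zip($x, $y)
@btime dot_view($x, $y)
@btime dot($x, $y)   # conjugated BLAS dot, for reference
```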
xandyareAbstractVectors, the proposed method does not use allocations:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could do
dot(transpose(x'), x), maybe?There was a problem hiding this comment.
I didn't know that `x'` was nonallocating! This looked like a great nonallocating implementation, but in my benchmarks it turns out to be slower than the earlier proposed solution using `@view`. Also, neither `x'` nor `transpose(x)` works for `x` whose dimension is greater than 2.

I just pushed an implementation using `@view` for now.