Skip to content

Commit 52f2191

Browse files
Add API to update iterative solver tolerances
1 parent d9b3a96 commit 52f2191

File tree

5 files changed

+72
-1
lines changed

5 files changed

+72
-1
lines changed

ext/LinearSolveIterativeSolversExt.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...
105105
cache.Pr = Pr
106106
cache.precsisfresh = false
107107
end
108-
if cache.isfresh || !(alg isa IterativeSolvers.GMRESIterable)
108+
if cache.isfresh || !(cache.cacheval isa IterativeSolvers.GMRESIterable)
109109
solver = LinearSolve.init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl,
110110
cache.Pr,
111111
cache.maxiters, cache.abstol, cache.reltol,
@@ -149,4 +149,29 @@ function purge_history!(iter::IterativeSolvers.GMRESIterable, x, b)
149149
nothing
150150
end
151151

152+
# The constructors above all set the tolerance as follows.
153+
# tol = max(reltol * ||residual||, abstol)
154+
#
155+
# The iterable in turn is stored in `cache.cacheval`.
156+
function update_tolerances_iterativesolversjl!(iter, atol, rtol)
157+
Rnorm = norm(iter.r)
158+
iter.tol = max(rtol * Rnorm, atol)
159+
end
160+
function update_tolerances_iterativesolversjl!(iter::IterativeSolvers.GMRESIterable, atol, rtol)
161+
Rnorm = iter.residual.current
162+
iter.tol = max(rtol * Rnorm, atol)
163+
end
164+
function update_tolerances_iterativesolversjl!(iter::IterativeSolvers.MINRESIterable, atol, rtol)
165+
Rnorm = norm(iter.v_curr)
166+
iter.tol = max(rtol * Rnorm, atol)
167+
end
168+
function update_tolerances_iterativesolversjl!(iter::IterativeSolvers.IDRSIterable, atol, rtol)
169+
Rnorm = iter.normR
170+
iter.tol = max(rtol * Rnorm, atol)
171+
end
172+
173+
function LinearSolve.update_tolerances_internal!(cache, alg::IterativeSolversJL, atol, rtol)
174+
update_tolerances_iterativesolversjl!(cache.cacheval, atol, rtol)
175+
end
176+
152177
end

ext/LinearSolveKrylovKitExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,6 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovKitJL; kwargs...)
4848
iters = iters)
4949
end
5050

51+
LinearSolve.update_tolerances_internal!(cache, alg::KrylovKitJL, atol, rtol) = nothing
52+
5153
end

src/common.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,22 @@ function SciMLBase.solve(prob::StaticLinearProblem,
478478
return SciMLBase.build_linear_solution(
479479
alg, u, nothing, prob; retcode = ReturnCode.Success)
480480
end
481+
482+
function update_tolerances!(cache; abstol = nothing, reltol = nothing)
483+
if abstol !== nothing
484+
cache.abstol = abstol
485+
end
486+
if reltol !== nothing
487+
cache.reltol = reltol
488+
end
489+
update_tolerances_internal!(cache, cache.alg, abstol, reltol)
490+
end
491+
492+
493+
function update_tolerances_internal!(cache, alg::AbstractFactorization, abstol, reltol)
494+
error("Cannot update tolerances for factorization.")
495+
end
496+
497+
function update_tolerances_internal!(cache, alg::AbstractKrylovSubspaceMethod, abstol, reltol)
498+
@warn "Tolerance update for Krylov subspace method '$typeof(alg)' not implemented." maxlog = 1
499+
end

src/iterative_wrappers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,5 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
338338
return SciMLBase.build_linear_solution(alg, cache.u, Ref(resid), cache;
339339
iters = stats.niter, retcode, stats)
340340
end
341+
342+
update_tolerances_internal!(cache, alg::KrylovJL, atol, rtol) = nothing

test/basictests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,17 @@ A4 = A2 .|> ComplexF32
3333
b4 = b2 .|> ComplexF32
3434
x4 = x2 .|> ComplexF32
3535

36+
A5_ = A - 0.01Tridiagonal(ones(n,n)) + sparse([1], [8], 0.5, n,n)
37+
A5 = sparse(transpose(A5_) * A5_)
38+
x5 = zeros(n)
39+
u5 = ones(n)
40+
b5 = A5*u5
41+
3642
prob1 = LinearProblem(A1, b1; u0 = x1)
3743
prob2 = LinearProblem(A2, b2; u0 = x2)
3844
prob3 = LinearProblem(A3, b3; u0 = x3)
3945
prob4 = LinearProblem(A4, b4; u0 = x4)
46+
prob5 = LinearProblem(A5, b5)
4047

4148
cache_kwargs = (;abstol = 1e-8, reltol = 1e-8, maxiter = 30)
4249

@@ -69,6 +76,19 @@ function test_interface(alg, prob1, prob2)
6976
return
7077
end
7178

79+
function test_tolerance_update(alg, prob, u)
80+
cache = init(prob, alg; verbose=LinearVerbosity(; error_control=SciMLLogging.WarnLevel(), numerical=SciMLLogging.WarnLevel()))
81+
LinearSolve.update_tolerances!(cache; reltol = 1e-2, abstol=1e-8)
82+
u1 = copy(solve!(cache).u)
83+
84+
LinearSolve.update_tolerances!(cache; reltol = 1e-8, abstol=1e-8)
85+
u2 = solve!(cache).u
86+
87+
@test norm(u2 - u) < norm(u1 - u)
88+
89+
return
90+
end
91+
7292
@testset "LinearSolve" begin
7393
@testset "Default Linear Solver" begin
7494
test_interface(nothing, prob1, prob2)
@@ -379,6 +399,7 @@ end
379399
@testset "$name" begin
380400
test_interface(algorithm, prob1, prob2)
381401
test_interface(algorithm, prob3, prob4)
402+
test_tolerance_update(algorithm, prob5, u5)
382403
end
383404
end
384405
end
@@ -418,6 +439,7 @@ end
418439
@testset "$(alg[1])" begin
419440
test_interface(alg[2], prob1, prob2)
420441
test_interface(alg[2], prob3, prob4)
442+
test_tolerance_update(alg[2], prob5, u5)
421443
end
422444
end
423445
end
@@ -432,6 +454,7 @@ end
432454
@testset "$(alg[1])" begin
433455
test_interface(alg[2], prob1, prob2)
434456
test_interface(alg[2], prob3, prob4)
457+
test_tolerance_update(alg[2], prob5, u5)
435458
end
436459
@test alg[2] isa KrylovKitJL
437460
end

0 commit comments

Comments
 (0)