Skip to content

Commit 89baa2d

Browse files
authored
Fix handling for default algorithms for DualLinearProblems (#775)
* fix handling for default algs for DualLinearProblems * remove redunant check * fix alg * make sure JET test uses DualCache * add JET tests * add comment * default tests are broken anyway * add comment * remove broken * set tests broken only for < 1.11
1 parent ed1ea74 commit 89baa2d

File tree

3 files changed

+42
-12
lines changed

3 files changed

+42
-12
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module LinearSolveForwardDiffExt
22

33
using LinearSolve
4-
using LinearSolve: SciMLLinearSolveAlgorithm, __init
4+
using LinearSolve: SciMLLinearSolveAlgorithm, __init, DefaultLinearSolver, DefaultAlgorithmChoice, defaultalg
55
using LinearAlgebra
66
using ForwardDiff
77
using ForwardDiff: Dual, Partials
@@ -196,6 +196,24 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactoriza
196196
return __init(prob, alg, args...; kwargs...)
197197
end
198198

199+
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::DefaultLinearSolver, args...; kwargs...)
200+
if alg.alg === DefaultAlgorithmChoice.GenericLUFactorization
201+
return __init(prob, alg, args...; kwargs...)
202+
else
203+
return __dual_init(prob, alg, args...; kwargs...)
204+
end
205+
end
206+
207+
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::Nothing,
208+
args...;
209+
assumptions = OperatorAssumptions(issquare(prob.A)),
210+
kwargs...)
211+
new_A = nodual_value(prob.A)
212+
new_b = nodual_value(prob.b)
213+
SciMLBase.init(
214+
prob, defaultalg(new_A, new_b, assumptions), args...; assumptions, kwargs...)
215+
end
216+
199217
function __dual_init(
200218
prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm,
201219
args...;
@@ -225,11 +243,8 @@ function __dual_init(
225243
dual_type = get_dual_type(prob.b)
226244
end
227245

228-
alg isa LinearSolve.DefaultLinearSolver ?
229-
real_alg = LinearSolve.defaultalg(primal_prob.A, primal_prob.b) : real_alg = alg
230-
231246
non_partial_cache = init(
232-
primal_prob, real_alg, assumptions, args...;
247+
primal_prob, alg, assumptions, args...;
233248
alias = alias, abstol = abstol, reltol = reltol,
234249
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
235250
sensealg = sensealg, u0 = new_u0, kwargs...)

test/forwarddiff_overloads.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ end
1313
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
1414

1515
prob = LinearProblem(A, b)
16-
overload_x_p = solve(prob)
16+
overload_x_p = solve(prob, LUFactorization())
1717
backslash_x_p = A \ b
1818
krylov_overload_x_p = solve(prob, KrylovJL_GMRES())
1919
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
@@ -42,7 +42,7 @@ prob = LinearProblem(A, b)
4242
A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)])
4343

4444
prob = LinearProblem(A, b)
45-
cache = init(prob)
45+
cache = init(prob, LUFactorization())
4646

4747
new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
4848
cache.A = new_A
@@ -60,7 +60,7 @@ backslash_x_p = new_A \ new_b
6060
A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)])
6161

6262
prob = LinearProblem(A, b)
63-
cache = init(prob)
63+
cache = init(prob, LUFactorization())
6464

6565
new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
6666
cache.A = new_A
@@ -75,7 +75,7 @@ backslash_x_p = new_A \ b
7575
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
7676

7777
prob = LinearProblem(A, b)
78-
cache = init(prob)
78+
cache = init(prob, LUFactorization())
7979

8080
_, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
8181
cache.b = new_b
@@ -99,7 +99,7 @@ original_x_p = A \ b
9999
@test (overload_x_p, original_x_p, rtol = 1e-9)
100100

101101
prob = LinearProblem(A, b)
102-
cache = init(prob)
102+
cache = init(prob, LUFactorization())
103103

104104
new_A,
105105
new_b = h([ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 1.0, 0.0),
@@ -155,7 +155,7 @@ end
155155
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
156156

157157
prob = LinearProblem(A, b)
158-
cache = init(prob)
158+
cache = init(prob, LUFactorization())
159159

160160
new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
161161
cache.A = new_A
@@ -193,3 +193,5 @@ A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
193193

194194
prob = LinearProblem(A, b)
195195
@test init(prob, GenericLUFactorization()) isa LinearSolve.LinearCache
196+
197+
@test init(prob) isa LinearSolve.LinearCache

test/nopre/jet.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,21 @@ end
136136

137137
@testset "JET Tests for creating Dual solutions" begin
138138
# Make sure there's no runtime dispatch when making solutions of Dual problems
139-
dual_cache = init(dual_prob)
139+
dual_cache = init(dual_prob, LUFactorization())
140140
ext = Base.get_extension(LinearSolve, :LinearSolveForwardDiffExt)
141141
JET.@test_opt ext.linearsolve_dual_solution(
142142
[1.0, 1.0, 1.0], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dual_cache)
143+
end
144+
145+
@testset "JET Tests for default algs with DualLinear Problems" begin
146+
# Test for Default alg choosing for DualLinear Problems
147+
# These should both produce a LinearCache
148+
alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization)
149+
if VERSION < v"1.11"
150+
JET.@test_opt init(dual_prob, alg) broken=true
151+
JET.@test_opt init(dual_prob) broken=true
152+
else
153+
JET.@test_opt init(dual_prob, alg)
154+
JET.@test_opt init(dual_prob)
155+
end
143156
end

0 commit comments

Comments
 (0)