16 changes: 16 additions & 0 deletions ext/LinearSolveMooncakeExt.jl
@@ -29,4 +29,20 @@ function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T}
    end
end

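# Accumulate the `A`, `b`, and `u` components of a `LinearCache` tangent `t`
# into the fdata `f` in place. `LinearCache` is mutable, so the gradient lives
# entirely in the fdata and there is no rdata to return.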
function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearCache)
    f.fields.A .+= t.A
    f.fields.b .+= t.b
    f.fields.u .+= t.u

    return NoRData()
end

# rrules for `init`, which constructs a `LinearCache`
@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,Nothing} true ReverseMode

# rrules for `solve!`
@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing} true ReverseMode

end
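Taken together, the `@from_chainrules` lines expose the ChainRules `rrule`s defined in `src/adjoint.jl` (below) to Mooncake's reverse mode, covering both an explicitly supplied algorithm and `nothing` (automatic algorithm selection). A minimal sketch of the workflow this enables; the `loss` function and data are illustrative, not part of the PR:

using LinearAlgebra, LinearSolve, Mooncake

# Differentiate a scalar loss through the mutating init/solve! path.
loss(A, b) = norm(solve!(init(LinearProblem(A, b), LUFactorization())).u)

A, b = rand(4, 4) + 4I, rand(4)  # shifted to keep the system well-conditioned
cache = Mooncake.prepare_gradient_cache(loss, A, b)
val, grads = Mooncake.value_and_gradient!!(cache, loss, A, b)
# grads[2] ≈ ∂loss/∂A and grads[3] ≈ ∂loss/∂b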
78 changes: 78 additions & 0 deletions src/adjoint.jl
@@ -99,3 +99,81 @@ function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
    ∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
    return prob, ∇prob
end

function CRC.rrule(T::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Nothing, args...; kwargs...)
    assump = OperatorAssumptions(issquare(prob.A))
    alg = defaultalg(prob.A, prob.b, assump)
    CRC.rrule(T, prob, alg, args...; kwargs...)
end

function CRC.rrule(::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Union{LinearSolve.SciMLLinearSolveAlgorithm, Nothing}, args...; kwargs...)
    init_res = LinearSolve.init(prob, alg)
    function init_adjoint(∂init)
        ∂prob = LinearProblem(∂init.A, ∂init.b, NoTangent())
        return NoTangent(), ∂prob, NoTangent(), ntuple(_ -> NoTangent(), length(args))...
    end

    return init_res, init_adjoint
end

function CRC.rrule(T::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::Nothing, args...; kwargs...)
    assump = OperatorAssumptions()
    alg = defaultalg(cache.A, cache.b, assump)
    CRC.rrule(T, cache, alg, args...; kwargs...)
end

function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache,
        alg::LinearSolve.SciMLLinearSolveAlgorithm, args...;
        alias_A = default_alias_A(alg, cache.A, cache.b), kwargs...)
    (; A, sensealg) = cache
    @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."

    # Cache `A` for the reverse pass, mirroring the rrule for `SciMLBase.solve`
    # above: factorization, Krylov, and default algorithms can recover the
    # adjoint solve from `cache` itself, so `A` is only copied when no such
    # shortcut exists or when a custom adjoint solver is requested.
    if sensealg.linsolve === missing
        if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
             alg isa DefaultLinearSolver)
            A_ = alias_A ? deepcopy(A) : A
        end
    else
        A_ = deepcopy(A)
    end

    sol = solve!(cache)
    function solve!_adjoint(∂sol)
        ∂∅ = NoTangent()
        ∂u = ∂sol.u

        if sensealg.linsolve === missing
            λ = if cache.cacheval isa Factorization
                cache.cacheval' \ ∂u
            elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
                first(cache.cacheval)' \ ∂u
            elseif alg isa AbstractKrylovSubspaceMethod
                invprob = LinearProblem(adjoint(cache.A), ∂u)
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            elseif alg isa DefaultLinearSolver
                LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
            else
                invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            end
        else
            invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
            λ = solve(
                invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
        end

        tu = adjoint(sol.u)
        ∂A = BroadcastArray(@~ .-(λ .* tu))
        ∂b = λ

        if (iszero(∂b) || iszero(∂A)) && !iszero(tu)
            error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.")
        end

        ∂prob = LinearProblem(∂A, ∂b, ∂∅)
        ∂cache = LinearSolve.init(∂prob, u = ∂u)
        return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
    end

    return sol, solve!_adjoint
end
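For reference, the reverse pass above implements the standard linear-solve adjoint: for `x = A \ b` with incoming cotangent `x̄`, solve the adjoint system `Aᵀλ = x̄`; then `∂b = λ` and `∂A = -λxᵀ`, which the `BroadcastArray` materializes lazily. A self-contained check of that identity, independent of this PR (illustrative names only):

using LinearAlgebra, ForwardDiff

A, b = rand(4, 4) + 4I, rand(4)  # well-conditioned test system
x̄ = rand(4)                      # cotangent seeded on x = A \ b
λ = A' \ x̄                       # adjoint solve
∂A = -λ * (A \ b)'               # cotangent w.r.t. A (rank-one outer product)
∂b = λ                           # cotangent w.r.t. b
@assert ∂A ≈ ForwardDiff.gradient(M -> dot(x̄, M \ b), A)
@assert ∂b ≈ ForwardDiff.gradient(v -> dot(x̄, A \ v), b)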
132 changes: 132 additions & 0 deletions test/nopre/mooncake.jl
@@ -153,3 +153,135 @@ for alg in (
    @test results[1] ≈ fA(A)
    @test mooncake_gradient ≈ fd_jac rtol = 1e-5
end

# Tests for solve! and init rrules.
n = 4
A = rand(n, n);
b1 = rand(n);
b2 = rand(n);

function f(A, b1, b2; alg=LUFactorization())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    s1 = copy(solve!(cache).u)
    cache.b = b2
    s2 = solve!(cache).u
    norm(s1 + s2)
end

f_primal = f(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
    prepare_gradient_cache(f, copy(A), copy(b1), copy(b2)),
    f, copy(A), copy(b1), copy(b2)
)

dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f2(A, b1, b2; alg=RFLUFactorization())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    s1 = copy(solve!(cache).u)
    cache.b = b2
    s2 = solve!(cache).u
    norm(s1 + s2)
end

f_primal = f2(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
    prepare_gradient_cache(f2, copy(A), copy(b1), copy(b2)),
    f2, copy(A), copy(b1), copy(b2)
)

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f3(A, b1, b2; alg=KrylovJL_GMRES())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    s1 = copy(solve!(cache).u)
    cache.b = b2
    s2 = solve!(cache).u
    norm(s1 + s2)
end

f_primal = f3(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
    prepare_gradient_cache(f3, copy(A), copy(b1), copy(b2)),
    f3, copy(A), copy(b1), copy(b2)
)

@test value == f_primal
@test gradient[2] ≈ dA2 atol = 5e-5
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f4(A, b1, b2; alg=LUFactorization())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    solve!(cache)
    s1 = copy(cache.u)
    cache.b = b2
    solve!(cache)
    s2 = copy(cache.u)
    norm(s1 + s2)
end

A = rand(n, n);
b1 = rand(n);
b2 = rand(n);
# f_primal = f4(copy(A), copy(b1), copy(b2))

rule = Mooncake.build_rrule(f4, copy(A), copy(b1), copy(b2))
@test_throws "Adjoint case currently not handled" Mooncake.value_and_pullback!!(
    rule, 1.0,
    f4, copy(A), copy(b1), copy(b2)
)
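# `f4` reads the solution from `cache.u` instead of the `sol` returned by `solve!`,
# so no cotangent ever reaches `sol.u`; the rrule detects the resulting zero
# `∂b`/`∂A` alongside a nonzero primal and throws the error asserted above.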

# dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
# db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
# db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b2))

# @test value == f_primal
# @test grad[2] ≈ dA2
# @test grad[3] ≈ db12
# @test grad[4] ≈ db22

A = rand(n, n);
b1 = rand(n);

function fnice(A, b, alg)
    prob = LinearProblem(A, b)
    sol1 = solve(prob, alg)
    return sum(sol1.u)
end

@testset for alg in (
    LUFactorization(),
    RFLUFactorization(),
    KrylovJL_GMRES()
)
    # for b
    fb_closure = b -> fnice(A, b, alg)
    fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec

    val, en_jac = Mooncake.value_and_gradient!!(
        prepare_gradient_cache(fnice, copy(A), copy(b1), alg),
        fnice, copy(A), copy(b1), alg
    )
    @test en_jac[3] ≈ fd_jac_b rtol = 1e-5

    # for A
    fA_closure = A -> fnice(A, b1, alg)
    fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
    A_grad = en_jac[2] |> vec
    @test A_grad ≈ fd_jac_A rtol = 1e-5
end