Skip to content

Commit e111712

Browse files
Merge pull request #4037 from hersle/sparse_performance
Improve performance with sparse Jacobians
2 parents 8b0f3cb + 5ad8691 commit e111712

File tree

3 files changed

+49
-22
lines changed

3 files changed

+49
-22
lines changed

benchmark/benchmarks.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,18 @@ prob = ODEProblem(model, u0, tspan)
7676
large_param_init["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan)
7777

7878
large_param_init["init"] = @benchmarkable init($prob)
79+
80+
sparse_analytical_jacobian = SUITE["sparse_analytical_jacobian"]
81+
82+
eqs = [D(x[i]) ~ prod(x[j] for j in 1:N if (i+j) % 3 == 0) for i in 1:N]
83+
@mtkcompile model = System(eqs, t)
84+
u0 = collect(x .=> 1.0)
85+
tspan = (0.0, 1.0)
86+
jac = true
87+
sparse = true
88+
prob = ODEProblem(model, u0, tspan; jac, sparse)
89+
out = similar(prob.f.jac_prototype)
90+
91+
sparse_analytical_jacobian["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan; jac, sparse)
92+
sparse_analytical_jacobian["f_oop"] = @benchmarkable $(prob.f.jac.f_oop)($(prob.u0), $(prob.p), $(first(tspan)))
93+
sparse_analytical_jacobian["f_iip"] = @benchmarkable $(prob.f.jac.f_iip)($out, $(prob.u0), $(prob.p), $(first(tspan)))

src/systems/codegen.jl

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -186,16 +186,14 @@ function calculate_jacobian(sys::System;
186186
if sparse
187187
jac = sparsejacobian(rhs, dvs; simplify)
188188
if get_iv(sys) !== nothing
189-
W_s = W_sparsity(sys)
190-
(Is, Js, Vs) = findnz(W_s)
191-
# Add nonzeros of W as non-structural zeros of the Jacobian (to ensure equal
192-
# results for oop and iip Jacobian)
193-
for (i, j) in zip(Is, Js)
194-
iszero(jac[i, j]) && begin
195-
jac[i, j] = 1
196-
jac[i, j] = 0
197-
end
198-
end
189+
# Add nonzeros of W as non-structural zeros of the Jacobian
190+
# (to ensure equal results for oop and iip Jacobian)
191+
JIs, JJs, JVs = findnz(jac)
192+
WIs, WJs, _ = findnz(W_sparsity(sys))
193+
append!(JIs, WIs) # explicitly put all W's indices also in J,
194+
append!(JJs, WJs) # even if it duplicates some indices
195+
append!(JVs, zeros(eltype(JVs), length(WIs))) # add zero
196+
jac = SparseArrays.sparse(JIs, JJs, JVs) # values at duplicate indices are summed; not overwritten
199197
end
200198
else
201199
jac = jacobian(rhs, dvs; simplify)
@@ -213,21 +211,23 @@ Generate the jacobian function for the equations of a [`System`](@ref).
213211
214212
$GENERATE_X_KWARGS
215213
- `simplify`, `sparse`: Forwarded to [`calculate_jacobian`](@ref).
214+
- `checkbounds`: Whether to check correctness of indices at runtime if `sparse`.
215+
Also forwarded to `build_function_wrapper`.
216216
217217
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
218218
"""
219219
function generate_jacobian(sys::System;
220220
simplify = false, sparse = false, eval_expression = false,
221221
eval_module = @__MODULE__, expression = Val{true}, wrap_gfw = Val{false},
222-
kwargs...)
222+
checkbounds = false, kwargs...)
223223
dvs = unknowns(sys)
224224
jac = calculate_jacobian(sys; simplify, sparse, dvs)
225225
p = reorder_parameters(sys)
226226
t = get_iv(sys)
227-
if t === nothing
228-
wrap_code = (identity, identity)
227+
if t !== nothing && sparse && checkbounds
228+
wrap_code = assert_jac_length_header(sys) # checking sparse J indices at runtime is expensive for large systems
229229
else
230-
wrap_code = sparse ? assert_jac_length_header(sys) : (identity, identity)
230+
wrap_code = (identity, identity)
231231
end
232232
args = (dvs, p...)
233233
nargs = 2
@@ -236,7 +236,7 @@ function generate_jacobian(sys::System;
236236
nargs = 3
237237
end
238238
res = build_function_wrapper(sys, jac, args...; wrap_code, expression = Val{true},
239-
expression_module = eval_module, kwargs...)
239+
expression_module = eval_module, checkbounds, kwargs...)
240240
return maybe_compile_function(
241241
expression, wrap_gfw, (2, nargs, is_split(sys)), res; eval_expression, eval_module)
242242
end
@@ -328,12 +328,14 @@ Generate the `W = γ * M + J` function for the equations of a [`System`](@ref).
328328
329329
$GENERATE_X_KWARGS
330330
- `simplify`, `sparse`: Forwarded to [`calculate_jacobian`](@ref).
331+
- `checkbounds`: Whether to check correctness of indices at runtime if `sparse`.
332+
Also forwarded to `build_function_wrapper`.
331333
332334
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
333335
"""
334336
function generate_W(sys::System;
335337
simplify = false, sparse = false, expression = Val{true}, wrap_gfw = Val{false},
336-
eval_expression = false, eval_module = @__MODULE__, kwargs...)
338+
eval_expression = false, eval_module = @__MODULE__, checkbounds = false, kwargs...)
337339
dvs = unknowns(sys)
338340
ps = parameters(sys; initial_parameters = true)
339341
M = calculate_massmatrix(sys; simplify)
@@ -343,13 +345,15 @@ function generate_W(sys::System;
343345
J = calculate_jacobian(sys; simplify, sparse, dvs)
344346
W = W_GAMMA * M + J
345347
t = get_iv(sys)
346-
if t !== nothing
347-
wrap_code = sparse ? assert_jac_length_header(sys) : (identity, identity)
348+
if t !== nothing && sparse && checkbounds
349+
wrap_code = assert_jac_length_header(sys)
350+
else
351+
wrap_code = (identity, identity)
348352
end
349353

350354
p = reorder_parameters(sys, ps)
351355
res = build_function_wrapper(sys, W, dvs, p..., W_GAMMA, t; wrap_code,
352-
p_end = 1 + length(p), kwargs...)
356+
p_end = 1 + length(p), checkbounds, kwargs...)
353357
return maybe_compile_function(
354358
expression, wrap_gfw, (2, 4, is_split(sys)), res; eval_expression, eval_module)
355359
end

test/jacobiansparsity.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, SparseArrays, OrdinaryDiffEq, DiffEqBase
1+
using ModelingToolkit, SparseArrays, OrdinaryDiffEq, DiffEqBase, BenchmarkTools
22

33
N = 3
44
xyd_brusselator = range(0, stop = 1, length = N)
@@ -58,6 +58,8 @@ prob = ODEProblem(sys, u0, (0, 11.5), sparse = true, jac = true)
5858
#@test_nowarn solve(prob, Rosenbrock23())
5959
@test findnz(calculate_jacobian(sys, sparse = true))[1:2] ==
6060
findnz(prob.f.jac_prototype)[1:2]
61+
out = similar(prob.f.jac_prototype)
62+
@test (@ballocated $(prob.f.jac.f_iip)($out, $(prob.u0), $(prob.p), 0.0)) == 0 # should not allocate
6163

6264
# test when not sparse
6365
prob = ODEProblem(sys, u0, (0, 11.5), sparse = false, jac = true)
@@ -75,6 +77,12 @@ f = DiffEqBase.ODEFunction(sys, u0 = nothing, sparse = true, jac = false)
7577
@test findnz(f.jac_prototype)[1:2] == findnz(JP)[1:2]
7678
@test eltype(f.jac_prototype) == Float64
7779

80+
# test sparsity index pattern checking
81+
f = DiffEqBase.ODEFunction(sys, u0 = nothing, sparse = true, jac = true, checkbounds = true)
82+
out = sparse([1.0 0.0; 0.0 1.0]) # choose a wrong size on purpose
83+
@test size(out) != size(f.jac_prototype) # check that the size is indeed wrong
84+
@test_throws AssertionError f.jac.f_iip(out, u0, p, 0.0) # check that we get an error
85+
7886
# test when u0 is not Float64
7987
u0 = similar(init_brusselator_2d(xyd_brusselator), Float32)
8088
prob_ode_brusselator_2d = ODEProblem(brusselator_2d_loop,
@@ -100,7 +108,7 @@ prob = ODEProblem(sys, u0, (0, 11.5), sparse = true, jac = true)
100108
u0 = [x => 1, y => 0]
101109
prob = ODEProblem(
102110
pend, [u0; [g => 1]], (0, 11.5), guesses ==> 1], sparse = true, jac = true)
103-
jac, jac! = generate_jacobian(pend; expression = Val{false}, sparse = true)
111+
jac, jac! = generate_jacobian(pend; expression = Val{false}, sparse = true, checkbounds = true)
104112
jac_prototype = ModelingToolkit.jacobian_sparsity(pend)
105113
W_prototype = ModelingToolkit.W_sparsity(pend)
106114
@test nnz(W_prototype) == nnz(jac_prototype) + 2
@@ -113,7 +121,7 @@ prob = ODEProblem(sys, u0, (0, 11.5), sparse = true, jac = true)
113121
t = 0.0
114122
@test_throws AssertionError jac!(similar(jac_prototype, Float64), u, p, t)
115123

116-
W, W! = generate_W(pend; expression = Val{false}, sparse = true)
124+
W, W! = generate_W(pend; expression = Val{false}, sparse = true, checkbounds = true)
117125
γ = 0.1
118126
M = sparse(calculate_massmatrix(pend))
119127
@test_throws AssertionError W!(similar(jac_prototype, Float64), u, p, γ, t)

0 commit comments

Comments
 (0)