Skip to content

Commit 6bc7834

Browse files
Refactored Rosenbrock32 and Rosenbrock23
1 parent 4abda1b commit 6bc7834

File tree

5 files changed

+82
-257
lines changed

5 files changed

+82
-257
lines changed

lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ function SciMLBase.interp_summary(::Type{cacheType},
33
cacheType <:
44
Union{Rosenbrock23ConstantCache,
55
Rosenbrock32ConstantCache,
6-
Rosenbrock23Cache,
7-
Rosenbrock32Cache}}
6+
RosenbrockCombinedCache}}
87
dense ? "specialized 2nd order \"free\" stiffness-aware interpolation" :
98
"1st order linear"
109
end

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 7 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <:
6262
interp_order::Int
6363
end
6464

65-
@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
65+
@cache mutable struct RosenbrockCombinedCache{uType, rateType, uNoUnitsType, JType, WType,
6666
TabType, TFType, UFType, F, JCType, GCType,
6767
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
6868
u::uType
@@ -95,40 +95,7 @@ end
9595
stage_limiter!::StageLimiter
9696
end
9797

98-
@cache mutable struct Rosenbrock32Cache{uType, rateType, uNoUnitsType, JType, WType,
99-
TabType, TFType, UFType, F, JCType, GCType,
100-
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
101-
u::uType
102-
uprev::uType
103-
k₁::rateType
104-
k₂::rateType
105-
k₃::rateType
106-
du1::rateType
107-
du2::rateType
108-
f₁::rateType
109-
fsalfirst::rateType
110-
fsallast::rateType
111-
dT::rateType
112-
J::JType
113-
W::WType
114-
tmp::rateType
115-
atmp::uNoUnitsType
116-
weight::uNoUnitsType
117-
tab::TabType
118-
tf::TFType
119-
uf::UFType
120-
linsolve_tmp::rateType
121-
linsolve::F
122-
jac_config::JCType
123-
grad_config::GCType
124-
reltol::RTolType
125-
alg::A
126-
algebraic_vars::AV
127-
step_limiter!::StepLimiter
128-
stage_limiter!::StageLimiter
129-
end
130-
131-
function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
98+
function alg_cache(alg::Union{Rosenbrock23, Rosenbrock32}, u, rate_prototype, ::Type{uEltypeNoUnits},
13299
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
133100
dt, reltol, p, calck,
134101
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
@@ -147,7 +114,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
147114
recursivefill!(atmp, false)
148115
weight = similar(u, uEltypeNoUnits)
149116
recursivefill!(weight, false)
150-
tab = Rosenbrock23Tableau(constvalue(uBottomEltypeNoUnits))
117+
tab = RosenbrockCombinedTableau(constvalue(uBottomEltypeNoUnits))
151118
tf = TimeGradientWrapper(f, uprev, p)
152119
uf = UJacobianWrapper(f, t, p)
153120
linsolve_tmp = zero(rate_prototype)
@@ -170,61 +137,13 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
170137
algebraic_vars = f.mass_matrix === I ? nothing :
171138
[all(iszero, x) for x in eachcol(f.mass_matrix)]
172139

173-
Rosenbrock23Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁,
140+
RosenbrockCombinedCache(u, uprev, k₁, k₂, k₃, du1, du2, f₁,
174141
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
175142
linsolve_tmp,
176143
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!,
177144
alg.stage_limiter!)
178145
end
179146

180-
function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
181-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
182-
dt, reltol, p, calck,
183-
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
184-
k₁ = zero(rate_prototype)
185-
k₂ = zero(rate_prototype)
186-
k₃ = zero(rate_prototype)
187-
du1 = zero(rate_prototype)
188-
du2 = zero(rate_prototype)
189-
# f₀ = zero(u) fsalfirst
190-
f₁ = zero(rate_prototype)
191-
fsalfirst = zero(rate_prototype)
192-
fsallast = zero(rate_prototype)
193-
dT = zero(rate_prototype)
194-
tmp = zero(rate_prototype)
195-
atmp = similar(u, uEltypeNoUnits)
196-
recursivefill!(atmp, false)
197-
weight = similar(u, uEltypeNoUnits)
198-
recursivefill!(weight, false)
199-
tab = Rosenbrock32Tableau(constvalue(uBottomEltypeNoUnits))
200-
201-
tf = TimeGradientWrapper(f, uprev, p)
202-
uf = UJacobianWrapper(f, t, p)
203-
linsolve_tmp = zero(rate_prototype)
204-
205-
grad_config = build_grad_config(alg, f, tf, du1, t)
206-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
207-
208-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
209-
210-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
211-
212-
Pl, Pr = wrapprecs(
213-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
214-
nothing)..., weight, tmp)
215-
linsolve = init(
216-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
217-
Pl = Pl, Pr = Pr,
218-
assumptions = LinearSolve.OperatorAssumptions(true))
219-
220-
algebraic_vars = f.mass_matrix === I ? nothing :
221-
[all(iszero, x) for x in eachcol(f.mass_matrix)]
222-
223-
Rosenbrock32Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
224-
tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
225-
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!)
226-
end
227-
228147
struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <:
229148
RosenbrockConstantCache
230149
c₃₂::T
@@ -238,7 +157,7 @@ struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <:
238157
end
239158

240159
function Rosenbrock23ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T}
241-
tab = Rosenbrock23Tableau(T)
160+
tab = RosenbrockCombinedTableau(T)
242161
Rosenbrock23ConstantCache(tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff)
243162
end
244163

@@ -268,7 +187,7 @@ struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD} <:
268187
end
269188

270189
function Rosenbrock32ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T}
271-
tab = Rosenbrock32Tableau(T)
190+
tab = RosenbrockCombinedTableau(T)
272191
Rosenbrock32ConstantCache(tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff)
273192
end
274193

@@ -832,7 +751,7 @@ function alg_cache(
832751
end
833752

834753
function get_fsalfirstlast(
835-
cache::Union{Rosenbrock23Cache, Rosenbrock32Cache, Rosenbrock33Cache,
754+
cache::Union{RosenbrockCombinedCache, Rosenbrock33Cache,
836755
Rosenbrock34Cache,
837756
Rosenbrock4Cache},
838757
u)

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
### Fallbacks to capture
2-
ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
3-
Rosenbrock32ConstantCache, Rosenbrock32Cache,
2+
ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, RosenbrockCombinedCache,
3+
Rosenbrock32ConstantCache,
44
Rodas23WConstantCache, Rodas3PConstantCache,
55
Rodas23WCache, Rodas3PCache,
66
RosenbrockCombinedConstantCache,
@@ -46,24 +46,24 @@ end
4646
end
4747

4848
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
49-
cache::Union{Rosenbrock23Cache, Rosenbrock32Cache},
49+
cache::RosenbrockCombinedCache,
5050
idxs::Nothing, T::Type{Val{0}}, differential_vars)
5151
@rosenbrock2332pre0
5252
@inbounds @.. y₀ + dt * (c1 * k[1] + c2 * k[2])
5353
end
5454

5555
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
56-
cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
57-
Rosenbrock32ConstantCache, Rosenbrock32Cache
56+
cache::Union{Rosenbrock23ConstantCache, RosenbrockCombinedCache,
57+
Rosenbrock32ConstantCache
5858
}, idxs, T::Type{Val{0}}, differential_vars)
5959
@rosenbrock2332pre0
6060
@.. y₀[idxs] + dt * (c1 * k[1][idxs] + c2 * k[2][idxs])
6161
end
6262

6363
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
6464
cache::Union{Rosenbrock23ConstantCache,
65-
Rosenbrock23Cache,
66-
Rosenbrock32ConstantCache, Rosenbrock32Cache
65+
RosenbrockCombinedCache,
66+
Rosenbrock32ConstantCache
6767
}, idxs::Nothing, T::Type{Val{0}}, differential_vars)
6868
@rosenbrock2332pre0
6969
@inbounds @.. out = y₀ + dt * (c1 * k[1] + c2 * k[2])
@@ -72,8 +72,8 @@ end
7272

7373
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
7474
cache::Union{Rosenbrock23ConstantCache,
75-
Rosenbrock23Cache,
76-
Rosenbrock32ConstantCache, Rosenbrock32Cache
75+
RosenbrockCombinedCache,
76+
Rosenbrock32ConstantCache
7777
}, idxs, T::Type{Val{0}}, differential_vars)
7878
@rosenbrock2332pre0
7979
@views @.. out = y₀[idxs] + dt * (c1 * k[1][idxs] + c2 * k[2][idxs])
@@ -88,25 +88,25 @@ end
8888
end
8989

9090
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
91-
cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
92-
Rosenbrock32ConstantCache, Rosenbrock32Cache
91+
cache::Union{Rosenbrock23ConstantCache, RosenbrockCombinedCache,
92+
Rosenbrock32ConstantCache
9393
}, idxs::Nothing, T::Type{Val{1}}, differential_vars)
9494
@rosenbrock2332pre1
9595
@.. c1diff * k[1] + c2diff * k[2]
9696
end
9797

9898
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
99-
cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
100-
Rosenbrock32ConstantCache, Rosenbrock32Cache
99+
cache::Union{Rosenbrock23ConstantCache, RosenbrockCombinedCache,
100+
Rosenbrock32ConstantCache
101101
}, idxs, T::Type{Val{1}}, differential_vars)
102102
@rosenbrock2332pre1
103103
@.. c1diff * k[1][idxs] + c2diff * k[2][idxs]
104104
end
105105

106106
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
107107
cache::Union{Rosenbrock23ConstantCache,
108-
Rosenbrock23Cache,
109-
Rosenbrock32ConstantCache, Rosenbrock32Cache
108+
RosenbrockCombinedCache,
109+
Rosenbrock32ConstantCache
110110
}, idxs::Nothing, T::Type{Val{1}}, differential_vars)
111111
@rosenbrock2332pre1
112112
@.. out = c1diff * k[1] + c2diff * k[2]
@@ -115,8 +115,8 @@ end
115115

116116
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
117117
cache::Union{Rosenbrock23ConstantCache,
118-
Rosenbrock23Cache,
119-
Rosenbrock32ConstantCache, Rosenbrock32Cache
118+
RosenbrockCombinedCache,
119+
Rosenbrock32ConstantCache
120120
}, idxs, T::Type{Val{1}}, differential_vars)
121121
@rosenbrock2332pre1
122122
@views @.. out = c1diff * k[1][idxs] + c2diff * k[2][idxs]

0 commit comments

Comments
 (0)