@@ -8,6 +8,220 @@ function get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u)
88 (cache. fsalfirst, cache. fsallast)
99end
1010
11+ @cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
12+ TabType, TFType, UFType, F, JCType, GCType,
13+ RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
14+ u:: uType
15+ uprev:: uType
16+ k₁:: rateType
17+ k₂:: rateType
18+ k₃:: rateType
19+ du1:: rateType
20+ du2:: rateType
21+ f₁:: rateType
22+ fsalfirst:: rateType
23+ fsallast:: rateType
24+ dT:: rateType
25+ J:: JType
26+ W:: WType
27+ tmp:: rateType
28+ atmp:: uNoUnitsType
29+ weight:: uNoUnitsType
30+ tab:: TabType
31+ tf:: TFType
32+ uf:: UFType
33+ linsolve_tmp:: rateType
34+ linsolve:: F
35+ jac_config:: JCType
36+ grad_config:: GCType
37+ reltol:: RTolType
38+ alg:: A
39+ algebraic_vars:: AV
40+ step_limiter!:: StepLimiter
41+ stage_limiter!:: StageLimiter
42+ end
43+
44+ @cache mutable struct Rosenbrock32Cache{uType, rateType, uNoUnitsType, JType, WType,
45+ TabType, TFType, UFType, F, JCType, GCType,
46+ RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
47+ u:: uType
48+ uprev:: uType
49+ k₁:: rateType
50+ k₂:: rateType
51+ k₃:: rateType
52+ du1:: rateType
53+ du2:: rateType
54+ f₁:: rateType
55+ fsalfirst:: rateType
56+ fsallast:: rateType
57+ dT:: rateType
58+ J:: JType
59+ W:: WType
60+ tmp:: rateType
61+ atmp:: uNoUnitsType
62+ weight:: uNoUnitsType
63+ tab:: TabType
64+ tf:: TFType
65+ uf:: UFType
66+ linsolve_tmp:: rateType
67+ linsolve:: F
68+ jac_config:: JCType
69+ grad_config:: GCType
70+ reltol:: RTolType
71+ alg:: A
72+ algebraic_vars:: AV
73+ step_limiter!:: StepLimiter
74+ stage_limiter!:: StageLimiter
75+ end
76+
77+ function alg_cache (alg:: Rosenbrock23 , u, rate_prototype, :: Type{uEltypeNoUnits} ,
78+ :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
79+ dt, reltol, p, calck,
80+ :: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
81+ k₁ = zero (rate_prototype)
82+ k₂ = zero (rate_prototype)
83+ k₃ = zero (rate_prototype)
84+ du1 = zero (rate_prototype)
85+ du2 = zero (rate_prototype)
86+ # f₀ = zero(u) fsalfirst
87+ f₁ = zero (rate_prototype)
88+ fsalfirst = zero (rate_prototype)
89+ fsallast = zero (rate_prototype)
90+ dT = zero (rate_prototype)
91+ J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (true ))
92+ tmp = zero (rate_prototype)
93+ atmp = similar (u, uEltypeNoUnits)
94+ recursivefill! (atmp, false )
95+ weight = similar (u, uEltypeNoUnits)
96+ recursivefill! (weight, false )
97+ tab = Rosenbrock23Tableau (constvalue (uBottomEltypeNoUnits))
98+ tf = TimeGradientWrapper (f, uprev, p)
99+ uf = UJacobianWrapper (f, t, p)
100+ linsolve_tmp = zero (rate_prototype)
101+
102+ linprob = LinearProblem (W, _vec (linsolve_tmp); u0 = _vec (tmp))
103+ Pl, Pr = wrapprecs (
104+ alg. precs (W, nothing , u, p, t, nothing , nothing , nothing ,
105+ nothing )... , weight, tmp)
106+ linsolve = init (linprob, alg. linsolve, alias_A = true , alias_b = true ,
107+ Pl = Pl, Pr = Pr,
108+ assumptions = LinearSolve. OperatorAssumptions (true ))
109+
110+ grad_config = build_grad_config (alg, f, tf, du1, t)
111+ jac_config = build_jac_config (alg, f, uf, du1, uprev, u, tmp, du2)
112+ algebraic_vars = f. mass_matrix === I ? nothing :
113+ [all (iszero, x) for x in eachcol (f. mass_matrix)]
114+
115+ Rosenbrock23Cache (u, uprev, k₁, k₂, k₃, du1, du2, f₁,
116+ fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
117+ linsolve_tmp,
118+ linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg. step_limiter!,
119+ alg. stage_limiter!)
120+ end
121+
122+ function alg_cache (alg:: Rosenbrock32 , u, rate_prototype, :: Type{uEltypeNoUnits} ,
123+ :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
124+ dt, reltol, p, calck,
125+ :: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
126+ k₁ = zero (rate_prototype)
127+ k₂ = zero (rate_prototype)
128+ k₃ = zero (rate_prototype)
129+ du1 = zero (rate_prototype)
130+ du2 = zero (rate_prototype)
131+ # f₀ = zero(u) fsalfirst
132+ f₁ = zero (rate_prototype)
133+ fsalfirst = zero (rate_prototype)
134+ fsallast = zero (rate_prototype)
135+ dT = zero (rate_prototype)
136+ J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (true ))
137+ tmp = zero (rate_prototype)
138+ atmp = similar (u, uEltypeNoUnits)
139+ recursivefill! (atmp, false )
140+ weight = similar (u, uEltypeNoUnits)
141+ recursivefill! (weight, false )
142+ tab = Rosenbrock32Tableau (constvalue (uBottomEltypeNoUnits))
143+
144+ tf = TimeGradientWrapper (f, uprev, p)
145+ uf = UJacobianWrapper (f, t, p)
146+ linsolve_tmp = zero (rate_prototype)
147+ linprob = LinearProblem (W, _vec (linsolve_tmp); u0 = _vec (tmp))
148+
149+ Pl, Pr = wrapprecs (
150+ alg. precs (W, nothing , u, p, t, nothing , nothing , nothing ,
151+ nothing )... , weight, tmp)
152+ linsolve = init (linprob, alg. linsolve, alias_A = true , alias_b = true ,
153+ Pl = Pl, Pr = Pr,
154+ assumptions = LinearSolve. OperatorAssumptions (true ))
155+ grad_config = build_grad_config (alg, f, tf, du1, t)
156+ jac_config = build_jac_config (alg, f, uf, du1, uprev, u, tmp, du2)
157+ algebraic_vars = f. mass_matrix === I ? nothing :
158+ [all (iszero, x) for x in eachcol (f. mass_matrix)]
159+
160+ Rosenbrock32Cache (u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
161+ tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
162+ grad_config, reltol, alg, algebraic_vars, alg. step_limiter!, alg. stage_limiter!)
163+ end
164+
165+ struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} < :
166+ RosenbrockConstantCache
167+ c₃₂:: T
168+ d:: T
169+ tf:: TF
170+ uf:: UF
171+ J:: JType
172+ W:: WType
173+ linsolve:: F
174+ autodiff:: AD
175+ end
176+
177+ function Rosenbrock23ConstantCache (:: Type{T} , tf, uf, J, W, linsolve, autodiff) where {T}
178+ tab = Rosenbrock23Tableau (T)
179+ Rosenbrock23ConstantCache (tab. c₃₂, tab. d, tf, uf, J, W, linsolve, autodiff)
180+ end
181+
182+ function alg_cache (alg:: Rosenbrock23 , u, rate_prototype, :: Type{uEltypeNoUnits} ,
183+ :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
184+ dt, reltol, p, calck,
185+ :: Val{false} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
186+ tf = TimeDerivativeWrapper (f, u, p)
187+ uf = UDerivativeWrapper (f, t, p)
188+ J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (false ))
189+ linprob = nothing # LinearProblem(W,copy(u); u0=copy(u))
190+ linsolve = nothing # init(linprob,alg.linsolve,alias_A=true,alias_b=true)
191+ Rosenbrock23ConstantCache (constvalue (uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
192+ alg_autodiff (alg))
193+ end
194+
195+ struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD} < :
196+ RosenbrockConstantCache
197+ c₃₂:: T
198+ d:: T
199+ tf:: TF
200+ uf:: UF
201+ J:: JType
202+ W:: WType
203+ linsolve:: F
204+ autodiff:: AD
205+ end
206+
207+ function Rosenbrock32ConstantCache (:: Type{T} , tf, uf, J, W, linsolve, autodiff) where {T}
208+ tab = Rosenbrock32Tableau (T)
209+ Rosenbrock32ConstantCache (tab. c₃₂, tab. d, tf, uf, J, W, linsolve, autodiff)
210+ end
211+
212+ function alg_cache (alg:: Rosenbrock32 , u, rate_prototype, :: Type{uEltypeNoUnits} ,
213+ :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
214+ dt, reltol, p, calck,
215+ :: Val{false} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
216+ tf = TimeDerivativeWrapper (f, u, p)
217+ uf = UDerivativeWrapper (f, t, p)
218+ J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (false ))
219+ linprob = nothing # LinearProblem(W,copy(u); u0=copy(u))
220+ linsolve = nothing # init(linprob,alg.linsolve,alias_A=true,alias_b=true)
221+ Rosenbrock32ConstantCache (constvalue (uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
222+ alg_autodiff (alg))
223+ end
224+
11225# ###############################################################################
12226
13227# Shampine's Low-order Rosenbrocks
@@ -19,7 +233,6 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
19233 du:: rateType
20234 du1:: rateType
21235 du2:: rateType
22- f₁:: rateType
23236 ks:: Vector{rateType}
24237 fsalfirst:: rateType
25238 fsallast:: rateType
@@ -43,6 +256,7 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
43256 stage_limiter!:: StageLimiter
44257 interp_order:: Int
45258end
259+
46260function full_cache (c:: RosenbrockCache )
47261 return [c. u, c. uprev, c. dense... , c. du, c. du1, c. du2,
48262 c. ks... , c. fsalfirst, c. fsallast, c. dT, c. tmp, c. atmp, c. weight, c. linsolve_tmp]
@@ -98,19 +312,18 @@ tabtype(::Rodas5Pr) = Rodas5PTableau
98312tabtype (:: Rodas5Pe ) = Rodas5PTableau
99313
100314function alg_cache (
101- alg:: Union{Rosenbrock23, Rosenbrock32, ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr} ,
315+ alg:: Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr} ,
102316 u, rate_prototype, :: Type{uEltypeNoUnits} ,
103317 :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
104318 dt, reltol, p, calck,
105319 :: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
106- tab = Rodas5PTableau (constvalue (uBottomEltypeNoUnits), constvalue (tTypeNoUnits))
320+ tab = tabtype (alg) (constvalue (uBottomEltypeNoUnits), constvalue (tTypeNoUnits))
107321 dense = [zero (rate_prototype) for _ in 1 : size (tab. H, 1 )]
108322 du = zero (rate_prototype)
109323 du1 = zero (rate_prototype)
110324 du2 = zero (rate_prototype)
111325 ks = [zero (rate_prototype) for _ in 1 : size (tab. A, 1 )]
112326
113- f₁ = zero (rate_prototype)
114327 fsalfirst = zero (rate_prototype)
115328 fsallast = zero (rate_prototype)
116329 dT = zero (rate_prototype)
@@ -135,14 +348,14 @@ function alg_cache(
135348 jac_config = build_jac_config (alg, f, uf, du1, uprev, u, tmp, du2)
136349 algebraic_vars = f. mass_matrix === I ? nothing :
137350 [all (iszero, x) for x in eachcol (f. mass_matrix)]
138- RosenbrockCache (u, uprev, dense, du, du1, du2, ks, f₁, fsalfirst, fsallast,
351+ RosenbrockCache (u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
139352 dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
140353 linsolve, jac_config, grad_config, reltol, alg, algebraic_vars,
141354 alg. step_limiter!, alg. stage_limiter!, size (tab. H, 1 ))
142355end
143356
144357function alg_cache (
145- alg:: Union{Rosenbrock23, Rosenbrock32, ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr} ,
358+ alg:: Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr} ,
146359 u, rate_prototype, :: Type{uEltypeNoUnits} ,
147360 :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
148361 dt, reltol, p, calck,
@@ -152,7 +365,7 @@ function alg_cache(
152365 J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (false ))
153366 linprob = nothing # LinearProblem(W,copy(u); u0=copy(u))
154367 linsolve = nothing # init(linprob,alg.linsolve,alias_A=true,alias_b=true)
155- tab =
368+ tab = tabtype (alg)( constvalue (uBottomEltypeNoUnits), constvalue (tTypeNoUnits))
156369 RosenbrockCombinedConstantCache (tf, uf, tab, J, W, linsolve, alg_autodiff (alg), size (tab. H, 1 ))
157370end
158371
0 commit comments