Skip to content

Commit fce0f9f

Browse files
Simplify tstop flag logic per feedback
- Use simple ifelse on original_dt < distance_to_tstop for flag setting - Remove unnecessary complexity in flag detection - Handle t snapping in fixed_t_for_floatingpoint_error! - Reset flag when snapping to tstop target - Don't modify t in handle_tstop_step! - let normal flow handle it 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 97e035e commit fce0f9f

File tree

2 files changed

+186
-35
lines changed

2 files changed

+186
-35
lines changed

lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -82,40 +82,31 @@ function modify_dt_for_tstops!(integrator)
8282
original_dt = abs(integrator.dt)
8383

8484
if integrator.opts.adaptive
85-
new_dt = min(original_dt, distance_to_tstop)
86-
integrator.dt = integrator.tdir * new_dt
87-
88-
# Check if dt was significantly shrunk for tstop
89-
if new_dt < original_dt * 0.999
85+
if original_dt < distance_to_tstop
86+
# Normal step, no tstop interference
87+
integrator.next_step_tstop = false
88+
else
89+
# Distance is smaller, entering tstop snap mode
9090
integrator.next_step_tstop = true
9191
integrator.tstop_target = integrator.tdir * tdir_tstop
92-
93-
# If dt became extremely small (< eps(t)), flag for special handling
94-
eps_threshold = eps(abs(integrator.t))
95-
if new_dt < eps_threshold
96-
integrator.dt = integrator.tdir * eps_threshold # Minimal non-zero dt
97-
end
98-
else
99-
integrator.next_step_tstop = false
10092
end
93+
integrator.dt = integrator.tdir * min(original_dt, distance_to_tstop)
10194
elseif iszero(integrator.dtcache) && integrator.dtchangeable
102-
new_dt = distance_to_tstop
103-
integrator.dt = integrator.tdir * new_dt
95+
integrator.dt = integrator.tdir * distance_to_tstop
10496
integrator.next_step_tstop = true
10597
integrator.tstop_target = integrator.tdir * tdir_tstop
10698
elseif integrator.dtchangeable && !integrator.force_stepfail
10799
# always try to step! with dtcache, but lower if a tstop
108100
# however, if force_stepfail then don't set to dtcache, and no tstop worry
109-
new_dt = min(abs(integrator.dtcache), distance_to_tstop)
110-
integrator.dt = integrator.tdir * new_dt
111-
112-
# Check if dt was reduced for tstop
113-
if new_dt < abs(integrator.dtcache) * 0.999
101+
if abs(integrator.dtcache) < distance_to_tstop
102+
# Normal step with dtcache, no tstop interference
103+
integrator.next_step_tstop = false
104+
else
105+
# Distance is smaller, entering tstop snap mode
114106
integrator.next_step_tstop = true
115107
integrator.tstop_target = integrator.tdir * tdir_tstop
116-
else
117-
integrator.next_step_tstop = false
118108
end
109+
integrator.dt = integrator.tdir * min(abs(integrator.dtcache), distance_to_tstop)
119110
else
120111
integrator.next_step_tstop = false
121112
end
@@ -125,27 +116,18 @@ function modify_dt_for_tstops!(integrator)
125116
end
126117

127118
function handle_tstop_step!(integrator)
128-
# Check if dt became extremely small (< eps(t))
119+
# Check if dt is extremely small (< eps(t))
129120
eps_threshold = eps(abs(integrator.t))
130121

131122
if abs(integrator.dt) < eps_threshold
132-
# Skip perform_step! entirely for tiny dt, just snap to tstop
133-
integrator.t = integrator.tstop_target
134-
# Keep u and other states unchanged (no physics step)
123+
# Skip perform_step! entirely for tiny dt
135124
integrator.accept_step = true
136125
else
137-
# Normal step but with guaranteed exact tstop snapping
126+
# Normal step
138127
perform_step!(integrator, integrator.cache)
139-
# After the step, snap exactly to tstop to eliminate floating-point errors
140-
integrator.t = integrator.tstop_target
141-
integrator.accept_step = true
142128
end
143129

144-
# Reset the flag for next iteration
145-
integrator.next_step_tstop = false
146-
147-
# Mark that we hit a tstop for callback handling
148-
integrator.just_hit_tstop = true
130+
# Flag will be reset in fixed_t_for_floatingpoint_error! when t is updated
149131
end
150132

151133
# Want to extend savevalues! for DDEIntegrator
@@ -386,6 +368,13 @@ function log_step!(progress_name, progress_id, progress_message, dt, u, p, t, ts
386368
end
387369

388370
function fixed_t_for_floatingpoint_error!(integrator, ttmp)
371+
# If we're in tstop snap mode, use exact tstop target
372+
if integrator.next_step_tstop
373+
# Reset the flag now that we're snapping to tstop
374+
integrator.next_step_tstop = false
375+
return integrator.tstop_target
376+
end
377+
389378
if has_tstop(integrator)
390379
tstop = integrator.tdir * first_tstop(integrator)
391380
if abs(ttmp - tstop) <

test/tstop_robustness_tests.jl

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
using OrdinaryDiffEqVerner, StaticArrays, Test
2+
using OrdinaryDiffEqCore: handle_tstop_step!
3+
4+
# Test cases for the tstop robustness fix with next_step_tstop flag
5+
@testset "Tstop Robustness Tests" begin
6+
7+
@testset "Basic tstop flag functionality" begin
8+
# Simple ODE problem to test flag behavior
9+
function simple_ode!(du, u, p, t)
10+
du[1] = -u[1]
11+
du[2] = u[1] - u[2]
12+
end
13+
14+
function simple_ode(u, p, t)
15+
[-u[1], u[1] - u[2]]
16+
end
17+
18+
u0_array = [1.0, 0.0]
19+
u0_static = SVector{2}(1.0, 0.0)
20+
tspan = (0.0, 1.0)
21+
22+
# Test with regular arrays
23+
prob_array = ODEProblem(simple_ode!, u0_array, tspan)
24+
sol_array = solve(prob_array, Vern9(); reltol=1e-12, abstol=1e-12,
25+
tstops=[0.5], save_everystep=false)
26+
@test sol_array.retcode == :Success
27+
@test 0.5 in sol_array.t # Should have saved at tstop
28+
29+
# Test with StaticArrays - should work now without tstop error
30+
prob_static = ODEProblem(simple_ode, u0_static, tspan)
31+
sol_static = solve(prob_static, Vern9(); reltol=1e-12, abstol=1e-12,
32+
tstops=[0.5], save_everystep=false)
33+
@test sol_static.retcode == :Success
34+
@test 0.5 in sol_static.t # Should have saved at tstop
35+
36+
# Solutions should be very close despite different array types
37+
@test isapprox(sol_array(1.0), sol_static(1.0), rtol=1e-10)
38+
end
39+
40+
@testset "Tiny tstop step handling" begin
41+
# Test case where tstop is very close to current time
42+
function test_ode(u, p, t)
43+
[u[1]] # Simple growth
44+
end
45+
46+
u0 = SVector{1}(1.0)
47+
tspan = (0.0, 1.0)
48+
49+
# Create tstop very close to start time (would cause tiny dt)
50+
tiny_tstop = 1e-15
51+
52+
prob = ODEProblem(test_ode, u0, tspan)
53+
sol = solve(prob, Vern9(); tstops=[tiny_tstop], save_everystep=false)
54+
55+
@test sol.retcode == :Success
56+
@test tiny_tstop in sol.t # Should handle tiny tstop correctly
57+
end
58+
59+
@testset "Multiple close tstops" begin
60+
# Test with multiple tstops that are very close together
61+
function growth_ode(u, p, t)
62+
[0.1 * u[1]]
63+
end
64+
65+
u0 = SVector{1}(1.0)
66+
tspan = (0.0, 2.0)
67+
68+
# Multiple tstops close together
69+
close_tstops = [0.5, 0.5 + 1e-14, 0.5 + 2e-14, 1.0]
70+
71+
prob = ODEProblem(growth_ode, u0, tspan)
72+
sol = solve(prob, Vern9(); tstops=close_tstops, reltol=1e-12, abstol=1e-12)
73+
74+
@test sol.retcode == :Success
75+
# All tstops should be handled correctly
76+
for tstop in close_tstops
77+
@test any(abs.(sol.t .- tstop) .< 1e-12) # Should have hit each tstop
78+
end
79+
end
80+
81+
@testset "Extreme precision with StaticArrays" begin
82+
# Test the specific case that was failing: extreme precision + StaticArrays
83+
function precise_dynamics(u, p, t)
84+
# Simplified electromagnetic-like dynamics
85+
x = @view u[1:2]
86+
v = @view u[3:4]
87+
88+
# Simple force model
89+
dv = -0.01 * x + 1e-6 * sin(1000*t) * [1, 1]
90+
91+
return SVector{4}(v[1], v[2], dv[1], dv[2])
92+
end
93+
94+
# Initial conditions similar to the original issue
95+
u0 = SVector{4}(1.0, -0.5, 0.01, 0.01)
96+
tspan = (-1.0, 1.0)
97+
98+
# Test with extreme tolerances that originally caused issues
99+
prob = ODEProblem(precise_dynamics, u0, tspan)
100+
sol = solve(prob, Vern9(); reltol=1e-12, abstol=1e-15,
101+
tstops=[0.0], save_everystep=false, maxiters=10^6)
102+
103+
@test sol.retcode == :Success
104+
@test 0.0 in sol.t
105+
end
106+
107+
@testset "Flag state management" begin
108+
# Test that flags are properly set and reset
109+
function flag_test_ode(u, p, t)
110+
[u[1]]
111+
end
112+
113+
u0 = SVector{1}(1.0)
114+
prob = ODEProblem(flag_test_ode, u0, (0.0, 2.0))
115+
116+
# Create integrator manually to inspect flag states
117+
integrator = init(prob, Vern9(); tstops=[1.0])
118+
119+
# Initially, flag should be false
120+
@test integrator.next_step_tstop == false
121+
122+
# Step until we approach the tstop
123+
while integrator.t < 0.9
124+
step!(integrator)
125+
# Flag should still be false when not near tstop
126+
@test integrator.next_step_tstop == false
127+
end
128+
129+
# Take steps near tstop - flag should get set
130+
while integrator.t < 1.0
131+
step!(integrator)
132+
# When dt is reduced for tstop, flag should be set
133+
if integrator.next_step_tstop
134+
@test integrator.tstop_target 1.0
135+
break
136+
end
137+
end
138+
139+
# After hitting tstop, flag should be reset
140+
step!(integrator)
141+
@test integrator.next_step_tstop == false
142+
143+
finalize!(integrator)
144+
end
145+
146+
@testset "Backward time integration" begin
147+
# Test that the fix works for backward time integration too
148+
function backward_ode(u, p, t)
149+
[-u[1]] # Decay
150+
end
151+
152+
u0 = SVector{1}(1.0)
153+
tspan = (1.0, 0.0) # Backward integration
154+
155+
prob = ODEProblem(backward_ode, u0, tspan)
156+
sol = solve(prob, Vern9(); tstops=[0.5], reltol=1e-12, abstol=1e-12)
157+
158+
@test sol.retcode == :Success
159+
@test 0.5 in sol.t
160+
end
161+
162+
end

0 commit comments

Comments
 (0)