Skip to content

Commit 865e339

Browse files
Merge pull request #3931 from AayushSabharwal/as/fix-gen-initsys
fix: fix minor bug in `generate_initializesystem`
2 parents 8e9ef31 + 997f1ed commit 865e339

File tree

5 files changed

+20
-21
lines changed

5 files changed

+20
-21
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,8 @@ function reorder_vars!(state::TearingState, var_eq_matching, var_sccs, eq_orderi
889889
# the new reality of the system we've just created.
890890
new_graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm,
891891
nsolved_eq, nsolved_var)
892+
new_solvable_graph = contract_variables(solvable_graph, var_eq_matching, varsperm, eqsperm,
893+
nsolved_eq, nsolved_var)
892894

893895
new_var_to_diff = complete(DiffGraph(length(var_ordering)))
894896
for (v, d) in enumerate(var_to_diff)
@@ -919,6 +921,7 @@ function reorder_vars!(state::TearingState, var_eq_matching, var_sccs, eq_orderi
919921

920922
# Update system structure
921923
@set! state.structure.graph = complete(new_graph)
924+
@set! state.structure.solvable_graph = complete(new_solvable_graph)
922925
@set! state.structure.var_to_diff = new_var_to_diff
923926
@set! state.structure.eq_to_diff = new_eq_to_diff
924927
@set! state.fullvars = new_fullvars

src/systems/connectors.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -895,10 +895,12 @@ function expand_connections(sys::AbstractSystem; tol = 1e-10)
895895
stream_eqs, instream_subs = expand_instream(instream_csets, sys; tol = tol)
896896

897897
eqs = [equations(sys); ceqs; stream_eqs]
898-
# substitute `instream(..)` expressions with their new values
899-
for i in eachindex(eqs)
900-
eqs[i] = fixpoint_sub(
901-
eqs[i], instream_subs; maxiters = max(length(instream_subs), 10))
898+
if !isempty(instream_subs)
899+
# substitute `instream(..)` expressions with their new values
900+
for i in eachindex(eqs)
901+
eqs[i] = fixpoint_sub(
902+
eqs[i], instream_subs; maxiters = max(length(instream_subs), 10))
903+
end
902904
end
903905
# get the defaults for domain networks
904906
d_defs = domain_defaults(sys, domain_csets)

src/systems/nonlinear/initializesystem.jl

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
154154
# 3) process other variables
155155
for var in vars
156156
if var keys(op)
157-
push!(eqs_ics, var ~ defs[var])
157+
push!(eqs_ics, var ~ op[var])
158158
elseif var keys(guesses)
159159
push!(defs, var => guesses[var])
160160
elseif check_defguess
@@ -824,8 +824,6 @@ Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works w
824824
initialization.
825825
"""
826826
function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation})
827-
subs = Dict()
828-
tempvars = Set()
829827
rm_idxs = Int[]
830828
for (i, eq) in enumerate(obseqs)
831829
iscall(eq.rhs) || continue
@@ -835,20 +833,7 @@ function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation})
835833
end
836834
end
837835

838-
for (i, eq) in enumerate(obseqs)
839-
if eq.lhs in tempvars
840-
subs[eq.lhs] = eq.rhs
841-
push!(rm_idxs, i)
842-
end
843-
end
844-
845836
obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)]
846-
obseqs = map(obseqs) do eq
847-
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
848-
end
849-
eqs = map(eqs) do eq
850-
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
851-
end
852837
return obseqs, eqs
853838
end
854839

src/systems/systemstructure.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,7 @@ function trivial_tearing!(ts::TearingState)
635635
matched_vars = BitSet()
636636
# variable to index in fullvars
637637
var_to_idx = Dict{Any, Int}(ts.fullvars .=> eachindex(ts.fullvars))
638+
sys_eqs = equations(ts)
638639

639640
complete!(ts.structure)
640641
var_to_diff = ts.structure.var_to_diff
@@ -654,6 +655,14 @@ function trivial_tearing!(ts::TearingState)
654655
push!(blacklist, i)
655656
continue
656657
end
658+
# Edge case for `var ~ var` equations. They don't show up in the incidence
659+
# graph because `TearingState` makes them `0 ~ 0`, but they do cause `var`
660+
# to show up twice in `original_eqs` which fails the assertion.
661+
sys_eq = sys_eqs[i]
662+
if isequal(sys_eq.lhs, 0) && isequal(sys_eq.rhs, 0)
663+
continue
664+
end
665+
657666
# if a variable was the LHS of two trivial observed equations, we wouldn't have
658667
# included it in the list. Error if somehow it made it through.
659668
@assert !(vari in matched_vars)

test/extensions/dynamic_optimization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ end
243243
D = M.D_nounits
244244

245245
@parameters h_c m₀ h₀ g₀ D_c c Tₘ m_c
246-
@variables h(..) v(..) m(..) [bounds = (m_c, 1)] T(..) [input = true, bounds = (0, Tₘ)]
246+
@variables h(..) v(..) m(..) = m₀ [bounds = (m_c, 1)] T(..) [input = true, bounds = (0, Tₘ)]
247247
drag(h, v) = D_c * v^2 * exp(-h_c * (h - h₀) / h₀)
248248
gravity(h) = g₀ * (h₀ / h)
249249

0 commit comments

Comments
 (0)