Skip to content

Commit 7b714be

Browse files
Merge pull request #4018 from SciML/as/gpu-stuff
fix: support GPUs again
2 parents ab858cf + 6a34ef8 commit 7b714be

File tree

5 files changed

+45
-14
lines changed

5 files changed

+45
-14
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ function SciMLBase.remake_initialization_data(
669669
end
670670

671671
promote_type_with_nothing(::Type{T}, ::Nothing) where {T} = T
672-
promote_type_with_nothing(::Type{T}, ::SizedVector{0}) where {T} = T
672+
promote_type_with_nothing(::Type{T}, ::StaticVector{0}) where {T} = T
673673
function promote_type_with_nothing(::Type{T}, ::AbstractArray{T2}) where {T, T2}
674674
promote_type(T, T2)
675675
end
@@ -678,7 +678,7 @@ function promote_type_with_nothing(::Type{T}, p::MTKParameters) where {T}
678678
end
679679

680680
promote_with_nothing(::Type, ::Nothing) = nothing
681-
promote_with_nothing(::Type, x::SizedVector{0}) = x
681+
promote_with_nothing(::Type, x::StaticVector{0}) = x
682682
promote_with_nothing(::Type{T}, x::AbstractArray{T}) where {T} = x
683683
function promote_with_nothing(::Type{T}, x::AbstractArray{T2}) where {T, T2}
684684
if ArrayInterface.ismutable(x)

src/systems/parameter_buffer.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,26 @@ struct MTKParameters{T, I, D, C, N, H}
1010
constant::C
1111
nonnumeric::N
1212
caches::H
13+
14+
function MTKParameters{T, I, D, C, N, H}(tunables::T, initials::I, discrete::D,
15+
constant::C, nonnumeric::N,
16+
caches::H) where {T, I, D, C, N, H}
17+
if tunables isa StaticVector{0}
18+
tunables = SVector{0, eltype(tunables)}()
19+
end
20+
if initials isa StaticVector{0}
21+
initials = SVector{0, eltype(initials)}()
22+
end
23+
return new{typeof(tunables), typeof(initials), D, C, N, H}(tunables, initials,
24+
discrete, constant,
25+
nonnumeric, caches)
26+
end
27+
function MTKParameters(tunables::T, initials::I, discrete::D,
28+
constant::C, nonnumeric::N,
29+
caches::H) where {T, I, D, C, N, H}
30+
return MTKParameters{T, I, D, C, N, H}(tunables, initials, discrete, constant,
31+
nonnumeric, caches)
32+
end
1333
end
1434

1535
"""
@@ -138,11 +158,11 @@ function MTKParameters(
138158
end
139159
tunable_buffer = narrow_buffer_type(tunable_buffer; p_constructor)
140160
if isempty(tunable_buffer)
141-
tunable_buffer = SizedVector{0, Float64}()
161+
tunable_buffer = SVector{0, Float64}()
142162
end
143163
initials_buffer = narrow_buffer_type(initials_buffer; p_constructor)
144164
if isempty(initials_buffer)
145-
initials_buffer = SizedVector{0, Float64}()
165+
initials_buffer = SVector{0, Float64}()
146166
end
147167
disc_buffer = narrow_buffer_type.(disc_buffer; p_constructor)
148168
const_buffer = narrow_buffer_type.(const_buffer; p_constructor)
@@ -879,10 +899,10 @@ end
879899
@generated function Base.getindex(
880900
ps::MTKParameters{T, I, D, C, N, H}, idx::Int) where {T, I, D, C, N, H}
881901
paths = []
882-
if !(T <: SizedVector{0})
902+
if !(T <: StaticVector{0})
883903
push!(paths, :(ps.tunable))
884904
end
885-
if !(I <: SizedVector{0})
905+
if !(I <: StaticVector{0})
886906
push!(paths, :(ps.initials))
887907
end
888908
for i in 1:fieldcount(D)
@@ -909,10 +929,10 @@ end
909929
@generated function Base.length(ps::MTKParameters{
910930
T, I, D, C, N, H}) where {T, I, D, C, N, H}
911931
len = 0
912-
if !(T <: SizedVector{0})
932+
if !(T <: SVector{0})
913933
len += 1
914934
end
915-
if !(I <: SizedVector{0})
935+
if !(I <: SVector{0})
916936
len += 1
917937
end
918938
len += fieldcount(D) + fieldcount(C) + fieldcount(N) + fieldcount(H)

src/systems/problem_utils.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
780780
# `syms[1]` is always the tunables because `srcsys` will have initials.
781781
tunable_syms = syms[1]
782782
tunable_getter = if isempty(tunable_syms)
783-
Returns(SizedVector{0, Float64}())
783+
Returns(SVector{0, Float64}())
784784
else
785785
p_constructor concrete_getu(srcsys, tunable_syms; eval_expression, eval_module)
786786
end
@@ -803,7 +803,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
803803
end
804804
p_constructor concrete_getu(srcsys, initsyms; eval_expression, eval_module)
805805
else
806-
Returns(SizedVector{0, Float64}())
806+
Returns(SVector{0, Float64}())
807807
end
808808
discs_getter = if isempty(syms[3])
809809
Returns(())
@@ -923,7 +923,7 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
923923
if newp isa MTKParameters
924924
# and initials portion
925925
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
926-
if eltype(buf) != T
926+
if eltype(buf) != T && !(buf isa SVector{0})
927927
newbuf = similar(buf, T)
928928
copyto!(newbuf, buf)
929929
newp = repack(newbuf)
@@ -1148,8 +1148,10 @@ function maybe_build_initialization_problem(
11481148
if is_split(sys)
11491149
buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), initp)
11501150
initp = repack(floatT.(buffer))
1151-
buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp)
1152-
initp = repack(floatT.(buffer))
1151+
if !(initp.initials isa StaticVector{0})
1152+
buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp)
1153+
initp = repack(floatT.(buffer))
1154+
end
11531155
elseif initp isa AbstractArray
11541156
if ArrayInterface.ismutable(initp)
11551157
initp′ = similar(initp, floatT)

test/mtkparameters.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,3 +429,13 @@ end
429429
grad = ForwardDiff.gradient(Base.Fix2(loss, (setter, prob)), [3.0])
430430
@test grad [0.14882627068752538] atol=1e-10
431431
end
432+
433+
@testset "MTKParameters can be made `isbits`" begin
434+
@variables x(t)
435+
@parameters p
436+
@named sys = System(D(x) ~ x * p, t)
437+
sys = complete(sys)
438+
prob = ODEProblem(sys, SA[x => 1.0, p => 1.0], (0.0, 1.0))
439+
@test isbits(prob.p)
440+
@test isbits(prob.f.initialization_data.initializeprob.p)
441+
end

test/split_parameters.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ using BlockArrays: BlockedArray
66
using ModelingToolkit: t_nounits as t, D_nounits as D
77
using ModelingToolkit: MTKParameters, ParameterIndex, NONNUMERIC_PORTION
88
using SciMLStructures: Tunable, Discrete, Constants, Initials
9-
using StaticArrays: SizedVector
109
using SymbolicIndexingInterface: is_parameter, getp
1110

1211
x = [1, 2.0, false, [1, 2, 3], Parameter(1.0)]

0 commit comments

Comments
 (0)