Skip to content

Commit b128e05

Browse files
reneSchmxsaschakoannawendlerMaxBetzDLR
authored
468 ABM parameter study (#1395)
- Refactor ParameterStudy, so it can use any parameter type and more generic simulations. - Add example for an ABM study. - Update python bindings, behavior is mostly unchanged. Co-authored-by: Sascha Korf <51127093+xsaschako@users.noreply.github.com> Co-authored-by: annawendler <106674756+annawendler@users.noreply.github.com> Co-authored-by: MaxBetz <104758467+MaxBetzDLR@users.noreply.github.com>
1 parent 47f4820 commit b128e05

28 files changed

+971
-533
lines changed

cpp/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ set(CMAKE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
7474
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
7575
set(CMAKE_INSTALL_RPATH "${CMAKE_BINARY_DIR}/lib" "${CMAKE_BINARY_DIR}/bin")
7676

77-
file(TO_CMAKE_PATH "${PROJECT_SOURCE_DIR}/.." MEMILIO_BASE_DIR)
77+
# sets MEMILIO_BASE_DIR to the directory containing cpp (i.e., the root of the git repo)
78+
cmake_path(CONVERT "${PROJECT_SOURCE_DIR}/.." TO_CMAKE_PATH_LIST MEMILIO_BASE_DIR NORMALIZE)
7879

7980
# code coverage analysis
8081
# Note: this only works under linux and with make

cpp/examples/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ add_executable(abm_minimal_example abm_minimal.cpp)
108108
target_link_libraries(abm_minimal_example PRIVATE memilio abm)
109109
target_compile_options(abm_minimal_example PRIVATE ${MEMILIO_CXX_FLAGS_ENABLE_WARNING_ERRORS})
110110

111+
if(MEMILIO_HAS_HDF5)
112+
add_executable(abm_parameter_study_example abm_parameter_study.cpp)
113+
target_link_libraries(abm_parameter_study_example PRIVATE memilio abm)
114+
target_compile_options(abm_parameter_study_example PRIVATE ${MEMILIO_CXX_FLAGS_ENABLE_WARNING_ERRORS})
115+
endif()
116+
111117
add_executable(abm_history_example abm_history_object.cpp)
112118
target_link_libraries(abm_history_example PRIVATE memilio abm)
113119
target_compile_options(abm_history_example PRIVATE ${MEMILIO_CXX_FLAGS_ENABLE_WARNING_ERRORS})

cpp/examples/abm_minimal.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include "abm/lockdown_rules.h"
2222
#include "abm/model.h"
2323
#include "abm/common_abm_loggers.h"
24-
#include "memilio/utils/abstract_parameter_distribution.h"
2524

2625
#include <fstream>
2726

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
/*
2+
* Copyright (C) 2020-2025 MEmilio
3+
*
4+
* Authors: Rene Schmieding, Sascha Korf
5+
*
6+
* Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7+
*
8+
* Licensed under the Apache License, Version 2.0 (the "License");
9+
* you may not use this file except in compliance with the License.
10+
* You may obtain a copy of the License at
11+
*
12+
* http://www.apache.org/licenses/LICENSE-2.0
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*/
20+
#include "abm/result_simulation.h"
21+
#include "abm/household.h"
22+
#include "abm/lockdown_rules.h"
23+
#include "abm/model.h"
24+
#include "abm/time.h"
25+
26+
#include "memilio/compartments/parameter_studies.h"
27+
#include "memilio/data/analyze_result.h"
28+
#include "memilio/io/io.h"
29+
#include "memilio/io/result_io.h"
30+
#include "memilio/utils/base_dir.h"
31+
#include "memilio/utils/logging.h"
32+
#include "memilio/utils/miompi.h"
33+
#include "memilio/utils/random_number_generator.h"
34+
#include "memilio/utils/stl_util.h"
35+
36+
#include <string>
37+
38+
constexpr size_t num_age_groups = 4;
39+
40+
/// An ABM setup taken from abm_minimal.cpp.
41+
mio::abm::Model make_model(const mio::RandomNumberGenerator& rng)
42+
{
43+
44+
const auto age_group_0_to_4 = mio::AgeGroup(0);
45+
const auto age_group_5_to_14 = mio::AgeGroup(1);
46+
const auto age_group_15_to_34 = mio::AgeGroup(2);
47+
const auto age_group_35_to_59 = mio::AgeGroup(3);
48+
// Create the model with 4 age groups.
49+
auto model = mio::abm::Model(num_age_groups);
50+
model.get_rng() = rng;
51+
52+
// Set same infection parameter for all age groups. For example, the incubation period is log normally distributed with parameters 4 and 1.
53+
model.parameters.get<mio::abm::TimeExposedToNoSymptoms>() = mio::ParameterDistributionLogNormal(4., 1.);
54+
55+
// Set the age groups that can go to school; here this is AgeGroup(1) (i.e. 5-14)
56+
model.parameters.get<mio::abm::AgeGroupGotoSchool>() = false;
57+
model.parameters.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
58+
// Set the age groups that can go to work; here these are AgeGroup(2) and AgeGroup(3) (i.e. 15-34 and 35-59)
59+
model.parameters.get<mio::abm::AgeGroupGotoWork>().set_multiple({age_group_15_to_34, age_group_35_to_59}, true);
60+
61+
// Check if the parameters satisfy their constraints.
62+
model.parameters.check_constraints();
63+
64+
// There are 10 households for each household group.
65+
int n_households = 10;
66+
67+
// For more than 1 family households we need families. These are parents and children and randoms (which are distributed like the data we have for these households).
68+
auto child = mio::abm::HouseholdMember(num_age_groups); // A child is 50/50% 0-4 or 5-14.
69+
child.set_age_weight(age_group_0_to_4, 1);
70+
child.set_age_weight(age_group_5_to_14, 1);
71+
72+
auto parent = mio::abm::HouseholdMember(num_age_groups); // A parent is 50/50% 15-34 or 35-59.
73+
parent.set_age_weight(age_group_15_to_34, 1);
74+
parent.set_age_weight(age_group_35_to_59, 1);
75+
76+
// Two-person household with one parent and one child.
77+
auto twoPersonHousehold_group = mio::abm::HouseholdGroup();
78+
auto twoPersonHousehold_full = mio::abm::Household();
79+
twoPersonHousehold_full.add_members(child, 1);
80+
twoPersonHousehold_full.add_members(parent, 1);
81+
twoPersonHousehold_group.add_households(twoPersonHousehold_full, n_households);
82+
add_household_group_to_model(model, twoPersonHousehold_group);
83+
84+
// Three-person household with two parent and one child.
85+
auto threePersonHousehold_group = mio::abm::HouseholdGroup();
86+
auto threePersonHousehold_full = mio::abm::Household();
87+
threePersonHousehold_full.add_members(child, 1);
88+
threePersonHousehold_full.add_members(parent, 2);
89+
threePersonHousehold_group.add_households(threePersonHousehold_full, n_households);
90+
add_household_group_to_model(model, threePersonHousehold_group);
91+
92+
// Add one social event with 5 maximum contacts.
93+
// Maximum contacts limit the number of people that a person can infect while being at this location.
94+
auto event = model.add_location(mio::abm::LocationType::SocialEvent);
95+
model.get_location(event).get_infection_parameters().set<mio::abm::MaximumContacts>(5);
96+
// Add hospital and ICU with 5 maximum contacs.
97+
auto hospital = model.add_location(mio::abm::LocationType::Hospital);
98+
model.get_location(hospital).get_infection_parameters().set<mio::abm::MaximumContacts>(5);
99+
auto icu = model.add_location(mio::abm::LocationType::ICU);
100+
model.get_location(icu).get_infection_parameters().set<mio::abm::MaximumContacts>(5);
101+
// Add one supermarket, maximum constacts are assumed to be 20.
102+
auto shop = model.add_location(mio::abm::LocationType::BasicsShop);
103+
model.get_location(shop).get_infection_parameters().set<mio::abm::MaximumContacts>(20);
104+
// At every school, the maximum contacts are 20.
105+
auto school = model.add_location(mio::abm::LocationType::School);
106+
model.get_location(school).get_infection_parameters().set<mio::abm::MaximumContacts>(20);
107+
// At every workplace, maximum contacts are 20.
108+
auto work = model.add_location(mio::abm::LocationType::Work);
109+
model.get_location(work).get_infection_parameters().set<mio::abm::MaximumContacts>(20);
110+
111+
// Increase aerosol transmission for all locations
112+
model.parameters.get<mio::abm::AerosolTransmissionRates>() = 10.0;
113+
// Increase contact rate for all people between 15 and 34 (i.e. people meet more often in the same location)
114+
model.get_location(work)
115+
.get_infection_parameters()
116+
.get<mio::abm::ContactRates>()[{age_group_15_to_34, age_group_15_to_34}] = 10.0;
117+
118+
// People can get tested at work (and do this with 0.5 probability) from time point 0 to day 10.
119+
auto validity_period = mio::abm::days(1);
120+
auto probability = 0.5;
121+
auto start_date = mio::abm::TimePoint(0);
122+
auto end_date = mio::abm::TimePoint(0) + mio::abm::days(10);
123+
auto test_type = mio::abm::TestType::Antigen;
124+
auto test_parameters = model.parameters.get<mio::abm::TestData>()[test_type];
125+
auto testing_criteria_work = mio::abm::TestingCriteria();
126+
auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, validity_period, start_date, end_date,
127+
test_parameters, probability);
128+
model.get_testing_strategy().add_scheme(mio::abm::LocationType::Work, testing_scheme_work);
129+
130+
// Assign infection state to each person.
131+
// The infection states are chosen randomly with the following distribution
132+
std::vector<ScalarType> infection_distribution{0.5, 0.3, 0.05, 0.05, 0.05, 0.05, 0.0, 0.0};
133+
for (auto& person : model.get_persons()) {
134+
mio::abm::InfectionState infection_state = mio::abm::InfectionState(
135+
mio::DiscreteDistribution<size_t>::get_instance()(mio::thread_local_rng(), infection_distribution));
136+
auto person_rng = mio::abm::PersonalRandomNumberGenerator(person);
137+
if (infection_state != mio::abm::InfectionState::Susceptible) {
138+
person.add_new_infection(mio::abm::Infection(person_rng, mio::abm::VirusVariant::Wildtype, person.get_age(),
139+
model.parameters, start_date, infection_state));
140+
}
141+
}
142+
143+
// Assign locations to the people
144+
for (auto& person : model.get_persons()) {
145+
const auto id = person.get_id();
146+
//assign shop and event
147+
model.assign_location(id, event);
148+
model.assign_location(id, shop);
149+
//assign hospital and ICU
150+
model.assign_location(id, hospital);
151+
model.assign_location(id, icu);
152+
//assign work/school to people depending on their age
153+
if (person.get_age() == age_group_5_to_14) {
154+
model.assign_location(id, school);
155+
}
156+
if (person.get_age() == age_group_15_to_34 || person.get_age() == age_group_35_to_59) {
157+
model.assign_location(id, work);
158+
}
159+
}
160+
161+
// During the lockdown, social events are closed for 90% of people.
162+
auto t_lockdown = mio::abm::TimePoint(0) + mio::abm::days(10);
163+
mio::abm::close_social_events(t_lockdown, 0.9, model.parameters);
164+
165+
return model;
166+
}
167+
168+
int main()
169+
{
170+
mio::mpi::init();
171+
172+
mio::set_log_level(mio::LogLevel::warn);
173+
174+
// Set start and end time for the simulation.
175+
auto t0 = mio::abm::TimePoint(0);
176+
auto tmax = t0 + mio::abm::days(5);
177+
// Set the number of simulations to run in the study
178+
const size_t num_runs = 3;
179+
180+
// Create a parameter study.
181+
// Note that the study for the ABM currently does not make use of the arguments "parameters" or "dt", as we create
182+
// a new model for each simulation. Hence we set both arguments to 0.
183+
// This is mostly due to https://github.com/SciCompMod/memilio/issues/1400
184+
mio::ParameterStudy study(0, t0, tmax, mio::abm::TimeSpan(0), num_runs);
185+
186+
// Optional: set seeds to get reproducable results
187+
// study.get_rng().seed({12341234, 53456, 63451, 5232576, 84586, 52345});
188+
189+
const std::string result_dir = mio::path_join(mio::base_dir(), "example_results");
190+
if (!mio::create_directory(result_dir)) {
191+
mio::log_error("Could not create result directory \"{}\".", result_dir);
192+
return 1;
193+
}
194+
195+
auto ensemble_results = study.run(
196+
[](auto, auto t0_, auto, size_t) {
197+
return mio::abm::ResultSimulation(make_model(mio::thread_local_rng()), t0_);
198+
},
199+
[result_dir](auto&& sim, auto&& run_idx) {
200+
auto interpolated_result = mio::interpolate_simulation_result(sim.get_result());
201+
std::string outpath = mio::path_join(result_dir, "abm_minimal_run_" + std::to_string(run_idx) + ".txt");
202+
std::ofstream outfile_run(outpath);
203+
sim.get_result().print_table(outfile_run, {"S", "E", "I_NS", "I_Sy", "I_Sev", "I_Crit", "R", "D"}, 7, 4);
204+
std::cout << "Results written to " << outpath << std::endl;
205+
auto params = std::vector<mio::abm::Model>{};
206+
return std::vector{interpolated_result};
207+
});
208+
209+
if (ensemble_results.size() > 0) {
210+
auto ensemble_results_p05 = ensemble_percentile(ensemble_results, 0.05);
211+
auto ensemble_results_p25 = ensemble_percentile(ensemble_results, 0.25);
212+
auto ensemble_results_p50 = ensemble_percentile(ensemble_results, 0.50);
213+
auto ensemble_results_p75 = ensemble_percentile(ensemble_results, 0.75);
214+
auto ensemble_results_p95 = ensemble_percentile(ensemble_results, 0.95);
215+
216+
mio::unused(save_result(ensemble_results_p05, {0}, num_age_groups,
217+
mio::path_join(result_dir, "Results_" + std::string("p05") + ".h5")));
218+
mio::unused(save_result(ensemble_results_p25, {0}, num_age_groups,
219+
mio::path_join(result_dir, "Results_" + std::string("p25") + ".h5")));
220+
mio::unused(save_result(ensemble_results_p50, {0}, num_age_groups,
221+
mio::path_join(result_dir, "Results_" + std::string("p50") + ".h5")));
222+
mio::unused(save_result(ensemble_results_p75, {0}, num_age_groups,
223+
mio::path_join(result_dir, "Results_" + std::string("p75") + ".h5")));
224+
mio::unused(save_result(ensemble_results_p95, {0}, num_age_groups,
225+
mio::path_join(result_dir, "Results_" + std::string("p95") + ".h5")));
226+
}
227+
228+
mio::mpi::finalize();
229+
230+
return 0;
231+
}

cpp/examples/ode_secir_parameter_study.cpp

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717
* See the License for the specific language governing permissions and
1818
* limitations under the License.
1919
*/
20+
#include "memilio/config.h"
21+
#include "memilio/utils/base_dir.h"
22+
#include "memilio/utils/miompi.h"
23+
#include "memilio/utils/stl_util.h"
24+
#include "ode_secir/model.h"
2025
#include "ode_secir/parameters_io.h"
2126
#include "ode_secir/parameter_space.h"
2227
#include "memilio/compartments/parameter_studies.h"
23-
#include "memilio/mobility/metapopulation_mobility_instant.h"
2428
#include "memilio/io/result_io.h"
2529

2630
/**
@@ -30,13 +34,10 @@
3034
* @param t0 starting point of simulation
3135
* @param tmax end point of simulation
3236
*/
33-
mio::IOResult<void>
34-
write_single_run_result(const size_t run,
35-
const mio::Graph<mio::SimulationNode<ScalarType, mio::osecir::Simulation<ScalarType>>,
36-
mio::MobilityEdge<ScalarType>>& graph)
37+
mio::IOResult<void> write_single_run_result(const size_t run, const mio::osecir::Simulation<ScalarType>& sim)
3738
{
38-
std::string abs_path;
39-
BOOST_OUTCOME_TRY(auto&& created, mio::create_directory("results", abs_path));
39+
std::string abs_path = mio::path_join(mio::base_dir(), "example_results");
40+
BOOST_OUTCOME_TRY(auto&& created, mio::create_directory(abs_path));
4041

4142
if (run == 0) {
4243
std::cout << "Results are stored in " << abs_path << '\n';
@@ -46,44 +47,29 @@ write_single_run_result(const size_t run,
4647
}
4748

4849
//write sampled parameters for this run
49-
//omit edges to save space as they are not sampled
50-
int inode = 0;
51-
for (auto&& node : graph.nodes()) {
52-
BOOST_OUTCOME_TRY(auto&& js_node_model, serialize_json(node.property.get_result(), mio::IOF_OmitDistributions));
53-
Json::Value js_node(Json::objectValue);
54-
js_node["NodeId"] = node.id;
55-
js_node["Model"] = js_node_model;
56-
auto node_filename = mio::path_join(abs_path, "Parameters_run" + std::to_string(run) + "_node" +
57-
std::to_string(inode++) + ".json");
58-
BOOST_OUTCOME_TRY(mio::write_json(node_filename, js_node));
59-
}
50+
auto node_filename = mio::path_join(abs_path, "Parameters_run" + std::to_string(run) + ".json");
51+
BOOST_OUTCOME_TRY(mio::write_json(node_filename, sim.get_result()));
6052

6153
//write results for this run
6254
std::vector<mio::TimeSeries<ScalarType>> all_results;
6355
std::vector<int> ids;
6456

65-
ids.reserve(graph.nodes().size());
66-
all_results.reserve(graph.nodes().size());
67-
std::transform(graph.nodes().begin(), graph.nodes().end(), std::back_inserter(all_results), [](auto& node) {
68-
return node.property.get_result();
69-
});
70-
std::transform(graph.nodes().begin(), graph.nodes().end(), std::back_inserter(ids), [](auto& node) {
71-
return node.id;
72-
});
73-
auto num_groups = (int)(size_t)graph.nodes()[0].property.get_simulation().get_model().parameters.get_num_groups();
74-
BOOST_OUTCOME_TRY(mio::save_result(all_results, ids, num_groups,
75-
mio::path_join(abs_path, ("Results_run" + std::to_string(run) + ".h5"))));
57+
BOOST_OUTCOME_TRY(mio::save_result({sim.get_result()}, {0}, (int)sim.get_model().parameters.get_num_groups().get(),
58+
mio::path_join(abs_path, "Results_run" + std::to_string(run) + ".h5")));
7659

7760
return mio::success();
7861
}
7962

8063
int main()
8164
{
82-
mio::set_log_level(mio::LogLevel::debug);
65+
mio::mpi::init();
66+
mio::set_log_level(mio::LogLevel::warn);
8367

8468
ScalarType t0 = 0;
8569
ScalarType tmax = 50;
70+
ScalarType dt = 0.1;
8671

72+
// set up model with parameters
8773
ScalarType cont_freq = 10; // see Polymod study
8874

8975
ScalarType num_total_t0 = 10000, num_exp_t0 = 100, num_inf_t0 = 50, num_car_t0 = 50, num_hosp_t0 = 20,
@@ -139,22 +125,30 @@ int main()
139125
return -1;
140126
}
141127

142-
//create study
143-
auto num_runs = size_t(1);
144-
mio::ParameterStudy<ScalarType, mio::osecir::Simulation<ScalarType>> parameter_study(model, t0, tmax, num_runs);
128+
// create study
129+
auto num_runs = size_t(3);
130+
mio::ParameterStudy parameter_study(model, t0, tmax, dt, num_runs);
145131

146-
//run study
147-
auto sample_graph = [](auto&& graph) {
148-
return mio::osecir::draw_sample(graph);
132+
// set up for run
133+
auto sample_graph = [](const auto& model_, ScalarType t0_, ScalarType dt_, size_t) {
134+
mio::osecir::Model<ScalarType> copy = model_;
135+
mio::osecir::draw_sample(copy);
136+
return mio::osecir::Simulation<ScalarType>(std::move(copy), t0_, dt_);
149137
};
150-
auto handle_result = [](auto&& graph, auto&& run) {
151-
auto write_result_status = write_single_run_result(run, graph);
138+
auto handle_result = [](auto&& sim, auto&& run) {
139+
auto write_result_status = write_single_run_result(run, sim);
152140
if (!write_result_status) {
153141
std::cout << "Error writing result: " << write_result_status.error().formatted_message();
154142
}
155-
return 0; //Result handler must return something, but only meaningful when using MPI.
156143
};
144+
145+
// Optional: set seeds to get reproducable results
146+
// parameter_study.get_rng().seed({1456, 157456, 521346, 35345, 6875, 6435});
147+
148+
// run study
157149
parameter_study.run(sample_graph, handle_result);
158150

151+
mio::mpi::finalize();
152+
159153
return 0;
160154
}

0 commit comments

Comments
 (0)