Skip to content

Commit cb3cf90

Browse files
Add missing docstrings and tests inside Python interface (#80)
Address issue #54 I have added docstrings where they were missing for each class we have in Python Interface. Also, inside the `shared_decoding_tests` module that I created, I have moved tests that are the same for `Simplex` and `Tesseract`, so we have it more generic. Our Python Interface/wrapper should look much prettier now. --------- Co-authored-by: oscarhiggott <29460323+oscarhiggott@users.noreply.github.com>
1 parent 6da2b8f commit cb3cf90

15 files changed

+923
-239
lines changed

src/common.cc

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,16 @@ std::string common::Symptom::str() {
2525
}
2626

2727
common::Error::Error(const stim::DemInstruction& error) {
28+
if (error.type != stim::DemInstructionType::DEM_ERROR) {
29+
throw std::invalid_argument(
30+
"Error must be loaded from an error dem instruction, but received: " + error.str());
31+
}
2832
assert(error.type == stim::DemInstructionType::DEM_ERROR);
29-
probability = error.arg_data[0];
30-
assert(probability >= 0 && probability <= 1);
33+
double probability = error.arg_data[0];
34+
if (probability < 0 || probability > 1) {
35+
throw std::invalid_argument("Probability must be between 0 and 1, but received: " +
36+
std::to_string(probability));
37+
}
3138

3239
std::set<int> detectors_set;
3340
std::set<int> observables_set;
@@ -60,6 +67,17 @@ std::string common::Error::str() {
6067
return "Error{cost=" + std::to_string(likelihood_cost) + ", symptom=" + symptom.str() + "}";
6168
}
6269

70+
double common::Error::get_probability() const {
71+
return 1.0 / (1.0 + std::exp(likelihood_cost));
72+
}
73+
74+
void common::Error::set_with_probability(double p) {
75+
if (p <= 0 || p >= 1) {
76+
throw std::invalid_argument("Probability must be between 0 and 1.");
77+
}
78+
likelihood_cost = -std::log(p / (1.0 - p));
79+
}
80+
6381
std::vector<stim::DemTarget> common::Symptom::as_dem_instruction_targets() const {
6482
std::vector<stim::DemTarget> targets;
6583
for (int d : detectors) {
@@ -71,7 +89,15 @@ std::vector<stim::DemTarget> common::Symptom::as_dem_instruction_targets() const
7189
return targets;
7290
}
7391

74-
stim::DetectorErrorModel common::merge_identical_errors(const stim::DetectorErrorModel& dem) {
92+
double common::merge_weights(double a, double b) {
93+
auto sgn = std::copysign(1, a) * std::copysign(1, b);
94+
auto signed_min = sgn * std::min(std::abs(a), std::abs(b));
95+
return signed_min + std::log(1 + std::exp(-std::abs(a + b))) -
96+
std::log(1 + std::exp(-std::abs(a - b)));
97+
}
98+
99+
stim::DetectorErrorModel common::merge_indistinguishable_errors(
100+
const stim::DetectorErrorModel& dem) {
75101
stim::DetectorErrorModel out_dem;
76102

77103
// Map to track the distinct symptoms
@@ -80,13 +106,14 @@ stim::DetectorErrorModel common::merge_identical_errors(const stim::DetectorErro
80106
switch (instruction.type) {
81107
case stim::DemInstructionType::DEM_ERROR: {
82108
Error error(instruction);
83-
assert(error.symptom.detectors.size());
84-
// Merge with existing error with the same symptom (if applicable)
109+
if (error.symptom.detectors.size() == 0) {
110+
throw std::invalid_argument("Errors that do not flip any detectors are not supported.");
111+
}
112+
85113
if (errors_by_symptom.find(error.symptom) != errors_by_symptom.end()) {
86-
double p0 = errors_by_symptom[error.symptom].probability;
87-
error.probability = p0 * (1 - error.probability) + (1 - p0) * error.probability;
114+
error.likelihood_cost = merge_weights(error.likelihood_cost,
115+
errors_by_symptom[error.symptom].likelihood_cost);
88116
}
89-
error.likelihood_cost = -1 * std::log(error.probability / (1 - error.probability));
90117
errors_by_symptom[error.symptom] = error;
91118
break;
92119
}
@@ -103,7 +130,7 @@ stim::DetectorErrorModel common::merge_identical_errors(const stim::DetectorErro
103130
}
104131
}
105132
for (const auto& it : errors_by_symptom) {
106-
out_dem.append_error_instruction(it.second.probability,
133+
out_dem.append_error_instruction(it.second.get_probability(),
107134
it.second.symptom.as_dem_instruction_targets(),
108135
/*tag=*/"");
109136
}

src/common.h

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,26 +48,23 @@ struct Symptom {
4848
// Represents an error / weighted hyperedge
4949
struct Error {
5050
double likelihood_cost;
51-
double probability;
5251
Symptom symptom;
53-
std::vector<bool> dets_array;
5452
Error() = default;
55-
Error(double likelihood_cost, std::vector<int>& detectors, std::vector<int> observables,
56-
std::vector<bool>& dets_array)
57-
: likelihood_cost(likelihood_cost), symptom{detectors, observables}, dets_array(dets_array) {}
58-
Error(double likelihood_cost, double probability, std::vector<int>& detectors,
59-
std::vector<int> observables, std::vector<bool>& dets_array)
60-
: likelihood_cost(likelihood_cost),
61-
probability(probability),
62-
symptom{detectors, observables},
63-
dets_array(dets_array) {}
53+
Error(double likelihood_cost, std::vector<int>& detectors, std::vector<int> observables)
54+
: likelihood_cost(likelihood_cost), symptom{detectors, observables} {}
6455
Error(const stim::DemInstruction& error);
6556
std::string str();
57+
58+
// Get/calculate the probability from the likelihood cost.
59+
double get_probability() const;
60+
61+
// Set/calculate the likelihood cost from a probability.
62+
void set_with_probability(double p);
6663
};
6764

6865
// Makes a new (flattened) dem where identical error mechanisms have been
6966
// merged.
70-
stim::DetectorErrorModel merge_identical_errors(const stim::DetectorErrorModel& dem);
67+
stim::DetectorErrorModel merge_indistinguishable_errors(const stim::DetectorErrorModel& dem);
7168

7269
// Returns a copy of the given error model with any zero-probability DEM_ERROR
7370
// instructions removed.
@@ -80,6 +77,12 @@ stim::DetectorErrorModel remove_zero_probability_errors(const stim::DetectorErro
8077
stim::DetectorErrorModel dem_from_counts(stim::DetectorErrorModel& orig_dem,
8178
const std::vector<size_t>& error_counts, size_t num_shots);
8279

80+
/// Computes the weight of an edge resulting from merging edges with weight `a' and weight `b',
81+
/// assuming each edge weight is a log-likelihood ratio log((1-p)/p) associated with the probability
82+
/// p of an error occurring on the edge, and that the error mechanisms associated with the two edges
83+
/// being merged are independent.
84+
double merge_weights(double a, double b);
85+
8386
} // namespace common
8487

8588
#endif

src/common.pybind.h

Lines changed: 150 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,65 +29,187 @@ namespace py = pybind11;
2929
void add_common_module(py::module &root) {
3030
auto m = root.def_submodule("common", "classes commonly used by the decoder");
3131

32-
py::class_<common::Symptom>(m, "Symptom")
32+
py::class_<common::Symptom>(m, "Symptom", R"pbdoc(
33+
Represents a symptom of an error, which is a list of detectors and a list of observables
34+
35+
A symptom is defined by a list of detectors that are flipped and a list of
36+
observables that are flipped.
37+
)pbdoc")
3338
.def(py::init<std::vector<int>, std::vector<int>>(),
34-
py::arg("detectors") = std::vector<int>(), py::arg("observables") = std::vector<int>())
35-
.def_readwrite("detectors", &common::Symptom::detectors)
36-
.def_readwrite("observables", &common::Symptom::observables)
39+
py::arg("detectors") = std::vector<int>(), py::arg("observables") = std::vector<int>(),
40+
R"pbdoc(
41+
The constructor for the `Symptom` class.
42+
43+
Parameters
44+
----------
45+
detectors : list[int], default=[]
46+
The indices of the detectors in this symptom.
47+
observables : list[int], default=[]
48+
The indices of the flipped observables.
49+
)pbdoc")
50+
.def_readwrite("detectors", &common::Symptom::detectors,
51+
"A list of the detector indices that are flipped in this symptom.")
52+
.def_readwrite("observables", &common::Symptom::observables,
53+
"A list of observable indices that are flipped in this symptom.")
3754
.def("__str__", &common::Symptom::str)
3855
.def(py::self == py::self)
3956
.def(py::self != py::self)
40-
.def("as_dem_instruction_targets", [](common::Symptom s) {
41-
std::vector<py::object> ret;
42-
for (auto &t : s.as_dem_instruction_targets())
43-
ret.push_back(make_py_object(t, "DemTarget"));
44-
return ret;
45-
});
46-
47-
py::class_<common::Error>(m, "Error")
48-
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
49-
.def_readwrite("probability", &common::Error::probability)
50-
.def_readwrite("symptom", &common::Error::symptom)
57+
.def(
58+
"as_dem_instruction_targets",
59+
[](common::Symptom s) {
60+
std::vector<py::object> ret;
61+
for (auto &t : s.as_dem_instruction_targets())
62+
ret.push_back(make_py_object(t, "DemTarget"));
63+
return ret;
64+
},
65+
R"pbdoc(
66+
Converts the symptom into a list of `stim.DemTarget` objects.
67+
68+
Returns
69+
-------
70+
list[stim.DemTarget]
71+
A list of `stim.DemTarget` objects representing the detectors and observables.
72+
)pbdoc");
73+
74+
py::class_<common::Error>(m, "Error", R"pbdoc(
75+
Represents an error, including its cost, and symptom.
76+
77+
An error is a physical event (or set of indistinguishable physical events)
78+
defined by the detectors and observables that it flips in the circuit.
79+
)pbdoc")
80+
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost,
81+
"The cost of this error (often log((1 - probability) / probability)).")
82+
.def_readwrite("symptom", &common::Error::symptom, "The symptom associated with this error.")
5183
.def("__str__", &common::Error::str)
52-
.def(py::init<>())
53-
.def(py::init<double, std::vector<int> &, std::vector<int>, std::vector<bool> &>(),
54-
py::arg("likelihood_cost"), py::arg("detectors"), py::arg("observables"),
55-
py::arg("dets_array"))
56-
.def(py::init<double, double, std::vector<int> &, std::vector<int>, std::vector<bool> &>(),
57-
py::arg("likelihood_cost"), py::arg("probability"), py::arg("detectors"),
58-
py::arg("observables"), py::arg("dets_array"))
84+
.def(py::init<>(), R"pbdoc(
85+
Default constructor for the `Error` class.
86+
)pbdoc")
87+
.def(py::init<double, std::vector<int> &, std::vector<int>>(), py::arg("likelihood_cost"),
88+
py::arg("detectors"), py::arg("observables"), R"pbdoc(
89+
Constructor for the `Error` class.
90+
91+
Parameters
92+
----------
93+
likelihood_cost : float
94+
The cost of this error.
95+
This is often `log((1 - probability) / probability)`.
96+
detectors : list[int]
97+
A list of indices of the detectors flipped by this error.
98+
observables : list[int]
99+
A list of indices of the observables flipped by this error.
100+
)pbdoc")
101+
59102
.def(py::init([](py::object edi) {
60103
std::vector<double> args;
61104
std::vector<stim::DemTarget> targets;
62105
auto di = parse_py_dem_instruction(edi, args, targets);
63106
return new common::Error(di);
64107
}),
65-
py::arg("error"));
108+
py::arg("error"), R"pbdoc(
109+
Constructor that creates an `Error` from a `stim.DemInstruction`.
110+
111+
Parameters
112+
----------
113+
error : stim.DemInstruction
114+
The instruction to convert into an `Error` object.
115+
)pbdoc")
116+
.def("get_probability", &common::Error::get_probability,
117+
R"pbdoc(
118+
Gets the probability associated with the likelihood cost.
119+
120+
Returns
121+
-------
122+
float
123+
The probability of the error, calculated from the likelihood cost.
124+
)pbdoc")
125+
.def("set_with_probability", &common::Error::set_with_probability, py::arg("probability"),
126+
R"pbdoc(
127+
Sets the likelihood cost based on a given probability.
128+
129+
Parameters
130+
----------
131+
probability : float
132+
The probability to use for setting the likelihood cost.
133+
Must be between 0 and 1 (exclusive).
134+
135+
Raises
136+
------
137+
ValueError
138+
If the provided probability is not between 0 and 1.
139+
)pbdoc");
66140

67141
m.def(
68-
"merge_identical_errors",
142+
"merge_indistinguishable_errors",
69143
[](py::object dem) {
70144
auto input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
71-
auto res = common::merge_identical_errors(input_dem);
145+
auto res = common::merge_indistinguishable_errors(input_dem);
72146
return make_py_object(res, "DetectorErrorModel");
73147
},
74-
py::arg("dem"));
148+
py::arg("dem"), R"pbdoc(
149+
Merges identical errors in a `stim.DetectorErrorModel`.
150+
151+
Errors are identical if they flip the same set of detectors and observables (the same symptom).
152+
For example, two identical errors with probabilities p1 and p2
153+
would be merged into a single error with the same symptom,
154+
but with probability `p1 * (1 - p2) + p2 * (1 - p1)`
155+
156+
Parameters
157+
----------
158+
dem : stim.DetectorErrorModel
159+
The detector error model to process.
160+
161+
Returns
162+
-------
163+
stim.DetectorErrorModel
164+
A new `DetectorErrorModel` with identical errors merged.
165+
)pbdoc");
75166
m.def(
76167
"remove_zero_probability_errors",
77168
[](py::object dem) {
78169
return make_py_object(
79170
common::remove_zero_probability_errors(parse_py_object<stim::DetectorErrorModel>(dem)),
80171
"DetectorErrorModel");
81172
},
82-
py::arg("dem"));
173+
py::arg("dem"), R"pbdoc(
174+
Removes errors with a probability of 0 from a `stim.DetectorErrorModel`.
175+
176+
Parameters
177+
----------
178+
dem : stim.DetectorErrorModel
179+
The detector error model to process.
180+
181+
Returns
182+
-------
183+
stim.DetectorErrorModel
184+
A new `DetectorErrorModel` with zero-probability errors removed.
185+
)pbdoc");
83186
m.def(
84187
"dem_from_counts",
85188
[](py::object orig_dem, const std::vector<size_t> error_counts, size_t num_shots) {
86189
auto dem = parse_py_object<stim::DetectorErrorModel>(orig_dem);
87190
return make_py_object(common::dem_from_counts(dem, error_counts, num_shots),
88191
"DetectorErrorModel");
89192
},
90-
py::arg("orig_dem"), py::arg("error_counts"), py::arg("num_shots"));
193+
py::arg("orig_dem"), py::arg("error_counts"), py::arg("num_shots"), R"pbdoc(
194+
Re-weights errors in a `stim.DetectorErrorModel` based on observed counts.
195+
196+
This function re-calculates the probability of each error based on a list of
197+
observed counts and the total number of shots.
198+
199+
Parameters
200+
----------
201+
orig_dem : stim.DetectorErrorModel
202+
The original detector error model.
203+
error_counts : list[int]
204+
A list of counts for each error in the DEM.
205+
num_shots : int
206+
The total number of shots in the experiment.
207+
208+
Returns
209+
-------
210+
stim.DetectorErrorModel
211+
A new `DetectorErrorModel` with updated error probabilities.
212+
)pbdoc");
91213
}
92214

93215
#endif

0 commit comments

Comments
 (0)