Skip to content

Commit 547adc3

Browse files
draganaurosgrbicnoajshuLalehB
authored
Arbitrary number of observables (#64)
Address issue #61 --------- Signed-off-by: Dragana Grbic <draganaurosgrbic@gmail.com> Co-authored-by: noajshu <shutty@google.com> Co-authored-by: LaLeh <lalehbeni@google.com>
1 parent 17ec792 commit 547adc3

19 files changed

+266
-63
lines changed

src/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ cc_library(
193193
linkopts = OPT_LINKOPTS,
194194
deps = [
195195
":libcommon",
196+
":libutils",
196197
"@highs",
197198
"@stim//:stim_lib",
198199
],

src/common.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,15 @@ common::Error::Error(const stim::DemInstruction& error) {
3030
assert(probability >= 0 && probability <= 1);
3131

3232
std::set<int> detectors_set;
33-
ObservablesMask observables = 0;
33+
std::set<int> observables_set;
3434

3535
for (const stim::DemTarget& target : error.target_data) {
3636
if (target.is_observable_id()) {
37-
observables ^= (1 << target.val());
37+
if (observables_set.find(target.val()) != observables_set.end()) {
38+
observables_set.erase(target.val());
39+
} else {
40+
observables_set.insert(target.val());
41+
}
3842
} else if (target.is_relative_detector_id()) {
3943
if (detectors_set.find(target.val()) != detectors_set.end()) {
4044
detectors_set.erase(target.val());
@@ -46,6 +50,7 @@ common::Error::Error(const stim::DemInstruction& error) {
4650
// Detectors in the set are already sorted order, which we need so that there
4751
// is a unique canonical representative for each set of detectors.
4852
std::vector<int> detectors(detectors_set.begin(), detectors_set.end());
53+
std::vector<int> observables(observables_set.begin(), observables_set.end());
4954
likelihood_cost = -1 * std::log(probability / (1 - probability));
5055
symptom.detectors = detectors;
5156
symptom.observables = observables;
@@ -60,12 +65,8 @@ std::vector<stim::DemTarget> common::Symptom::as_dem_instruction_targets() const
6065
for (int d : detectors) {
6166
targets.push_back(stim::DemTarget::relative_detector_id(d));
6267
}
63-
common::ObservablesMask mask = observables;
64-
for (size_t oi = 0; oi < 8 * sizeof(common::ObservablesMask); ++oi) {
65-
if (mask & 1) {
66-
targets.push_back(stim::DemTarget::observable_id(oi));
67-
}
68-
mask >>= 1;
68+
for (int o : observables) {
69+
targets.push_back(stim::DemTarget::observable_id(o));
6970
}
7071
return targets;
7172
}

src/common.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,21 @@
1919
#include "stim.h"
2020

2121
namespace common {
22-
using ObservablesMask = std::uint64_t;
2322

2423
// Represents the effect of an error
2524
struct Symptom {
2625
std::vector<int> detectors;
27-
ObservablesMask observables;
26+
std::vector<int> observables;
2827

2928
struct hash {
3029
size_t operator()(const Symptom& s) const {
3130
size_t hash = 0;
3231
for (int i : s.detectors) {
3332
hash += std::hash<int>{}(i);
3433
}
35-
hash ^= s.observables;
34+
for (int i : s.observables) {
35+
hash += std::hash<int>{}(i);
36+
}
3637
return hash;
3738
}
3839
};
@@ -51,11 +52,11 @@ struct Error {
5152
Symptom symptom;
5253
std::vector<bool> dets_array;
5354
Error() = default;
54-
Error(double likelihood_cost, std::vector<int>& detectors, ObservablesMask observables,
55+
Error(double likelihood_cost, std::vector<int>& detectors, std::vector<int> observables,
5556
std::vector<bool>& dets_array)
5657
: likelihood_cost(likelihood_cost), symptom{detectors, observables}, dets_array(dets_array) {}
5758
Error(double likelihood_cost, double probability, std::vector<int>& detectors,
58-
ObservablesMask observables, std::vector<bool>& dets_array)
59+
std::vector<int> observables, std::vector<bool>& dets_array)
5960
: likelihood_cost(likelihood_cost),
6061
probability(probability),
6162
symptom{detectors, observables},

src/common.pybind.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ void add_common_module(py::module &root) {
3030
auto m = root.def_submodule("common", "classes commonly used by the decoder");
3131

3232
py::class_<common::Symptom>(m, "Symptom")
33-
.def(py::init<std::vector<int>, common::ObservablesMask>(),
34-
py::arg("detectors") = std::vector<int>(), py::arg("observables") = 0)
33+
.def(py::init<std::vector<int>, std::vector<int>>(),
34+
py::arg("detectors") = std::vector<int>(), py::arg("observables") = std::vector<int>())
3535
.def_readwrite("detectors", &common::Symptom::detectors)
3636
.def_readwrite("observables", &common::Symptom::observables)
3737
.def("__str__", &common::Symptom::str)
@@ -50,11 +50,10 @@ void add_common_module(py::module &root) {
5050
.def_readwrite("symptom", &common::Error::symptom)
5151
.def("__str__", &common::Error::str)
5252
.def(py::init<>())
53-
.def(py::init<double, std::vector<int> &, common::ObservablesMask, std::vector<bool> &>(),
53+
.def(py::init<double, std::vector<int> &, std::vector<int>, std::vector<bool> &>(),
5454
py::arg("likelihood_cost"), py::arg("detectors"), py::arg("observables"),
5555
py::arg("dets_array"))
56-
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask,
57-
std::vector<bool> &>(),
56+
.def(py::init<double, double, std::vector<int> &, std::vector<int>, std::vector<bool> &>(),
5857
py::arg("likelihood_cost"), py::arg("probability"), py::arg("detectors"),
5958
py::arg("observables"), py::arg("dets_array"))
6059
.def(py::init([](py::object edi) {

src/common.test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ TEST(common, ErrorsStructFromDemInstruction) {
2323
stim::DemInstruction instruction = dem.instructions.at(0);
2424
common::Error ES(instruction);
2525
EXPECT_EQ(ES.symptom.detectors, std::vector<int>{1});
26-
EXPECT_EQ(ES.symptom.observables, 0b01);
26+
EXPECT_EQ(ES.symptom.observables, std::vector<int>{0});
2727
}
2828

2929
TEST(common, DemFromCountsRejectsZeroProbabilityErrors) {

src/py/common_test.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,29 @@
1717

1818
from src import tesseract_decoder
1919

20+
def get_set_bits(n):
21+
"""
22+
Converts an observable bitmask (integer) into a list of observable indices.
23+
24+
Args:
25+
n (int): The integer representing the observable bitmask.
26+
27+
Returns:
28+
list[int]: A list containing the indices of the set bits (observable IDs)
29+
"""
30+
bits = []
31+
i = 0
32+
33+
while n > 0:
34+
if n & 1:
35+
bits.append(i)
36+
n >>= 1
37+
i += 1
38+
return bits
39+
2040

2141
def test_as_dem_instruction_targets():
22-
s = tesseract_decoder.common.Symptom([1, 2], 4324)
42+
s = tesseract_decoder.common.Symptom([1, 2], get_set_bits(4324))
2343
dits = s.as_dem_instruction_targets()
2444
assert dits == [
2545
stim.DemTarget("D1"),

src/py/simplex_test.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,48 @@ def test_create_simplex_decoder():
4343
)
4444
decoder.init_ilp()
4545
decoder.decode_to_errors([1])
46-
assert decoder.mask_from_errors([1]) == 0
46+
assert decoder.get_observables_from_errors([1]) == []
4747
assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123)
48-
assert decoder.decode([1]) == 0
48+
assert decoder.decode([1]) == []
4949

50+
def test_simplex_decoder_predicts_various_observable_flips():
51+
"""
52+
Tests that the Simplex decoder correctly predicts a logical observable
53+
flip when a specific detector is triggered by an error that explicitly
54+
flips that logical observable.
55+
56+
This test iterates through various observable IDs to ensure the backend logic
57+
correctly handles different positions.
58+
"""
59+
# Iterate through observable IDs from 0 to 63 (inclusive)
60+
for observable_id in range(64):
61+
# Create a simple DetectorErrorModel where an error on D0 also flips L{observable_id}
62+
dem_string = f'''
63+
error(0.01) D0 L{observable_id}
64+
'''
65+
dem = stim.DetectorErrorModel(dem_string)
66+
67+
# Initialize SimplexConfig and SimplexDecoder with the generated DEM
68+
config = tesseract_decoder.simplex.SimplexConfig(dem, window_length=1) # window_length must be set
69+
decoder = tesseract_decoder.simplex.SimplexDecoder(config)
70+
decoder.init_ilp() # Initialize the ILP solver
71+
72+
# Simulate a detection event on D0.
73+
# The decoder should identify the most likely error causing D0,
74+
# which in this DEM is the error that also flips L{observable_id}.
75+
# The decode method is expected to return an array where array[i] is True if observable i is flipped.
76+
predicted_logical_flips_array = decoder.decode(detections=[0])
77+
78+
# Convert the boolean array/list to a list of flipped observable IDs
79+
actual_flipped_observables = [
80+
idx for idx, is_flipped in enumerate(predicted_logical_flips_array) if is_flipped
81+
]
82+
83+
# Assert that the list of actual flipped observables matches the single expected observable_id.
84+
assert actual_flipped_observables == [observable_id], \
85+
(f"For observable L{observable_id}: "
86+
f"Expected predicted logical flips: [{observable_id}], "
87+
f"but got: {actual_flipped_observables} (from raw: {predicted_logical_flips_array})")
5088

5189
if __name__ == "__main__":
5290
raise SystemExit(pytest.main([__file__]))

src/py/tesseract_test.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,48 @@ def test_create_decoder():
4848
decoder = tesseract_decoder.tesseract.TesseractDecoder(config)
4949
decoder.decode_to_errors([0])
5050
decoder.decode_to_errors(detections=[0], det_order=0, det_beam=0)
51-
assert decoder.mask_from_errors([1]) == 0
51+
assert decoder.get_observables_from_errors([1]) == []
5252
assert decoder.cost_from_errors([1]) == pytest.approx(0.5108256237659907)
53-
assert decoder.decode([0]) == 0
53+
assert decoder.decode([0]) == []
54+
55+
def test_tesseract_decoder_predicts_various_observable_flips():
56+
"""
57+
Tests that the Tesseract decoder correctly predicts a logical observable
58+
flip when a specific detector is triggered by an error that explicitly
59+
flips that logical observable.
60+
61+
This test iterates through various observable IDs to ensure the backend logic
62+
correctly handles different positions.
63+
"""
64+
# Iterate through observable IDs from 0 to 63 (inclusive)
65+
for observable_id in range(64):
66+
# Create a simple DetectorErrorModel where an error on D0 also flips L{observable_id}
67+
dem_string = f'''
68+
error(0.01) D0 L{observable_id}
69+
'''
70+
dem = stim.DetectorErrorModel(dem_string)
71+
72+
# Initialize TesseractConfig and TesseractDecoder with the generated DEM
73+
config = tesseract_decoder.tesseract.TesseractConfig(dem)
74+
decoder = tesseract_decoder.tesseract.TesseractDecoder(config)
75+
76+
# Simulate a detection event on D0.
77+
# The decoder should identify the most likely error causing D0,
78+
# which in this DEM is the error that also flips L{observable_id}.
79+
# The decode method is expected to return an array where array[i] is True if observable i is flipped.
80+
predicted_logical_flips_array = decoder.decode(detections=[0])
81+
82+
# Convert the boolean array/list to a list of flipped observable IDs
83+
actual_flipped_observables = [
84+
idx for idx, is_flipped in enumerate(predicted_logical_flips_array) if is_flipped
85+
]
86+
87+
# Assert that the list of actual flipped observables matches the single expected observable_id.
88+
assert actual_flipped_observables == [observable_id], \
89+
(f"For observable L{observable_id}: "
90+
f"Expected predicted logical flips: [{observable_id}], "
91+
f"but got: {actual_flipped_observables} (from raw: {predicted_logical_flips_array})")
92+
5493

5594

5695
if __name__ == "__main__":

src/simplex.cc

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -326,23 +326,41 @@ double SimplexDecoder::cost_from_errors(const std::vector<size_t>& predicted_err
326326
return total_cost;
327327
}
328328

329-
common::ObservablesMask SimplexDecoder::mask_from_errors(
329+
std::vector<int> SimplexDecoder::get_flipped_observables(
330330
const std::vector<size_t>& predicted_errors) {
331-
common::ObservablesMask mask = 0;
332-
// Iterate over all errors and add to the mask
333-
for (size_t ei : predicted_errors_buffer) {
334-
mask ^= errors[ei].symptom.observables;
331+
std::unordered_set<int> flipped_observables_set;
332+
333+
// Iterate over all predicted errors
334+
for (size_t ei : predicted_errors) {
335+
// Iterate over the observables associated with each error
336+
for (int obs_index : errors[ei].symptom.observables) {
337+
// Perform an XOR-like sum using a set.
338+
// If the observable is already in the set, it means we've seen it an
339+
// even number of times, so we remove it.
340+
// If it's not, we add it, which means we've seen it an odd number of times.
341+
if (flipped_observables_set.count(obs_index)) {
342+
flipped_observables_set.erase(obs_index);
343+
} else {
344+
flipped_observables_set.insert(obs_index);
345+
}
346+
}
335347
}
336-
return mask;
348+
349+
// Convert the set to a vector and return it.
350+
std::vector<int> flipped_observables(flipped_observables_set.begin(),
351+
flipped_observables_set.end());
352+
// Sort observables
353+
std::sort(flipped_observables.begin(), flipped_observables.end());
354+
return flipped_observables;
337355
}
338356

339-
common::ObservablesMask SimplexDecoder::decode(const std::vector<uint64_t>& detections) {
357+
std::vector<int> SimplexDecoder::decode(const std::vector<uint64_t>& detections) {
340358
decode_to_errors(detections);
341-
return mask_from_errors(predicted_errors_buffer);
359+
return get_flipped_observables(predicted_errors_buffer);
342360
}
343361

344362
void SimplexDecoder::decode_shots(std::vector<stim::SparseShot>& shots,
345-
std::vector<common::ObservablesMask>& obs_predicted) {
363+
std::vector<std::vector<int>>& obs_predicted) {
346364
obs_predicted.resize(shots.size());
347365
for (size_t i = 0; i < shots.size(); ++i) {
348366
obs_predicted[i] = decode(shots[i].hits);

src/simplex.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#ifndef SIMPLEX_HPP
1616
#define SIMPLEX_HPP
17+
#include <unordered_set>
1718
#include <vector>
1819

1920
#include "common.h"
@@ -41,7 +42,7 @@ struct SimplexDecoder {
4142
size_t num_detectors = 0;
4243
size_t num_observables = 0;
4344
std::vector<size_t> predicted_errors_buffer;
44-
std::vector<common::ObservablesMask> error_masks;
45+
std::vector<std::vector<int>> error_masks;
4546
std::vector<std::vector<size_t>> start_time_to_errors;
4647
std::vector<std::vector<size_t>> end_time_to_errors;
4748

@@ -62,14 +63,14 @@ struct SimplexDecoder {
6263
void decode_to_errors(const std::vector<uint64_t>& detections);
6364
// Returns the bitwise XOR of all the observables bitmasks of all errors in
6465
// the predicted errors buffer.
65-
common::ObservablesMask mask_from_errors(const std::vector<size_t>& predicted_errors);
66+
std::vector<int> get_flipped_observables(const std::vector<size_t>& predicted_errors);
6667
// Returns the sum of the likelihood costs (minus-log-likelihood-ratios) of
6768
// all errors in the predicted errors buffer.
6869
double cost_from_errors(const std::vector<size_t>& predicted_errors);
69-
common::ObservablesMask decode(const std::vector<uint64_t>& detections);
70+
std::vector<int> decode(const std::vector<uint64_t>& detections);
7071

7172
void decode_shots(std::vector<stim::SparseShot>& shots,
72-
std::vector<common::ObservablesMask>& obs_predicted);
73+
std::vector<std::vector<int>>& obs_predicted);
7374

7475
~SimplexDecoder();
7576
};

0 commit comments

Comments
 (0)