Skip to content

Commit 05697ee

Browse files
fixing #92 by updating decode_to_errors (#111)
Co-authored-by: Noureldin <noureldinyosri@gmail.com>
1 parent d3b4459 commit 05697ee

File tree

11 files changed

+143
-63
lines changed

11 files changed

+143
-63
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,6 @@ eclipse-*bin/
3232
/.sass-cache
3333
# User-specific .bazelrc
3434
user.bazelrc
35+
36+
# Ignore python extension module produced by CMake.
37+
src/tesseract_decoder*.so

AGENTS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Agent Instructions
2+
3+
- Use the **CMake** build system when interacting with this repository. Humans use Bazel.
4+
- A bug in some LLM coding environments makes Bazel difficult to use, so agents should rely on CMake.
5+
- Keep both the CMake and Bazel builds working at all times.

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,11 @@ pybind11_add_module(tesseract_decoder MODULE ${TESSERACT_SRC_DIR}/tesseract.pybi
103103
target_compile_options(tesseract_decoder PRIVATE ${OPT_COPTS})
104104
target_include_directories(tesseract_decoder PRIVATE ${TESSERACT_SRC_DIR})
105105
target_link_libraries(tesseract_decoder PRIVATE common utils simplex tesseract_lib)
106+
set_target_properties(tesseract_decoder PROPERTIES
107+
LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/src
108+
LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/src
109+
LIBRARY_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/src
110+
LIBRARY_OUTPUT_DIRECTORY_MINSIZEREL ${PROJECT_SOURCE_DIR}/src
111+
LIBRARY_OUTPUT_DIRECTORY_RELWITHDEBINFO ${PROJECT_SOURCE_DIR}/src
112+
)
106113

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,15 @@ config = tesseract.TesseractConfig(dem=dem, det_beam=50)
190190
# 3. Create a decoder instance
191191
decoder = config.compile_decoder()
192192

193-
# 4. Simulate detection events
194-
syndrome = [0, 1, 1]
193+
# 4. Simulate detector outcomes
194+
syndrome = np.array([0, 1, 1], dtype=bool)
195195

196196
# 5a. Decode to observables
197197
flipped_observables = decoder.decode(syndrome)
198198
print(f"Flipped observables: {flipped_observables}")
199199

200200
# 5b. Alternatively, decode to errors
201-
decoder.decode_to_errors(np.where(syndrome)[0])
201+
decoder.decode_to_errors(syndrome)
202202
predicted_errors = decoder.predicted_errors_buffer
203203
# Indices of predicted errors
204204
print(f"Predicted errors indices: {predicted_errors}")

src/py/README.md

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,28 +64,28 @@ print(f"Custom configuration detection penalty: {config2.det_beam}")
6464
#### Class `tesseract.TesseractDecoder`
6565
This is the main class that implements the Tesseract decoding logic.
6666
* `TesseractDecoder(config: tesseract.TesseractConfig)`
67-
* `decode_to_errors(detections: list[int])`
68-
* `decode_to_errors(detections: list[int], det_order: int, det_beam: int)`
67+
* `decode_to_errors(syndrome: np.ndarray)`
68+
* `decode_to_errors(syndrome: np.ndarray, det_order: int, det_beam: int)`
6969
* `get_observables_from_errors(predicted_errors: list[int]) -> list[bool]`
7070
* `cost_from_errors(predicted_errors: list[int]) -> float`
71-
* `decode(detections: list[int]) -> list[bool]`
71+
* `decode(syndrome: np.ndarray) -> np.ndarray`
7272

7373
Explanation of each method:
74-
#### `decode_to_errors(detections: list[int])`
74+
#### `decode_to_errors(syndrome: np.ndarray)`
7575

7676
Decodes a single measurement shot to predict a list of errors.
7777

78-
* **Parameters:** `detections` is a list of integers that represent the indices of the detectors that have fired in a single shot.
78+
* **Parameters:** `syndrome` is a 1D NumPy array of booleans representing the detector outcomes for a single shot.
7979

8080
* **Returns:** A list of integers, where each integer is the index of a predicted error.
8181

82-
#### `decode_to_errors(detections: list[int], det_order: int, det_beam: int)`
82+
#### `decode_to_errors(syndrome: np.ndarray, det_order: int, det_beam: int)`
8383

8484
An overloaded version of the `decode_to_errors` method that allows for a different decoding strategy.
8585

8686
* **Parameters:**
8787

88-
* `detections` is a list of integers representing the indices of the fired detectors.
88+
* `syndrome` is a 1D NumPy array of booleans representing the detector outcomes for a single shot.
8989

9090
* `det_order` is an integer that specifies a different ordering of detectors to use for the decoding.
9191

@@ -219,17 +219,18 @@ print(f"Configuration verbose enabled: {config.verbose}")
219219
This is the main class for performing decoding using the Simplex algorithm.
220220
* `SimplexDecoder(config: simplex.SimplexConfig)`
221221
* `init_ilp()`
222-
* `decode_to_errors(detections: list[int])`
222+
* `decode_to_errors(syndrome: np.ndarray)`
223223
* `get_observables_from_errors(predicted_errors: list[int]) -> list[bool]`
224224
* `cost_from_errors(predicted_errors: list[int]) -> float`
225-
* `decode(detections: list[int]) -> list[bool]`
225+
* `decode(syndrome: np.ndarray) -> np.ndarray`
226226

227227
**Example Usage**:
228228

229229
```python
230230
import tesseract_decoder.simplex as simplex
231231
import stim
232232
import tesseract_decoder.common as common
233+
import numpy as np
233234

234235
# Create a DEM and a configuration
235236
dem = stim.DetectorErrorModel("""
@@ -245,9 +246,9 @@ decoder = simplex.SimplexDecoder(config)
245246
decoder.init_ilp()
246247

247248
# Decode a shot where detector D1 fired
248-
detections = [1]
249-
flipped_observables = decoder.decode(detections)
250-
print(f"Flipped observables for detections {detections}: {flipped_observables}")
249+
syndrome = np.array([0, 1], dtype=bool)
250+
flipped_observables = decoder.decode(syndrome)
251+
print(f"Flipped observables for syndrome {syndrome.tolist()}: {flipped_observables}")
251252

252253
# Access predicted errors
253254
predicted_error_indices = decoder.predicted_errors_buffer

src/py/shared_decoding_tests.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,16 +316,16 @@ def shared_test_merge_errors_affects_cost(decoder_class, config_class):
316316
error(0.01) D0
317317
"""
318318
)
319-
detections = [0]
319+
syndrome = np.array([True], dtype=bool)
320320

321321
config_no_merge = config_class(dem, merge_errors=False)
322322
decoder_no_merge = decoder_class(config_no_merge)
323-
predicted_errors_no_merge = decoder_no_merge.decode_to_errors(detections)
323+
predicted_errors_no_merge = decoder_no_merge.decode_to_errors(syndrome)
324324
cost_no_merge = decoder_no_merge.cost_from_errors(decoder_no_merge.predicted_errors_buffer)
325325

326326
config_merge = config_class(dem, merge_errors=True)
327327
decoder_merge = decoder_class(config_merge)
328-
predicted_errors_merge = decoder_merge.decode_to_errors(detections)
328+
predicted_errors_merge = decoder_merge.decode_to_errors(syndrome)
329329
cost_merge = decoder_merge.cost_from_errors(decoder_merge.predicted_errors_buffer)
330330

331331
p_merged = 0.1 * (1 - 0.01) + 0.01 * (1 - 0.1)

src/py/simplex_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import pytest
16+
import numpy as np
1617
import stim
1718

1819
from src import tesseract_decoder
@@ -56,7 +57,7 @@ def test_create_simplex_decoder():
5657
decoder = tesseract_decoder.simplex.SimplexDecoder(
5758
tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5)
5859
)
59-
decoder.decode_to_errors([1])
60+
decoder.decode_to_errors(np.array([False, True], dtype=bool))
6061
assert decoder.get_observables_from_errors([1]) == []
6162
assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123)
6263

src/py/tesseract_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import pytest
16+
import numpy as np
1617
import stim
1718

1819
from src import tesseract_decoder
@@ -191,8 +192,10 @@ def test_create_tesseract_config_no_dem_with_custom_args():
191192
def test_create_tesseract_decoder():
192193
config = tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL)
193194
decoder = tesseract_decoder.tesseract.TesseractDecoder(config)
194-
decoder.decode_to_errors([0])
195-
decoder.decode_to_errors(detections=[0], det_order=0, det_beam=0)
195+
decoder.decode_to_errors(np.array([True, False], dtype=bool))
196+
decoder.decode_to_errors(
197+
syndrome=np.array([True, False], dtype=bool), det_order=0, det_beam=0
198+
)
196199
assert decoder.get_observables_from_errors([1]) == []
197200
assert decoder.cost_from_errors([1]) == pytest.approx(0.5108256237659907)
198201

src/simplex.pybind.h

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,42 @@ void add_simplex_module(py::module& root) {
140140
141141
This method must be called before decoding.
142142
)pbdoc")
143-
.def("decode_to_errors", &SimplexDecoder::decode_to_errors, py::arg("detections"),
144-
py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(), R"pbdoc(
143+
.def(
144+
"decode_to_errors",
145+
[](SimplexDecoder& self, const py::array_t<bool>& syndrome) {
146+
if ((size_t)syndrome.size() != self.num_detectors) {
147+
std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) +
148+
") does not match the number of detectors in the decoder (" +
149+
std::to_string(self.num_detectors) + ").";
150+
throw std::invalid_argument(msg);
151+
}
152+
153+
std::vector<uint64_t> detections;
154+
auto syndrome_unchecked = syndrome.unchecked<1>();
155+
for (size_t i = 0; i < (size_t)syndrome_unchecked.size(); ++i) {
156+
if (syndrome_unchecked(i)) {
157+
detections.push_back(i);
158+
}
159+
}
160+
self.decode_to_errors(detections);
161+
return self.predicted_errors_buffer;
162+
},
163+
py::arg("syndrome"),
164+
py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(),
165+
R"pbdoc(
145166
Decodes a single shot to a list of error indices.
146167
147168
Parameters
148169
----------
149-
detections : list[int]
150-
A list of indices of the detectors that have fired.
170+
syndrome : np.ndarray
171+
A 1D NumPy array of booleans representing the detector outcomes for a single shot.
172+
The length of the array should match the number of detectors in the DEM.
151173
152174
Returns
153175
-------
154176
list[int]
155177
A list of predicted error indices.
156-
)pbdoc")
178+
)pbdoc")
157179
.def(
158180
"get_observables_from_errors",
159181
[](SimplexDecoder& self, const std::vector<size_t>& predicted_errors) {
@@ -228,11 +250,10 @@ void add_simplex_module(py::module& root) {
228250
"decode",
229251
[](SimplexDecoder& self, const py::array_t<bool>& syndrome) {
230252
if ((size_t)syndrome.size() != self.num_detectors) {
231-
std::ostringstream msg;
232-
msg << "Syndrome array size (" << syndrome.size()
233-
<< ") does not match the number of detectors in the decoder ("
234-
<< self.num_detectors << ").";
235-
throw std::invalid_argument(msg.str());
253+
std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) +
254+
") does not match the number of detectors in the decoder (" +
255+
std::to_string(self.num_detectors) + ").";
256+
throw std::invalid_argument(msg);
236257
}
237258

238259
std::vector<uint64_t> detections;
@@ -287,11 +308,11 @@ void add_simplex_module(py::module& root) {
287308
size_t num_detectors = syndromes_unchecked.shape(1);
288309

289310
if (num_detectors != self.num_detectors) {
290-
std::ostringstream msg;
291-
msg << "The number of detectors in the input array (" << num_detectors
292-
<< ") does not match the number of detectors in the decoder ("
293-
<< self.num_detectors << ").";
294-
throw std::invalid_argument(msg.str());
311+
std::string msg = "The number of detectors in the input array (" +
312+
std::to_string(num_detectors) +
313+
") does not match the number of detectors in the decoder (" +
314+
std::to_string(self.num_detectors) + ").";
315+
throw std::invalid_argument(msg);
295316
}
296317

297318
// Allocate the result array.

src/tesseract.pybind.h

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -243,33 +243,73 @@ void add_tesseract_module(py::module& root) {
243243
config : TesseractConfig
244244
The configuration object for the decoder.
245245
)pbdoc")
246-
.def("decode_to_errors",
247-
py::overload_cast<const std::vector<uint64_t>&>(&TesseractDecoder::decode_to_errors),
248-
py::arg("detections"),
249-
py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(), R"pbdoc(
246+
.def(
247+
"decode_to_errors",
248+
[](TesseractDecoder& self, const py::array_t<bool>& syndrome) {
249+
if ((size_t)syndrome.size() != self.num_detectors) {
250+
std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) +
251+
") does not match the number of detectors in the decoder (" +
252+
std::to_string(self.num_detectors) + ").";
253+
throw std::invalid_argument(msg);
254+
}
255+
256+
std::vector<uint64_t> detections;
257+
auto syndrome_unchecked = syndrome.unchecked<1>();
258+
for (size_t i = 0; i < (size_t)syndrome_unchecked.size(); ++i) {
259+
if (syndrome_unchecked(i)) {
260+
detections.push_back(i);
261+
}
262+
}
263+
self.decode_to_errors(detections);
264+
return self.predicted_errors_buffer;
265+
},
266+
py::arg("syndrome"),
267+
py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(),
268+
R"pbdoc(
250269
Decodes a single shot to a list of error indices.
251270
252271
Parameters
253272
----------
254-
detections : list[int]
255-
A list of indices of the detectors that have fired.
273+
syndrome : np.ndarray
274+
A 1D NumPy array of booleans representing the detector outcomes for a single shot.
275+
The length of the array should match the number of detectors in the DEM.
256276
257277
Returns
258278
-------
259279
list[int]
260280
A list of predicted error indices.
261-
)pbdoc")
262-
.def("decode_to_errors",
263-
py::overload_cast<const std::vector<uint64_t>&, size_t, size_t>(
264-
&TesseractDecoder::decode_to_errors),
265-
py::arg("detections"), py::arg("det_order"), py::arg("det_beam"),
266-
py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(), R"pbdoc(
281+
)pbdoc")
282+
.def(
283+
"decode_to_errors",
284+
[](TesseractDecoder& self, const py::array_t<bool>& syndrome, size_t det_order,
285+
size_t det_beam) {
286+
if ((size_t)syndrome.size() != self.num_detectors) {
287+
std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) +
288+
") does not match the number of detectors in the decoder (" +
289+
std::to_string(self.num_detectors) + ").";
290+
throw std::invalid_argument(msg);
291+
}
292+
293+
std::vector<uint64_t> detections;
294+
auto syndrome_unchecked = syndrome.unchecked<1>();
295+
for (size_t i = 0; i < (size_t)syndrome_unchecked.size(); ++i) {
296+
if (syndrome_unchecked(i)) {
297+
detections.push_back(i);
298+
}
299+
}
300+
self.decode_to_errors(detections, det_order, det_beam);
301+
return self.predicted_errors_buffer;
302+
},
303+
py::arg("syndrome"), py::arg("det_order"), py::arg("det_beam"),
304+
py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>(),
305+
R"pbdoc(
267306
Decodes a single shot using a specific detector ordering and beam size.
268307
269308
Parameters
270309
----------
271-
detections : list[int]
272-
A list of indices of the detectors that have fired.
310+
syndrome : np.ndarray
311+
A 1D NumPy array of booleans representing the detector outcomes for a single shot.
312+
The length of the array should match the number of detectors in the DEM.
273313
det_order : int
274314
The index of the detector ordering to use.
275315
det_beam : int
@@ -279,7 +319,7 @@ void add_tesseract_module(py::module& root) {
279319
-------
280320
list[int]
281321
A list of predicted error indices.
282-
)pbdoc")
322+
)pbdoc")
283323
.def(
284324
"get_observables_from_errors",
285325
[](TesseractDecoder& self, const std::vector<size_t>& predicted_errors) {
@@ -355,11 +395,10 @@ void add_tesseract_module(py::module& root) {
355395
"decode",
356396
[](TesseractDecoder& self, const py::array_t<bool>& syndrome) {
357397
if ((size_t)syndrome.size() != self.num_detectors) {
358-
std::ostringstream msg;
359-
msg << "Syndrome array size (" << syndrome.size()
360-
<< ") does not match the number of detectors in the decoder ("
361-
<< self.num_detectors << ").";
362-
throw std::invalid_argument(msg.str());
398+
std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) +
399+
") does not match the number of detectors in the decoder (" +
400+
std::to_string(self.num_detectors) + ").";
401+
throw std::invalid_argument(msg);
363402
}
364403

365404
std::vector<uint64_t> detections;
@@ -413,11 +452,11 @@ void add_tesseract_module(py::module& root) {
413452
size_t num_detectors = syndromes_unchecked.shape(1);
414453

415454
if (num_detectors != self.num_detectors) {
416-
std::ostringstream msg;
417-
msg << "The number of detectors in the input array (" << num_detectors
418-
<< ") does not match the number of detectors in the decoder ("
419-
<< self.num_detectors << ").";
420-
throw std::invalid_argument(msg.str());
455+
std::string msg = "The number of detectors in the input array (" +
456+
std::to_string(num_detectors) +
457+
") does not match the number of detectors in the decoder (" +
458+
std::to_string(self.num_detectors) + ").";
459+
throw std::invalid_argument(msg);
421460
}
422461

423462
// Allocate the result array.

0 commit comments

Comments
 (0)