Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
cmake_minimum_required(VERSION 3.16)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change is unrelated to the functionality in the PR title

project(tesseract_decoder LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

include(FetchContent)
find_package(Threads REQUIRED)

# === External dependencies ===
# Stim
FetchContent_Declare(
stim
GIT_REPOSITORY https://github.com/quantumlib/stim.git
GIT_TAG bd60b73525fd5a9b30839020eb7554ad369e4337
)
FetchContent_MakeAvailable(stim)

# HiGHS
FetchContent_Declare(
highs
URL https://github.com/ERGO-Code/HiGHS/archive/refs/tags/v1.9.0.tar.gz
URL_HASH SHA256=dff575df08d88583c109702c7c5c75ff6e51611e6eacca8b5b3fdfba8ecc2cb4
)
FetchContent_MakeAvailable(highs)

# argparse (header only)
FetchContent_Declare(
argparse
URL https://github.com/p-ranav/argparse/archive/refs/tags/v3.1.zip
URL_HASH SHA256=3e5a59ab7688dcd1f918bc92051a10564113d4f36c3bbed3ef596c25e519a062
)
FetchContent_MakeAvailable(argparse)

# nlohmann_json (header only)
FetchContent_Declare(
nlohmann_json
URL https://github.com/nlohmann/json/archive/9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03.zip
)
FetchContent_MakeAvailable(nlohmann_json)

# Boost headers
FetchContent_Declare(
boost
URL https://archives.boost.io/release/1.83.0/source/boost_1_83_0.tar.gz
URL_HASH SHA256=c0685b68dd44cc46574cce86c4e17c0f611b15e195be9848dfd0769a0a207628
)
FetchContent_MakeAvailable(boost)
add_library(boost_headers INTERFACE)
target_include_directories(boost_headers INTERFACE ${boost_SOURCE_DIR})

# pybind11
FetchContent_Declare(
pybind11
GIT_REPOSITORY https://github.com/pybind/pybind11.git
GIT_TAG v2.11.1
)
set(PYBIND11_TEST OFF CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(pybind11)

set(OPT_COPTS -Ofast -fno-fast-math -march=native)

set(TESSERACT_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src)

# === Libraries ===
add_library(common ${TESSERACT_SRC_DIR}/common.cc ${TESSERACT_SRC_DIR}/common.h)
target_include_directories(common PUBLIC ${TESSERACT_SRC_DIR})
target_compile_options(common PRIVATE ${OPT_COPTS})
target_link_libraries(common PUBLIC libstim Threads::Threads)

add_library(utils ${TESSERACT_SRC_DIR}/utils.cc ${TESSERACT_SRC_DIR}/utils.h)
target_include_directories(utils PUBLIC ${TESSERACT_SRC_DIR})
target_compile_options(utils PRIVATE ${OPT_COPTS})
target_link_libraries(utils PUBLIC common libstim Threads::Threads)

add_library(tesseract_lib ${TESSERACT_SRC_DIR}/tesseract.cc ${TESSERACT_SRC_DIR}/tesseract.h)
target_include_directories(tesseract_lib PUBLIC ${TESSERACT_SRC_DIR})
target_compile_options(tesseract_lib PRIVATE ${OPT_COPTS})
target_link_libraries(tesseract_lib PUBLIC utils boost_headers)

add_library(simplex ${TESSERACT_SRC_DIR}/simplex.cc ${TESSERACT_SRC_DIR}/simplex.h)
target_include_directories(simplex PUBLIC ${TESSERACT_SRC_DIR})
target_compile_options(simplex PRIVATE ${OPT_COPTS})
target_link_libraries(simplex PUBLIC common utils tesseract_lib highs libstim Threads::Threads)

# === Executables ===
add_executable(tesseract ${TESSERACT_SRC_DIR}/tesseract_main.cc)
target_compile_options(tesseract PRIVATE ${OPT_COPTS})
target_link_libraries(tesseract PRIVATE tesseract_lib argparse::argparse nlohmann_json::nlohmann_json)

add_executable(simplex_bin ${TESSERACT_SRC_DIR}/simplex_main.cc)
set_target_properties(simplex_bin PROPERTIES OUTPUT_NAME simplex)
target_compile_options(simplex_bin PRIVATE ${OPT_COPTS})
target_link_libraries(simplex_bin PRIVATE common simplex argparse::argparse nlohmann_json::nlohmann_json)

# === Python module ===
pybind11_add_module(tesseract_decoder MODULE ${TESSERACT_SRC_DIR}/tesseract.pybind.cc)
target_compile_options(tesseract_decoder PRIVATE ${OPT_COPTS})
target_include_directories(tesseract_decoder PRIVATE ${TESSERACT_SRC_DIR})
target_link_libraries(tesseract_decoder PRIVATE common utils simplex tesseract_lib)

18 changes: 15 additions & 3 deletions src/py/simplex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,33 @@ def test_create_simplex_config():
assert sc.window_length == 5
assert (
str(sc)
== "SimplexConfig(dem=DetectorErrorModel_Object, window_length=5, window_slide_length=0, verbose=0)"
== "SimplexConfig(dem=DetectorErrorModel_Object, window_length=5, window_slide_length=0)"
)


def test_create_simplex_decoder():
decoder = tesseract_decoder.simplex.SimplexDecoder(
tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5)
)
decoder.init_ilp()
decoder.decode_to_errors([1])
assert decoder.get_observables_from_errors([1]) == []
assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123)
assert decoder.decode([1]) == []


def test_simplex_verbose_callback_receives_output():
lines = []

def cb(s: str) -> None:
lines.append(s)

config = tesseract_decoder.simplex.SimplexConfig(
_DETECTOR_ERROR_MODEL, window_length=5, verbose_callback=cb
)
decoder = tesseract_decoder.simplex.SimplexDecoder(config)
decoder.decode_to_errors([1])
assert any(lines)

def test_simplex_decoder_predicts_various_observable_flips():
"""
Tests that the Simplex decoder correctly predicts a logical observable
Expand All @@ -67,7 +80,6 @@ def test_simplex_decoder_predicts_various_observable_flips():
# Initialize SimplexConfig and SimplexDecoder with the generated DEM
config = tesseract_decoder.simplex.SimplexConfig(dem, window_length=1) # window_length must be set
decoder = tesseract_decoder.simplex.SimplexDecoder(config)
decoder.init_ilp() # Initialize the ILP solver

# Simulate a detection event on D0.
# The decoder should identify the most likely error causing D0,
Expand Down
16 changes: 15 additions & 1 deletion src/py/tesseract_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
def test_create_config():
assert (
str(tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL))
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)"
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)"
)
assert (
tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL).dem
Expand All @@ -52,6 +52,20 @@ def test_create_decoder():
assert decoder.cost_from_errors([1]) == pytest.approx(0.5108256237659907)
assert decoder.decode([0]) == []


def test_tesseract_verbose_callback_receives_output():
lines = []

def cb(s: str) -> None:
lines.append(s)

config = tesseract_decoder.tesseract.TesseractConfig(
_DETECTOR_ERROR_MODEL, verbose_callback=cb
)
decoder = tesseract_decoder.tesseract.TesseractDecoder(config)
decoder.decode_to_errors([0])
assert any(lines)

def test_tesseract_decoder_predicts_various_observable_flips():
"""
Tests that the Tesseract decoder correctly predicts a logical observable
Expand Down
66 changes: 42 additions & 24 deletions src/simplex.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,36 @@
#include "simplex.h"

#include <cassert>
#include <iostream>

#include "Highs.h"
#include "io/HMPSIO.h"

constexpr size_t T_COORD = 2;

namespace {
void highs_log_cb(HighsLogType, const char* msg, void* user_data) {
CallbackStream* stream = static_cast<CallbackStream*>(user_data);
(*stream) << msg;
stream->flush();
}
} // namespace

std::string SimplexConfig::str() {
auto& self = *this;
std::stringstream ss;
ss << "SimplexConfig(";
ss << "dem=" << "DetectorErrorModel_Object" << ", ";
ss << "window_length=" << self.window_length << ", ";
ss << "window_slide_length=" << self.window_slide_length << ", ";
ss << "verbose=" << self.verbose << ")";
ss << "window_slide_length=" << self.window_slide_length << ")";
return ss.str();
}

SimplexDecoder::SimplexDecoder(SimplexConfig _config) : config(_config) {
if (!config.verbose_callback) {
config.verbose_callback = [](const std::string& s) { std::cout << s; };
}
config.log_stream.callback = config.verbose_callback;
config.dem = common::remove_zero_probability_errors(config.dem);
std::vector<double> detector_t_coords(config.dem.count_detectors());
for (const stim::DemInstruction& instruction : config.dem.flattened().instructions) {
Expand Down Expand Up @@ -152,7 +164,11 @@ void SimplexDecoder::init_ilp() {
// Disabled presolve entirely after encountering bugs similar to this one:
// https://github.com/ERGO-Code/HiGHS/issues/1273
highs->setOptionValue("presolve", "off");
highs->setOptionValue("output_flag", config.verbose);
highs->setOptionValue("output_flag", config.log_stream.active);
highs->setOptionValue("log_to_console", config.log_stream.active);
if (config.log_stream.active) {
highs->setLogCallback(highs_log_cb, &config.log_stream);
}
}

void SimplexDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
Expand Down Expand Up @@ -197,9 +213,7 @@ void SimplexDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
add_costs_for_time(t1);
++t1;
}
if (config.verbose) {
std::cout << "t0 = " << t0 << " t1 = " << t1 << std::endl;
}
config.log_stream << "t0 = " << t0 << " t1 = " << t1 << std::endl;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also have this: if (config.log_stream.active) { for this log_stream?


// Pass the model to HiGHS
*return_status = highs->passModel(*model);
Expand Down Expand Up @@ -235,16 +249,19 @@ void SimplexDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
}
assert(*return_status == HighsStatus::kOk);

if (config.verbose) {
// Get the solution information
if (config.log_stream.active) {
const HighsInfo& info = highs->getInfo();
std::cout << "Simplex iteration count: " << info.simplex_iteration_count << std::endl;
std::cout << "Objective function value: " << info.objective_function_value << std::endl;
std::cout << "Primal solution status: "
<< highs->solutionStatusToString(info.primal_solution_status) << std::endl;
std::cout << "Dual solution status: "
<< highs->solutionStatusToString(info.dual_solution_status) << std::endl;
std::cout << "Basis: " << highs->basisValidityToString(info.basis_validity) << std::endl;
config.log_stream << "Simplex iteration count: " << info.simplex_iteration_count
<< std::endl;
config.log_stream << "Objective function value: " << info.objective_function_value
<< std::endl;
config.log_stream << "Primal solution status: "
<< highs->solutionStatusToString(info.primal_solution_status)
<< std::endl;
config.log_stream << "Dual solution status: "
<< highs->solutionStatusToString(info.dual_solution_status) << std::endl;
config.log_stream << "Basis: " << highs->basisValidityToString(info.basis_validity)
<< std::endl;
}

// Get the model status
Expand Down Expand Up @@ -286,16 +303,17 @@ void SimplexDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
*return_status = highs->run();
assert(*return_status == HighsStatus::kOk);

if (config.verbose) {
// Get the solution information
if (config.log_stream.active) {
const HighsInfo& info = highs->getInfo();
std::cout << "Simplex iteration count: " << info.simplex_iteration_count << std::endl;
std::cout << "Objective function value: " << info.objective_function_value << std::endl;
std::cout << "Primal solution status: "
<< highs->solutionStatusToString(info.primal_solution_status) << std::endl;
std::cout << "Dual solution status: "
<< highs->solutionStatusToString(info.dual_solution_status) << std::endl;
std::cout << "Basis: " << highs->basisValidityToString(info.basis_validity) << std::endl;
config.log_stream << "Simplex iteration count: " << info.simplex_iteration_count << std::endl;
config.log_stream << "Objective function value: " << info.objective_function_value
<< std::endl;
config.log_stream << "Primal solution status: "
<< highs->solutionStatusToString(info.primal_solution_status) << std::endl;
config.log_stream << "Dual solution status: "
<< highs->solutionStatusToString(info.dual_solution_status) << std::endl;
config.log_stream << "Basis: " << highs->basisValidityToString(info.basis_validity)
<< std::endl;
}

// Get the model status
Expand Down
10 changes: 7 additions & 3 deletions src/simplex.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

#ifndef SIMPLEX_HPP
#define SIMPLEX_HPP
#include <functional>
#include <unordered_set>
#include <vector>

#include "common.h"
#include "stim.h"
#include "utils.h"

struct HighsModel;
struct Highs;
Expand All @@ -29,7 +31,8 @@ struct SimplexConfig {
bool parallelize = false;
size_t window_length = 0;
size_t window_slide_length = 0;
bool verbose = false;
std::function<void(const std::string&)> verbose_callback;
CallbackStream log_stream;
bool windowing_enabled() {
return (window_length != 0);
}
Expand All @@ -56,8 +59,6 @@ struct SimplexDecoder {

SimplexDecoder(SimplexConfig config);

void init_ilp();

// Clears the predicted_errors_buffer and fills it with the decoded errors for
// these detection events.
void decode_to_errors(const std::vector<uint64_t>& detections);
Expand All @@ -73,6 +74,9 @@ struct SimplexDecoder {
std::vector<std::vector<int>>& obs_predicted);

~SimplexDecoder();

private:
void init_ilp();
};

#endif // SIMPLEX_HPP
27 changes: 22 additions & 5 deletions src/simplex.pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <iostream>

#include "common.h"
#include "simplex.h"
#include "stim_utils.pybind.h"
Expand All @@ -28,9 +30,26 @@ namespace py = pybind11;
namespace {
SimplexConfig simplex_config_maker(py::object dem, bool parallelize = false,
size_t window_length = 0, size_t window_slide_length = 0,
bool verbose = false) {
py::object verbose_callback = py::none()) {
stim::DetectorErrorModel input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
return SimplexConfig({input_dem, parallelize, window_length, window_slide_length, verbose});
SimplexConfig cfg;
cfg.dem = input_dem;
cfg.parallelize = parallelize;
cfg.window_length = window_length;
cfg.window_slide_length = window_slide_length;
std::function<void(const std::string&)> cb;
bool active = false;
if (!verbose_callback.is_none()) {
py::function f = verbose_callback;
cb = [f](const std::string& s) {
py::gil_scoped_acquire gil;
f(s);
};
active = true;
}
cfg.verbose_callback = cb;
cfg.log_stream = CallbackStream(active, cfg.verbose_callback);
return cfg;
}

}; // namespace
Expand All @@ -42,12 +61,11 @@ void add_simplex_module(py::module& root) {
py::class_<SimplexConfig>(m, "SimplexConfig")
.def(py::init(&simplex_config_maker), py::arg("dem"), py::arg("parallelize") = false,
py::arg("window_length") = 0, py::arg("window_slide_length") = 0,
py::arg("verbose") = false)
py::arg("verbose_callback") = py::none())
.def_property("dem", &dem_getter<SimplexConfig>, &dem_setter<SimplexConfig>)
.def_readwrite("parallelize", &SimplexConfig::parallelize)
.def_readwrite("window_length", &SimplexConfig::window_length)
.def_readwrite("window_slide_length", &SimplexConfig::window_slide_length)
.def_readwrite("verbose", &SimplexConfig::verbose)
.def("windowing_enabled", &SimplexConfig::windowing_enabled)
.def("__str__", &SimplexConfig::str);

Expand All @@ -62,7 +80,6 @@ void add_simplex_module(py::module& root) {
.def_readwrite("start_time_to_errors", &SimplexDecoder::start_time_to_errors)
.def_readwrite("end_time_to_errors", &SimplexDecoder::end_time_to_errors)
.def_readonly("low_confidence_flag", &SimplexDecoder::low_confidence_flag)
.def("init_ilp", &SimplexDecoder::init_ilp)
.def("decode_to_errors", &SimplexDecoder::decode_to_errors, py::arg("detections"))
.def(
"get_observables_from_errors",
Expand Down
Loading
Loading