diff --git a/src/BUILD b/src/BUILD index e92a614..748b945 100644 --- a/src/BUILD +++ b/src/BUILD @@ -70,6 +70,7 @@ pybind_library( "common.pybind.h", "utils.pybind.h", "simplex.pybind.h", + "visualization.pybind.h", "tesseract.pybind.h", ], deps = [ @@ -113,6 +114,19 @@ cc_library( ], ) +cc_library( + name = "libviz", + srcs = ["visualization.cc"], + hdrs = ["visualization.h"], + copts = OPT_COPTS, + linkopts = OPT_LINKOPTS, + deps = [ + ":libutils", + "@boost//:dynamic_bitset", + ], + +) + cc_library( name = "libtesseract", srcs = ["tesseract.cc"], @@ -121,6 +135,7 @@ cc_library( linkopts = OPT_LINKOPTS, deps = [ ":libutils", + ":libviz", "@boost//:dynamic_bitset", ], ) diff --git a/src/common.cc b/src/common.cc index da97a41..f0d4ed7 100644 --- a/src/common.cc +++ b/src/common.cc @@ -14,7 +14,7 @@ #include "common.h" -std::string common::Symptom::str() { +std::string common::Symptom::str() const { std::string s = "Symptom{"; for (size_t d : detectors) { s += "D" + std::to_string(d); @@ -63,7 +63,7 @@ common::Error::Error(const stim::DemInstruction& error) { symptom.observables = observables; } -std::string common::Error::str() { +std::string common::Error::str() const { return "Error{cost=" + std::to_string(likelihood_cost) + ", symptom=" + symptom.str() + "}"; } diff --git a/src/common.h b/src/common.h index 67536a1..e2789d7 100644 --- a/src/common.h +++ b/src/common.h @@ -42,7 +42,7 @@ struct Symptom { bool operator==(const Symptom& other) const { return detectors == other.detectors && observables == other.observables; } - std::string str(); + std::string str() const; }; // Represents an error / weighted hyperedge @@ -53,7 +53,7 @@ struct Error { Error(double likelihood_cost, std::vector& detectors, std::vector observables) : likelihood_cost(likelihood_cost), symptom{detectors, observables} {} Error(const stim::DemInstruction& error); - std::string str(); + std::string str() const; // Get/calculate the probability from the likelihood cost. double get_probability() const; diff --git a/src/py/tesseract_test.py b/src/py/tesseract_test.py index 2fe6915..b947870 100644 --- a/src/py/tesseract_test.py +++ b/src/py/tesseract_test.py @@ -43,7 +43,7 @@ def test_create_config(): assert ( str(tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL)) - == "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)" + == "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0, create_visualization=0)" ) assert ( tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL).dem diff --git a/src/tesseract.cc b/src/tesseract.cc index b90ef62..73ba422 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -61,7 +61,9 @@ std::string TesseractConfig::str() { ss << "verbose=" << config.verbose << ", "; ss << "pqlimit=" << config.pqlimit << ", "; ss << "det_orders=" << config.det_orders << ", "; - ss << "det_penalty=" << config.det_penalty << ")"; + ss << "det_penalty=" << config.det_penalty << ", "; + ss << "create_visualization=" << config.create_visualization; + ss << ")"; return ss.str(); } @@ -124,6 +126,11 @@ TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) { num_errors = config.dem.count_errors(); num_observables = config.dem.count_observables(); initialize_structures(config.dem.count_detectors()); + if (config.create_visualization) { + auto detectors = get_detector_coords(config.dem); + visualizer.add_detector_coords(detectors); + visualizer.add_errors(errors); + } } void TesseractDecoder::initialize_structures(size_t num_detectors) { @@ -294,6 +301,10 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, flip_detectors_and_block_errors(detector_order, node.errors, detectors, detector_cost_tuples); if (node.num_detectors == 0) { + if (config.create_visualization) { + visualizer.add_activated_errors(node.errors); + visualizer.add_activated_detectors(detectors, num_detectors); + } if (config.verbose) { std::cout << "activated_errors = "; for (size_t oei : node.errors) { @@ -318,6 +329,10 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, if (config.no_revisit_dets && !visited_detectors[node.num_detectors].insert(detectors).second) continue; + if (config.create_visualization) { + visualizer.add_activated_errors(node.errors); + visualizer.add_activated_detectors(detectors, num_detectors); + } if (config.verbose) { std::cout.precision(13); std::cout << "len(pq) = " << pq.size() << " num_pq_pushed = " << num_pq_pushed << std::endl; diff --git a/src/tesseract.h b/src/tesseract.h index 9ca8c54..88a7d92 100644 --- a/src/tesseract.h +++ b/src/tesseract.h @@ -25,6 +25,7 @@ #include "common.h" #include "stim.h" #include "utils.h" +#include "visualization.h" constexpr size_t INF_DET_BEAM = std::numeric_limits::max(); @@ -38,6 +39,7 @@ struct TesseractConfig { size_t pqlimit = std::numeric_limits::max(); std::vector> det_orders; double det_penalty = 0; + bool create_visualization = false; std::string str(); }; @@ -64,6 +66,8 @@ struct ErrorCost { struct TesseractDecoder { TesseractConfig config; + Visualizer visualizer; + explicit TesseractDecoder(TesseractConfig config); // Clears the predicted_errors_buffer and fills it with the decoded errors for diff --git a/src/tesseract.pybind.cc b/src/tesseract.pybind.cc index b133285..6342d6e 100644 --- a/src/tesseract.pybind.cc +++ b/src/tesseract.pybind.cc @@ -21,6 +21,7 @@ #include "pybind11/detail/common.h" #include "simplex.pybind.h" #include "utils.pybind.h" +#include "visualization.pybind.h" PYBIND11_MODULE(tesseract_decoder, tesseract) { py::module::import("stim"); @@ -29,6 +30,7 @@ PYBIND11_MODULE(tesseract_decoder, tesseract) { add_utils_module(tesseract); add_simplex_module(tesseract); add_tesseract_module(tesseract); + add_visualization_module(tesseract); // Adds a context manager to the python library that can be used to redirect C++'s stdout/stderr // to python's stdout/stderr at run time like diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index b9be989..bcd35c5 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -37,11 +37,11 @@ TesseractConfig tesseract_config_maker( bool no_revisit_dets = false, bool at_most_two_errors_per_detector = false, bool verbose = false, size_t pqlimit = std::numeric_limits::max(), std::vector> det_orders = std::vector>(), - double det_penalty = 0.0) { + double det_penalty = 0.0, bool create_visualization = false) { stim::DetectorErrorModel input_dem = parse_py_object(dem); return TesseractConfig({input_dem, det_beam, beam_climbing, no_revisit_dets, at_most_two_errors_per_detector, verbose, pqlimit, det_orders, - det_penalty}); + det_penalty, create_visualization}); } }; // namespace void add_tesseract_module(py::module& root) { @@ -61,6 +61,7 @@ void add_tesseract_module(py::module& root) { py::arg("at_most_two_errors_per_detector") = false, py::arg("verbose") = false, py::arg("pqlimit") = std::numeric_limits::max(), py::arg("det_orders") = std::vector>(), py::arg("det_penalty") = 0.0, + py::arg("create_visualization") = false, R"pbdoc( The constructor for the `TesseractConfig` class. @@ -86,6 +87,8 @@ void add_tesseract_module(py::module& root) { will generate its own orderings. det_penalty : float, default=0.0 A penalty value added to the cost of each detector visited. + create_visualization: bool, defualt=False + Whether to record the information needed to create a visualization or not. )pbdoc") .def_property("dem", &dem_getter, &dem_setter, "The `stim.DetectorErrorModel` that defines the error channels and detectors.") @@ -106,6 +109,8 @@ void add_tesseract_module(py::module& root) { "A list of pre-specified detector orderings.") .def_readwrite("det_penalty", &TesseractConfig::det_penalty, "The penalty cost added for each detector.") + .def_readwrite("create_visualization", &TesseractConfig::create_visualization, + "If True, records necessary information to create visualization.") .def("__str__", &TesseractConfig::str) .def("compile_decoder", &_compile_tesseract_decoder_helper, py::return_value_policy::take_ownership, @@ -374,7 +379,10 @@ void add_tesseract_module(py::module& root) { .def_readwrite("errors", &TesseractDecoder::errors, "The list of all errors in the detector error model.") .def_readwrite("num_observables", &TesseractDecoder::num_observables, - "The total number of logical observables in the detector error model."); + "The total number of logical observables in the detector error model.") + .def_readonly("visualizer", &TesseractDecoder::visualizer, + "An object that can (if config.create_visualization=True) be used to generate " + "visualization of the algorithm"); } -#endif \ No newline at end of file +#endif diff --git a/src/visualization.cc b/src/visualization.cc new file mode 100644 index 0000000..59943f4 --- /dev/null +++ b/src/visualization.cc @@ -0,0 +1,53 @@ + +#include "visualization.h" + +void Visualizer::add_errors(const std::vector& errors) { + for (auto& error : errors) { + lines.push_back(error.str()); + } +} +void Visualizer::add_detector_coords(const std::vector>& detector_coords) { + for (size_t d = 0; d < detector_coords.size(); ++d) { + std::stringstream ss; + ss << "Detector D" << d << " coordinate ("; + size_t e = std::min(3ul, detector_coords[d].size()); + for (size_t i = 0; i < e; ++i) { + ss << detector_coords[d][i]; + if (i + 1 < e) ss << ", "; + } + ss << ")"; + lines.push_back(ss.str()); + } +} + +void Visualizer::add_activated_errors(const std::vector& activated_errors) { + std::stringstream ss; + ss << "activated_errors = "; + for (size_t oei : activated_errors) { + ss << oei << ", "; + } + lines.push_back(ss.str()); +} + +void Visualizer::add_activated_detectors(const boost::dynamic_bitset<>& detectors, + size_t num_detectors) { + std::stringstream ss; + ss << "activated_detectors = "; + for (size_t d = 0; d < num_detectors; ++d) { + if (detectors[d]) { + ss << d << ", "; + } + } + lines.push_back(ss.str()); +} + +void Visualizer::write(const char* fpath) { + FILE* fout = fopen(fpath, "w"); + + for (std::string& line : lines) { + fprintf(fout, line.c_str()); + fputs("\n", fout); + } + + fclose(fout); +} diff --git a/src/visualization.h b/src/visualization.h new file mode 100644 index 0000000..bda88bb --- /dev/null +++ b/src/visualization.h @@ -0,0 +1,22 @@ +#ifndef _VISUALIZATION_H +#define _VISUALIZATION_H + +#include +#include +#include + +#include "common.h" + +struct Visualizer { + void add_detector_coords(const std::vector> &); + void add_errors(const std::vector &); + void add_activated_errors(const std::vector &); + void add_activated_detectors(const boost::dynamic_bitset<> &, size_t); + + void write(const char *fpath); + + private: + std::list lines; +}; + +#endif diff --git a/src/visualization.pybind.h b/src/visualization.pybind.h new file mode 100644 index 0000000..820bf86 --- /dev/null +++ b/src/visualization.pybind.h @@ -0,0 +1,16 @@ +#include +#include +#include +#include +#include + +#include "visualization.h" + +namespace py = pybind11; + +void add_visualization_module(py::module& root) { + auto m = root.def_submodule("viz", "Module containing the visualization tools"); + py::class_(m, "Visualizer") + .def(py::init<>()) + .def("write", &Visualizer::write, py::arg("fpath")); +}