Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ pybind_library(
"common.pybind.h",
"utils.pybind.h",
"simplex.pybind.h",
"visualization.pybind.h",
"tesseract.pybind.h",
],
deps = [
Expand Down Expand Up @@ -113,6 +114,19 @@ cc_library(
],
)

cc_library(
name = "libviz",
srcs = ["visualization.cc"],
hdrs = ["visualization.h"],
copts = OPT_COPTS,
linkopts = OPT_LINKOPTS,
deps = [
":libutils",
"@boost//:dynamic_bitset",
],

)

cc_library(
name = "libtesseract",
srcs = ["tesseract.cc"],
Expand All @@ -121,6 +135,7 @@ cc_library(
linkopts = OPT_LINKOPTS,
deps = [
":libutils",
":libviz",
"@boost//:dynamic_bitset",
],
)
Expand Down
4 changes: 2 additions & 2 deletions src/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#include "common.h"

std::string common::Symptom::str() {
std::string common::Symptom::str() const {
std::string s = "Symptom{";
for (size_t d : detectors) {
s += "D" + std::to_string(d);
Expand Down Expand Up @@ -63,7 +63,7 @@ common::Error::Error(const stim::DemInstruction& error) {
symptom.observables = observables;
}

std::string common::Error::str() {
std::string common::Error::str() const {
return "Error{cost=" + std::to_string(likelihood_cost) + ", symptom=" + symptom.str() + "}";
}

Expand Down
4 changes: 2 additions & 2 deletions src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct Symptom {
bool operator==(const Symptom& other) const {
return detectors == other.detectors && observables == other.observables;
}
std::string str();
std::string str() const;
};

// Represents an error / weighted hyperedge
Expand All @@ -53,7 +53,7 @@ struct Error {
Error(double likelihood_cost, std::vector<int>& detectors, std::vector<int> observables)
: likelihood_cost(likelihood_cost), symptom{detectors, observables} {}
Error(const stim::DemInstruction& error);
std::string str();
std::string str() const;

// Get/calculate the probability from the likelihood cost.
double get_probability() const;
Expand Down
2 changes: 1 addition & 1 deletion src/py/tesseract_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
def test_create_config():
assert (
str(tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL))
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)"
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0, create_visualization=0)"
)
assert (
tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL).dem
Expand Down
17 changes: 16 additions & 1 deletion src/tesseract.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ std::string TesseractConfig::str() {
ss << "verbose=" << config.verbose << ", ";
ss << "pqlimit=" << config.pqlimit << ", ";
ss << "det_orders=" << config.det_orders << ", ";
ss << "det_penalty=" << config.det_penalty << ")";
ss << "det_penalty=" << config.det_penalty << ", ";
ss << "create_visualization=" << config.create_visualization;
ss << ")";
return ss.str();
}

Expand Down Expand Up @@ -124,6 +126,11 @@ TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) {
num_errors = config.dem.count_errors();
num_observables = config.dem.count_observables();
initialize_structures(config.dem.count_detectors());
if (config.create_visualization) {
auto detectors = get_detector_coords(config.dem);
visualizer.add_detector_coords(detectors);
visualizer.add_errors(errors);
}
}

void TesseractDecoder::initialize_structures(size_t num_detectors) {
Expand Down Expand Up @@ -294,6 +301,10 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
flip_detectors_and_block_errors(detector_order, node.errors, detectors, detector_cost_tuples);

if (node.num_detectors == 0) {
if (config.create_visualization) {
visualizer.add_activated_errors(node.errors);
visualizer.add_activated_detectors(detectors, num_detectors);
}
if (config.verbose) {
std::cout << "activated_errors = ";
for (size_t oei : node.errors) {
Expand All @@ -318,6 +329,10 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
if (config.no_revisit_dets && !visited_detectors[node.num_detectors].insert(detectors).second)
continue;

if (config.create_visualization) {
visualizer.add_activated_errors(node.errors);
visualizer.add_activated_detectors(detectors, num_detectors);
}
if (config.verbose) {
std::cout.precision(13);
std::cout << "len(pq) = " << pq.size() << " num_pq_pushed = " << num_pq_pushed << std::endl;
Expand Down
4 changes: 4 additions & 0 deletions src/tesseract.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "common.h"
#include "stim.h"
#include "utils.h"
#include "visualization.h"

constexpr size_t INF_DET_BEAM = std::numeric_limits<uint16_t>::max();

Expand All @@ -38,6 +39,7 @@ struct TesseractConfig {
size_t pqlimit = std::numeric_limits<size_t>::max();
std::vector<std::vector<size_t>> det_orders;
double det_penalty = 0;
bool create_visualization = false;

std::string str();
};
Expand All @@ -64,6 +66,8 @@ struct ErrorCost {

struct TesseractDecoder {
TesseractConfig config;
Visualizer visualizer;

explicit TesseractDecoder(TesseractConfig config);

// Clears the predicted_errors_buffer and fills it with the decoded errors for
Expand Down
2 changes: 2 additions & 0 deletions src/tesseract.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "pybind11/detail/common.h"
#include "simplex.pybind.h"
#include "utils.pybind.h"
#include "visualization.pybind.h"

PYBIND11_MODULE(tesseract_decoder, tesseract) {
py::module::import("stim");
Expand All @@ -29,6 +30,7 @@ PYBIND11_MODULE(tesseract_decoder, tesseract) {
add_utils_module(tesseract);
add_simplex_module(tesseract);
add_tesseract_module(tesseract);
add_visualization_module(tesseract);

// Adds a context manager to the python library that can be used to redirect C++'s stdout/stderr
// to python's stdout/stderr at run time like
Expand Down
16 changes: 12 additions & 4 deletions src/tesseract.pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ TesseractConfig tesseract_config_maker(
bool no_revisit_dets = false, bool at_most_two_errors_per_detector = false,
bool verbose = false, size_t pqlimit = std::numeric_limits<size_t>::max(),
std::vector<std::vector<size_t>> det_orders = std::vector<std::vector<size_t>>(),
double det_penalty = 0.0) {
double det_penalty = 0.0, bool create_visualization = false) {
stim::DetectorErrorModel input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
return TesseractConfig({input_dem, det_beam, beam_climbing, no_revisit_dets,
at_most_two_errors_per_detector, verbose, pqlimit, det_orders,
det_penalty});
det_penalty, create_visualization});
}
}; // namespace
void add_tesseract_module(py::module& root) {
Expand All @@ -61,6 +61,7 @@ void add_tesseract_module(py::module& root) {
py::arg("at_most_two_errors_per_detector") = false, py::arg("verbose") = false,
py::arg("pqlimit") = std::numeric_limits<size_t>::max(),
py::arg("det_orders") = std::vector<std::vector<size_t>>(), py::arg("det_penalty") = 0.0,
py::arg("create_visualization") = false,
R"pbdoc(
The constructor for the `TesseractConfig` class.

Expand All @@ -86,6 +87,8 @@ void add_tesseract_module(py::module& root) {
will generate its own orderings.
det_penalty : float, default=0.0
A penalty value added to the cost of each detector visited.
create_visualization: bool, defualt=False
Whether to record the information needed to create a visualization or not.
)pbdoc")
.def_property("dem", &dem_getter<TesseractConfig>, &dem_setter<TesseractConfig>,
"The `stim.DetectorErrorModel` that defines the error channels and detectors.")
Expand All @@ -106,6 +109,8 @@ void add_tesseract_module(py::module& root) {
"A list of pre-specified detector orderings.")
.def_readwrite("det_penalty", &TesseractConfig::det_penalty,
"The penalty cost added for each detector.")
.def_readwrite("create_visualization", &TesseractConfig::create_visualization,
"If True, records necessary information to create visualization.")
.def("__str__", &TesseractConfig::str)
.def("compile_decoder", &_compile_tesseract_decoder_helper,
py::return_value_policy::take_ownership,
Expand Down Expand Up @@ -374,7 +379,10 @@ void add_tesseract_module(py::module& root) {
.def_readwrite("errors", &TesseractDecoder::errors,
"The list of all errors in the detector error model.")
.def_readwrite("num_observables", &TesseractDecoder::num_observables,
"The total number of logical observables in the detector error model.");
"The total number of logical observables in the detector error model.")
.def_readonly("visualizer", &TesseractDecoder::visualizer,
"An object that can (if config.create_visualization=True) be used to generate "
"visualization of the algorithm");
}

#endif
#endif
53 changes: 53 additions & 0 deletions src/visualization.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@

#include "visualization.h"

void Visualizer::add_errors(const std::vector<common::Error>& errors) {
for (auto& error : errors) {
lines.push_back(error.str());
}
}
void Visualizer::add_detector_coords(const std::vector<std::vector<double>>& detector_coords) {
for (size_t d = 0; d < detector_coords.size(); ++d) {
std::stringstream ss;
ss << "Detector D" << d << " coordinate (";
size_t e = std::min(3ul, detector_coords[d].size());
for (size_t i = 0; i < e; ++i) {
ss << detector_coords[d][i];
if (i + 1 < e) ss << ", ";
}
ss << ")";
lines.push_back(ss.str());
}
}

void Visualizer::add_activated_errors(const std::vector<size_t>& activated_errors) {
std::stringstream ss;
ss << "activated_errors = ";
for (size_t oei : activated_errors) {
ss << oei << ", ";
}
lines.push_back(ss.str());
}

void Visualizer::add_activated_detectors(const boost::dynamic_bitset<>& detectors,
size_t num_detectors) {
std::stringstream ss;
ss << "activated_detectors = ";
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LalehB is it supposed to be activated_detectors or activated_dets ... the script seems to look for _dets

activated_dets = parse_implicit_list(det_line, "activated_dets =")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NoureldinYosri yeah amazing catch! it used to be activated_dets and then in the code got changed to activated_detectors but apparently the visualization did not get updated! thanks a lot!

for (size_t d = 0; d < num_detectors; ++d) {
if (detectors[d]) {
ss << d << ", ";
}
}
lines.push_back(ss.str());
}

void Visualizer::write(const char* fpath) {
FILE* fout = fopen(fpath, "w");

for (std::string& line : lines) {
fprintf(fout, line.c_str());
fputs("\n", fout);
}

fclose(fout);
}
22 changes: 22 additions & 0 deletions src/visualization.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef _VISUALIZATION_H
#define _VISUALIZATION_H

#include <boost/dynamic_bitset.hpp>
#include <list>
#include <vector>

#include "common.h"

struct Visualizer {
void add_detector_coords(const std::vector<std::vector<double>> &);
void add_errors(const std::vector<common::Error> &);
void add_activated_errors(const std::vector<size_t> &);
void add_activated_detectors(const boost::dynamic_bitset<> &, size_t);

void write(const char *fpath);

private:
std::list<std::string> lines;
};

#endif
16 changes: 16 additions & 0 deletions src/visualization.pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <pybind11/iostream.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "visualization.h"

namespace py = pybind11;

void add_visualization_module(py::module& root) {
auto m = root.def_submodule("viz", "Module containing the visualization tools");
py::class_<Visualizer>(m, "Visualizer")
.def(py::init<>())
.def("write", &Visualizer::write, py::arg("fpath"));
}
Loading