diff --git a/src/tesseract.cc b/src/tesseract.cc index 2c002b4..c9e530b 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -24,12 +24,11 @@ bool Node::operator>(const Node& other) const { double TesseractDecoder::get_detcost(size_t d, const std::vector& blocked_errs, - const std::vector& det_counts, - const std::vector& dets) const { + const std::vector& det_counts) const { double min_cost = INF; for (size_t ei : d2e[d]) { if (!blocked_errs[ei]) { - double ecost = (errors[ei].likelihood_cost) / det_counts[ei]; + double ecost = errors[ei].likelihood_cost / det_counts[ei]; min_cost = std::min(min_cost, ecost); assert(det_counts[ei]); } @@ -46,7 +45,7 @@ TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) { assert(config.det_orders[i].size() == config.dem.count_detectors()); } } - assert(this->config.det_orders.size()); + assert(config.det_orders.size()); errors = get_errors_from_dem(config.dem.flattened()); if (config.verbose) { for (auto& error : errors) { @@ -120,7 +119,7 @@ void TesseractDecoder::decode_to_errors( size_t det_order = beam % config.det_orders.size(); decode_to_errors(detections, det_order); double this_cost = cost_from_errors(predicted_errors_buffer); - if (!low_confidence_flag and this_cost < best_cost) { + if (!low_confidence_flag && this_cost < best_cost) { best_errors = predicted_errors_buffer; best_cost = this_cost; } @@ -137,7 +136,7 @@ void TesseractDecoder::decode_to_errors( ++det_order) { decode_to_errors(detections, det_order); double this_cost = cost_from_errors(predicted_errors_buffer); - if (!low_confidence_flag and this_cost < best_cost) { + if (!low_confidence_flag && this_cost < best_cost) { best_errors = predicted_errors_buffer; best_cost = this_cost; } @@ -153,7 +152,7 @@ void TesseractDecoder::decode_to_errors( } config.det_beam = max_det_beam; predicted_errors_buffer = best_errors; - low_confidence_flag = (best_cost == std::numeric_limits::max()); + low_confidence_flag = best_cost == std::numeric_limits::max(); } bool QNode::operator>(const QNode& other) const { @@ -183,20 +182,16 @@ void TesseractDecoder::to_node(const QNode& qnode, // Reconstruct the blocked_errs for (size_t oei : d2e[min_det]) { node.blocked_errs[oei] = true; - if (!config.at_most_two_errors_per_detector and oei == ei) break; + if (!config.at_most_two_errors_per_detector && oei == ei) break; } // Reconstruct the dets for (size_t d : edets[ei]) { - if (node.dets[d]) { - node.dets[d] = false; - if (config.at_most_two_errors_per_detector) { - for (size_t oei : d2e[d]) { - node.blocked_errs[oei] = true; - } + node.dets[d] = !node.dets[d]; + if (!node.dets[d] && config.at_most_two_errors_per_detector) { + for (size_t oei : d2e[d]) { + node.blocked_errs[oei] = true; } - } else { - node.dets[d] = true; } } } @@ -217,40 +212,37 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, std::unordered_set, VectorCharHash>> discovered_dets; - size_t min_num_dets; - { - std::vector errs; - std::vector blocked_errs(num_errors, false); - std::vector det_counts(num_errors, 0); + size_t min_num_dets = detections.size(); + std::vector errs; + std::vector blocked_errs(num_errors, false); + std::vector det_counts(num_errors, 0); - for (size_t d = 0; d < num_detectors; ++d) { - if (!dets[d]) continue; - for (int ei : d2e[d]) { - det_counts[ei]++; - } - } - double initial_cost = 0.0; - for (size_t d = 0; d < num_detectors; ++d) { - if (!dets[d]) continue; - initial_cost += get_detcost(d, blocked_errs, det_counts, dets); + for (size_t d = 0; d < num_detectors; ++d) { + if (!dets[d]) continue; + for (int ei : d2e[d]) { + ++det_counts[ei]; } - if (initial_cost == INF) { - low_confidence_flag = true; - return; - } - min_num_dets = - static_cast(std::count(dets.begin(), dets.end(), true)); - // pq.push({errs, dets, initial_cost, min_num_dets, blocked_errs}); - pq.push({initial_cost, min_num_dets, errs}); } - size_t num_pq_pushed = 1; + double initial_cost = 0.0; + for (size_t d = 0; d < num_detectors; ++d) { + if (!dets[d]) continue; + initial_cost += get_detcost(d, blocked_errs, det_counts); + } + if (initial_cost == INF) { + low_confidence_flag = true; + return; + } + // pq.push({errs, dets, initial_cost, min_num_dets, blocked_errs}); + pq.push({initial_cost, min_num_dets, errs}); + size_t num_pq_pushed = 1; size_t max_num_dets = min_num_dets + det_beam; Node node; std::vector next_det_counts; std::vector next_next_blocked_errs; std::vector next_dets; std::vector next_errs; + while (!pq.empty()) { const QNode qnode = pq.top(); if (qnode.num_dets > max_num_dets) { @@ -280,13 +272,12 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, } // Store the predicted errors into the buffer predicted_errors_buffer = node.errs; - return; } if (node.num_dets > max_num_dets) continue; - if (config.no_revisit_dets and + if (config.no_revisit_dets && !discovered_dets[node.num_dets].insert(node.dets).second) { continue; } @@ -336,9 +327,10 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, for (size_t d = 0; d < num_detectors; ++d) { if (!node.dets[d]) continue; for (int ei : d2e[d]) { - det_counts[ei]++; + ++det_counts[ei]; } } + // We cache as we recompute the det costs std::vector det_costs(num_detectors, -1); std::vector next_blocked_errs = node.blocked_errs; @@ -362,19 +354,14 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, // iteration if (last_ei != std::numeric_limits::max()) { for (int d : edets[last_ei]) { - if (node.dets[d]) { - for (int oei : d2e[d]) { - ++next_det_counts[oei]; - } - } else { - for (int oei : d2e[d]) { - --next_det_counts[oei]; - } + int fired = node.dets[d] ? 1 : -1; + for (int oei : d2e[d]) { + next_det_counts[oei] += fired; } } } - last_ei = ei; + last_ei = ei; next_blocked_errs[ei] = true; next_errs = node.errs; @@ -384,24 +371,21 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, double next_cost = node.cost + errors[ei].likelihood_cost; size_t next_num_dets = node.num_dets; - next_next_blocked_errs = next_blocked_errs; + if (config.at_most_two_errors_per_detector) { + next_next_blocked_errs = next_blocked_errs; + } + for (int d : edets[ei]) { - if (next_dets[d]) { - next_dets[d] = false; - --next_num_dets; - for (int oei : d2e[d]) { - --next_det_counts[oei]; - } - if (config.at_most_two_errors_per_detector) { - for (size_t oei : d2e[d]) { - next_next_blocked_errs[oei] = true; - } - } - } else { - next_dets[d] = true; - ++next_num_dets; - for (int oei : d2e[d]) { - ++next_det_counts[oei]; + next_dets[d] = !next_dets[d]; + int fired = next_dets[d] ? 1 : -1; + next_num_dets += fired; + for (int oei : d2e[d]) { + next_det_counts[oei] += fired; + } + + if (!next_dets[d] && config.at_most_two_errors_per_detector) { + for (size_t oei : d2e[d]) { + next_next_blocked_errs[oei] = true; } } } @@ -410,7 +394,7 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, continue; } - if (config.no_revisit_dets and + if (config.no_revisit_dets && discovered_dets[next_num_dets].find(next_dets) != discovered_dets[next_num_dets].end()) { continue; @@ -420,23 +404,22 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, if (node.dets[d]) { if (det_costs[d] == -1) { det_costs[d] = - get_detcost(d, node.blocked_errs, det_counts, node.dets); + get_detcost(d, node.blocked_errs, det_counts); } next_cost -= det_costs[d]; } else { - next_cost += get_detcost(d, next_next_blocked_errs, next_det_counts, - next_dets); + next_cost += get_detcost(d, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts); } } for (size_t od : eneighbors[ei]) { if (!node.dets[od] || !next_dets[od]) continue; if (det_costs[od] == -1) { det_costs[od] = - get_detcost(od, node.blocked_errs, det_counts, node.dets); + get_detcost(od, node.blocked_errs, det_counts); } next_cost -= det_costs[od]; next_cost += - get_detcost(od, next_next_blocked_errs, next_det_counts, next_dets); + get_detcost(od, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts); } if (next_cost == INF) { @@ -496,4 +479,4 @@ void TesseractDecoder::decode_shots( for (size_t i = 0; i < shots.size(); ++i) { obs_predicted[i] = decode(shots[i].hits); } -} +} \ No newline at end of file diff --git a/src/tesseract.h b/src/tesseract.h index 96b40fa..0997fda 100644 --- a/src/tesseract.h +++ b/src/tesseract.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "common.h" #include "stim.h" @@ -70,10 +71,12 @@ struct TesseractDecoder { // these detection events, using a specified detector ordering index. void decode_to_errors(const std::vector& detections, size_t det_order); + // Returns the bitwise XOR of all the observables bitmasks of all errors in // the predicted errors buffer. common::ObservablesMask mask_from_errors( const std::vector& predicted_errors); + // Returns the sum of the likelihood costs (minus-log-likelihood-ratios) of // all errors in the predicted errors buffer. double cost_from_errors(const std::vector& predicted_errors); @@ -97,10 +100,9 @@ struct TesseractDecoder { void initialize_structures(size_t num_detectors); double get_detcost(size_t d, const std::vector& blocked_errs, - const std::vector& det_counts, - const std::vector& dets) const; + const std::vector& det_counts) const; void to_node(const QNode& qnode, const std::vector& shot_dets, size_t det_order, Node& node) const; }; -#endif // TESSERACT_DECODER_H +#endif // TESSERACT_DECODER_H \ No newline at end of file