diff --git a/src/tesseract.cc b/src/tesseract.cc index 0c1099c..ae802b2 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -57,47 +57,50 @@ std::string TesseractConfig::str() return ss.str(); } -bool Node::operator>(const Node& other) const { - return cost > other.cost || (cost == other.cost && num_dets < other.num_dets); -} - std::string Node::str() { std::stringstream ss; auto &self = *this; ss << "Node("; - ss << "errs=" << self.errs << ", "; - ss << "dets=" << self.dets << ", "; + ss << "errors=" << self.errors << ", "; ss << "cost=" << self.cost << ", "; - ss << "num_dets=" << self.num_dets << ", "; - ss << "blocked_errs=" << self.blocked_errs << ")"; + ss << "num_detectors=" << self.num_detectors << ", "; return ss.str(); } -std::string QNode::str() { - auto & self = *this; - std::stringstream ss; - ss << "QNode("; - ss << "cost=" << self.cost << ", "; - ss << "num_dets=" << self.num_dets << ", "; - ss << "errs=" << self.errs << ")"; - return ss.str(); +bool Node::operator>(const Node& other) const { + return cost > other.cost || (cost == other.cost && num_detectors < other.num_detectors); } -double TesseractDecoder::get_detcost(size_t d, - const std::vector& blocked_errs, - const std::vector& det_counts) const { +double TesseractDecoder::get_detcost(size_t d, const std::vector& detector_cost_tuples) const { double min_cost = INF; + ErrorCost ec; + DetectorCostTuple dct; + for (size_t ei : d2e[d]) { - if (!blocked_errs[ei]) { - double ecost = errors[ei].likelihood_cost / det_counts[ei]; - min_cost = std::min(min_cost, ecost); - assert(det_counts[ei]); + ec = error_costs[ei]; + dct = detector_cost_tuples[ei]; + if (ec.min_cost >= min_cost) break; + if (!dct.error_blocked) { + double error_cost = ec.likelihood_cost / dct.detectors_count; + min_cost = std::min(min_cost, error_cost); } } + return min_cost + config.det_penalty; } +struct VectorCharHash { + size_t operator()(const std::vector& v) const { + size_t seed = v.size(); + + for (char el : v) { + seed = seed * 31 + static_cast(el); + } + return seed; + } +}; + TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) { config.dem = common::remove_zero_probability_errors(config.dem); if (config.det_orders.empty()) { @@ -131,11 +134,20 @@ void TesseractDecoder::initialize_structures(size_t num_detectors) { } } + for (size_t i = 0; i < errors.size(); ++i) { + error_costs.push_back({errors[i].likelihood_cost, errors[i].likelihood_cost / errors[i].symptom.detectors.size()}); + } + + for (size_t d = 0; d < num_detectors; ++d) { + std::sort(d2e[d].begin(), d2e[d].end(), [this](size_t idx_a, size_t idx_b) { + return error_costs[idx_a].min_cost < error_costs[idx_b].min_cost; + }); + } + eneighbors.resize(num_errors); std::vector> edets_sets(edets.size()); for (size_t ei = 0; ei < edets.size(); ++ei) { - edets_sets[ei] = - std::unordered_set(edets[ei].begin(), edets[ei].end()); + edets_sets[ei] = std::unordered_set(edets[ei].begin(), edets[ei].end()); } for (size_t ei = 0; ei < num_errors; ++ei) { std::set neighbor_set; @@ -152,179 +164,136 @@ void TesseractDecoder::initialize_structures(size_t num_detectors) { } } -struct VectorCharHash { - size_t operator()(const std::vector& v) const { - size_t seed = v.size(); // Still good practice to incorporate vector size - - // Iterate over char elements. Accessing 'b_val' is now a direct memory - // read. - for (char b_val : v) { - // The polynomial rolling hash with 31 (or another prime) - // 'b_val' is already a char (an 8-bit integer). - // static_cast(b_val) ensures it's promoted to size_t before - // arithmetic. This cast is efficient (likely a simple register - // extension/move). - seed = seed * 31 + static_cast(b_val); - } - return seed; - } -}; - -void TesseractDecoder::decode_to_errors( - const std::vector& detections) { +void TesseractDecoder::decode_to_errors(const std::vector& detections) { std::vector best_errors; double best_cost = std::numeric_limits::max(); assert(config.det_orders.size()); - int max_det_beam = config.det_beam; + if (config.beam_climbing) { - for (int beam = max_det_beam; beam >= 0; --beam) { - config.det_beam = beam; - size_t det_order = beam % config.det_orders.size(); - decode_to_errors(detections, det_order); - double this_cost = cost_from_errors(predicted_errors_buffer); - if (!low_confidence_flag && this_cost < best_cost) { + for (int beam = config.det_beam; beam >= 0; --beam) { + size_t detector_order = beam % config.det_orders.size(); + decode_to_errors(detections, detector_order, beam); + double local_cost = cost_from_errors(predicted_errors_buffer); + if (!low_confidence_flag && local_cost < best_cost) { best_errors = predicted_errors_buffer; - best_cost = this_cost; + best_cost = local_cost; } if (config.verbose) { - std::cout << "for det_order " << det_order << " beam " << beam + std::cout << "for detector_order " << detector_order << " beam " << beam << " got low confidence " << low_confidence_flag - << " and cost " << this_cost << " and obs_mask " + << " and cost " << local_cost << " and obs_mask " << mask_from_errors(predicted_errors_buffer) << ". Best cost so far: " << best_cost << std::endl; } } } else { - for (size_t det_order = 0; det_order < config.det_orders.size(); - ++det_order) { - decode_to_errors(detections, det_order); - double this_cost = cost_from_errors(predicted_errors_buffer); - if (!low_confidence_flag && this_cost < best_cost) { + for (size_t detector_order = 0; detector_order < config.det_orders.size(); ++detector_order) { + decode_to_errors(detections, detector_order, config.det_beam); + double local_cost = cost_from_errors(predicted_errors_buffer); + if (!low_confidence_flag && local_cost < best_cost) { best_errors = predicted_errors_buffer; - best_cost = this_cost; + best_cost = local_cost; } if (config.verbose) { - std::cout << "for det_order " << det_order << " beam " + std::cout << "for detector_order " << detector_order << " beam " << config.det_beam << " got low confidence " - << low_confidence_flag << " and cost " << this_cost + << low_confidence_flag << " and cost " << local_cost << " and obs_mask " << mask_from_errors(predicted_errors_buffer) << ". Best cost so far: " << best_cost << std::endl; } } } - config.det_beam = max_det_beam; predicted_errors_buffer = best_errors; low_confidence_flag = best_cost == std::numeric_limits::max(); } -bool QNode::operator>(const QNode& other) const { - return cost > other.cost || (cost == other.cost && num_dets < other.num_dets); -} +void TesseractDecoder::flip_detectors_and_block_errors(size_t detector_order, const std::vector& errors, + std::vector& detectors, std::vector& detector_cost_tuples) const { -void TesseractDecoder::to_node(const QNode& qnode, - const std::vector& shot_dets, - size_t det_order, Node& node) const { - node.cost = qnode.cost; - node.errs = qnode.errs; - node.num_dets = qnode.num_dets; - - // Reconstruct the dets and blocked_errs - node.dets = shot_dets; - node.blocked_errs.resize(0); - node.blocked_errs.resize(num_errors, false); - for (size_t ei : node.errs) { - // Get the min index activated detector before updating the dets - size_t min_det = std::numeric_limits::max(); + for (size_t ei : errors) { + size_t min_detector = std::numeric_limits::max(); for (size_t d = 0; d < num_detectors; ++d) { - if (node.dets[config.det_orders[det_order][d]]) { - min_det = config.det_orders[det_order][d]; + if (detectors[config.det_orders[detector_order][d]]) { + min_detector = config.det_orders[detector_order][d]; break; } } - // Reconstruct the blocked_errs - for (size_t oei : d2e[min_det]) { - node.blocked_errs[oei] = true; + + for (size_t oei : d2e[min_detector]) { + detector_cost_tuples[oei].error_blocked = true; if (!config.at_most_two_errors_per_detector && oei == ei) break; } - // Reconstruct the dets for (size_t d : edets[ei]) { - node.dets[d] = !node.dets[d]; - if (!node.dets[d] && config.at_most_two_errors_per_detector) { + detectors[d] = !detectors[d]; + if (!detectors[d] && config.at_most_two_errors_per_detector) { for (size_t oei : d2e[d]) { - node.blocked_errs[oei] = true; + detector_cost_tuples[oei].error_blocked = true; } } } } } -void TesseractDecoder::decode_to_errors(const std::vector& detections, - size_t det_order) { - size_t det_beam = config.det_beam; +void TesseractDecoder::decode_to_errors(const std::vector& detections, size_t detector_order, size_t detector_beam) { predicted_errors_buffer.clear(); low_confidence_flag = false; - std::vector dets(num_detectors, false); - for (size_t d : detections) { - dets[d] = true; - } - std::priority_queue, std::greater> pq; - std::unordered_map, VectorCharHash>> - discovered_dets; + std::priority_queue, std::greater> pq; + std::unordered_map, VectorCharHash>> discovered_detectors; - size_t min_num_dets = detections.size(); - std::vector errs; - std::vector blocked_errs(num_errors, false); - std::vector det_counts(num_errors, 0); + std::vector initial_detectors(num_detectors, false); + std::vector initial_detector_cost_tuples(num_errors); - for (size_t d = 0; d < num_detectors; ++d) { - if (!dets[d]) continue; + for (size_t d : detections) { + initial_detectors[d] = true; for (int ei : d2e[d]) { - ++det_counts[ei]; + ++initial_detector_cost_tuples[ei].detectors_count; } } - double initial_cost = 0.0; - for (size_t d = 0; d < num_detectors; ++d) { - if (!dets[d]) continue; - initial_cost += get_detcost(d, blocked_errs, det_counts); + + double initial_cost = 0; + for (size_t d : detections) { + initial_cost += get_detcost(d, initial_detector_cost_tuples); } + if (initial_cost == INF) { low_confidence_flag = true; return; } - // pq.push({errs, dets, initial_cost, min_num_dets, blocked_errs}); - pq.push({initial_cost, min_num_dets, errs}); + size_t min_num_detectors = detections.size(); + size_t max_num_detectors = min_num_detectors + detector_beam; + + std::vector next_errors; + std::vector next_detectors; + std::vector next_detector_cost_tuples; + std::vector next_next_detector_cost_tuples; + + pq.push({initial_cost, min_num_detectors, std::vector()}); size_t num_pq_pushed = 1; - size_t max_num_dets = min_num_dets + det_beam; - Node node; - std::vector next_det_counts; - std::vector next_next_blocked_errs; - std::vector next_dets; - std::vector next_errs; while (!pq.empty()) { - const QNode qnode = pq.top(); - if (qnode.num_dets > max_num_dets) { - pq.pop(); - continue; - } - to_node(qnode, dets, det_order, node); + const Node node = pq.top(); pq.pop(); - if (node.num_dets == 0) { + if (node.num_detectors > max_num_detectors) continue; + + std::vector detectors = initial_detectors; + std::vector detector_cost_tuples(num_errors); + flip_detectors_and_block_errors(detector_order, node.errors, detectors, detector_cost_tuples); + + if (node.num_detectors == 0) { if (config.verbose) { std::cout << "activated_errors = "; - for (size_t oei : node.errs) { + for (size_t oei : node.errors) { std::cout << oei << ", "; } std::cout << std::endl; - std::cout << "activated_dets = "; + std::cout << "activated_detectors = "; for (size_t d = 0; d < num_detectors; ++d) { - if (node.dets[d]) { + if (detectors[d]) { std::cout << d << ", "; } } @@ -333,15 +302,11 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, std::cout << "Decoding complete. Cost: " << node.cost << " num_pq_pushed = " << num_pq_pushed << std::endl; } - // Store the predicted errors into the buffer - predicted_errors_buffer = node.errs; + predicted_errors_buffer = node.errors; return; } - if (node.num_dets > max_num_dets) continue; - - if (config.no_revisit_dets && - !discovered_dets[node.num_dets].insert(node.dets).second) { + if (config.no_revisit_dets && !discovered_detectors[node.num_detectors].insert(detectors).second) { continue; } @@ -349,149 +314,130 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, std::cout.precision(13); std::cout << "len(pq) = " << pq.size() << " num_pq_pushed = " << num_pq_pushed << std::endl; - std::cout << "num_dets = " << node.num_dets - << " max_num_dets = " << max_num_dets << " cost = " << node.cost + std::cout << "num_detectors = " << node.num_detectors + << " max_num_detectors = " << max_num_detectors << " cost = " << node.cost << std::endl; std::cout << "activated_errors = "; - for (size_t oei : node.errs) { + for (size_t oei : node.errors) { std::cout << oei << ", "; } std::cout << std::endl; - std::cout << "activated_dets = "; + std::cout << "activated_detectors = "; for (size_t d = 0; d < num_detectors; ++d) { - if (node.dets[d]) { + if (detectors[d]) { std::cout << d << ", "; } } std::cout << std::endl; } - if (node.num_dets < min_num_dets) { - min_num_dets = node.num_dets; + if (node.num_detectors < min_num_detectors) { + min_num_detectors = node.num_detectors; if (config.no_revisit_dets) { - for (size_t i = min_num_dets + det_beam + 1; i <= max_num_dets; ++i) { - discovered_dets[i].clear(); + for (size_t i = min_num_detectors + detector_beam + 1; i <= max_num_detectors; ++i) { + discovered_detectors[i].clear(); } } - max_num_dets = std::min(max_num_dets, min_num_dets + det_beam); + max_num_detectors = std::min(max_num_detectors, min_num_detectors + detector_beam); } - // Choose the min det to be the minimum index activated detector - size_t min_det = std::numeric_limits::max(); for (size_t d = 0; d < num_detectors; ++d) { - if (node.dets[config.det_orders[det_order][d]]) { - min_det = config.det_orders[det_order][d]; - break; + if (!detectors[d]) continue; + for (int ei : d2e[d]) { + ++detector_cost_tuples[ei].detectors_count; } } - // Recompute the det counts - std::vector det_counts(num_errors, 0); + next_detector_cost_tuples = detector_cost_tuples; + + size_t min_detector = std::numeric_limits::max(); for (size_t d = 0; d < num_detectors; ++d) { - if (!node.dets[d]) continue; - for (int ei : d2e[d]) { - ++det_counts[ei]; + if (detectors[config.det_orders[detector_order][d]]) { + min_detector = config.det_orders[detector_order][d]; + break; } } - // We cache as we recompute the det costs - std::vector det_costs(num_detectors, -1); - std::vector next_blocked_errs = node.blocked_errs; if (config.at_most_two_errors_per_detector) { - for (int ei : d2e[min_det]) { - // Block all errors of this detector -- note this is an approximation - // where we insist at most 2 errors are incident to any detector - next_blocked_errs[ei] = true; + for (int ei : d2e[min_detector]) { + next_detector_cost_tuples[ei].error_blocked = true; } } - // Consider activating any error of the lowest index activated detector - next_det_counts = det_counts; - size_t last_ei = std::numeric_limits::max(); - for (size_t ei : d2e[min_det]) { - if (node.blocked_errs[ei]) { + size_t prev_ei = std::numeric_limits::max(); + std::vector detector_cost_cache(num_detectors, -1); + + for (size_t ei : d2e[min_detector]) { + if (detector_cost_tuples[ei].error_blocked) { continue; } - // Uncompute the previous edits to the next det counts on the last - // iteration - if (last_ei != std::numeric_limits::max()) { - for (int d : edets[last_ei]) { - int fired = node.dets[d] ? 1 : -1; + if (prev_ei != std::numeric_limits::max()) { + for (int d : edets[prev_ei]) { + int fired = detectors[d] ? 1 : -1; for (int oei : d2e[d]) { - next_det_counts[oei] += fired; + next_detector_cost_tuples[oei].detectors_count += fired; } } } + prev_ei = ei; + + next_errors = node.errors; + next_errors.push_back(ei); + next_detectors = detectors; + next_detector_cost_tuples[ei].error_blocked = true; - last_ei = ei; - next_blocked_errs[ei] = true; - - next_errs = node.errs; - next_errs.push_back(ei); - - next_dets = node.dets; double next_cost = node.cost + errors[ei].likelihood_cost; + size_t next_num_detectors = node.num_detectors; - size_t next_num_dets = node.num_dets; if (config.at_most_two_errors_per_detector) { - next_next_blocked_errs = next_blocked_errs; + next_next_detector_cost_tuples = next_detector_cost_tuples; } for (int d : edets[ei]) { - next_dets[d] = !next_dets[d]; - int fired = next_dets[d] ? 1 : -1; - next_num_dets += fired; + next_detectors[d] = !next_detectors[d]; + int fired = next_detectors[d] ? 1 : -1; + next_num_detectors += fired; for (int oei : d2e[d]) { - next_det_counts[oei] += fired; + next_detector_cost_tuples[oei].detectors_count += fired; } - if (!next_dets[d] && config.at_most_two_errors_per_detector) { + if (!next_detectors[d] && config.at_most_two_errors_per_detector) { for (size_t oei : d2e[d]) { - next_next_blocked_errs[oei] = true; + next_next_detector_cost_tuples[oei].error_blocked = true; } } } - if (next_num_dets > max_num_dets) { - continue; - } + if (next_num_detectors > max_num_detectors) continue; - if (config.no_revisit_dets && - discovered_dets[next_num_dets].find(next_dets) != - discovered_dets[next_num_dets].end()) { + if (config.no_revisit_dets && discovered_detectors[next_num_detectors].find(next_detectors) != discovered_detectors[next_num_detectors].end()) { continue; } for (int d : edets[ei]) { - if (node.dets[d]) { - if (det_costs[d] == -1) { - det_costs[d] = - get_detcost(d, node.blocked_errs, det_counts); + if (detectors[d]) { + if (detector_cost_cache[d] == -1) { + detector_cost_cache[d] = get_detcost(d, detector_cost_tuples); } - next_cost -= det_costs[d]; + next_cost -= detector_cost_cache[d]; } else { - next_cost += get_detcost(d, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts); + next_cost += get_detcost(d, config.at_most_two_errors_per_detector ? next_next_detector_cost_tuples : next_detector_cost_tuples); } } + for (size_t od : eneighbors[ei]) { - if (!node.dets[od] || !next_dets[od]) continue; - if (det_costs[od] == -1) { - det_costs[od] = - get_detcost(od, node.blocked_errs, det_counts); + if (!detectors[od] || !next_detectors[od]) continue; + if (detector_cost_cache[od] == -1) { + detector_cost_cache[od] = get_detcost(od, detector_cost_tuples); } - next_cost -= det_costs[od]; - next_cost += - get_detcost(od, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts); + next_cost -= detector_cost_cache[od]; + next_cost += get_detcost(od, config.at_most_two_errors_per_detector ? next_next_detector_cost_tuples : next_detector_cost_tuples); } - if (next_cost == INF) { - continue; - } + if (next_cost == INF) continue; - // pq.push({next_errs, next_dets, next_cost, next_num_dets, - // next_blocked_errs}); - pq.push({next_cost, next_num_dets, next_errs}); + pq.push({next_cost, next_num_detectors, next_errors}); ++num_pq_pushed; if (num_pq_pushed > config.pqlimit) { @@ -506,31 +452,27 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, std::cout << "Decoding failed to converge within beam limit." << std::endl; } low_confidence_flag = true; - return; } -double TesseractDecoder::cost_from_errors( - const std::vector& predicted_errors) { +double TesseractDecoder::cost_from_errors(const std::vector& predicted_errors) { double total_cost = 0; - // Iterate over all errors and add to the mask - for (size_t ei : predicted_errors_buffer) { + // Iterate over all errors and compute the cost + for (size_t ei : predicted_errors) { total_cost += errors[ei].likelihood_cost; } return total_cost; } -common::ObservablesMask TesseractDecoder::mask_from_errors( - const std::vector& predicted_errors) { +common::ObservablesMask TesseractDecoder::mask_from_errors(const std::vector& predicted_errors) { common::ObservablesMask mask = 0; - // Iterate over all errors and add to the mask - for (size_t ei : predicted_errors_buffer) { + // Iterate over all errors and compute the mask + for (size_t ei : predicted_errors) { mask ^= errors[ei].symptom.observables; } return mask; } -common::ObservablesMask TesseractDecoder::decode( - const std::vector& detections) { +common::ObservablesMask TesseractDecoder::decode(const std::vector& detections) { decode_to_errors(detections); return mask_from_errors(predicted_errors_buffer); } diff --git a/src/tesseract.h b/src/tesseract.h index cd290eb..c380ed3 100644 --- a/src/tesseract.h +++ b/src/tesseract.h @@ -15,14 +15,14 @@ #ifndef TESSERACT_DECODER_H #define TESSERACT_DECODER_H -#include #include -#include #include +#include +#include #include -#include "common.h" #include "stim.h" +#include "common.h" #include "utils.h" constexpr size_t INF_DET_BEAM = std::numeric_limits::max(); @@ -43,24 +43,22 @@ struct TesseractConfig { class Node { public: - std::vector errs; - std::vector dets; double cost; - size_t num_dets; - std::vector blocked_errs; + size_t num_detectors; + std::vector errors; bool operator>(const Node& other) const; std::string str(); }; -class QNode { - public: - double cost; - size_t num_dets; - std::vector errs; +struct DetectorCostTuple { + uint32_t error_blocked; + uint32_t detectors_count; +}; - bool operator>(const QNode& other) const; - std::string str(); +struct ErrorCost { + double likelihood_cost; + double min_cost; }; struct TesseractDecoder { @@ -73,26 +71,21 @@ struct TesseractDecoder { // Clears the predicted_errors_buffer and fills it with the decoded errors for // these detection events, using a specified detector ordering index. - void decode_to_errors(const std::vector& detections, - size_t det_order); + void decode_to_errors(const std::vector& detections, size_t detector_order, size_t detector_beam); // Returns the bitwise XOR of all the observables bitmasks of all errors in // the predicted errors buffer. - common::ObservablesMask mask_from_errors( - const std::vector& predicted_errors); + common::ObservablesMask mask_from_errors(const std::vector& predicted_errors); // Returns the sum of the likelihood costs (minus-log-likelihood-ratios) of // all errors in the predicted errors buffer. double cost_from_errors(const std::vector& predicted_errors); - common::ObservablesMask decode(const std::vector& detections); - void decode_shots(std::vector& shots, - std::vector& obs_predicted); + common::ObservablesMask decode(const std::vector& detections); + void decode_shots(std::vector& shots, std::vector& obs_predicted); bool low_confidence_flag = false; std::vector predicted_errors_buffer; - - int det_beam; std::vector errors; private: @@ -101,12 +94,12 @@ struct TesseractDecoder { std::vector> edets; size_t num_detectors; size_t num_errors; + std::vector error_costs; void initialize_structures(size_t num_detectors); - double get_detcost(size_t d, const std::vector& blocked_errs, - const std::vector& det_counts) const; - void to_node(const QNode& qnode, const std::vector& shot_dets, - size_t det_order, Node& node) const; + double get_detcost(size_t d, const std::vector& detector_cost_tuples) const; + void flip_detectors_and_block_errors(size_t detector_order, const std::vector& errors, + std::vector& detectors, std::vector& detector_cost_tuples) const; }; #endif // TESSERACT_DECODER_H \ No newline at end of file diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 5c57a2e..2c36b99 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -50,30 +50,16 @@ void add_tesseract_module(py::module &root) { .def("__str__", &TesseractConfig::str); py::class_(m, "Node") - .def(py::init, std::vector, double, size_t, - std::vector>(), + .def(py::init>(), py::arg("errs") = std::vector(), - py::arg("dets") = std::vector(), py::arg("cost") = 0.0, - py::arg("num_dets") = 0, - py::arg("blocked_errs") = std::vector()) - .def_readwrite("errs", &Node::errs) - .def_readwrite("dets", &Node::dets) + py::arg("cost") = 0.0, + py::arg("num_dets") = 0) + .def_readwrite("errs", &Node::errors) .def_readwrite("cost", &Node::cost) - .def_readwrite("num_dets", &Node::num_dets) - .def_readwrite("blocked_errs", &Node::blocked_errs) + .def_readwrite("num_dets", &Node::num_detectors) .def(py::self > py::self) .def("__str__", &Node::str); - py::class_(m, "QNode") - .def(py::init>(), - py::arg("cost") = 0.0, py::arg("num_dets") = 0, - py::arg("errs") = std::vector()) - .def_readwrite("cost", &QNode::cost) - .def_readwrite("num_dets", &QNode::num_dets) - .def_readwrite("errs", &QNode::errs) - .def(py::self > py::self) - .def("__str__", &QNode::str); - py::class_(m, "TesseractDecoder") .def(py::init(), py::arg("config")) .def("decode_to_errors", @@ -81,9 +67,9 @@ void add_tesseract_module(py::module &root) { &TesseractDecoder::decode_to_errors), py::arg("detections")) .def("decode_to_errors", - py::overload_cast &, size_t>( + py::overload_cast &, size_t, size_t>( &TesseractDecoder::decode_to_errors), - py::arg("detections"), py::arg("det_order")) + py::arg("detections"), py::arg("det_order"), py::arg("det_beam")) .def("mask_from_errors", &TesseractDecoder::mask_from_errors, py::arg("predicted_errors")) .def("cost_from_errors", &TesseractDecoder::cost_from_errors, @@ -93,7 +79,6 @@ void add_tesseract_module(py::module &root) { &TesseractDecoder::low_confidence_flag) .def_readwrite("predicted_errors_buffer", &TesseractDecoder::predicted_errors_buffer) - .def_readwrite("det_beam", &TesseractDecoder::det_beam) .def_readwrite("errors", &TesseractDecoder::errors); }