From 64dd86af318c4c5aae826fee13c0046acbe03051 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Wed, 5 Nov 2025 16:29:12 -0500 Subject: [PATCH 01/20] Implement no-cloning validation --- .../analysis/validation/nocloning/__init__.py | 2 + .../analysis/validation/nocloning/analysis.py | 122 ++++++++++++++++++ .../analysis/validation/nocloning/impls.py | 33 +++++ .../analysis/validation/nocloning/lattice.py | 77 +++++++++++ test/analysis/validation/test_no_cloning.py | 80 ++++++++++++ test/analysis/validation/util.py | 17 +++ 6 files changed, 331 insertions(+) create mode 100644 src/bloqade/analysis/validation/nocloning/__init__.py create mode 100644 src/bloqade/analysis/validation/nocloning/analysis.py create mode 100644 src/bloqade/analysis/validation/nocloning/impls.py create mode 100644 src/bloqade/analysis/validation/nocloning/lattice.py create mode 100644 test/analysis/validation/test_no_cloning.py create mode 100644 test/analysis/validation/util.py diff --git a/src/bloqade/analysis/validation/nocloning/__init__.py b/src/bloqade/analysis/validation/nocloning/__init__.py new file mode 100644 index 00000000..a61f8ba0 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/__init__.py @@ -0,0 +1,2 @@ +from . import impls as impls +from .analysis import NoCloningValidation as NoCloningValidation diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py new file mode 100644 index 00000000..a1e9c703 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -0,0 +1,122 @@ +from dataclasses import field + +from kirin import ir +from kirin.analysis import Forward, TypeInference +from kirin.dialects import func +from kirin.analysis.forward import ForwardFrame + +from bloqade.analysis.address import ( + Address, + AddressReg, + AddressQubit, + AddressAnalysis, +) +from bloqade.analysis.address.lattice import QubitLike + +from .lattice import QubitValidation + + +class NoCloningValidation(Forward[QubitValidation]): + """ + Validates the no-cloning theorem by tracking qubit addresses. + + Built on top of AddressAnalysis to get qubit address information. + """ + + keys = ["validate.nocloning"] + lattice = QubitValidation + _address_frame: ForwardFrame[Address] = field(init=False) + _type_frame: ForwardFrame = field(init=False) + method: ir.Method + violations: int = field(default=0, init=False) + + def __init__(self, mtd: ir.Method): + """ + Input: + - an ir.Method / kernel function + infer dialects from it and remember method. + """ + self.method = mtd + super().__init__(mtd.dialects) + + def initialize(self): + super().initialize() + + address_analysis = AddressAnalysis(self.dialects) + address_analysis.initialize() + self._address_frame, _ = address_analysis.run_analysis(self.method) + + type_inference = TypeInference(self.dialects) + type_inference.initialize() + self._type_frame, _ = type_inference.run_analysis(self.method) + + return self + + def method_self(self, method: ir.Method) -> QubitValidation: + return self.lattice.bottom() + + def get_qubit_addresses(self, addr: Address) -> frozenset[int]: + """Extract concrete qubit addresses from an Address lattice element.""" + match addr: + case AddressQubit(data=qubit_addr): + return frozenset([qubit_addr]) + case AddressReg(data=addrs): + return frozenset(addrs) + case _: + return frozenset() + + def get_stmt_info(self, stmt: ir.Statement) -> str: + """String Report about the statement for violation messages.""" + if isinstance(stmt, func.Invoke) and hasattr(stmt, "callee"): + gate_name = stmt.callee.sym_name.upper() + return f"{gate_name} Gate" + + return f"{stmt.__class__.__name__}@{stmt}" + + def eval_stmt_fallback( + self, frame: ForwardFrame[QubitValidation], stmt: ir.Statement + ) -> tuple[QubitValidation, ...]: + """ + Default statement evaluation: check for qubit usage violations. + """ + + if not isinstance(stmt, func.Invoke): + return tuple(QubitValidation.bottom() for _ in stmt.results) + + address_frame = self._address_frame + if address_frame is None: + return tuple(QubitValidation.top() for _ in stmt.results) + + has_qubit_args = any( + isinstance(address_frame.get(arg), QubitLike) for arg in stmt.args + ) + + if not has_qubit_args: + return tuple(QubitValidation.bottom() for _ in stmt.results) + + used_addrs: list[int] = [] + for arg in stmt.args: + addr = address_frame.get(arg) + qubit_addrs = self.get_qubit_addresses(addr) + used_addrs.extend(qubit_addrs) + + seen: set[int] = set() + violations: list[str] = [] + stmt_info = self.get_stmt_info(stmt) + + for qubit_addr in used_addrs: + if qubit_addr in seen: + violations.append(f"Qubit[{qubit_addr}] at {stmt_info}") + seen.add(qubit_addr) + + if not violations: + return tuple(QubitValidation(violations=frozenset()) for _ in stmt.results) + + usage = QubitValidation(violations=frozenset(violations)) + return tuple(usage for _ in stmt.results) if stmt.results else (usage,) + + def run_method( + self, method: ir.Method, args: tuple[QubitValidation, ...] + ) -> tuple[ForwardFrame[QubitValidation], QubitValidation]: + self_mt = self.method_self(method) + return self.run_callable(method.code, (self_mt,) + args) diff --git a/src/bloqade/analysis/validation/nocloning/impls.py b/src/bloqade/analysis/validation/nocloning/impls.py new file mode 100644 index 00000000..c2c7fc1c --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/impls.py @@ -0,0 +1,33 @@ +from kirin import interp +from kirin.analysis import ForwardFrame +from kirin.dialects import scf + +from .lattice import QubitValidation +from .analysis import NoCloningValidation + + +@scf.dialect.register(key="validate.nocloning") +class Scf(interp.MethodTable): + @interp.impl(scf.IfElse) + def if_else( + self, + interp_: NoCloningValidation, + frame: ForwardFrame[QubitValidation], + stmt: scf.IfElse, + ): + cond_validation = frame.get(stmt.cond) + + then_results = interp_.run_callable_region( + frame, stmt, stmt.then_body, (cond_validation,) + ) + + if stmt.else_body: + else_results = interp_.run_callable_region( + frame, stmt, stmt.else_body, (cond_validation,) + ) + + merged = tuple(then_results.join(else_results) for _ in stmt.results) + else: + merged = tuple(then_results for _ in stmt.results) + + return merged if merged else (QubitValidation.bottom(),) diff --git a/src/bloqade/analysis/validation/nocloning/lattice.py b/src/bloqade/analysis/validation/nocloning/lattice.py new file mode 100644 index 00000000..fbc54634 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/lattice.py @@ -0,0 +1,77 @@ +from typing import FrozenSet, final +from dataclasses import field, dataclass + +from kirin.lattice import ( + SingletonMeta, + BoundedLattice, + SimpleJoinMixin, + SimpleMeetMixin, +) + + +@dataclass +class QubitValidation( + SimpleJoinMixin["QubitValidation"], + SimpleMeetMixin["QubitValidation"], + BoundedLattice["QubitValidation"], +): + """Tracks cloning violations detected during analysis.""" + + violations: FrozenSet[str] = field(default_factory=frozenset) + + @classmethod + def bottom(cls) -> "QubitValidation": + """No violations detected""" + return Bottom() + + @classmethod + def top(cls) -> "QubitValidation": + """Unknown state - assume potential violations""" + return Top() + + def is_subseteq(self, other: "QubitValidation") -> bool: + """Check if this state is a subset of another. + + Lattice ordering: + Bottom ⊑ {{'Qubit[1] at CX Gate'}} ⊑ {{'Qubit[0] at CX Gate'},{'Qubit[1] at CX Gate'}} ⊑ Top + """ + if isinstance(other, Top): + return True + if isinstance(self, Bottom): + return True + if isinstance(other, Bottom): + return False + + return self.violations.issubset(other.violations) + + def __repr__(self) -> str: + """Custom repr to show violations clearly.""" + if not self.violations: + return "QubitValidation()" + return f"QubitValidation(violations={self.violations})" + + +@final +class Bottom(QubitValidation, metaclass=SingletonMeta): + """Bottom element representing no violations.""" + + def is_subseteq(self, other: QubitValidation) -> bool: + """Bottom is subset of everything.""" + return True + + def __repr__(self) -> str: + """Cleaner printing.""" + return "⊥ (Bottom)" + + +@final +class Top(QubitValidation, metaclass=SingletonMeta): + """Top element representing unknown state with potential violations.""" + + def is_subseteq(self, other: QubitValidation) -> bool: + """Top is only subset of Top.""" + return isinstance(other, Top) + + def __repr__(self) -> str: + """Cleaner printing.""" + return "⊤ (Top)" diff --git a/test/analysis/validation/test_no_cloning.py b/test/analysis/validation/test_no_cloning.py new file mode 100644 index 00000000..85e22616 --- /dev/null +++ b/test/analysis/validation/test_no_cloning.py @@ -0,0 +1,80 @@ +from typing import Any + +import pytest +from util import collect_validation_errors +from kirin import ir +from kirin.dialects.ilist.runtime import IList + +from bloqade import squin +from bloqade.types import Qubit +from bloqade.analysis.validation.nocloning.lattice import QubitValidation +from bloqade.analysis.validation.nocloning.analysis import NoCloningValidation + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_control_gate_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(): + q = squin.qalloc(1) + control_gate(q[0], q[0]) + + validation = NoCloningValidation(bad_control) + validation.initialize() + frame, _ = validation.run_analysis(bad_control) + print() + bad_control.print(analysis=frame.entries) + validation_errors = collect_validation_errors(frame, QubitValidation) + assert len(validation_errors) == 1 + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_control_gate_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(cond: bool): + q = squin.qalloc(10) + if cond: + control_gate(q[0], q[0]) + else: + control_gate(q[0], q[1]) + squin.cx(q[1], q[1]) + + validation = NoCloningValidation(bad_control) + validation.initialize() + frame, _ = validation.run_analysis(bad_control) + print() + bad_control.print(analysis=frame.entries) + validation_errors = collect_validation_errors(frame, QubitValidation) + # print("Violations:", validation_errors) + assert len(validation_errors) == 2 + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_control_gate_parallel_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(): + q = squin.qalloc(2) + control_gate(q[0], q[1]) + + validation = NoCloningValidation(bad_control) + validation.initialize() + frame, _ = validation.run_analysis(bad_control) + print() + bad_control.print(analysis=frame.entries) + validation_errors = collect_validation_errors(frame, QubitValidation) + assert len(validation_errors) == 0 + + +def test_control_gate_parallel_pass(): + @squin.kernel + def good_kernel(): + q = squin.qalloc(2) + squin.cx(q[0], q[1]) + squin.cy(q[1], q[1]) + + validation = NoCloningValidation(good_kernel) + validation.initialize() + frame, _ = validation.run_analysis(good_kernel) + print() + good_kernel.print(analysis=frame.entries) + validation_errors = collect_validation_errors(frame, QubitValidation) + assert len(validation_errors) == 1 diff --git a/test/analysis/validation/util.py b/test/analysis/validation/util.py new file mode 100644 index 00000000..bdfae6a9 --- /dev/null +++ b/test/analysis/validation/util.py @@ -0,0 +1,17 @@ +from typing import TypeVar + +from kirin.analysis import ForwardFrame + +from bloqade.analysis.validation.nocloning.lattice import QubitValidation + +T = TypeVar("T", bound=QubitValidation) + + +def collect_validation_errors( + frame: ForwardFrame[QubitValidation], typ: type[T] +) -> list[T]: + return [ + validation_errors + for validation_errors in frame.entries.values() + if isinstance(validation_errors, typ) and len(validation_errors.violations) > 0 + ] From 1b773d40c527371da5e2269ee85e7607e87519f3 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Thu, 6 Nov 2025 09:52:23 -0500 Subject: [PATCH 02/20] improve error reporting and update test cases for validation errors --- .../analysis/validation/nocloning/analysis.py | 20 ++++++++----- test/analysis/validation/test_no_cloning.py | 28 ++++++++++++++----- test/analysis/validation/util.py | 20 ++++++------- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index a1e9c703..04c59952 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -7,11 +7,9 @@ from bloqade.analysis.address import ( Address, - AddressReg, - AddressQubit, AddressAnalysis, ) -from bloqade.analysis.address.lattice import QubitLike +from bloqade.analysis.address.lattice import AddressReg, AddressQubit from .lattice import QubitValidation @@ -28,7 +26,7 @@ class NoCloningValidation(Forward[QubitValidation]): _address_frame: ForwardFrame[Address] = field(init=False) _type_frame: ForwardFrame = field(init=False) method: ir.Method - violations: int = field(default=0, init=False) + _validation_errors: list[str] = field(default_factory=list, init=False) def __init__(self, mtd: ir.Method): """ @@ -41,7 +39,7 @@ def __init__(self, mtd: ir.Method): def initialize(self): super().initialize() - + self._validation_errors = [] address_analysis = AddressAnalysis(self.dialects) address_analysis.initialize() self._address_frame, _ = address_analysis.run_analysis(self.method) @@ -88,7 +86,8 @@ def eval_stmt_fallback( return tuple(QubitValidation.top() for _ in stmt.results) has_qubit_args = any( - isinstance(address_frame.get(arg), QubitLike) for arg in stmt.args + isinstance(address_frame.get(arg), (AddressQubit, AddressReg)) + for arg in stmt.args ) if not has_qubit_args: @@ -106,7 +105,10 @@ def eval_stmt_fallback( for qubit_addr in used_addrs: if qubit_addr in seen: - violations.append(f"Qubit[{qubit_addr}] at {stmt_info}") + violations.append(f"Qubit[{qubit_addr}] on {stmt_info}") + self._validation_errors.append( + f"Qubit[{qubit_addr}] on {stmt_info} in {stmt.source}" + ) seen.add(qubit_addr) if not violations: @@ -120,3 +122,7 @@ def run_method( ) -> tuple[ForwardFrame[QubitValidation], QubitValidation]: self_mt = self.method_self(method) return self.run_callable(method.code, (self_mt,) + args) + + def get_validation_errors(self) -> str: + """Retrieve collected validation error messages.""" + return "\n".join(self._validation_errors) diff --git a/test/analysis/validation/test_no_cloning.py b/test/analysis/validation/test_no_cloning.py index 85e22616..cf666640 100644 --- a/test/analysis/validation/test_no_cloning.py +++ b/test/analysis/validation/test_no_cloning.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) -def test_control_gate_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): +def test_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): @squin.kernel def bad_control(): q = squin.qalloc(1) @@ -28,7 +28,7 @@ def bad_control(): @pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) -def test_control_gate_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): +def test_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): @squin.kernel def bad_control(cond: bool): q = squin.qalloc(10) @@ -44,12 +44,11 @@ def bad_control(cond: bool): print() bad_control.print(analysis=frame.entries) validation_errors = collect_validation_errors(frame, QubitValidation) - # print("Violations:", validation_errors) assert len(validation_errors) == 2 @pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) -def test_control_gate_parallel_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): +def test_pass(control_gate: ir.Method[[Qubit, Qubit], Any]): @squin.kernel def bad_control(): q = squin.qalloc(2) @@ -64,7 +63,7 @@ def bad_control(): assert len(validation_errors) == 0 -def test_control_gate_parallel_pass(): +def test_fail_2(): @squin.kernel def good_kernel(): q = squin.qalloc(2) @@ -74,7 +73,22 @@ def good_kernel(): validation = NoCloningValidation(good_kernel) validation.initialize() frame, _ = validation.run_analysis(good_kernel) - print() - good_kernel.print(analysis=frame.entries) validation_errors = collect_validation_errors(frame, QubitValidation) assert len(validation_errors) == 1 + print(validation.get_validation_errors()) + + +def test_parallel_fail(): + @squin.kernel + def bad_kernel(): + q = squin.qalloc(5) + squin.broadcast.cx(IList([q[0], q[1], q[2]]), IList([q[1], q[2], q[3]])) + + validation = NoCloningValidation(bad_kernel) + validation.initialize() + frame, _ = validation.run_analysis(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + validation_errors = collect_validation_errors(frame, QubitValidation) + assert len(validation_errors) == 2 + print(validation.get_validation_errors()) diff --git a/test/analysis/validation/util.py b/test/analysis/validation/util.py index bdfae6a9..d82b3322 100644 --- a/test/analysis/validation/util.py +++ b/test/analysis/validation/util.py @@ -1,17 +1,17 @@ -from typing import TypeVar +from typing import List from kirin.analysis import ForwardFrame from bloqade.analysis.validation.nocloning.lattice import QubitValidation -T = TypeVar("T", bound=QubitValidation) - def collect_validation_errors( - frame: ForwardFrame[QubitValidation], typ: type[T] -) -> list[T]: - return [ - validation_errors - for validation_errors in frame.entries.values() - if isinstance(validation_errors, typ) and len(validation_errors.violations) > 0 - ] + frame: ForwardFrame[QubitValidation], typ: type[QubitValidation] +) -> List[str]: + """Collect individual violation strings from all QubitValidation entries of type `typ`.""" + violations: List[str] = [] + for validation_val in frame.entries.values(): + if isinstance(validation_val, typ): + for v in getattr(validation_val, "violations", ()): + violations.append(v) + return violations From dca242296e9bd4ef0df3591ad9fe0f2490802766 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Thu, 6 Nov 2025 12:25:37 -0500 Subject: [PATCH 03/20] Updated ValidationError reporting --- .../analysis/validation/nocloning/analysis.py | 55 ++++++++++++++----- test/analysis/validation/test_no_cloning.py | 25 ++++++++- 2 files changed, 64 insertions(+), 16 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index 04c59952..43274950 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -3,6 +3,7 @@ from kirin import ir from kirin.analysis import Forward, TypeInference from kirin.dialects import func +from kirin.ir.exception import ValidationError from kirin.analysis.forward import ForwardFrame from bloqade.analysis.address import ( @@ -14,6 +15,19 @@ from .lattice import QubitValidation +class QubitValidationError(ValidationError): + """ValidationError that records which qubit and gate caused the violation.""" + + qubit_id: int + gate_name: str + + def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str): + # message stored in ValidationError so formatting/hint() will include it + super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.") + self.qubit_id = qubit_id + self.gate_name = gate_name + + class NoCloningValidation(Forward[QubitValidation]): """ Validates the no-cloning theorem by tracking qubit addresses. @@ -26,7 +40,9 @@ class NoCloningValidation(Forward[QubitValidation]): _address_frame: ForwardFrame[Address] = field(init=False) _type_frame: ForwardFrame = field(init=False) method: ir.Method - _validation_errors: list[str] = field(default_factory=list, init=False) + _validation_errors: list[QubitValidationError] = field( + default_factory=list, init=False + ) def __init__(self, mtd: ir.Method): """ @@ -63,13 +79,9 @@ def get_qubit_addresses(self, addr: Address) -> frozenset[int]: case _: return frozenset() - def get_stmt_info(self, stmt: ir.Statement) -> str: - """String Report about the statement for violation messages.""" - if isinstance(stmt, func.Invoke) and hasattr(stmt, "callee"): - gate_name = stmt.callee.sym_name.upper() - return f"{gate_name} Gate" - - return f"{stmt.__class__.__name__}@{stmt}" + def format_violation(self, qubit_id: int, gate_name: str) -> str: + """Return the violation string for a qubit + gate.""" + return f"Qubit[{qubit_id}] on {gate_name} Gate" def eval_stmt_fallback( self, frame: ForwardFrame[QubitValidation], stmt: ir.Statement @@ -101,13 +113,13 @@ def eval_stmt_fallback( seen: set[int] = set() violations: list[str] = [] - stmt_info = self.get_stmt_info(stmt) for qubit_addr in used_addrs: if qubit_addr in seen: - violations.append(f"Qubit[{qubit_addr}] on {stmt_info}") + gate_name = stmt.callee.sym_name.upper() + violations.append(self.format_violation(qubit_addr, gate_name)) self._validation_errors.append( - f"Qubit[{qubit_addr}] on {stmt_info} in {stmt.source}" + QubitValidationError(stmt, qubit_addr, gate_name) ) seen.add(qubit_addr) @@ -123,6 +135,21 @@ def run_method( self_mt = self.method_self(method) return self.run_callable(method.code, (self_mt,) + args) - def get_validation_errors(self) -> str: - """Retrieve collected validation error messages.""" - return "\n".join(self._validation_errors) + def raise_validation_errors(self): + """Raise validation error for each no-cloning violation found. + Points to source file and line with snippet. + """ + if not self._validation_errors: + return + + # If multiple errors, print all with snippets first + if len(self._validation_errors) > 1: + for err in self._validation_errors: + err.attach(self.method) + # Print error message before snippet + print( + f"\033[31mValidation Error\033[0m: Cloned qubit [{err.qubit_id}] at {err.gate_name} gate." + ) + print(err.hint()) + print(f"Raised {len(self._validation_errors)} error(s).") + raise diff --git a/test/analysis/validation/test_no_cloning.py b/test/analysis/validation/test_no_cloning.py index cf666640..57536bf0 100644 --- a/test/analysis/validation/test_no_cloning.py +++ b/test/analysis/validation/test_no_cloning.py @@ -25,6 +25,8 @@ def bad_control(): bad_control.print(analysis=frame.entries) validation_errors = collect_validation_errors(frame, QubitValidation) assert len(validation_errors) == 1 + with pytest.raises(Exception): + validation.raise_validation_errors() @pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) @@ -45,6 +47,8 @@ def bad_control(cond: bool): bad_control.print(analysis=frame.entries) validation_errors = collect_validation_errors(frame, QubitValidation) assert len(validation_errors) == 2 + with pytest.raises(Exception): + validation.raise_validation_errors() @pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) @@ -75,7 +79,8 @@ def good_kernel(): frame, _ = validation.run_analysis(good_kernel) validation_errors = collect_validation_errors(frame, QubitValidation) assert len(validation_errors) == 1 - print(validation.get_validation_errors()) + with pytest.raises(Exception): + validation.raise_validation_errors() def test_parallel_fail(): @@ -91,4 +96,20 @@ def bad_kernel(): bad_kernel.print(analysis=frame.entries) validation_errors = collect_validation_errors(frame, QubitValidation) assert len(validation_errors) == 2 - print(validation.get_validation_errors()) + with pytest.raises(Exception): + validation.raise_validation_errors() + + +# def test_potential_fail(): +# @squin.kernel +# def bad_kernel(a: int, b: int): +# q = squin.qalloc(5) +# squin.cx(q[a], q[b]) + +# validation = NoCloningValidation(bad_kernel) +# validation.initialize() +# frame, _ = validation.run_analysis(bad_kernel) +# print() +# bad_kernel.print(analysis=frame.entries) +# validation_errors = collect_validation_errors(frame, QubitValidation) +# assert len(validation_errors) == 0 From 40f5d6ccc0293fbb3f1a4e1114d31115a0b0f995 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Thu, 6 Nov 2025 14:51:43 -0500 Subject: [PATCH 04/20] Refactor no-cloning validation: enhance error handling and improve test coverage --- .../analysis/validation/nocloning/analysis.py | 145 +++++++++++++----- .../analysis/validation/nocloning/lattice.py | 102 ++++++++---- test/analysis/validation/test_no_cloning.py | 96 ++++++++---- test/analysis/validation/util.py | 20 ++- 4 files changed, 264 insertions(+), 99 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index 43274950..2202cad8 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -10,24 +10,43 @@ Address, AddressAnalysis, ) -from bloqade.analysis.address.lattice import AddressReg, AddressQubit +from bloqade.analysis.address.lattice import ( + Unknown, + AddressReg, + UnknownReg, + AddressQubit, + PartialIList, + PartialTuple, + UnknownQubit, +) -from .lattice import QubitValidation +from .lattice import May, Top, Must, Bottom, QubitValidation class QubitValidationError(ValidationError): - """ValidationError that records which qubit and gate caused the violation.""" + """ValidationError for definite (Must) violations with concrete qubit addresses.""" qubit_id: int gate_name: str def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str): - # message stored in ValidationError so formatting/hint() will include it super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.") self.qubit_id = qubit_id self.gate_name = gate_name +class PotentialQubitValidationError(ValidationError): + """ValidationError for potential (May) violations with unknown addresses.""" + + gate_name: str + condition: str + + def __init__(self, node: ir.IRNode, gate_name: str, condition: str): + super().__init__(node, f"Potential cloning at {gate_name} gate{condition}.") + self.gate_name = gate_name + self.condition = condition + + class NoCloningValidation(Forward[QubitValidation]): """ Validates the no-cloning theorem by tracking qubit addresses. @@ -40,9 +59,7 @@ class NoCloningValidation(Forward[QubitValidation]): _address_frame: ForwardFrame[Address] = field(init=False) _type_frame: ForwardFrame = field(init=False) method: ir.Method - _validation_errors: list[QubitValidationError] = field( - default_factory=list, init=False - ) + _validation_errors: list[ValidationError] = field(default_factory=list, init=False) def __init__(self, mtd: ir.Method): """ @@ -88,47 +105,95 @@ def eval_stmt_fallback( ) -> tuple[QubitValidation, ...]: """ Default statement evaluation: check for qubit usage violations. + Returns Bottom, May, Must, or Top depending on what we can prove. """ if not isinstance(stmt, func.Invoke): - return tuple(QubitValidation.bottom() for _ in stmt.results) + return tuple(Bottom() for _ in stmt.results) address_frame = self._address_frame if address_frame is None: - return tuple(QubitValidation.top() for _ in stmt.results) + return tuple(Top() for _ in stmt.results) - has_qubit_args = any( - isinstance(address_frame.get(arg), (AddressQubit, AddressReg)) - for arg in stmt.args - ) + concrete_addrs: list[int] = [] + has_unknown = False + has_qubit_args = False + unknown_arg_names: list[str] = [] - if not has_qubit_args: - return tuple(QubitValidation.bottom() for _ in stmt.results) - - used_addrs: list[int] = [] for arg in stmt.args: addr = address_frame.get(arg) - qubit_addrs = self.get_qubit_addresses(addr) - used_addrs.extend(qubit_addrs) + match addr: + case AddressQubit(data=qubit_addr): + has_qubit_args = True + concrete_addrs.append(qubit_addr) + case AddressReg(data=addrs): + has_qubit_args = True + concrete_addrs.extend(addrs) + case UnknownQubit() | UnknownReg() | Unknown(): + has_qubit_args = True + has_unknown = True + arg_name = self._get_source_name(arg) + unknown_arg_names.append(arg_name) + case _: + pass + + if not has_qubit_args: + return tuple(Bottom() for _ in stmt.results) seen: set[int] = set() - violations: list[str] = [] + must_violations: list[str] = [] + gate_name = stmt.callee.sym_name.upper() - for qubit_addr in used_addrs: + for qubit_addr in concrete_addrs: if qubit_addr in seen: - gate_name = stmt.callee.sym_name.upper() - violations.append(self.format_violation(qubit_addr, gate_name)) + violation = self.format_violation(qubit_addr, gate_name) + must_violations.append(violation) self._validation_errors.append( QubitValidationError(stmt, qubit_addr, gate_name) ) seen.add(qubit_addr) - if not violations: - return tuple(QubitValidation(violations=frozenset()) for _ in stmt.results) + if must_violations: + usage = Must(violations=frozenset(must_violations)) + elif has_unknown: + args_str = " == ".join(unknown_arg_names) + if len(unknown_arg_names) > 1: + condition = f", when {args_str}" + else: + condition = f", with unknown index {args_str}" + + self._validation_errors.append( + PotentialQubitValidationError(stmt, gate_name, condition) + ) + + usage = May( + violations=frozenset([f"Unknown qubits at {gate_name} Gate{condition}"]) + ) + else: + usage = Bottom() - usage = QubitValidation(violations=frozenset(violations)) return tuple(usage for _ in stmt.results) if stmt.results else (usage,) + def _get_source_name(self, value: ir.SSAValue) -> str: + """Trace back to get the source variable name for a value. + + For getitem operations like q[a], returns 'a'. + For direct values, returns the value's name. + """ + from kirin.dialects.py.indexing import GetItem + + if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem): + index_arg = value.stmt.args[1] + return self._get_source_name(index_arg) + + if isinstance(value, ir.BlockArgument): + return value.name or f"arg{value.index}" + + if hasattr(value, "name") and value.name: + return value.name + + return str(value) + def run_method( self, method: ir.Method, args: tuple[QubitValidation, ...] ) -> tuple[ForwardFrame[QubitValidation], QubitValidation]: @@ -136,20 +201,30 @@ def run_method( return self.run_callable(method.code, (self_mt,) + args) def raise_validation_errors(self): - """Raise validation error for each no-cloning violation found. + """Raise validation errors for both definite and potential violations. Points to source file and line with snippet. """ if not self._validation_errors: return - # If multiple errors, print all with snippets first - if len(self._validation_errors) > 1: - for err in self._validation_errors: - err.attach(self.method) - # Print error message before snippet + # Print all errors with snippets + for err in self._validation_errors: + err.attach(self.method) + + # Format error message based on type + if isinstance(err, QubitValidationError): print( - f"\033[31mValidation Error\033[0m: Cloned qubit [{err.qubit_id}] at {err.gate_name} gate." + f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate" ) - print(err.hint()) - print(f"Raised {len(self._validation_errors)} error(s).") + elif isinstance(err, PotentialQubitValidationError): + print( + f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}" + ) + else: + print( + f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}" + ) + + print(err.hint()) + raise diff --git a/src/bloqade/analysis/validation/nocloning/lattice.py b/src/bloqade/analysis/validation/nocloning/lattice.py index fbc54634..559ef7eb 100644 --- a/src/bloqade/analysis/validation/nocloning/lattice.py +++ b/src/bloqade/analysis/validation/nocloning/lattice.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from typing import FrozenSet, final from dataclasses import field, dataclass @@ -15,9 +16,11 @@ class QubitValidation( SimpleMeetMixin["QubitValidation"], BoundedLattice["QubitValidation"], ): - """Tracks cloning violations detected during analysis.""" + """Base class for qubit cloning validation lattice. - violations: FrozenSet[str] = field(default_factory=frozenset) + Lattice ordering: + Bottom ⊑ May{...} ⊑ Must{...} ⊑ Top + """ @classmethod def bottom(cls) -> "QubitValidation": @@ -26,52 +29,95 @@ def bottom(cls) -> "QubitValidation": @classmethod def top(cls) -> "QubitValidation": - """Unknown state - assume potential violations""" + """Unknown state""" return Top() + @abstractmethod def is_subseteq(self, other: "QubitValidation") -> bool: - """Check if this state is a subset of another. - - Lattice ordering: - Bottom ⊑ {{'Qubit[1] at CX Gate'}} ⊑ {{'Qubit[0] at CX Gate'},{'Qubit[1] at CX Gate'}} ⊑ Top - """ - if isinstance(other, Top): - return True - if isinstance(self, Bottom): - return True - if isinstance(other, Bottom): - return False - - return self.violations.issubset(other.violations) - - def __repr__(self) -> str: - """Custom repr to show violations clearly.""" - if not self.violations: - return "QubitValidation()" - return f"QubitValidation(violations={self.violations})" + """Check if this state is a subset of another.""" + ... @final class Bottom(QubitValidation, metaclass=SingletonMeta): - """Bottom element representing no violations.""" + """Bottom element: no violations detected (safe).""" def is_subseteq(self, other: QubitValidation) -> bool: """Bottom is subset of everything.""" return True def __repr__(self) -> str: - """Cleaner printing.""" - return "⊥ (Bottom)" + return "⊥ (No Errors)" @final class Top(QubitValidation, metaclass=SingletonMeta): - """Top element representing unknown state with potential violations.""" + """Top element: unknown state (worst case - assume violations possible).""" def is_subseteq(self, other: QubitValidation) -> bool: """Top is only subset of Top.""" return isinstance(other, Top) def __repr__(self) -> str: - """Cleaner printing.""" - return "⊤ (Top)" + return "⊤ (Unknown)" + + +@final +@dataclass +class May(QubitValidation): + """Potential violations that may occur depending on runtime values. + + Used when we have unknown addresses (UnknownQubit, UnknownReg, Unknown). + """ + + violations: FrozenSet[str] = field(default_factory=frozenset) + + def is_subseteq(self, other: QubitValidation) -> bool: + """May ⊑ May' if violations ⊆ violations' + May ⊑ Must (any may is less precise than must) + May ⊑ Top + """ + match other: + case Bottom(): + return False + case May(violations=other_violations): + return self.violations.issubset(other_violations) + case Must(): + return True # May is less precise than Must + case Top(): + return True + return False + + def __repr__(self) -> str: + if not self.violations: + return "MayError(∅)" + return f"MayError({self.violations})" + + +@final +@dataclass +class Must(QubitValidation): + """Definite violations with concrete qubit addresses. + + These are violations we can prove will definitely occur. + """ + + violations: FrozenSet[str] = field(default_factory=frozenset) + + def is_subseteq(self, other: QubitValidation) -> bool: + """Must ⊑ Must' if violations ⊆ violations' + Must ⊑ Top + """ + match other: + case Bottom() | May(): + return False + case Must(violations=other_violations): + return self.violations.issubset(other_violations) + case Top(): + return True + return False + + def __repr__(self) -> str: + if not self.violations: + return "MustError(∅)" + return f"MustError({self.violations})" diff --git a/test/analysis/validation/test_no_cloning.py b/test/analysis/validation/test_no_cloning.py index 57536bf0..7a2e9510 100644 --- a/test/analysis/validation/test_no_cloning.py +++ b/test/analysis/validation/test_no_cloning.py @@ -1,13 +1,12 @@ from typing import Any import pytest -from util import collect_validation_errors +from util import collect_may_errors, collect_must_errors from kirin import ir from kirin.dialects.ilist.runtime import IList from bloqade import squin from bloqade.types import Qubit -from bloqade.analysis.validation.nocloning.lattice import QubitValidation from bloqade.analysis.validation.nocloning.analysis import NoCloningValidation @@ -23,8 +22,10 @@ def bad_control(): frame, _ = validation.run_analysis(bad_control) print() bad_control.print(analysis=frame.entries) - validation_errors = collect_validation_errors(frame, QubitValidation) - assert len(validation_errors) == 1 + must_errors = collect_must_errors(frame) + may_errors = collect_may_errors(frame) + assert len(must_errors) == 1 + assert len(may_errors) == 0 with pytest.raises(Exception): validation.raise_validation_errors() @@ -45,8 +46,10 @@ def bad_control(cond: bool): frame, _ = validation.run_analysis(bad_control) print() bad_control.print(analysis=frame.entries) - validation_errors = collect_validation_errors(frame, QubitValidation) - assert len(validation_errors) == 2 + must_errors = collect_must_errors(frame) + may_errors = collect_may_errors(frame) + assert len(must_errors) == 2 + assert len(may_errors) == 0 with pytest.raises(Exception): validation.raise_validation_errors() @@ -54,31 +57,39 @@ def bad_control(cond: bool): @pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) def test_pass(control_gate: ir.Method[[Qubit, Qubit], Any]): @squin.kernel - def bad_control(): - q = squin.qalloc(2) + def test(): + q = squin.qalloc(3) control_gate(q[0], q[1]) + squin.rx(1.57, q[0]) + squin.measure(q[0]) + control_gate(q[0], q[2]) - validation = NoCloningValidation(bad_control) + validation = NoCloningValidation(test) validation.initialize() - frame, _ = validation.run_analysis(bad_control) + frame, _ = validation.run_analysis(test) print() - bad_control.print(analysis=frame.entries) - validation_errors = collect_validation_errors(frame, QubitValidation) - assert len(validation_errors) == 0 + test.print(analysis=frame.entries) + must_errors = collect_must_errors(frame) + may_errors = collect_may_errors(frame) + assert len(must_errors) == 0 + assert len(may_errors) == 0 def test_fail_2(): @squin.kernel def good_kernel(): q = squin.qalloc(2) + a = 1 squin.cx(q[0], q[1]) - squin.cy(q[1], q[1]) + squin.cy(q[1], q[a]) validation = NoCloningValidation(good_kernel) validation.initialize() frame, _ = validation.run_analysis(good_kernel) - validation_errors = collect_validation_errors(frame, QubitValidation) - assert len(validation_errors) == 1 + must_errors = collect_must_errors(frame) + may_errors = collect_may_errors(frame) + assert len(must_errors) == 1 + assert len(may_errors) == 0 with pytest.raises(Exception): validation.raise_validation_errors() @@ -94,22 +105,47 @@ def bad_kernel(): frame, _ = validation.run_analysis(bad_kernel) print() bad_kernel.print(analysis=frame.entries) - validation_errors = collect_validation_errors(frame, QubitValidation) - assert len(validation_errors) == 2 + must_errors = collect_must_errors(frame) + may_errors = collect_may_errors(frame) + assert len(must_errors) == 2 + assert len(may_errors) == 0 + with pytest.raises(Exception): + validation.raise_validation_errors() + + +def test_potential_fail(): + @squin.kernel + def bad_kernel(a: int, b: int): + q = squin.qalloc(5) + squin.cx(q[a], q[2]) + + validation = NoCloningValidation(bad_kernel) + validation.initialize() + frame, _ = validation.run_analysis(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + must_errors = collect_must_errors(frame) + may_errors = collect_may_errors(frame) + assert len(must_errors) == 0 + assert len(may_errors) == 1 with pytest.raises(Exception): validation.raise_validation_errors() -# def test_potential_fail(): -# @squin.kernel -# def bad_kernel(a: int, b: int): -# q = squin.qalloc(5) -# squin.cx(q[a], q[b]) +def test_potential_parallel_fail(): + @squin.kernel + def bad_kernel(a: IList): + q = squin.qalloc(5) + squin.broadcast.cx(a, IList([q[2], q[3], q[4]])) -# validation = NoCloningValidation(bad_kernel) -# validation.initialize() -# frame, _ = validation.run_analysis(bad_kernel) -# print() -# bad_kernel.print(analysis=frame.entries) -# validation_errors = collect_validation_errors(frame, QubitValidation) -# assert len(validation_errors) == 0 + validation = NoCloningValidation(bad_kernel) + validation.initialize() + frame, _ = validation.run_analysis(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + must_errors = collect_must_errors(frame) + may_errors = collect_may_errors(frame) + assert len(must_errors) == 0 + assert len(may_errors) == 1 + with pytest.raises(Exception): + validation.raise_validation_errors() diff --git a/test/analysis/validation/util.py b/test/analysis/validation/util.py index d82b3322..c8736331 100644 --- a/test/analysis/validation/util.py +++ b/test/analysis/validation/util.py @@ -1,17 +1,25 @@ -from typing import List +from typing import List, TypeVar from kirin.analysis import ForwardFrame -from bloqade.analysis.validation.nocloning.lattice import QubitValidation +from bloqade.analysis.validation.nocloning.lattice import May, Must +T = TypeVar("T", bound=Must | May) -def collect_validation_errors( - frame: ForwardFrame[QubitValidation], typ: type[QubitValidation] -) -> List[str]: + +def collect_errors(frame: ForwardFrame[T], typ: type[T]) -> List[str]: """Collect individual violation strings from all QubitValidation entries of type `typ`.""" violations: List[str] = [] for validation_val in frame.entries.values(): if isinstance(validation_val, typ): - for v in getattr(validation_val, "violations", ()): + for v in validation_val.violations: violations.append(v) return violations + + +def collect_must_errors(frame): + return collect_errors(frame, Must) + + +def collect_may_errors(frame): + return collect_errors(frame, May) From 2a98e552051cf0ceee0c6fc9ae4fa503fb0274f4 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Thu, 6 Nov 2025 14:57:58 -0500 Subject: [PATCH 05/20] Shorter error messages --- src/bloqade/analysis/validation/nocloning/analysis.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index 2202cad8..65fcf257 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -160,15 +160,13 @@ def eval_stmt_fallback( if len(unknown_arg_names) > 1: condition = f", when {args_str}" else: - condition = f", with unknown index {args_str}" + condition = f", with unknown argument {args_str}" self._validation_errors.append( PotentialQubitValidationError(stmt, gate_name, condition) ) - usage = May( - violations=frozenset([f"Unknown qubits at {gate_name} Gate{condition}"]) - ) + usage = May(violations=frozenset([f"{gate_name} Gate{condition}"])) else: usage = Bottom() From 5f4500d6756719b0fff30859d166cbf1aa797620 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Thu, 6 Nov 2025 15:19:40 -0500 Subject: [PATCH 06/20] clarify join/meet and lattice structure --- .../analysis/validation/nocloning/lattice.py | 154 +++++++++++------- 1 file changed, 98 insertions(+), 56 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/lattice.py b/src/bloqade/analysis/validation/nocloning/lattice.py index 559ef7eb..f8086ec8 100644 --- a/src/bloqade/analysis/validation/nocloning/lattice.py +++ b/src/bloqade/analysis/validation/nocloning/lattice.py @@ -2,122 +2,164 @@ from typing import FrozenSet, final from dataclasses import field, dataclass -from kirin.lattice import ( - SingletonMeta, - BoundedLattice, - SimpleJoinMixin, - SimpleMeetMixin, -) +from kirin.lattice import SingletonMeta, BoundedLattice @dataclass -class QubitValidation( - SimpleJoinMixin["QubitValidation"], - SimpleMeetMixin["QubitValidation"], - BoundedLattice["QubitValidation"], -): - """Base class for qubit cloning validation lattice. - - Lattice ordering: - Bottom ⊑ May{...} ⊑ Must{...} ⊑ Top +class QubitValidation(BoundedLattice["QubitValidation"]): + r"""Base class for qubit-cloning validation lattice. + + Linear ordering (more precise --> less precise): + Bottom ⊑ Must ⊑ May ⊑ Top + + Semantics: + - Bottom: proven safe / never occurs + - Must: definitely occurs (strong) + - May: possibly occurs (weak) + - Top: unknown / no information """ @classmethod def bottom(cls) -> "QubitValidation": - """No violations detected""" return Bottom() @classmethod def top(cls) -> "QubitValidation": - """Unknown state""" return Top() @abstractmethod - def is_subseteq(self, other: "QubitValidation") -> bool: - """Check if this state is a subset of another.""" - ... + def is_subseteq(self, other: "QubitValidation") -> bool: ... + + @abstractmethod + def join(self, other: "QubitValidation") -> "QubitValidation": ... + + @abstractmethod + def meet(self, other: "QubitValidation") -> "QubitValidation": ... @final class Bottom(QubitValidation, metaclass=SingletonMeta): - """Bottom element: no violations detected (safe).""" - def is_subseteq(self, other: QubitValidation) -> bool: - """Bottom is subset of everything.""" return True + def join(self, other: QubitValidation) -> QubitValidation: + return other + + def meet(self, other: QubitValidation) -> QubitValidation: + return self + def __repr__(self) -> str: return "⊥ (No Errors)" @final class Top(QubitValidation, metaclass=SingletonMeta): - """Top element: unknown state (worst case - assume violations possible).""" - def is_subseteq(self, other: QubitValidation) -> bool: - """Top is only subset of Top.""" return isinstance(other, Top) + def join(self, other: QubitValidation) -> QubitValidation: + return self + + def meet(self, other: QubitValidation) -> QubitValidation: + return other + def __repr__(self) -> str: return "⊤ (Unknown)" @final @dataclass -class May(QubitValidation): - """Potential violations that may occur depending on runtime values. - - Used when we have unknown addresses (UnknownQubit, UnknownReg, Unknown). - """ +class Must(QubitValidation): + """Definite violations.""" violations: FrozenSet[str] = field(default_factory=frozenset) def is_subseteq(self, other: QubitValidation) -> bool: - """May ⊑ May' if violations ⊆ violations' - May ⊑ Must (any may is less precise than must) - May ⊑ Top - """ match other: case Bottom(): return False - case May(violations=other_violations): - return self.violations.issubset(other_violations) - case Must(): - return True # May is less precise than Must + case Must(violations=ov): + return self.violations.issubset(ov) + case May(violations=_): + return True case Top(): return True return False + def join(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return self + case Must(violations=ov): + return Must(violations=self.violations | ov) + case May(violations=ov): + return May(violations=self.violations | ov) + case Top(): + return Top() + return Top() + + def meet(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return Bottom() + case Must(violations=ov): + inter = self.violations & ov + return Must(violations=inter) if inter else Bottom() + case May(violations=ov): + inter = self.violations & ov if ov else self.violations + return Must(violations=inter) if inter else Bottom() + case Top(): + return self + return Bottom() + def __repr__(self) -> str: - if not self.violations: - return "MayError(∅)" - return f"MayError({self.violations})" + return f"Must({self.violations or '∅'})" @final @dataclass -class Must(QubitValidation): - """Definite violations with concrete qubit addresses. - - These are violations we can prove will definitely occur. - """ +class May(QubitValidation): + """Potential violations.""" violations: FrozenSet[str] = field(default_factory=frozenset) def is_subseteq(self, other: QubitValidation) -> bool: - """Must ⊑ Must' if violations ⊆ violations' - Must ⊑ Top - """ match other: - case Bottom() | May(): + case Bottom(): + return False + case Must(): return False - case Must(violations=other_violations): - return self.violations.issubset(other_violations) + case May(violations=ov): + return self.violations.issubset(ov) case Top(): return True return False + def join(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return self + case Must(violations=ov): + return May(violations=self.violations | ov) + case May(violations=ov): + return May(violations=self.violations | ov) + case Top(): + return Top() + return Top() + + def meet(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return Bottom() + case Must(violations=ov): + inter = self.violations & ov if ov else ov or self.violations + return Must(violations=inter) if inter else Bottom() + case May(violations=ov): + inter = self.violations & ov + return May(violations=inter) if inter else Bottom() + case Top(): + return self + return Bottom() + def __repr__(self) -> str: - if not self.violations: - return "MustError(∅)" - return f"MustError({self.violations})" + return f"May({self.violations or '∅'})" From 80c934945ce256ed2f7aefefed2dedbef92f8fd1 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Fri, 7 Nov 2025 09:49:17 -0500 Subject: [PATCH 07/20] fix linting --- src/bloqade/analysis/validation/nocloning/analysis.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index 65fcf257..bb7f13d1 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -15,8 +15,6 @@ AddressReg, UnknownReg, AddressQubit, - PartialIList, - PartialTuple, UnknownQubit, ) From 642914864f27434d2c947549fe3a9176fab22d82 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Fri, 7 Nov 2025 09:56:15 -0500 Subject: [PATCH 08/20] Fix import warning --- test/analysis/validation/test_no_cloning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/analysis/validation/test_no_cloning.py b/test/analysis/validation/test_no_cloning.py index 7a2e9510..eda22ea9 100644 --- a/test/analysis/validation/test_no_cloning.py +++ b/test/analysis/validation/test_no_cloning.py @@ -1,7 +1,7 @@ from typing import Any import pytest -from util import collect_may_errors, collect_must_errors +from .util import collect_may_errors, collect_must_errors from kirin import ir from kirin.dialects.ilist.runtime import IList From 2e155d9f4413168fdc8d9302ed1f20ec08efb2d6 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Fri, 7 Nov 2025 10:01:46 -0500 Subject: [PATCH 09/20] fix import errors --- test/analysis/validation/test_no_cloning.py | 45 +++++++++++++-------- test/analysis/validation/util.py | 25 ------------ 2 files changed, 29 insertions(+), 41 deletions(-) delete mode 100644 test/analysis/validation/util.py diff --git a/test/analysis/validation/test_no_cloning.py b/test/analysis/validation/test_no_cloning.py index eda22ea9..34b26122 100644 --- a/test/analysis/validation/test_no_cloning.py +++ b/test/analysis/validation/test_no_cloning.py @@ -1,14 +1,27 @@ -from typing import Any +from typing import Any, List, TypeVar import pytest -from .util import collect_may_errors, collect_must_errors from kirin import ir +from kirin.analysis import ForwardFrame from kirin.dialects.ilist.runtime import IList from bloqade import squin from bloqade.types import Qubit +from bloqade.analysis.validation.nocloning.lattice import May, Must, QubitValidation from bloqade.analysis.validation.nocloning.analysis import NoCloningValidation +T = TypeVar("T", bound=Must | May) + + +def collect_errors(frame: ForwardFrame[QubitValidation], typ: type[T]) -> List[str]: + """Collect individual violation strings from all QubitValidation entries of type `typ`.""" + violations: List[str] = [] + for validation_val in frame.entries.values(): + if isinstance(validation_val, typ): + for v in validation_val.violations: + violations.append(v) + return violations + @pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) def test_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): @@ -22,8 +35,8 @@ def bad_control(): frame, _ = validation.run_analysis(bad_control) print() bad_control.print(analysis=frame.entries) - must_errors = collect_must_errors(frame) - may_errors = collect_may_errors(frame) + must_errors = collect_errors(frame, Must) + may_errors = collect_errors(frame, May) assert len(must_errors) == 1 assert len(may_errors) == 0 with pytest.raises(Exception): @@ -46,8 +59,8 @@ def bad_control(cond: bool): frame, _ = validation.run_analysis(bad_control) print() bad_control.print(analysis=frame.entries) - must_errors = collect_must_errors(frame) - may_errors = collect_may_errors(frame) + must_errors = collect_errors(frame, Must) + may_errors = collect_errors(frame, May) assert len(must_errors) == 2 assert len(may_errors) == 0 with pytest.raises(Exception): @@ -69,8 +82,8 @@ def test(): frame, _ = validation.run_analysis(test) print() test.print(analysis=frame.entries) - must_errors = collect_must_errors(frame) - may_errors = collect_may_errors(frame) + must_errors = collect_errors(frame, Must) + may_errors = collect_errors(frame, May) assert len(must_errors) == 0 assert len(may_errors) == 0 @@ -86,8 +99,8 @@ def good_kernel(): validation = NoCloningValidation(good_kernel) validation.initialize() frame, _ = validation.run_analysis(good_kernel) - must_errors = collect_must_errors(frame) - may_errors = collect_may_errors(frame) + must_errors = collect_errors(frame, Must) + may_errors = collect_errors(frame, May) assert len(must_errors) == 1 assert len(may_errors) == 0 with pytest.raises(Exception): @@ -105,8 +118,8 @@ def bad_kernel(): frame, _ = validation.run_analysis(bad_kernel) print() bad_kernel.print(analysis=frame.entries) - must_errors = collect_must_errors(frame) - may_errors = collect_may_errors(frame) + must_errors = collect_errors(frame, Must) + may_errors = collect_errors(frame, May) assert len(must_errors) == 2 assert len(may_errors) == 0 with pytest.raises(Exception): @@ -124,8 +137,8 @@ def bad_kernel(a: int, b: int): frame, _ = validation.run_analysis(bad_kernel) print() bad_kernel.print(analysis=frame.entries) - must_errors = collect_must_errors(frame) - may_errors = collect_may_errors(frame) + must_errors = collect_errors(frame, Must) + may_errors = collect_errors(frame, May) assert len(must_errors) == 0 assert len(may_errors) == 1 with pytest.raises(Exception): @@ -143,8 +156,8 @@ def bad_kernel(a: IList): frame, _ = validation.run_analysis(bad_kernel) print() bad_kernel.print(analysis=frame.entries) - must_errors = collect_must_errors(frame) - may_errors = collect_may_errors(frame) + must_errors = collect_errors(frame, Must) + may_errors = collect_errors(frame, May) assert len(must_errors) == 0 assert len(may_errors) == 1 with pytest.raises(Exception): diff --git a/test/analysis/validation/util.py b/test/analysis/validation/util.py deleted file mode 100644 index c8736331..00000000 --- a/test/analysis/validation/util.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import List, TypeVar - -from kirin.analysis import ForwardFrame - -from bloqade.analysis.validation.nocloning.lattice import May, Must - -T = TypeVar("T", bound=Must | May) - - -def collect_errors(frame: ForwardFrame[T], typ: type[T]) -> List[str]: - """Collect individual violation strings from all QubitValidation entries of type `typ`.""" - violations: List[str] = [] - for validation_val in frame.entries.values(): - if isinstance(validation_val, typ): - for v in validation_val.violations: - violations.append(v) - return violations - - -def collect_must_errors(frame): - return collect_errors(frame, Must) - - -def collect_may_errors(frame): - return collect_errors(frame, May) From a35636abcda7c892f1291efa7cc926bfe8d9fb76 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Fri, 7 Nov 2025 13:38:15 -0500 Subject: [PATCH 10/20] Improve Validation framework to compose multiple validation analyses. --- .../analysis/validation/nocloning/analysis.py | 143 +++++++-------- .../analysis/validation/nocloning/impls.py | 77 ++++++-- .../analysis/validation/nocloning/lattice.py | 18 +- .../analysis/validation/validationpass.py | 168 ++++++++++++++++++ .../validation/nocloning/test_no_cloning.py | 168 ++++++++++++++++++ .../validation/test_compose_validation.py | 52 ++++++ test/analysis/validation/test_no_cloning.py | 164 ----------------- 7 files changed, 544 insertions(+), 246 deletions(-) create mode 100644 src/bloqade/analysis/validation/validationpass.py create mode 100644 test/analysis/validation/nocloning/test_no_cloning.py create mode 100644 test/analysis/validation/test_compose_validation.py delete mode 100644 test/analysis/validation/test_no_cloning.py diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index bb7f13d1..b73bb9b2 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -1,7 +1,7 @@ -from dataclasses import field +from typing import Any from kirin import ir -from kirin.analysis import Forward, TypeInference +from kirin.analysis import Forward from kirin.dialects import func from kirin.ir.exception import ValidationError from kirin.analysis.forward import ForwardFrame @@ -15,8 +15,11 @@ AddressReg, UnknownReg, AddressQubit, + PartialIList, + PartialTuple, UnknownQubit, ) +from bloqade.analysis.validation.validationpass import ValidationPass from .lattice import May, Top, Must, Bottom, QubitValidation @@ -45,66 +48,39 @@ def __init__(self, node: ir.IRNode, gate_name: str, condition: str): self.condition = condition -class NoCloningValidation(Forward[QubitValidation]): - """ - Validates the no-cloning theorem by tracking qubit addresses. - - Built on top of AddressAnalysis to get qubit address information. - """ +class _NoCloningAnalysis(Forward[QubitValidation]): + """Internal forward analysis for tracking qubit cloning violations.""" keys = ["validate.nocloning"] lattice = QubitValidation - _address_frame: ForwardFrame[Address] = field(init=False) - _type_frame: ForwardFrame = field(init=False) - method: ir.Method - _validation_errors: list[ValidationError] = field(default_factory=list, init=False) - def __init__(self, mtd: ir.Method): - """ - Input: - - an ir.Method / kernel function - infer dialects from it and remember method. - """ - self.method = mtd - super().__init__(mtd.dialects) + def __init__(self, dialects): + super().__init__(dialects) + self._address_frame: ForwardFrame[Address] | None = None + self._validation_errors: list[ValidationError] = [] def initialize(self): super().initialize() self._validation_errors = [] - address_analysis = AddressAnalysis(self.dialects) - address_analysis.initialize() - self._address_frame, _ = address_analysis.run_analysis(self.method) - - type_inference = TypeInference(self.dialects) - type_inference.initialize() - self._type_frame, _ = type_inference.run_analysis(self.method) - return self - def method_self(self, method: ir.Method) -> QubitValidation: - return self.lattice.bottom() + def run_method( + self, method: ir.Method, args: tuple[QubitValidation, ...] + ) -> tuple[ForwardFrame[QubitValidation], QubitValidation]: + if self._address_frame is None: + if getattr(self, "_address_analysis", None) is None: + addr_analysis = AddressAnalysis(self.dialects) + addr_analysis.initialize() + self._address_analysis = addr_analysis - def get_qubit_addresses(self, addr: Address) -> frozenset[int]: - """Extract concrete qubit addresses from an Address lattice element.""" - match addr: - case AddressQubit(data=qubit_addr): - return frozenset([qubit_addr]) - case AddressReg(data=addrs): - return frozenset(addrs) - case _: - return frozenset() + self._address_frame, _ = self._address_analysis.run_analysis(method) - def format_violation(self, qubit_id: int, gate_name: str) -> str: - """Return the violation string for a qubit + gate.""" - return f"Qubit[{qubit_id}] on {gate_name} Gate" + return self.run_callable(method.code, args) def eval_stmt_fallback( self, frame: ForwardFrame[QubitValidation], stmt: ir.Statement ) -> tuple[QubitValidation, ...]: - """ - Default statement evaluation: check for qubit usage violations. - Returns Bottom, May, Must, or Top depending on what we can prove. - """ + """Check for qubit usage violations.""" if not isinstance(stmt, func.Invoke): return tuple(Bottom() for _ in stmt.results) @@ -127,7 +103,13 @@ def eval_stmt_fallback( case AddressReg(data=addrs): has_qubit_args = True concrete_addrs.extend(addrs) - case UnknownQubit() | UnknownReg() | Unknown(): + case ( + UnknownQubit() + | UnknownReg() + | PartialIList() + | PartialTuple() + | Unknown() + ): has_qubit_args = True has_unknown = True arg_name = self._get_source_name(arg) @@ -144,7 +126,7 @@ def eval_stmt_fallback( for qubit_addr in concrete_addrs: if qubit_addr in seen: - violation = self.format_violation(qubit_addr, gate_name) + violation = f"Qubit[{qubit_addr}] on {gate_name} Gate" must_violations.append(violation) self._validation_errors.append( QubitValidationError(stmt, qubit_addr, gate_name) @@ -171,11 +153,7 @@ def eval_stmt_fallback( return tuple(usage for _ in stmt.results) if stmt.results else (usage,) def _get_source_name(self, value: ir.SSAValue) -> str: - """Trace back to get the source variable name for a value. - - For getitem operations like q[a], returns 'a'. - For direct values, returns the value's name. - """ + """Trace back to get the source variable name.""" from kirin.dialects.py.indexing import GetItem if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem): @@ -190,24 +168,52 @@ def _get_source_name(self, value: ir.SSAValue) -> str: return str(value) - def run_method( - self, method: ir.Method, args: tuple[QubitValidation, ...] - ) -> tuple[ForwardFrame[QubitValidation], QubitValidation]: - self_mt = self.method_self(method) - return self.run_callable(method.code, (self_mt,) + args) - def raise_validation_errors(self): - """Raise validation errors for both definite and potential violations. - Points to source file and line with snippet. +class NoCloningValidation(ValidationPass): + """Validates the no-cloning theorem by tracking qubit addresses.""" + + def __init__(self): + self.method: ir.Method | None = None + self._analysis: _NoCloningAnalysis | None = None + self._cached_address_frame = None + + def name(self) -> str: + return "No-Cloning Validation" + + def get_required_analyses(self) -> list[type]: + """Declare dependency on AddressAnalysis.""" + return [AddressAnalysis] + + def set_analysis_cache(self, cache: dict[type, Any]) -> None: + """Use cached AddressAnalysis result.""" + self._cached_address_frame = cache.get(AddressAnalysis) + + def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: + """Run the no-cloning validation analysis. + + Returns: + - frame: ForwardFrame with QubitValidation lattice values + - errors: List of validation errors found """ - if not self._validation_errors: - return + if self._analysis is None: + self._analysis = _NoCloningAnalysis(method.dialects) + + self.method = method + self._analysis.initialize() + if self._cached_address_frame is not None: + self._analysis._address_frame = self._cached_address_frame + frame, _ = self._analysis.run_analysis(method, args=None) - # Print all errors with snippets - for err in self._validation_errors: - err.attach(self.method) + return frame, self._analysis._validation_errors - # Format error message based on type + def print_validation_errors(self): + """Print all collected errors with formatted snippets.""" + if self._analysis is None: + return + errors = self._analysis._validation_errors + if not errors: + return + for err in errors: if isinstance(err, QubitValidationError): print( f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate" @@ -220,7 +226,4 @@ def raise_validation_errors(self): print( f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}" ) - print(err.hint()) - - raise diff --git a/src/bloqade/analysis/validation/nocloning/impls.py b/src/bloqade/analysis/validation/nocloning/impls.py index c2c7fc1c..343742d2 100644 --- a/src/bloqade/analysis/validation/nocloning/impls.py +++ b/src/bloqade/analysis/validation/nocloning/impls.py @@ -2,8 +2,12 @@ from kirin.analysis import ForwardFrame from kirin.dialects import scf -from .lattice import QubitValidation -from .analysis import NoCloningValidation +from .lattice import May, Top, Must, Bottom, QubitValidation +from .analysis import ( + QubitValidationError, + PotentialQubitValidationError, + _NoCloningAnalysis, +) @scf.dialect.register(key="validate.nocloning") @@ -11,23 +15,76 @@ class Scf(interp.MethodTable): @interp.impl(scf.IfElse) def if_else( self, - interp_: NoCloningValidation, + interp_: _NoCloningAnalysis, frame: ForwardFrame[QubitValidation], stmt: scf.IfElse, ): - cond_validation = frame.get(stmt.cond) + try: + cond_validation = frame.get(stmt.cond) + except Exception: + cond_validation = Top() - then_results = interp_.run_callable_region( - frame, stmt, stmt.then_body, (cond_validation,) + errors_before_then = len(interp_._validation_errors) + _ = interp_.run_callable_region(frame, stmt, stmt.then_body, (cond_validation,)) + errors_after_then = len(interp_._validation_errors) + + then_had_errors = errors_after_then > errors_before_then + then_errors = interp_._validation_errors[errors_before_then:errors_after_then] + then_state = ( + Must(violations=frozenset(err.args[0] for err in then_errors)) + if then_had_errors + else Bottom() ) if stmt.else_body: - else_results = interp_.run_callable_region( + errors_before_else = len(interp_._validation_errors) + _ = interp_.run_callable_region( frame, stmt, stmt.else_body, (cond_validation,) ) + errors_after_else = len(interp_._validation_errors) + + else_had_errors = errors_after_else > errors_before_else + else_errors = interp_._validation_errors[ + errors_before_else:errors_after_else + ] + else_state = ( + Must(violations=frozenset(err.args[0] for err in else_errors)) + if else_had_errors + else Bottom() + ) + + merged = then_state.join(else_state) - merged = tuple(then_results.join(else_results) for _ in stmt.results) + if isinstance(merged, May): + interp_._validation_errors = interp_._validation_errors[ + :errors_before_then + ] + + for err in then_errors + else_errors: + if isinstance(err, QubitValidationError): + potential_err = PotentialQubitValidationError( + err.node, + err.gate_name, + ( + ", when condition is true" + if err in then_errors + else ", when condition is false" + ), + ) + interp_._validation_errors.append(potential_err) else: - merged = tuple(then_results for _ in stmt.results) + merged = then_state.join(Bottom()) + + if isinstance(merged, May): + interp_._validation_errors = interp_._validation_errors[ + :errors_before_then + ] + + for err in then_errors: + if isinstance(err, QubitValidationError): + potential_err = PotentialQubitValidationError( + err.node, err.gate_name, ", when condition is true" + ) + interp_._validation_errors.append(potential_err) - return merged if merged else (QubitValidation.bottom(),) + return (merged,) diff --git a/src/bloqade/analysis/validation/nocloning/lattice.py b/src/bloqade/analysis/validation/nocloning/lattice.py index f8086ec8..c31cf5ad 100644 --- a/src/bloqade/analysis/validation/nocloning/lattice.py +++ b/src/bloqade/analysis/validation/nocloning/lattice.py @@ -87,11 +87,25 @@ def is_subseteq(self, other: QubitValidation) -> bool: return False def join(self, other: QubitValidation) -> QubitValidation: + """Join with another validation state. + + Key insight: Must ⊔ Bottom = May (error on one path, not all) + """ match other: case Bottom(): - return self + # Error in one branch, safe in other = May (conditional error) + result = May(violations=self.violations) + return result case Must(violations=ov): - return Must(violations=self.violations | ov) + # Errors in both branches + common = self.violations & ov + all_violations = self.violations | ov + if common == all_violations: + # Same errors on all paths = Must + return Must(violations=all_violations) + else: + # Different errors on different paths = May + return May(violations=all_violations) case May(violations=ov): return May(violations=self.violations | ov) case Top(): diff --git a/src/bloqade/analysis/validation/validationpass.py b/src/bloqade/analysis/validation/validationpass.py new file mode 100644 index 00000000..094d1ba4 --- /dev/null +++ b/src/bloqade/analysis/validation/validationpass.py @@ -0,0 +1,168 @@ +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar +from dataclasses import field, dataclass + +from kirin import ir +from kirin.ir.exception import ValidationError + +T = TypeVar("T") + + +class ValidationPass(ABC, Generic[T]): + """Base class for a validation pass. + + Each pass analyzes an IR method and collects validation errors. + """ + + @abstractmethod + def name(self) -> str: + """Return the name of this validation pass.""" + ... + + @abstractmethod + def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: + """Run validation and return (analysis_frame, errors). + + Returns: + - analysis_frame: The result frame from the analysis + - errors: List of validation errors (empty if valid) + """ + ... + + def get_required_analyses(self) -> list[type]: + """Return list of analysis classes this pass depends on. + + Override to declare dependencies (e.g., [AddressAnalysis, AnotherAnalysis]). + The suite will run these analyses once and cache results. + """ + return [] + + def set_analysis_cache(self, cache: dict[type, Any]) -> None: + """Receive cached analysis results from the suite. + + Override to store cached analysis frames/results. + Example: + self._address_frame = cache.get(AddressAnalysis) + """ + pass + + +@dataclass +class ValidationSuite: + """Compose multiple validation passes and run them together. + + Caches analysis results to avoid redundant computation when multiple + validation passes depend on the same underlying analysis. + + Example: + suite = ValidationSuite([ + NoCloningValidation, + AnotherValidation, + ]) + result = suite.validate(my_kernel) + print(result.format_errors()) + """ + + passes: list[type[ValidationPass]] = field(default_factory=list) + fail_fast: bool = False + _analysis_cache: dict[type, Any] = field(default_factory=dict, init=False) + + def add_pass(self, pass_cls: type[ValidationPass]) -> "ValidationSuite": + """Add a validation pass to the suite.""" + self.passes.append(pass_cls) + return self + + def validate(self, method: ir.Method) -> "ValidationResult": + """Run all validation passes and collect results.""" + all_errors: dict[str, list[ValidationError]] = {} + all_frames: dict[str, Any] = {} + self._analysis_cache.clear() + for pass_cls in self.passes: + validator = pass_cls() + pass_name = validator.name() + + try: + required = validator.get_required_analyses() + for required_analysis in required: + if required_analysis not in self._analysis_cache: + analysis = required_analysis(method.dialects) + analysis.initialize() + frame, _ = analysis.run_analysis(method) + self._analysis_cache[required_analysis] = frame + + validator.set_analysis_cache(self._analysis_cache) + + frame, errors = validator.run(method) + all_frames[pass_name] = frame + + for err in errors: + if isinstance(err, ValidationError): + try: + err.attach(method) + except Exception: + pass + + if errors: + all_errors[pass_name] = errors + if self.fail_fast: + break + except Exception as e: + import traceback + + tb = traceback.format_exc() + all_errors[pass_name] = [ + ValidationError( + method.code, f"Validation pass '{pass_name}' failed: {e}\n{tb}" + ) + ] + if self.fail_fast: + break + + return ValidationResult(all_errors, all_frames) + + +@dataclass +class ValidationResult: + """Result of running a validation suite.""" + + errors: dict[str, list[ValidationError]] + frames: dict[str, Any] = field(default_factory=dict) + + def is_valid(self) -> bool: + """Check if validation passed (no errors).""" + return len(self.errors) == 0 + + def error_count(self) -> int: + """Total number of errors across all passes.""" + return sum(len(errs) for errs in self.errors.values()) + + def get_frame(self, pass_name: str) -> Any: + """Get the analysis frame for a specific pass.""" + return self.frames.get(pass_name) + + def format_errors(self) -> str: + """Format all errors with their pass names.""" + if not self.errors: + return "\n\033[32mAll validation passes succeeded\033[0m" + + lines = [ + f"\n\033[31mValidation failed with {self.error_count()} error(s):\033[0m" + ] + + for pass_name, pass_errors in self.errors.items(): + lines.append(f"\n\033[31m{pass_name}:\033[0m") + for err in pass_errors: + err_msg = err.args[0] if err.args else str(err) + lines.append(f" - {err_msg}") + if hasattr(err, "hint"): + hint = err.hint() + if hint: + lines.append(f" {hint}") + + return "\n".join(lines) + + def raise_if_invalid(self): + """Raise an exception if validation failed.""" + if not self.is_valid(): + first_errors = next(iter(self.errors.values())) + raise first_errors[0] diff --git a/test/analysis/validation/nocloning/test_no_cloning.py b/test/analysis/validation/nocloning/test_no_cloning.py new file mode 100644 index 00000000..789b2db1 --- /dev/null +++ b/test/analysis/validation/nocloning/test_no_cloning.py @@ -0,0 +1,168 @@ +from typing import Any, TypeVar + +import pytest +from kirin import ir +from kirin.dialects.ilist.runtime import IList + +from bloqade import squin +from bloqade.types import Qubit +from bloqade.analysis.validation.nocloning.lattice import May, Must +from bloqade.analysis.validation.nocloning.analysis import ( + NoCloningValidation, + QubitValidationError, + PotentialQubitValidationError, +) + +T = TypeVar("T", bound=Must | May) + + +def collect_errors_from_validation( + validation: NoCloningValidation, +) -> tuple[int, int]: + """Count Must (definite) and May (potential) errors from the validation pass. + + Returns: + (must_count, may_count) - number of definite and potential errors + """ + must_count = 0 + may_count = 0 + + if validation._analysis is None: + return (must_count, may_count) + + for err in validation._analysis._validation_errors: + if isinstance(err, QubitValidationError): + must_count += 1 + elif isinstance(err, PotentialQubitValidationError): + may_count += 1 + + return must_count, may_count + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(): + q = squin.qalloc(1) + control_gate(q[0], q[0]) + + validation = NoCloningValidation() + + frame, _ = validation.run(bad_control) + print() + bad_control.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 1 + assert may_count == 0 + validation.print_validation_errors() + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(cond: bool): + q = squin.qalloc(10) + if cond: + control_gate(q[0], q[0]) + else: + control_gate(q[0], q[1]) + squin.cx(q[1], q[1]) + + validation = NoCloningValidation() + frame, _ = validation.run(bad_control) + print() + bad_control.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 1 # squin.cx(q[1], q[1]) outside conditional + assert may_count == 1 # control_gate(q[0], q[0]) inside conditional + validation.print_validation_errors() + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_pass(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def test(): + q = squin.qalloc(3) + control_gate(q[0], q[1]) + squin.rx(1.57, q[0]) + squin.measure(q[0]) + control_gate(q[0], q[2]) + + validation = NoCloningValidation() + frame, _ = validation.run(test) + print() + test.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 0 + assert may_count == 0 + + +def test_fail_2(): + @squin.kernel + def good_kernel(): + q = squin.qalloc(2) + a = 1 + squin.cx(q[0], q[1]) + squin.cy(q[1], q[a]) + + validation = NoCloningValidation() + frame, _ = validation.run(good_kernel) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 1 + assert may_count == 0 + validation.print_validation_errors() + + +def test_parallel_fail(): + @squin.kernel + def bad_kernel(): + q = squin.qalloc(5) + squin.broadcast.cx(IList([q[0], q[1], q[2]]), IList([q[1], q[2], q[3]])) + + validation = NoCloningValidation() + frame, _ = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 2 + assert may_count == 0 + validation.print_validation_errors() + + +def test_potential_fail(): + @squin.kernel + def bad_kernel(a: int, b: int): + q = squin.qalloc(5) + squin.cx(q[a], q[2]) + + validation = NoCloningValidation() + frame, _ = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 0 + assert may_count == 1 + validation.print_validation_errors() + + +def test_potential_parallel_fail(): + @squin.kernel + def bad_kernel(a: IList): + q = squin.qalloc(5) + squin.broadcast.cx(a, IList([q[2], q[3], q[4]])) + + validation = NoCloningValidation() + frame, _ = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 0 + assert may_count == 1 + validation.print_validation_errors() diff --git a/test/analysis/validation/test_compose_validation.py b/test/analysis/validation/test_compose_validation.py new file mode 100644 index 00000000..8547e7cc --- /dev/null +++ b/test/analysis/validation/test_compose_validation.py @@ -0,0 +1,52 @@ +import pytest + +from bloqade import squin +from bloqade.analysis.validation.nocloning import NoCloningValidation +from bloqade.analysis.validation.validationpass import ValidationSuite + + +def test_validation_suite(): + @squin.kernel + def bad_kernel(a: int): + q = squin.qalloc(2) + squin.cx(q[0], q[0]) # cloning error + squin.cx(q[a], q[1]) # cloning error + + # Running no-cloning validation multiple times + suite = ValidationSuite( + [ + NoCloningValidation, + NoCloningValidation, + NoCloningValidation, + ] + ) + result = suite.validate(bad_kernel) + + assert not result.is_valid() + assert ( + result.error_count() == 2 + ) # Report 2 errors, even when validated multiple times + print(result.format_errors()) + with pytest.raises(Exception): + result.raise_if_invalid() + + +def test_validation_suite2(): + @squin.kernel + def good_kernel(): + q = squin.qalloc(2) + squin.cx(q[0], q[1]) # cloning error + + # Running no-cloning validation multiple times + suite = ValidationSuite( + [ + NoCloningValidation, + ], + fail_fast=True, + ) + result = suite.validate(good_kernel) + + assert result.is_valid() + assert result.error_count() == 0 + print(result.format_errors()) + result.raise_if_invalid() diff --git a/test/analysis/validation/test_no_cloning.py b/test/analysis/validation/test_no_cloning.py deleted file mode 100644 index 34b26122..00000000 --- a/test/analysis/validation/test_no_cloning.py +++ /dev/null @@ -1,164 +0,0 @@ -from typing import Any, List, TypeVar - -import pytest -from kirin import ir -from kirin.analysis import ForwardFrame -from kirin.dialects.ilist.runtime import IList - -from bloqade import squin -from bloqade.types import Qubit -from bloqade.analysis.validation.nocloning.lattice import May, Must, QubitValidation -from bloqade.analysis.validation.nocloning.analysis import NoCloningValidation - -T = TypeVar("T", bound=Must | May) - - -def collect_errors(frame: ForwardFrame[QubitValidation], typ: type[T]) -> List[str]: - """Collect individual violation strings from all QubitValidation entries of type `typ`.""" - violations: List[str] = [] - for validation_val in frame.entries.values(): - if isinstance(validation_val, typ): - for v in validation_val.violations: - violations.append(v) - return violations - - -@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) -def test_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): - @squin.kernel - def bad_control(): - q = squin.qalloc(1) - control_gate(q[0], q[0]) - - validation = NoCloningValidation(bad_control) - validation.initialize() - frame, _ = validation.run_analysis(bad_control) - print() - bad_control.print(analysis=frame.entries) - must_errors = collect_errors(frame, Must) - may_errors = collect_errors(frame, May) - assert len(must_errors) == 1 - assert len(may_errors) == 0 - with pytest.raises(Exception): - validation.raise_validation_errors() - - -@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) -def test_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): - @squin.kernel - def bad_control(cond: bool): - q = squin.qalloc(10) - if cond: - control_gate(q[0], q[0]) - else: - control_gate(q[0], q[1]) - squin.cx(q[1], q[1]) - - validation = NoCloningValidation(bad_control) - validation.initialize() - frame, _ = validation.run_analysis(bad_control) - print() - bad_control.print(analysis=frame.entries) - must_errors = collect_errors(frame, Must) - may_errors = collect_errors(frame, May) - assert len(must_errors) == 2 - assert len(may_errors) == 0 - with pytest.raises(Exception): - validation.raise_validation_errors() - - -@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) -def test_pass(control_gate: ir.Method[[Qubit, Qubit], Any]): - @squin.kernel - def test(): - q = squin.qalloc(3) - control_gate(q[0], q[1]) - squin.rx(1.57, q[0]) - squin.measure(q[0]) - control_gate(q[0], q[2]) - - validation = NoCloningValidation(test) - validation.initialize() - frame, _ = validation.run_analysis(test) - print() - test.print(analysis=frame.entries) - must_errors = collect_errors(frame, Must) - may_errors = collect_errors(frame, May) - assert len(must_errors) == 0 - assert len(may_errors) == 0 - - -def test_fail_2(): - @squin.kernel - def good_kernel(): - q = squin.qalloc(2) - a = 1 - squin.cx(q[0], q[1]) - squin.cy(q[1], q[a]) - - validation = NoCloningValidation(good_kernel) - validation.initialize() - frame, _ = validation.run_analysis(good_kernel) - must_errors = collect_errors(frame, Must) - may_errors = collect_errors(frame, May) - assert len(must_errors) == 1 - assert len(may_errors) == 0 - with pytest.raises(Exception): - validation.raise_validation_errors() - - -def test_parallel_fail(): - @squin.kernel - def bad_kernel(): - q = squin.qalloc(5) - squin.broadcast.cx(IList([q[0], q[1], q[2]]), IList([q[1], q[2], q[3]])) - - validation = NoCloningValidation(bad_kernel) - validation.initialize() - frame, _ = validation.run_analysis(bad_kernel) - print() - bad_kernel.print(analysis=frame.entries) - must_errors = collect_errors(frame, Must) - may_errors = collect_errors(frame, May) - assert len(must_errors) == 2 - assert len(may_errors) == 0 - with pytest.raises(Exception): - validation.raise_validation_errors() - - -def test_potential_fail(): - @squin.kernel - def bad_kernel(a: int, b: int): - q = squin.qalloc(5) - squin.cx(q[a], q[2]) - - validation = NoCloningValidation(bad_kernel) - validation.initialize() - frame, _ = validation.run_analysis(bad_kernel) - print() - bad_kernel.print(analysis=frame.entries) - must_errors = collect_errors(frame, Must) - may_errors = collect_errors(frame, May) - assert len(must_errors) == 0 - assert len(may_errors) == 1 - with pytest.raises(Exception): - validation.raise_validation_errors() - - -def test_potential_parallel_fail(): - @squin.kernel - def bad_kernel(a: IList): - q = squin.qalloc(5) - squin.broadcast.cx(a, IList([q[2], q[3], q[4]])) - - validation = NoCloningValidation(bad_kernel) - validation.initialize() - frame, _ = validation.run_analysis(bad_kernel) - print() - bad_kernel.print(analysis=frame.entries) - must_errors = collect_errors(frame, Must) - may_errors = collect_errors(frame, May) - assert len(must_errors) == 0 - assert len(may_errors) == 1 - with pytest.raises(Exception): - validation.raise_validation_errors() From 2bf369ec29c20d4373308f2c565ed21f2ad1b3fc Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Fri, 7 Nov 2025 16:35:13 -0500 Subject: [PATCH 11/20] removed redundant `method` variable --- .../analysis/validation/nocloning/analysis.py | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index b73bb9b2..eff693c1 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -59,21 +59,16 @@ def __init__(self, dialects): self._address_frame: ForwardFrame[Address] | None = None self._validation_errors: list[ValidationError] = [] - def initialize(self): - super().initialize() - self._validation_errors = [] - return self + def method_self(self, method: ir.Method) -> QubitValidation: + return self.lattice.bottom() def run_method( self, method: ir.Method, args: tuple[QubitValidation, ...] ) -> tuple[ForwardFrame[QubitValidation], QubitValidation]: if self._address_frame is None: - if getattr(self, "_address_analysis", None) is None: - addr_analysis = AddressAnalysis(self.dialects) - addr_analysis.initialize() - self._address_analysis = addr_analysis - - self._address_frame, _ = self._address_analysis.run_analysis(method) + addr_analysis = AddressAnalysis(self.dialects) + addr_analysis.initialize() + self._address_frame, _ = addr_analysis.run_analysis(method) return self.run_callable(method.code, args) @@ -81,7 +76,6 @@ def eval_stmt_fallback( self, frame: ForwardFrame[QubitValidation], stmt: ir.Statement ) -> tuple[QubitValidation, ...]: """Check for qubit usage violations.""" - if not isinstance(stmt, func.Invoke): return tuple(Bottom() for _ in stmt.results) @@ -173,7 +167,6 @@ class NoCloningValidation(ValidationPass): """Validates the no-cloning theorem by tracking qubit addresses.""" def __init__(self): - self.method: ir.Method | None = None self._analysis: _NoCloningAnalysis | None = None self._cached_address_frame = None @@ -198,11 +191,10 @@ def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: if self._analysis is None: self._analysis = _NoCloningAnalysis(method.dialects) - self.method = method self._analysis.initialize() if self._cached_address_frame is not None: self._analysis._address_frame = self._cached_address_frame - frame, _ = self._analysis.run_analysis(method, args=None) + frame, _ = self._analysis.run_analysis(method) return frame, self._analysis._validation_errors @@ -210,10 +202,7 @@ def print_validation_errors(self): """Print all collected errors with formatted snippets.""" if self._analysis is None: return - errors = self._analysis._validation_errors - if not errors: - return - for err in errors: + for err in self._analysis._validation_errors: if isinstance(err, QubitValidationError): print( f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate" From a2af2f310d41779eba03c7e62e4de80f1522c358 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Mon, 10 Nov 2025 10:49:25 -0500 Subject: [PATCH 12/20] Fix commutativity of `join` operation --- .../analysis/validation/nocloning/lattice.py | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/lattice.py b/src/bloqade/analysis/validation/nocloning/lattice.py index c31cf5ad..24ce01ae 100644 --- a/src/bloqade/analysis/validation/nocloning/lattice.py +++ b/src/bloqade/analysis/validation/nocloning/lattice.py @@ -43,6 +43,13 @@ def is_subseteq(self, other: QubitValidation) -> bool: return True def join(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return self + case Must(violations=v): + return May(violations=v) + case May() | Top(): + return other return other def meet(self, other: QubitValidation) -> QubitValidation: @@ -87,25 +94,14 @@ def is_subseteq(self, other: QubitValidation) -> bool: return False def join(self, other: QubitValidation) -> QubitValidation: - """Join with another validation state. - - Key insight: Must ⊔ Bottom = May (error on one path, not all) - """ match other: case Bottom(): - # Error in one branch, safe in other = May (conditional error) - result = May(violations=self.violations) - return result + return May(violations=self.violations) case Must(violations=ov): - # Errors in both branches - common = self.violations & ov - all_violations = self.violations | ov - if common == all_violations: - # Same errors on all paths = Must - return Must(violations=all_violations) + if self.violations == ov: + return Must(violations=self.violations) else: - # Different errors on different paths = May - return May(violations=all_violations) + return May(violations=self.violations | ov) case May(violations=ov): return May(violations=self.violations | ov) case Top(): @@ -120,7 +116,7 @@ def meet(self, other: QubitValidation) -> QubitValidation: inter = self.violations & ov return Must(violations=inter) if inter else Bottom() case May(violations=ov): - inter = self.violations & ov if ov else self.violations + inter = self.violations & ov return Must(violations=inter) if inter else Bottom() case Top(): return self @@ -166,7 +162,7 @@ def meet(self, other: QubitValidation) -> QubitValidation: case Bottom(): return Bottom() case Must(violations=ov): - inter = self.violations & ov if ov else ov or self.violations + inter = self.violations & ov return Must(violations=inter) if inter else Bottom() case May(violations=ov): inter = self.violations & ov From b5724719742809ece2fcf2ed0d3420c51fe91801 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Mon, 10 Nov 2025 15:54:32 -0500 Subject: [PATCH 13/20] updated to work with new Kirin version --- .../analysis/validation/nocloning/analysis.py | 38 +++++++++---------- .../analysis/validation/nocloning/impls.py | 12 ++++-- .../analysis/validation/validationpass.py | 2 +- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index eff693c1..8cc1c579 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -51,7 +51,7 @@ def __init__(self, node: ir.IRNode, gate_name: str, condition: str): class _NoCloningAnalysis(Forward[QubitValidation]): """Internal forward analysis for tracking qubit cloning violations.""" - keys = ["validate.nocloning"] + keys = ("validate.nocloning",) lattice = QubitValidation def __init__(self, dialects): @@ -62,33 +62,32 @@ def __init__(self, dialects): def method_self(self, method: ir.Method) -> QubitValidation: return self.lattice.bottom() - def run_method( - self, method: ir.Method, args: tuple[QubitValidation, ...] - ) -> tuple[ForwardFrame[QubitValidation], QubitValidation]: + def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation): + # Set up address frame before analysis if not already cached if self._address_frame is None: addr_analysis = AddressAnalysis(self.dialects) addr_analysis.initialize() - self._address_frame, _ = addr_analysis.run_analysis(method) + self._address_frame, _ = addr_analysis.run(method) - return self.run_callable(method.code, args) + # Now run the forward analysis with address frame populated + return super().run(method, *args, **kwargs) - def eval_stmt_fallback( - self, frame: ForwardFrame[QubitValidation], stmt: ir.Statement + def eval_fallback( + self, frame: ForwardFrame[QubitValidation], node: ir.Statement ) -> tuple[QubitValidation, ...]: """Check for qubit usage violations.""" - if not isinstance(stmt, func.Invoke): - return tuple(Bottom() for _ in stmt.results) + if not isinstance(node, func.Invoke): + return tuple(Bottom() for _ in node.results) address_frame = self._address_frame if address_frame is None: - return tuple(Top() for _ in stmt.results) + return tuple(Top() for _ in node.results) concrete_addrs: list[int] = [] has_unknown = False has_qubit_args = False unknown_arg_names: list[str] = [] - - for arg in stmt.args: + for arg in node.args: addr = address_frame.get(arg) match addr: case AddressQubit(data=qubit_addr): @@ -112,18 +111,19 @@ def eval_stmt_fallback( pass if not has_qubit_args: - return tuple(Bottom() for _ in stmt.results) + return tuple(Bottom() for _ in node.results) seen: set[int] = set() must_violations: list[str] = [] - gate_name = stmt.callee.sym_name.upper() + s_name = getattr(node.callee, "sym_name", " str: """Trace back to get the source variable name.""" @@ -194,7 +194,7 @@ def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: self._analysis.initialize() if self._cached_address_frame is not None: self._analysis._address_frame = self._cached_address_frame - frame, _ = self._analysis.run_analysis(method) + frame, _ = self._analysis.run(method) return frame, self._analysis._validation_errors diff --git a/src/bloqade/analysis/validation/nocloning/impls.py b/src/bloqade/analysis/validation/nocloning/impls.py index 343742d2..1341414f 100644 --- a/src/bloqade/analysis/validation/nocloning/impls.py +++ b/src/bloqade/analysis/validation/nocloning/impls.py @@ -25,7 +25,9 @@ def if_else( cond_validation = Top() errors_before_then = len(interp_._validation_errors) - _ = interp_.run_callable_region(frame, stmt, stmt.then_body, (cond_validation,)) + with interp_.new_frame(stmt, has_parent_access=True) as then_frame: + interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation) + frame.set_values(then_frame.entries.keys(), then_frame.entries.values()) errors_after_then = len(interp_._validation_errors) then_had_errors = errors_after_then > errors_before_then @@ -38,9 +40,11 @@ def if_else( if stmt.else_body: errors_before_else = len(interp_._validation_errors) - _ = interp_.run_callable_region( - frame, stmt, stmt.else_body, (cond_validation,) - ) + with interp_.new_frame(stmt, has_parent_access=True) as else_frame: + interp_.frame_call_region( + else_frame, stmt, stmt.else_body, cond_validation + ) + frame.set_values(else_frame.entries.keys(), else_frame.entries.values()) errors_after_else = len(interp_._validation_errors) else_had_errors = errors_after_else > errors_before_else diff --git a/src/bloqade/analysis/validation/validationpass.py b/src/bloqade/analysis/validation/validationpass.py index 094d1ba4..0755762c 100644 --- a/src/bloqade/analysis/validation/validationpass.py +++ b/src/bloqade/analysis/validation/validationpass.py @@ -87,7 +87,7 @@ def validate(self, method: ir.Method) -> "ValidationResult": if required_analysis not in self._analysis_cache: analysis = required_analysis(method.dialects) analysis.initialize() - frame, _ = analysis.run_analysis(method) + frame, _ = analysis.run(method) self._analysis_cache[required_analysis] = frame validator.set_analysis_cache(self._analysis_cache) From 9ccc42c352499b5439768404454a13559fff5b35 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Wed, 12 Nov 2025 12:15:04 -0500 Subject: [PATCH 14/20] moved collecting errors to Kirin's InterpreterABC --- .../analysis/validation/nocloning/analysis.py | 21 +++---- .../analysis/validation/nocloning/impls.py | 55 +++++++++---------- .../validation/nocloning/test_no_cloning.py | 4 +- .../validation/test_compose_validation.py | 7 +-- 4 files changed, 40 insertions(+), 47 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index 8cc1c579..cd0ccf6b 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Sequence from kirin import ir from kirin.analysis import Forward @@ -57,19 +57,15 @@ class _NoCloningAnalysis(Forward[QubitValidation]): def __init__(self, dialects): super().__init__(dialects) self._address_frame: ForwardFrame[Address] | None = None - self._validation_errors: list[ValidationError] = [] def method_self(self, method: ir.Method) -> QubitValidation: return self.lattice.bottom() def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation): - # Set up address frame before analysis if not already cached if self._address_frame is None: addr_analysis = AddressAnalysis(self.dialects) addr_analysis.initialize() self._address_frame, _ = addr_analysis.run(method) - - # Now run the forward analysis with address frame populated return super().run(method, *args, **kwargs) def eval_fallback( @@ -122,9 +118,10 @@ def eval_fallback( if qubit_addr in seen: violation = f"Qubit[{qubit_addr}] on {gate_name} Gate" must_violations.append(violation) - self._validation_errors.append( - QubitValidationError(node, qubit_addr, gate_name) + self.add_validation_error( + node, QubitValidationError(node, qubit_addr, gate_name) ) + seen.add(qubit_addr) if must_violations: @@ -136,8 +133,8 @@ def eval_fallback( else: condition = f", with unknown argument {args_str}" - self._validation_errors.append( - PotentialQubitValidationError(node, gate_name, condition) + self.add_validation_error( + node, PotentialQubitValidationError(node, gate_name, condition) ) usage = May(violations=frozenset([f"{gate_name} Gate{condition}"])) @@ -195,14 +192,14 @@ def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: if self._cached_address_frame is not None: self._analysis._address_frame = self._cached_address_frame frame, _ = self._analysis.run(method) - - return frame, self._analysis._validation_errors + return frame, self._analysis.get_validation_errors() def print_validation_errors(self): """Print all collected errors with formatted snippets.""" if self._analysis is None: return - for err in self._analysis._validation_errors: + validation_errors = self._analysis.get_validation_errors() + for err in validation_errors: if isinstance(err, QubitValidationError): print( f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate" diff --git a/src/bloqade/analysis/validation/nocloning/impls.py b/src/bloqade/analysis/validation/nocloning/impls.py index 1341414f..c4062ce6 100644 --- a/src/bloqade/analysis/validation/nocloning/impls.py +++ b/src/bloqade/analysis/validation/nocloning/impls.py @@ -24,71 +24,68 @@ def if_else( except Exception: cond_validation = Top() - errors_before_then = len(interp_._validation_errors) + errors_before_then_keys = set(interp_._validation_errors.keys()) + with interp_.new_frame(stmt, has_parent_access=True) as then_frame: interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation) frame.set_values(then_frame.entries.keys(), then_frame.entries.values()) - errors_after_then = len(interp_._validation_errors) + then_keys = set(interp_._validation_errors.keys()) - errors_before_then_keys + then_errors = interp_.get_validation_errors(keys=then_keys) - then_had_errors = errors_after_then > errors_before_then - then_errors = interp_._validation_errors[errors_before_then:errors_after_then] then_state = ( Must(violations=frozenset(err.args[0] for err in then_errors)) - if then_had_errors + if bool(then_keys) else Bottom() ) if stmt.else_body: - errors_before_else = len(interp_._validation_errors) + errors_before_else_keys = set(interp_._validation_errors.keys()) + with interp_.new_frame(stmt, has_parent_access=True) as else_frame: interp_.frame_call_region( else_frame, stmt, stmt.else_body, cond_validation ) frame.set_values(else_frame.entries.keys(), else_frame.entries.values()) - errors_after_else = len(interp_._validation_errors) + else_keys = set(interp_._validation_errors.keys()) - errors_before_else_keys + else_errors = interp_.get_validation_errors(keys=else_keys) - else_had_errors = errors_after_else > errors_before_else - else_errors = interp_._validation_errors[ - errors_before_else:errors_after_else - ] else_state = ( Must(violations=frozenset(err.args[0] for err in else_errors)) - if else_had_errors + if bool(else_keys) else Bottom() ) merged = then_state.join(else_state) if isinstance(merged, May): - interp_._validation_errors = interp_._validation_errors[ - :errors_before_then - ] + branch_keys = then_keys | else_keys + for k in branch_keys: + interp_._validation_errors.pop(k, None) - for err in then_errors + else_errors: + for err in then_errors: if isinstance(err, QubitValidationError): potential_err = PotentialQubitValidationError( - err.node, - err.gate_name, - ( - ", when condition is true" - if err in then_errors - else ", when condition is false" - ), + err.node, err.gate_name, ", when condition is true" ) - interp_._validation_errors.append(potential_err) + interp_.add_validation_error(err.node, potential_err) + + for err in else_errors: + if isinstance(err, QubitValidationError): + potential_err = PotentialQubitValidationError( + err.node, err.gate_name, ", when condition is false" + ) + interp_.add_validation_error(err.node, potential_err) else: merged = then_state.join(Bottom()) if isinstance(merged, May): - interp_._validation_errors = interp_._validation_errors[ - :errors_before_then - ] - + for k in then_keys: + interp_._validation_errors.pop(k, None) for err in then_errors: if isinstance(err, QubitValidationError): potential_err = PotentialQubitValidationError( err.node, err.gate_name, ", when condition is true" ) - interp_._validation_errors.append(potential_err) + interp_.add_validation_error(err.node, potential_err) return (merged,) diff --git a/test/analysis/validation/nocloning/test_no_cloning.py b/test/analysis/validation/nocloning/test_no_cloning.py index 789b2db1..2faf9725 100644 --- a/test/analysis/validation/nocloning/test_no_cloning.py +++ b/test/analysis/validation/nocloning/test_no_cloning.py @@ -29,8 +29,8 @@ def collect_errors_from_validation( if validation._analysis is None: return (must_count, may_count) - - for err in validation._analysis._validation_errors: + print(validation._analysis.get_validation_errors()) + for err in validation._analysis.get_validation_errors(): if isinstance(err, QubitValidationError): must_count += 1 elif isinstance(err, PotentialQubitValidationError): diff --git a/test/analysis/validation/test_compose_validation.py b/test/analysis/validation/test_compose_validation.py index 8547e7cc..934de00c 100644 --- a/test/analysis/validation/test_compose_validation.py +++ b/test/analysis/validation/test_compose_validation.py @@ -9,8 +9,8 @@ def test_validation_suite(): @squin.kernel def bad_kernel(a: int): q = squin.qalloc(2) - squin.cx(q[0], q[0]) # cloning error - squin.cx(q[a], q[1]) # cloning error + squin.cx(q[0], q[0]) # definite cloning error + squin.cx(q[a], q[1]) # potential cloning error # Running no-cloning validation multiple times suite = ValidationSuite( @@ -35,9 +35,8 @@ def test_validation_suite2(): @squin.kernel def good_kernel(): q = squin.qalloc(2) - squin.cx(q[0], q[1]) # cloning error + squin.cx(q[0], q[1]) - # Running no-cloning validation multiple times suite = ValidationSuite( [ NoCloningValidation, From ca8f5cd8e054f19b2bf9b523139a5a3bcdd121ce Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Wed, 12 Nov 2025 12:16:22 -0500 Subject: [PATCH 15/20] fix unused import --- src/bloqade/analysis/validation/nocloning/analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index cd0ccf6b..e37b9eee 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -1,4 +1,4 @@ -from typing import Any, Sequence +from typing import Any from kirin import ir from kirin.analysis import Forward From aeaea1c36a987b6ee2e6893655c2a8679ac62562 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Wed, 12 Nov 2025 13:01:55 -0500 Subject: [PATCH 16/20] Moved ValidationPass to Kirin --- .../analysis/validation/nocloning/analysis.py | 2 +- .../analysis/validation/validationpass.py | 168 ------------------ .../validation/test_compose_validation.py | 2 +- 3 files changed, 2 insertions(+), 170 deletions(-) delete mode 100644 src/bloqade/analysis/validation/validationpass.py diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index e37b9eee..fc523e41 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -5,6 +5,7 @@ from kirin.dialects import func from kirin.ir.exception import ValidationError from kirin.analysis.forward import ForwardFrame +from kirin.validation.validationpass import ValidationPass from bloqade.analysis.address import ( Address, @@ -19,7 +20,6 @@ PartialTuple, UnknownQubit, ) -from bloqade.analysis.validation.validationpass import ValidationPass from .lattice import May, Top, Must, Bottom, QubitValidation diff --git a/src/bloqade/analysis/validation/validationpass.py b/src/bloqade/analysis/validation/validationpass.py deleted file mode 100644 index 0755762c..00000000 --- a/src/bloqade/analysis/validation/validationpass.py +++ /dev/null @@ -1,168 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Generic, TypeVar -from dataclasses import field, dataclass - -from kirin import ir -from kirin.ir.exception import ValidationError - -T = TypeVar("T") - - -class ValidationPass(ABC, Generic[T]): - """Base class for a validation pass. - - Each pass analyzes an IR method and collects validation errors. - """ - - @abstractmethod - def name(self) -> str: - """Return the name of this validation pass.""" - ... - - @abstractmethod - def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: - """Run validation and return (analysis_frame, errors). - - Returns: - - analysis_frame: The result frame from the analysis - - errors: List of validation errors (empty if valid) - """ - ... - - def get_required_analyses(self) -> list[type]: - """Return list of analysis classes this pass depends on. - - Override to declare dependencies (e.g., [AddressAnalysis, AnotherAnalysis]). - The suite will run these analyses once and cache results. - """ - return [] - - def set_analysis_cache(self, cache: dict[type, Any]) -> None: - """Receive cached analysis results from the suite. - - Override to store cached analysis frames/results. - Example: - self._address_frame = cache.get(AddressAnalysis) - """ - pass - - -@dataclass -class ValidationSuite: - """Compose multiple validation passes and run them together. - - Caches analysis results to avoid redundant computation when multiple - validation passes depend on the same underlying analysis. - - Example: - suite = ValidationSuite([ - NoCloningValidation, - AnotherValidation, - ]) - result = suite.validate(my_kernel) - print(result.format_errors()) - """ - - passes: list[type[ValidationPass]] = field(default_factory=list) - fail_fast: bool = False - _analysis_cache: dict[type, Any] = field(default_factory=dict, init=False) - - def add_pass(self, pass_cls: type[ValidationPass]) -> "ValidationSuite": - """Add a validation pass to the suite.""" - self.passes.append(pass_cls) - return self - - def validate(self, method: ir.Method) -> "ValidationResult": - """Run all validation passes and collect results.""" - all_errors: dict[str, list[ValidationError]] = {} - all_frames: dict[str, Any] = {} - self._analysis_cache.clear() - for pass_cls in self.passes: - validator = pass_cls() - pass_name = validator.name() - - try: - required = validator.get_required_analyses() - for required_analysis in required: - if required_analysis not in self._analysis_cache: - analysis = required_analysis(method.dialects) - analysis.initialize() - frame, _ = analysis.run(method) - self._analysis_cache[required_analysis] = frame - - validator.set_analysis_cache(self._analysis_cache) - - frame, errors = validator.run(method) - all_frames[pass_name] = frame - - for err in errors: - if isinstance(err, ValidationError): - try: - err.attach(method) - except Exception: - pass - - if errors: - all_errors[pass_name] = errors - if self.fail_fast: - break - except Exception as e: - import traceback - - tb = traceback.format_exc() - all_errors[pass_name] = [ - ValidationError( - method.code, f"Validation pass '{pass_name}' failed: {e}\n{tb}" - ) - ] - if self.fail_fast: - break - - return ValidationResult(all_errors, all_frames) - - -@dataclass -class ValidationResult: - """Result of running a validation suite.""" - - errors: dict[str, list[ValidationError]] - frames: dict[str, Any] = field(default_factory=dict) - - def is_valid(self) -> bool: - """Check if validation passed (no errors).""" - return len(self.errors) == 0 - - def error_count(self) -> int: - """Total number of errors across all passes.""" - return sum(len(errs) for errs in self.errors.values()) - - def get_frame(self, pass_name: str) -> Any: - """Get the analysis frame for a specific pass.""" - return self.frames.get(pass_name) - - def format_errors(self) -> str: - """Format all errors with their pass names.""" - if not self.errors: - return "\n\033[32mAll validation passes succeeded\033[0m" - - lines = [ - f"\n\033[31mValidation failed with {self.error_count()} error(s):\033[0m" - ] - - for pass_name, pass_errors in self.errors.items(): - lines.append(f"\n\033[31m{pass_name}:\033[0m") - for err in pass_errors: - err_msg = err.args[0] if err.args else str(err) - lines.append(f" - {err_msg}") - if hasattr(err, "hint"): - hint = err.hint() - if hint: - lines.append(f" {hint}") - - return "\n".join(lines) - - def raise_if_invalid(self): - """Raise an exception if validation failed.""" - if not self.is_valid(): - first_errors = next(iter(self.errors.values())) - raise first_errors[0] diff --git a/test/analysis/validation/test_compose_validation.py b/test/analysis/validation/test_compose_validation.py index 934de00c..86030a74 100644 --- a/test/analysis/validation/test_compose_validation.py +++ b/test/analysis/validation/test_compose_validation.py @@ -1,8 +1,8 @@ import pytest +from kirin.validation.validationpass import ValidationSuite from bloqade import squin from bloqade.analysis.validation.nocloning import NoCloningValidation -from bloqade.analysis.validation.validationpass import ValidationSuite def test_validation_suite(): From 06c7bbb7378e8ffee8d771ee9f64261ab6796904 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Fri, 14 Nov 2025 15:59:04 -0500 Subject: [PATCH 17/20] Remove redundant code in ifelse handling. --- .../analysis/validation/nocloning/impls.py | 69 +++++++++---------- 1 file changed, 32 insertions(+), 37 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/impls.py b/src/bloqade/analysis/validation/nocloning/impls.py index c4062ce6..7dae4d7b 100644 --- a/src/bloqade/analysis/validation/nocloning/impls.py +++ b/src/bloqade/analysis/validation/nocloning/impls.py @@ -24,68 +24,63 @@ def if_else( except Exception: cond_validation = Top() - errors_before_then_keys = set(interp_._validation_errors.keys()) + errors_before = set(interp_._validation_errors.keys()) with interp_.new_frame(stmt, has_parent_access=True) as then_frame: interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation) frame.set_values(then_frame.entries.keys(), then_frame.entries.values()) - then_keys = set(interp_._validation_errors.keys()) - errors_before_then_keys + + then_keys = set(interp_._validation_errors.keys()) - errors_before then_errors = interp_.get_validation_errors(keys=then_keys) then_state = ( Must(violations=frozenset(err.args[0] for err in then_errors)) - if bool(then_keys) + if then_keys else Bottom() ) if stmt.else_body: - errors_before_else_keys = set(interp_._validation_errors.keys()) + errors_before_else = set(interp_._validation_errors.keys()) with interp_.new_frame(stmt, has_parent_access=True) as else_frame: interp_.frame_call_region( else_frame, stmt, stmt.else_body, cond_validation ) frame.set_values(else_frame.entries.keys(), else_frame.entries.values()) - else_keys = set(interp_._validation_errors.keys()) - errors_before_else_keys + + else_keys = set(interp_._validation_errors.keys()) - errors_before_else else_errors = interp_.get_validation_errors(keys=else_keys) else_state = ( Must(violations=frozenset(err.args[0] for err in else_errors)) - if bool(else_keys) + if else_keys else Bottom() ) - - merged = then_state.join(else_state) - - if isinstance(merged, May): - branch_keys = then_keys | else_keys - for k in branch_keys: - interp_._validation_errors.pop(k, None) - - for err in then_errors: - if isinstance(err, QubitValidationError): - potential_err = PotentialQubitValidationError( - err.node, err.gate_name, ", when condition is true" - ) - interp_.add_validation_error(err.node, potential_err) - - for err in else_errors: - if isinstance(err, QubitValidationError): - potential_err = PotentialQubitValidationError( - err.node, err.gate_name, ", when condition is false" - ) - interp_.add_validation_error(err.node, potential_err) else: - merged = then_state.join(Bottom()) + else_state = Bottom() + else_keys = set() + else_errors = [] + merged = then_state.join(else_state) + all_branch_keys = then_keys | else_keys + for k in all_branch_keys: + interp_._validation_errors.pop(k, None) - if isinstance(merged, May): - for k in then_keys: - interp_._validation_errors.pop(k, None) - for err in then_errors: - if isinstance(err, QubitValidationError): - potential_err = PotentialQubitValidationError( - err.node, err.gate_name, ", when condition is true" - ) - interp_.add_validation_error(err.node, potential_err) + if isinstance(merged, Must): + for err in then_errors + else_errors: + if isinstance(err, QubitValidationError): + interp_.add_validation_error(err.node, err) + elif isinstance(merged, May): + for err in then_errors: + if isinstance(err, QubitValidationError): + potential_err = PotentialQubitValidationError( + err.node, err.gate_name, ", when condition is true" + ) + interp_.add_validation_error(err.node, potential_err) + for err in else_errors: + if isinstance(err, QubitValidationError): + potential_err = PotentialQubitValidationError( + err.node, err.gate_name, ", when condition is false" + ) + interp_.add_validation_error(err.node, potential_err) return (merged,) From 924d60aa61840600267cc526180e8b946dbce6a3 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Mon, 17 Nov 2025 10:57:48 -0500 Subject: [PATCH 18/20] Refactor validation analysis and error handling in NoCloningValidation. Simplified error handling of Scf.Ifelse --- .../analysis/validation/nocloning/analysis.py | 112 ++++++++++++------ .../analysis/validation/nocloning/impls.py | 73 ++++-------- .../analysis/validation/nocloning/lattice.py | 87 +++++++------- .../validation/nocloning/test_no_cloning.py | 57 ++++----- .../validation/test_compose_validation.py | 7 +- 5 files changed, 166 insertions(+), 170 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index fc523e41..d826a44a 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -71,7 +71,7 @@ def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidati def eval_fallback( self, frame: ForwardFrame[QubitValidation], node: ir.Statement ) -> tuple[QubitValidation, ...]: - """Check for qubit usage violations.""" + """Check for qubit usage violations and return lattice values.""" if not isinstance(node, func.Invoke): return tuple(Bottom() for _ in node.results) @@ -83,6 +83,7 @@ def eval_fallback( has_unknown = False has_qubit_args = False unknown_arg_names: list[str] = [] + for arg in node.args: addr = address_frame.get(arg) match addr: @@ -110,22 +111,17 @@ def eval_fallback( return tuple(Bottom() for _ in node.results) seen: set[int] = set() - must_violations: list[str] = [] - s_name = getattr(node.callee, "sym_name", "") gate_name = s_name.upper() for qubit_addr in concrete_addrs: if qubit_addr in seen: - violation = f"Qubit[{qubit_addr}] on {gate_name} Gate" - must_violations.append(violation) - self.add_validation_error( - node, QubitValidationError(node, qubit_addr, gate_name) - ) - + violations.add((qubit_addr, gate_name)) seen.add(qubit_addr) - if must_violations: - usage = Must(violations=frozenset(must_violations)) + if violations: + usage = Must(violations=frozenset(violations)) elif has_unknown: args_str = " == ".join(unknown_arg_names) if len(unknown_arg_names) > 1: @@ -133,11 +129,7 @@ def eval_fallback( else: condition = f", with unknown argument {args_str}" - self.add_validation_error( - node, PotentialQubitValidationError(node, gate_name, condition) - ) - - usage = May(violations=frozenset([f"{gate_name} Gate{condition}"])) + usage = May(violations=frozenset([(gate_name, condition)])) else: usage = Bottom() @@ -159,6 +151,48 @@ def _get_source_name(self, value: ir.SSAValue) -> str: return str(value) + def extract_errors_from_frame( + self, frame: ForwardFrame[QubitValidation] + ) -> list[ValidationError]: + """Extract validation errors from final lattice values. + + Only extracts errors from top-level statements (not nested in regions). + """ + errors = [] + seen_statements = set() + + for node, value in frame.entries.items(): + if isinstance(node, ir.ResultValue): + stmt = node.stmt + elif isinstance(node, ir.Statement): + stmt = node + else: + continue + if stmt in seen_statements: + continue + seen_statements.add(stmt) + if isinstance(value, Must): + for qubit_id, gate_name in value.violations: + errors.append(QubitValidationError(stmt, qubit_id, gate_name)) + elif isinstance(value, May): + for gate_name, condition in value.violations: + errors.append( + PotentialQubitValidationError(stmt, gate_name, condition) + ) + return errors + + def count_violations(self, frame: Any) -> int: + """Count individual violations from the frame, same as test helper.""" + from .lattice import May, Must + + total = 0 + for node, value in frame.entries.items(): + if isinstance(value, Must): + total += len(value.violations) + elif isinstance(value, May): + total += len(value.violations) + return total + class NoCloningValidation(ValidationPass): """Validates the no-cloning theorem by tracking qubit addresses.""" @@ -179,37 +213,39 @@ def set_analysis_cache(self, cache: dict[type, Any]) -> None: self._cached_address_frame = cache.get(AddressAnalysis) def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: - """Run the no-cloning validation analysis. - - Returns: - - frame: ForwardFrame with QubitValidation lattice values - - errors: List of validation errors found - """ + """Run the no-cloning validation analysis.""" if self._analysis is None: self._analysis = _NoCloningAnalysis(method.dialects) self._analysis.initialize() if self._cached_address_frame is not None: self._analysis._address_frame = self._cached_address_frame + frame, _ = self._analysis.run(method) - return frame, self._analysis.get_validation_errors() + errors = self._analysis.extract_errors_from_frame(frame) + + return frame, errors def print_validation_errors(self): """Print all collected errors with formatted snippets.""" if self._analysis is None: return - validation_errors = self._analysis.get_validation_errors() - for err in validation_errors: - if isinstance(err, QubitValidationError): - print( - f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate" - ) - elif isinstance(err, PotentialQubitValidationError): - print( - f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}" - ) - else: - print( - f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}" - ) - print(err.hint()) + + if self._analysis.state._current_frame: + frame = self._analysis.state._current_frame + errors = self._analysis.extract_errors_from_frame(frame) + + for err in errors: + if isinstance(err, QubitValidationError): + print( + f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate" + ) + elif isinstance(err, PotentialQubitValidationError): + print( + f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}" + ) + else: + print( + f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}" + ) + print(err.hint()) diff --git a/src/bloqade/analysis/validation/nocloning/impls.py b/src/bloqade/analysis/validation/nocloning/impls.py index 7dae4d7b..333a7079 100644 --- a/src/bloqade/analysis/validation/nocloning/impls.py +++ b/src/bloqade/analysis/validation/nocloning/impls.py @@ -3,11 +3,7 @@ from kirin.dialects import scf from .lattice import May, Top, Must, Bottom, QubitValidation -from .analysis import ( - QubitValidationError, - PotentialQubitValidationError, - _NoCloningAnalysis, -) +from .analysis import _NoCloningAnalysis @scf.dialect.register(key="validate.nocloning") @@ -24,63 +20,40 @@ def if_else( except Exception: cond_validation = Top() - errors_before = set(interp_._validation_errors.keys()) - with interp_.new_frame(stmt, has_parent_access=True) as then_frame: interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation) - frame.set_values(then_frame.entries.keys(), then_frame.entries.values()) - - then_keys = set(interp_._validation_errors.keys()) - errors_before - then_errors = interp_.get_validation_errors(keys=then_keys) - then_state = ( - Must(violations=frozenset(err.args[0] for err in then_errors)) - if then_keys - else Bottom() - ) + then_state = Bottom() + for node, val in then_frame.entries.items(): + if isinstance(val, (Must, May)): + then_state = then_state.join(val) + else_state = Bottom() if stmt.else_body: - errors_before_else = set(interp_._validation_errors.keys()) - with interp_.new_frame(stmt, has_parent_access=True) as else_frame: interp_.frame_call_region( else_frame, stmt, stmt.else_body, cond_validation ) - frame.set_values(else_frame.entries.keys(), else_frame.entries.values()) - else_keys = set(interp_._validation_errors.keys()) - errors_before_else - else_errors = interp_.get_validation_errors(keys=else_keys) + for node, val in else_frame.entries.items(): + if isinstance(val, (Must, May)): + else_state = else_state.join(val) - else_state = ( - Must(violations=frozenset(err.args[0] for err in else_errors)) - if else_keys - else Bottom() - ) - else: - else_state = Bottom() - else_keys = set() - else_errors = [] merged = then_state.join(else_state) - all_branch_keys = then_keys | else_keys - for k in all_branch_keys: - interp_._validation_errors.pop(k, None) - if isinstance(merged, Must): - for err in then_errors + else_errors: - if isinstance(err, QubitValidationError): - interp_.add_validation_error(err.node, err) - elif isinstance(merged, May): - for err in then_errors: - if isinstance(err, QubitValidationError): - potential_err = PotentialQubitValidationError( - err.node, err.gate_name, ", when condition is true" - ) - interp_.add_validation_error(err.node, potential_err) + if isinstance(merged, May): + then_has = not isinstance(then_state, Bottom) + else_has = not isinstance(else_state, Bottom) + + if then_has and not else_has: + new_violations = frozenset( + (gate, ", when condition is true") for gate, _ in merged.violations + ) + merged = May(violations=new_violations) + elif else_has and not then_has: + new_violations = frozenset( + (gate, ", when condition is false") for gate, _ in merged.violations + ) + merged = May(violations=new_violations) - for err in else_errors: - if isinstance(err, QubitValidationError): - potential_err = PotentialQubitValidationError( - err.node, err.gate_name, ", when condition is false" - ) - interp_.add_validation_error(err.node, potential_err) return (merged,) diff --git a/src/bloqade/analysis/validation/nocloning/lattice.py b/src/bloqade/analysis/validation/nocloning/lattice.py index 24ce01ae..a8ca5049 100644 --- a/src/bloqade/analysis/validation/nocloning/lattice.py +++ b/src/bloqade/analysis/validation/nocloning/lattice.py @@ -9,22 +9,25 @@ class QubitValidation(BoundedLattice["QubitValidation"]): r"""Base class for qubit-cloning validation lattice. - Linear ordering (more precise --> less precise): - Bottom ⊑ Must ⊑ May ⊑ Top - - Semantics: + Semantics for control flow: - Bottom: proven safe / never occurs - - Must: definitely occurs (strong) - - May: possibly occurs (weak) + - Must: definitely occurs on ALL paths + - May: possibly occurs on SOME paths - Top: unknown / no information + + Lattice ordering (more precise --> less precise): + Bottom ⊑ Must ⊑ May ⊑ Top + Bottom ⊑ May ⊑ Top + + Key insight: Must ⊔ Bottom = May (happens on some paths, not all) """ @classmethod - def bottom(cls) -> "QubitValidation": + def bottom(cls) -> "Bottom": return Bottom() @classmethod - def top(cls) -> "QubitValidation": + def top(cls) -> "Top": return Top() @abstractmethod @@ -43,13 +46,6 @@ def is_subseteq(self, other: QubitValidation) -> bool: return True def join(self, other: QubitValidation) -> QubitValidation: - match other: - case Bottom(): - return self - case Must(violations=v): - return May(violations=v) - case May() | Top(): - return other return other def meet(self, other: QubitValidation) -> QubitValidation: @@ -77,9 +73,10 @@ def __repr__(self) -> str: @final @dataclass class Must(QubitValidation): - """Definite violations.""" + """Definite violations with concrete qubit IDs and gate names.""" - violations: FrozenSet[str] = field(default_factory=frozenset) + violations: FrozenSet[tuple[int, str]] = field(default_factory=frozenset) + """Set of (qubit_id, gate_name) tuples""" def is_subseteq(self, other: QubitValidation) -> bool: match other: @@ -87,57 +84,56 @@ def is_subseteq(self, other: QubitValidation) -> bool: return False case Must(violations=ov): return self.violations.issubset(ov) - case May(violations=_): - return True - case Top(): + case May() | Top(): return True return False def join(self, other: QubitValidation) -> QubitValidation: match other: case Bottom(): - return May(violations=self.violations) + may_violations = frozenset((gate, "") for _, gate in self.violations) + return May(violations=may_violations) case Must(violations=ov): - if self.violations == ov: - return Must(violations=self.violations) - else: - return May(violations=self.violations | ov) + merged = self.violations | ov + return Must(violations=merged) case May(violations=ov): - return May(violations=self.violations | ov) + may_viols = frozenset((gate, "") for _, gate in self.violations) + return May(violations=may_viols | ov) case Top(): - return Top() + return other return Top() def meet(self, other: QubitValidation) -> QubitValidation: match other: case Bottom(): - return Bottom() + return other case Must(violations=ov): inter = self.violations & ov return Must(violations=inter) if inter else Bottom() - case May(violations=ov): - inter = self.violations & ov - return Must(violations=inter) if inter else Bottom() + case May(): + return self case Top(): return self return Bottom() def __repr__(self) -> str: - return f"Must({self.violations or '∅'})" + if not self.violations: + return "Must(∅)" + viols = ", ".join(f"Qubit[{qid}] at {gate}" for qid, gate in self.violations) + return f"Must({{{viols}}})" @final @dataclass class May(QubitValidation): - """Potential violations.""" + """Potential violations with gate names and conditions.""" - violations: FrozenSet[str] = field(default_factory=frozenset) + violations: FrozenSet[tuple[str, str]] = field(default_factory=frozenset) + """Set of (gate_name, condition) tuples""" def is_subseteq(self, other: QubitValidation) -> bool: match other: - case Bottom(): - return False - case Must(): + case Bottom() | Must(): return False case May(violations=ov): return self.violations.issubset(ov) @@ -150,20 +146,20 @@ def join(self, other: QubitValidation) -> QubitValidation: case Bottom(): return self case Must(violations=ov): - return May(violations=self.violations | ov) + may_viols = frozenset((gate, "") for _, gate in ov) + return May(violations=self.violations | may_viols) case May(violations=ov): return May(violations=self.violations | ov) case Top(): - return Top() + return other return Top() def meet(self, other: QubitValidation) -> QubitValidation: match other: case Bottom(): - return Bottom() - case Must(violations=ov): - inter = self.violations & ov - return Must(violations=inter) if inter else Bottom() + return other + case Must(): + return other case May(violations=ov): inter = self.violations & ov return May(violations=inter) if inter else Bottom() @@ -172,4 +168,7 @@ def meet(self, other: QubitValidation) -> QubitValidation: return Bottom() def __repr__(self) -> str: - return f"May({self.violations or '∅'})" + if not self.violations: + return "May(∅)" + viols = ", ".join(f"{gate}{cond}" for gate, cond in self.violations) + return f"May({{{viols}}})" diff --git a/test/analysis/validation/nocloning/test_no_cloning.py b/test/analysis/validation/nocloning/test_no_cloning.py index 2faf9725..b35636f1 100644 --- a/test/analysis/validation/nocloning/test_no_cloning.py +++ b/test/analysis/validation/nocloning/test_no_cloning.py @@ -7,34 +7,31 @@ from bloqade import squin from bloqade.types import Qubit from bloqade.analysis.validation.nocloning.lattice import May, Must -from bloqade.analysis.validation.nocloning.analysis import ( - NoCloningValidation, - QubitValidationError, - PotentialQubitValidationError, -) +from bloqade.analysis.validation.nocloning.analysis import NoCloningValidation T = TypeVar("T", bound=Must | May) def collect_errors_from_validation( validation: NoCloningValidation, + frame, ) -> tuple[int, int]: """Count Must (definite) and May (potential) errors from the validation pass. Returns: - (must_count, may_count) - number of definite and potential errors + (must_count, may_count) - number of definite and potential violations """ must_count = 0 may_count = 0 if validation._analysis is None: return (must_count, may_count) - print(validation._analysis.get_validation_errors()) - for err in validation._analysis.get_validation_errors(): - if isinstance(err, QubitValidationError): - must_count += 1 - elif isinstance(err, PotentialQubitValidationError): - may_count += 1 + + for node, value in frame.entries.items(): + if isinstance(value, Must): + must_count += len(value.violations) + elif isinstance(value, May): + may_count += len(value.violations) return must_count, may_count @@ -48,17 +45,16 @@ def bad_control(): validation = NoCloningValidation() - frame, _ = validation.run(bad_control) + frame, errors = validation.run(bad_control) print() bad_control.print(analysis=frame.entries) - must_count, may_count = collect_errors_from_validation(validation) + must_count, may_count = collect_errors_from_validation(validation, frame) assert must_count == 1 assert may_count == 0 - validation.print_validation_errors() -@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +@pytest.mark.parametrize("control_gate", [squin.cx]) def test_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): @squin.kernel def bad_control(cond: bool): @@ -70,14 +66,13 @@ def bad_control(cond: bool): squin.cx(q[1], q[1]) validation = NoCloningValidation() - frame, _ = validation.run(bad_control) + frame, errors = validation.run(bad_control) print() bad_control.print(analysis=frame.entries) - must_count, may_count = collect_errors_from_validation(validation) + must_count, may_count = collect_errors_from_validation(validation, frame) assert must_count == 1 # squin.cx(q[1], q[1]) outside conditional assert may_count == 1 # control_gate(q[0], q[0]) inside conditional - validation.print_validation_errors() @pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) @@ -91,11 +86,11 @@ def test(): control_gate(q[0], q[2]) validation = NoCloningValidation() - frame, _ = validation.run(test) + frame, errors = validation.run(test) print() test.print(analysis=frame.entries) - must_count, may_count = collect_errors_from_validation(validation) + must_count, may_count = collect_errors_from_validation(validation, frame) assert must_count == 0 assert may_count == 0 @@ -109,12 +104,11 @@ def good_kernel(): squin.cy(q[1], q[a]) validation = NoCloningValidation() - frame, _ = validation.run(good_kernel) + frame, errors = validation.run(good_kernel) - must_count, may_count = collect_errors_from_validation(validation) + must_count, may_count = collect_errors_from_validation(validation, frame) assert must_count == 1 assert may_count == 0 - validation.print_validation_errors() def test_parallel_fail(): @@ -124,14 +118,13 @@ def bad_kernel(): squin.broadcast.cx(IList([q[0], q[1], q[2]]), IList([q[1], q[2], q[3]])) validation = NoCloningValidation() - frame, _ = validation.run(bad_kernel) + frame, errors = validation.run(bad_kernel) print() bad_kernel.print(analysis=frame.entries) - must_count, may_count = collect_errors_from_validation(validation) + must_count, may_count = collect_errors_from_validation(validation, frame) assert must_count == 2 assert may_count == 0 - validation.print_validation_errors() def test_potential_fail(): @@ -141,14 +134,13 @@ def bad_kernel(a: int, b: int): squin.cx(q[a], q[2]) validation = NoCloningValidation() - frame, _ = validation.run(bad_kernel) + frame, errors = validation.run(bad_kernel) print() bad_kernel.print(analysis=frame.entries) - must_count, may_count = collect_errors_from_validation(validation) + must_count, may_count = collect_errors_from_validation(validation, frame) assert must_count == 0 assert may_count == 1 - validation.print_validation_errors() def test_potential_parallel_fail(): @@ -158,11 +150,10 @@ def bad_kernel(a: IList): squin.broadcast.cx(a, IList([q[2], q[3], q[4]])) validation = NoCloningValidation() - frame, _ = validation.run(bad_kernel) + frame, errors = validation.run(bad_kernel) print() bad_kernel.print(analysis=frame.entries) - must_count, may_count = collect_errors_from_validation(validation) + must_count, may_count = collect_errors_from_validation(validation, frame) assert must_count == 0 assert may_count == 1 - validation.print_validation_errors() diff --git a/test/analysis/validation/test_compose_validation.py b/test/analysis/validation/test_compose_validation.py index 86030a74..b6d9e26a 100644 --- a/test/analysis/validation/test_compose_validation.py +++ b/test/analysis/validation/test_compose_validation.py @@ -1,4 +1,3 @@ -import pytest from kirin.validation.validationpass import ValidationSuite from bloqade import squin @@ -22,13 +21,11 @@ def bad_kernel(a: int): ) result = suite.validate(bad_kernel) - assert not result.is_valid() + assert not result.is_valid assert ( result.error_count() == 2 ) # Report 2 errors, even when validated multiple times print(result.format_errors()) - with pytest.raises(Exception): - result.raise_if_invalid() def test_validation_suite2(): @@ -45,7 +42,7 @@ def good_kernel(): ) result = suite.validate(good_kernel) - assert result.is_valid() + assert result.is_valid assert result.error_count() == 0 print(result.format_errors()) result.raise_if_invalid() From c2c825e5bf386bdf7ae16ca867f3170bc4c16c75 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Mon, 17 Nov 2025 11:21:57 -0500 Subject: [PATCH 19/20] Fix import --- src/bloqade/analysis/validation/nocloning/analysis.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py index d826a44a..de9211bc 100644 --- a/src/bloqade/analysis/validation/nocloning/analysis.py +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -3,7 +3,11 @@ from kirin import ir from kirin.analysis import Forward from kirin.dialects import func -from kirin.ir.exception import ValidationError +from kirin.ir.exception import ( + ValidationError, + DefiniteValidationError, + PotentialValidationError, +) from kirin.analysis.forward import ForwardFrame from kirin.validation.validationpass import ValidationPass @@ -24,7 +28,7 @@ from .lattice import May, Top, Must, Bottom, QubitValidation -class QubitValidationError(ValidationError): +class QubitValidationError(DefiniteValidationError): """ValidationError for definite (Must) violations with concrete qubit addresses.""" qubit_id: int @@ -36,7 +40,7 @@ def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str): self.gate_name = gate_name -class PotentialQubitValidationError(ValidationError): +class PotentialQubitValidationError(PotentialValidationError): """ValidationError for potential (May) violations with unknown addresses.""" gate_name: str From 365758192a07f7d0000aa56b5fc014e621a40e34 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Wed, 19 Nov 2025 11:04:14 -0500 Subject: [PATCH 20/20] use `raise_if_invalid` instead of `format_errors`. --- test/analysis/validation/test_compose_validation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/analysis/validation/test_compose_validation.py b/test/analysis/validation/test_compose_validation.py index b6d9e26a..8396ccb6 100644 --- a/test/analysis/validation/test_compose_validation.py +++ b/test/analysis/validation/test_compose_validation.py @@ -1,3 +1,4 @@ +import pytest from kirin.validation.validationpass import ValidationSuite from bloqade import squin @@ -25,7 +26,9 @@ def bad_kernel(a: int): assert ( result.error_count() == 2 ) # Report 2 errors, even when validated multiple times - print(result.format_errors()) + with pytest.raises(Exception) as exc_info: + result.raise_if_invalid() + print(f"{exc_info.value}") def test_validation_suite2(): @@ -44,5 +47,4 @@ def good_kernel(): assert result.is_valid assert result.error_count() == 0 - print(result.format_errors()) result.raise_if_invalid()