diff --git a/src/bloqade/analysis/validation/nocloning/__init__.py b/src/bloqade/analysis/validation/nocloning/__init__.py new file mode 100644 index 00000000..a61f8ba0 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/__init__.py @@ -0,0 +1,2 @@ +from . import impls as impls +from .analysis import NoCloningValidation as NoCloningValidation diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py new file mode 100644 index 00000000..de9211bc --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -0,0 +1,255 @@ +from typing import Any + +from kirin import ir +from kirin.analysis import Forward +from kirin.dialects import func +from kirin.ir.exception import ( + ValidationError, + DefiniteValidationError, + PotentialValidationError, +) +from kirin.analysis.forward import ForwardFrame +from kirin.validation.validationpass import ValidationPass + +from bloqade.analysis.address import ( + Address, + AddressAnalysis, +) +from bloqade.analysis.address.lattice import ( + Unknown, + AddressReg, + UnknownReg, + AddressQubit, + PartialIList, + PartialTuple, + UnknownQubit, +) + +from .lattice import May, Top, Must, Bottom, QubitValidation + + +class QubitValidationError(DefiniteValidationError): + """ValidationError for definite (Must) violations with concrete qubit addresses.""" + + qubit_id: int + gate_name: str + + def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str): + super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.") + self.qubit_id = qubit_id + self.gate_name = gate_name + + +class PotentialQubitValidationError(PotentialValidationError): + """ValidationError for potential (May) violations with unknown addresses.""" + + gate_name: str + condition: str + + def __init__(self, node: ir.IRNode, gate_name: str, condition: str): + super().__init__(node, f"Potential cloning at {gate_name} gate{condition}.") + self.gate_name = gate_name + self.condition = condition + + +class _NoCloningAnalysis(Forward[QubitValidation]): + """Internal forward analysis for tracking qubit cloning violations.""" + + keys = ("validate.nocloning",) + lattice = QubitValidation + + def __init__(self, dialects): + super().__init__(dialects) + self._address_frame: ForwardFrame[Address] | None = None + + def method_self(self, method: ir.Method) -> QubitValidation: + return self.lattice.bottom() + + def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation): + if self._address_frame is None: + addr_analysis = AddressAnalysis(self.dialects) + addr_analysis.initialize() + self._address_frame, _ = addr_analysis.run(method) + return super().run(method, *args, **kwargs) + + def eval_fallback( + self, frame: ForwardFrame[QubitValidation], node: ir.Statement + ) -> tuple[QubitValidation, ...]: + """Check for qubit usage violations and return lattice values.""" + if not isinstance(node, func.Invoke): + return tuple(Bottom() for _ in node.results) + + address_frame = self._address_frame + if address_frame is None: + return tuple(Top() for _ in node.results) + + concrete_addrs: list[int] = [] + has_unknown = False + has_qubit_args = False + unknown_arg_names: list[str] = [] + + for arg in node.args: + addr = address_frame.get(arg) + match addr: + case AddressQubit(data=qubit_addr): + has_qubit_args = True + concrete_addrs.append(qubit_addr) + case AddressReg(data=addrs): + has_qubit_args = True + concrete_addrs.extend(addrs) + case ( + UnknownQubit() + | UnknownReg() + | PartialIList() + | PartialTuple() + | Unknown() + ): + has_qubit_args = True + has_unknown = True + arg_name = self._get_source_name(arg) + unknown_arg_names.append(arg_name) + case _: + pass + + if not has_qubit_args: + return tuple(Bottom() for _ in node.results) + + seen: set[int] = set() + violations: set[tuple[int, str]] = set() + s_name = getattr(node.callee, "sym_name", "") + gate_name = s_name.upper() + + for qubit_addr in concrete_addrs: + if qubit_addr in seen: + violations.add((qubit_addr, gate_name)) + seen.add(qubit_addr) + + if violations: + usage = Must(violations=frozenset(violations)) + elif has_unknown: + args_str = " == ".join(unknown_arg_names) + if len(unknown_arg_names) > 1: + condition = f", when {args_str}" + else: + condition = f", with unknown argument {args_str}" + + usage = May(violations=frozenset([(gate_name, condition)])) + else: + usage = Bottom() + + return tuple(usage for _ in node.results) if node.results else (usage,) + + def _get_source_name(self, value: ir.SSAValue) -> str: + """Trace back to get the source variable name.""" + from kirin.dialects.py.indexing import GetItem + + if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem): + index_arg = value.stmt.args[1] + return self._get_source_name(index_arg) + + if isinstance(value, ir.BlockArgument): + return value.name or f"arg{value.index}" + + if hasattr(value, "name") and value.name: + return value.name + + return str(value) + + def extract_errors_from_frame( + self, frame: ForwardFrame[QubitValidation] + ) -> list[ValidationError]: + """Extract validation errors from final lattice values. + + Only extracts errors from top-level statements (not nested in regions). + """ + errors = [] + seen_statements = set() + + for node, value in frame.entries.items(): + if isinstance(node, ir.ResultValue): + stmt = node.stmt + elif isinstance(node, ir.Statement): + stmt = node + else: + continue + if stmt in seen_statements: + continue + seen_statements.add(stmt) + if isinstance(value, Must): + for qubit_id, gate_name in value.violations: + errors.append(QubitValidationError(stmt, qubit_id, gate_name)) + elif isinstance(value, May): + for gate_name, condition in value.violations: + errors.append( + PotentialQubitValidationError(stmt, gate_name, condition) + ) + return errors + + def count_violations(self, frame: Any) -> int: + """Count individual violations from the frame, same as test helper.""" + from .lattice import May, Must + + total = 0 + for node, value in frame.entries.items(): + if isinstance(value, Must): + total += len(value.violations) + elif isinstance(value, May): + total += len(value.violations) + return total + + +class NoCloningValidation(ValidationPass): + """Validates the no-cloning theorem by tracking qubit addresses.""" + + def __init__(self): + self._analysis: _NoCloningAnalysis | None = None + self._cached_address_frame = None + + def name(self) -> str: + return "No-Cloning Validation" + + def get_required_analyses(self) -> list[type]: + """Declare dependency on AddressAnalysis.""" + return [AddressAnalysis] + + def set_analysis_cache(self, cache: dict[type, Any]) -> None: + """Use cached AddressAnalysis result.""" + self._cached_address_frame = cache.get(AddressAnalysis) + + def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: + """Run the no-cloning validation analysis.""" + if self._analysis is None: + self._analysis = _NoCloningAnalysis(method.dialects) + + self._analysis.initialize() + if self._cached_address_frame is not None: + self._analysis._address_frame = self._cached_address_frame + + frame, _ = self._analysis.run(method) + errors = self._analysis.extract_errors_from_frame(frame) + + return frame, errors + + def print_validation_errors(self): + """Print all collected errors with formatted snippets.""" + if self._analysis is None: + return + + if self._analysis.state._current_frame: + frame = self._analysis.state._current_frame + errors = self._analysis.extract_errors_from_frame(frame) + + for err in errors: + if isinstance(err, QubitValidationError): + print( + f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate" + ) + elif isinstance(err, PotentialQubitValidationError): + print( + f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}" + ) + else: + print( + f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}" + ) + print(err.hint()) diff --git a/src/bloqade/analysis/validation/nocloning/impls.py b/src/bloqade/analysis/validation/nocloning/impls.py new file mode 100644 index 00000000..333a7079 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/impls.py @@ -0,0 +1,59 @@ +from kirin import interp +from kirin.analysis import ForwardFrame +from kirin.dialects import scf + +from .lattice import May, Top, Must, Bottom, QubitValidation +from .analysis import _NoCloningAnalysis + + +@scf.dialect.register(key="validate.nocloning") +class Scf(interp.MethodTable): + @interp.impl(scf.IfElse) + def if_else( + self, + interp_: _NoCloningAnalysis, + frame: ForwardFrame[QubitValidation], + stmt: scf.IfElse, + ): + try: + cond_validation = frame.get(stmt.cond) + except Exception: + cond_validation = Top() + + with interp_.new_frame(stmt, has_parent_access=True) as then_frame: + interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation) + + then_state = Bottom() + for node, val in then_frame.entries.items(): + if isinstance(val, (Must, May)): + then_state = then_state.join(val) + + else_state = Bottom() + if stmt.else_body: + with interp_.new_frame(stmt, has_parent_access=True) as else_frame: + interp_.frame_call_region( + else_frame, stmt, stmt.else_body, cond_validation + ) + + for node, val in else_frame.entries.items(): + if isinstance(val, (Must, May)): + else_state = else_state.join(val) + + merged = then_state.join(else_state) + + if isinstance(merged, May): + then_has = not isinstance(then_state, Bottom) + else_has = not isinstance(else_state, Bottom) + + if then_has and not else_has: + new_violations = frozenset( + (gate, ", when condition is true") for gate, _ in merged.violations + ) + merged = May(violations=new_violations) + elif else_has and not then_has: + new_violations = frozenset( + (gate, ", when condition is false") for gate, _ in merged.violations + ) + merged = May(violations=new_violations) + + return (merged,) diff --git a/src/bloqade/analysis/validation/nocloning/lattice.py b/src/bloqade/analysis/validation/nocloning/lattice.py new file mode 100644 index 00000000..a8ca5049 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/lattice.py @@ -0,0 +1,174 @@ +from abc import abstractmethod +from typing import FrozenSet, final +from dataclasses import field, dataclass + +from kirin.lattice import SingletonMeta, BoundedLattice + + +@dataclass +class QubitValidation(BoundedLattice["QubitValidation"]): + r"""Base class for qubit-cloning validation lattice. + + Semantics for control flow: + - Bottom: proven safe / never occurs + - Must: definitely occurs on ALL paths + - May: possibly occurs on SOME paths + - Top: unknown / no information + + Lattice ordering (more precise --> less precise): + Bottom ⊑ Must ⊑ May ⊑ Top + Bottom ⊑ May ⊑ Top + + Key insight: Must ⊔ Bottom = May (happens on some paths, not all) + """ + + @classmethod + def bottom(cls) -> "Bottom": + return Bottom() + + @classmethod + def top(cls) -> "Top": + return Top() + + @abstractmethod + def is_subseteq(self, other: "QubitValidation") -> bool: ... + + @abstractmethod + def join(self, other: "QubitValidation") -> "QubitValidation": ... + + @abstractmethod + def meet(self, other: "QubitValidation") -> "QubitValidation": ... + + +@final +class Bottom(QubitValidation, metaclass=SingletonMeta): + def is_subseteq(self, other: QubitValidation) -> bool: + return True + + def join(self, other: QubitValidation) -> QubitValidation: + return other + + def meet(self, other: QubitValidation) -> QubitValidation: + return self + + def __repr__(self) -> str: + return "⊥ (No Errors)" + + +@final +class Top(QubitValidation, metaclass=SingletonMeta): + def is_subseteq(self, other: QubitValidation) -> bool: + return isinstance(other, Top) + + def join(self, other: QubitValidation) -> QubitValidation: + return self + + def meet(self, other: QubitValidation) -> QubitValidation: + return other + + def __repr__(self) -> str: + return "⊤ (Unknown)" + + +@final +@dataclass +class Must(QubitValidation): + """Definite violations with concrete qubit IDs and gate names.""" + + violations: FrozenSet[tuple[int, str]] = field(default_factory=frozenset) + """Set of (qubit_id, gate_name) tuples""" + + def is_subseteq(self, other: QubitValidation) -> bool: + match other: + case Bottom(): + return False + case Must(violations=ov): + return self.violations.issubset(ov) + case May() | Top(): + return True + return False + + def join(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + may_violations = frozenset((gate, "") for _, gate in self.violations) + return May(violations=may_violations) + case Must(violations=ov): + merged = self.violations | ov + return Must(violations=merged) + case May(violations=ov): + may_viols = frozenset((gate, "") for _, gate in self.violations) + return May(violations=may_viols | ov) + case Top(): + return other + return Top() + + def meet(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return other + case Must(violations=ov): + inter = self.violations & ov + return Must(violations=inter) if inter else Bottom() + case May(): + return self + case Top(): + return self + return Bottom() + + def __repr__(self) -> str: + if not self.violations: + return "Must(∅)" + viols = ", ".join(f"Qubit[{qid}] at {gate}" for qid, gate in self.violations) + return f"Must({{{viols}}})" + + +@final +@dataclass +class May(QubitValidation): + """Potential violations with gate names and conditions.""" + + violations: FrozenSet[tuple[str, str]] = field(default_factory=frozenset) + """Set of (gate_name, condition) tuples""" + + def is_subseteq(self, other: QubitValidation) -> bool: + match other: + case Bottom() | Must(): + return False + case May(violations=ov): + return self.violations.issubset(ov) + case Top(): + return True + return False + + def join(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return self + case Must(violations=ov): + may_viols = frozenset((gate, "") for _, gate in ov) + return May(violations=self.violations | may_viols) + case May(violations=ov): + return May(violations=self.violations | ov) + case Top(): + return other + return Top() + + def meet(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return other + case Must(): + return other + case May(violations=ov): + inter = self.violations & ov + return May(violations=inter) if inter else Bottom() + case Top(): + return self + return Bottom() + + def __repr__(self) -> str: + if not self.violations: + return "May(∅)" + viols = ", ".join(f"{gate}{cond}" for gate, cond in self.violations) + return f"May({{{viols}}})" diff --git a/test/analysis/validation/nocloning/test_no_cloning.py b/test/analysis/validation/nocloning/test_no_cloning.py new file mode 100644 index 00000000..b35636f1 --- /dev/null +++ b/test/analysis/validation/nocloning/test_no_cloning.py @@ -0,0 +1,159 @@ +from typing import Any, TypeVar + +import pytest +from kirin import ir +from kirin.dialects.ilist.runtime import IList + +from bloqade import squin +from bloqade.types import Qubit +from bloqade.analysis.validation.nocloning.lattice import May, Must +from bloqade.analysis.validation.nocloning.analysis import NoCloningValidation + +T = TypeVar("T", bound=Must | May) + + +def collect_errors_from_validation( + validation: NoCloningValidation, + frame, +) -> tuple[int, int]: + """Count Must (definite) and May (potential) errors from the validation pass. + + Returns: + (must_count, may_count) - number of definite and potential violations + """ + must_count = 0 + may_count = 0 + + if validation._analysis is None: + return (must_count, may_count) + + for node, value in frame.entries.items(): + if isinstance(value, Must): + must_count += len(value.violations) + elif isinstance(value, May): + may_count += len(value.violations) + + return must_count, may_count + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(): + q = squin.qalloc(1) + control_gate(q[0], q[0]) + + validation = NoCloningValidation() + + frame, errors = validation.run(bad_control) + print() + bad_control.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 1 + assert may_count == 0 + + +@pytest.mark.parametrize("control_gate", [squin.cx]) +def test_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(cond: bool): + q = squin.qalloc(10) + if cond: + control_gate(q[0], q[0]) + else: + control_gate(q[0], q[1]) + squin.cx(q[1], q[1]) + + validation = NoCloningValidation() + frame, errors = validation.run(bad_control) + print() + bad_control.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 1 # squin.cx(q[1], q[1]) outside conditional + assert may_count == 1 # control_gate(q[0], q[0]) inside conditional + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_pass(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def test(): + q = squin.qalloc(3) + control_gate(q[0], q[1]) + squin.rx(1.57, q[0]) + squin.measure(q[0]) + control_gate(q[0], q[2]) + + validation = NoCloningValidation() + frame, errors = validation.run(test) + print() + test.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 0 + assert may_count == 0 + + +def test_fail_2(): + @squin.kernel + def good_kernel(): + q = squin.qalloc(2) + a = 1 + squin.cx(q[0], q[1]) + squin.cy(q[1], q[a]) + + validation = NoCloningValidation() + frame, errors = validation.run(good_kernel) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 1 + assert may_count == 0 + + +def test_parallel_fail(): + @squin.kernel + def bad_kernel(): + q = squin.qalloc(5) + squin.broadcast.cx(IList([q[0], q[1], q[2]]), IList([q[1], q[2], q[3]])) + + validation = NoCloningValidation() + frame, errors = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 2 + assert may_count == 0 + + +def test_potential_fail(): + @squin.kernel + def bad_kernel(a: int, b: int): + q = squin.qalloc(5) + squin.cx(q[a], q[2]) + + validation = NoCloningValidation() + frame, errors = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 0 + assert may_count == 1 + + +def test_potential_parallel_fail(): + @squin.kernel + def bad_kernel(a: IList): + q = squin.qalloc(5) + squin.broadcast.cx(a, IList([q[2], q[3], q[4]])) + + validation = NoCloningValidation() + frame, errors = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 0 + assert may_count == 1 diff --git a/test/analysis/validation/test_compose_validation.py b/test/analysis/validation/test_compose_validation.py new file mode 100644 index 00000000..8396ccb6 --- /dev/null +++ b/test/analysis/validation/test_compose_validation.py @@ -0,0 +1,50 @@ +import pytest +from kirin.validation.validationpass import ValidationSuite + +from bloqade import squin +from bloqade.analysis.validation.nocloning import NoCloningValidation + + +def test_validation_suite(): + @squin.kernel + def bad_kernel(a: int): + q = squin.qalloc(2) + squin.cx(q[0], q[0]) # definite cloning error + squin.cx(q[a], q[1]) # potential cloning error + + # Running no-cloning validation multiple times + suite = ValidationSuite( + [ + NoCloningValidation, + NoCloningValidation, + NoCloningValidation, + ] + ) + result = suite.validate(bad_kernel) + + assert not result.is_valid + assert ( + result.error_count() == 2 + ) # Report 2 errors, even when validated multiple times + with pytest.raises(Exception) as exc_info: + result.raise_if_invalid() + print(f"{exc_info.value}") + + +def test_validation_suite2(): + @squin.kernel + def good_kernel(): + q = squin.qalloc(2) + squin.cx(q[0], q[1]) + + suite = ValidationSuite( + [ + NoCloningValidation, + ], + fail_fast=True, + ) + result = suite.validate(good_kernel) + + assert result.is_valid + assert result.error_count() == 0 + result.raise_if_invalid()