Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
64dd86a
Implement no-cloning validation
zhenrongliew Nov 5, 2025
1b773d4
improve error reporting and update test cases for validation errors
zhenrongliew Nov 6, 2025
dca2422
Updated ValidationError reporting
zhenrongliew Nov 6, 2025
40f5d6c
Refactor no-cloning validation: enhance error handling and improve te…
zhenrongliew Nov 6, 2025
2a98e55
Shorter error messages
zhenrongliew Nov 6, 2025
5f4500d
clarify join/meet and lattice structure
zhenrongliew Nov 6, 2025
ef5a4e1
Merge branch 'main' into dl/validate-no-cloning
zhenrongliew Nov 7, 2025
80c9349
fix linting
zhenrongliew Nov 7, 2025
6429148
Fix import warning
zhenrongliew Nov 7, 2025
2e155d9
fix import errors
zhenrongliew Nov 7, 2025
2a011fc
Merge branch 'main' into dl/validate-no-cloning
zhenrongliew Nov 7, 2025
a35636a
Improve Validation framework to compose multiple validation analyses.
zhenrongliew Nov 7, 2025
2bf369e
removed redundant `method` variable
zhenrongliew Nov 7, 2025
a2af2f3
Fix commutativity of `join` operation
zhenrongliew Nov 10, 2025
c3680c4
Merge branch 'main' into dl/validate-no-cloning
zhenrongliew Nov 10, 2025
b572471
updated to work with new Kirin version
zhenrongliew Nov 10, 2025
9ccc42c
moved collecting errors to Kirin's InterpreterABC
zhenrongliew Nov 12, 2025
ca8f5cd
fix unused import
zhenrongliew Nov 12, 2025
aeaea1c
Moved ValidationPass to Kirin
zhenrongliew Nov 12, 2025
06c7bbb
Remove redundant code in ifelse handling.
zhenrongliew Nov 14, 2025
924d60a
Refactor validation analysis and error handling in NoCloningValidatio…
zhenrongliew Nov 17, 2025
c2c825e
Fix import
zhenrongliew Nov 17, 2025
3657581
use `raise_if_invalid` instead of `format_errors`.
zhenrongliew Nov 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/bloqade/analysis/validation/nocloning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import impls as impls
from .analysis import NoCloningValidation as NoCloningValidation
255 changes: 255 additions & 0 deletions src/bloqade/analysis/validation/nocloning/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
from typing import Any

from kirin import ir
from kirin.analysis import Forward
from kirin.dialects import func
from kirin.ir.exception import (
ValidationError,
DefiniteValidationError,
PotentialValidationError,
)
from kirin.analysis.forward import ForwardFrame
from kirin.validation.validationpass import ValidationPass

from bloqade.analysis.address import (
Address,
AddressAnalysis,
)
from bloqade.analysis.address.lattice import (
Unknown,
AddressReg,
UnknownReg,
AddressQubit,
PartialIList,
PartialTuple,
UnknownQubit,
)

from .lattice import May, Top, Must, Bottom, QubitValidation


class QubitValidationError(DefiniteValidationError):
"""ValidationError for definite (Must) violations with concrete qubit addresses."""

qubit_id: int
gate_name: str

def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str):
super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.")
self.qubit_id = qubit_id
self.gate_name = gate_name


class PotentialQubitValidationError(PotentialValidationError):
"""ValidationError for potential (May) violations with unknown addresses."""

gate_name: str
condition: str

def __init__(self, node: ir.IRNode, gate_name: str, condition: str):
super().__init__(node, f"Potential cloning at {gate_name} gate{condition}.")
self.gate_name = gate_name
self.condition = condition


class _NoCloningAnalysis(Forward[QubitValidation]):
"""Internal forward analysis for tracking qubit cloning violations."""

keys = ("validate.nocloning",)
lattice = QubitValidation

def __init__(self, dialects):
super().__init__(dialects)
self._address_frame: ForwardFrame[Address] | None = None

def method_self(self, method: ir.Method) -> QubitValidation:
return self.lattice.bottom()

def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation):
if self._address_frame is None:
addr_analysis = AddressAnalysis(self.dialects)
addr_analysis.initialize()
self._address_frame, _ = addr_analysis.run(method)
return super().run(method, *args, **kwargs)

def eval_fallback(
self, frame: ForwardFrame[QubitValidation], node: ir.Statement
) -> tuple[QubitValidation, ...]:
"""Check for qubit usage violations and return lattice values."""
if not isinstance(node, func.Invoke):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some questions for my understanding:

Why is this not a method impl for func.Invoke?

Also, this seems to be assuming that all func.Invoke are stdlib functions. What if a user calls their own subroutine? If I understand correctly, this will result in an error for that if the user-defined function has any qubit arguments whose addresses overlap. Is that correct and expected?

Just to be clear: I don't really see a (valid) use-case when a user would do that, but then again it's very simple to write a kernel function that is perfectly valid even when called with overlapping qubit arguments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the analysis is able to distinguish between safe user-defined functions and unsafe ones.
e.g. foo is called with overlapping qubit 1 but it's implementation is safe:

@squin.kernel
def kernel(a: int):
    q = squin.qalloc(4)
    squin.cx(q[0], q[0])  # definite cloning error
    def foo(q1, q2):
        squin.cx(q1,q[3])
    foo(q[1], q[1]) 

This is the reported error:

Validation failed with 1 violation(s):

No-Cloning Validation:
  - Qubit[0] cloned at CX gate.
      File "/Users/Documents/kirin-workspace/bloqade-circuit/test/analysis/validation/test_compose_validation.py", line 12, col 4
    │  def kernel(a: int):
    │      q = squin.qalloc(4)
  12│      squin.cx(q[0], q[0])  # definite cloning error
    │      ^^^^^^^^^^^^^^^^^^^^
    │      def foo(q1, q2):
    │          squin.cx(q1,q[3])
    │      foo(q[1], q[1]) 
 (1 sub-exception)

And here is an unsafe user defined function:

@squin.kernel
def kernel(a: int):
   q = squin.qalloc(4)
   squin.cx(q[0], q[0])  # definite cloning error
   def foo(q1, q2):
       squin.cx(q1,q2)
   foo(q[1], q[1]) 

The report:

Validation failed with 2 violation(s):

No-Cloning Validation:
 - Qubit[0] cloned at CX gate.
     File "/Users/Documents/kirin-workspace/bloqade-circuit/test/analysis/validation/test_compose_validation.py", line 12, col 4
   │  def bad_kernel(a: int):
   │      q = squin.qalloc(4)
 12│      squin.cx(q[0], q[0])  # definite cloning error
   │      ^^^^^^^^^^^^^^^^^^^^
   │      def foo(q1, q2):
   │          squin.cx(q1,q2)
   │      foo(q[1], q[1]) 

 - Qubit[1] cloned at FOO gate.
     File "/Users/Documents/kirin-workspace/bloqade-circuit/test/analysis/validation/test_compose_validation.py", line 15, col 4
   │  def foo(q1, q2):
   │      squin.cx(q1,q2)
 15│  foo(q[1], q[1]) 
   │  ^^^
(2 sub-exceptions)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Granted this error message might be a little misleading: "Qubit[1] cloned at FOO gate."

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand, but what I mean is this: if I change foo in the above to e.g.

def foo(q1, q2):
    squin.x(q1)
    squin.x(q2)

then calling foo(q[0], q[0]) is fine since there's no overlapping operation. Still, this will be reported as an error.
Again, I don't see a valid use case for this (why would you want to write something like the above?), but I just wanted to make sure this behavior was expected.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see what you mean now. No, that behavior is not expected.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we could introduce some heuristic to define what function is a stdlib function and only then apply the analysis. @Roger-luo is there a way to do this via modules in kirin as of yet?

return tuple(Bottom() for _ in node.results)

address_frame = self._address_frame
if address_frame is None:
return tuple(Top() for _ in node.results)

concrete_addrs: list[int] = []
has_unknown = False
has_qubit_args = False
unknown_arg_names: list[str] = []

for arg in node.args:
addr = address_frame.get(arg)
match addr:
case AddressQubit(data=qubit_addr):
has_qubit_args = True
concrete_addrs.append(qubit_addr)
case AddressReg(data=addrs):
has_qubit_args = True
concrete_addrs.extend(addrs)
case (
UnknownQubit()
| UnknownReg()
| PartialIList()
| PartialTuple()
| Unknown()
):
has_qubit_args = True
has_unknown = True
arg_name = self._get_source_name(arg)
unknown_arg_names.append(arg_name)
case _:
pass

if not has_qubit_args:
return tuple(Bottom() for _ in node.results)

seen: set[int] = set()
violations: set[tuple[int, str]] = set()
s_name = getattr(node.callee, "sym_name", "<unknown>")
gate_name = s_name.upper()

for qubit_addr in concrete_addrs:
if qubit_addr in seen:
violations.add((qubit_addr, gate_name))
seen.add(qubit_addr)

if violations:
usage = Must(violations=frozenset(violations))
elif has_unknown:
args_str = " == ".join(unknown_arg_names)
if len(unknown_arg_names) > 1:
condition = f", when {args_str}"
else:
condition = f", with unknown argument {args_str}"

usage = May(violations=frozenset([(gate_name, condition)]))
else:
usage = Bottom()

return tuple(usage for _ in node.results) if node.results else (usage,)

def _get_source_name(self, value: ir.SSAValue) -> str:
"""Trace back to get the source variable name."""
from kirin.dialects.py.indexing import GetItem

if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem):
index_arg = value.stmt.args[1]
return self._get_source_name(index_arg)

if isinstance(value, ir.BlockArgument):
return value.name or f"arg{value.index}"

if hasattr(value, "name") and value.name:
return value.name

return str(value)

def extract_errors_from_frame(
self, frame: ForwardFrame[QubitValidation]
) -> list[ValidationError]:
"""Extract validation errors from final lattice values.

Only extracts errors from top-level statements (not nested in regions).
"""
errors = []
seen_statements = set()

for node, value in frame.entries.items():
if isinstance(node, ir.ResultValue):
stmt = node.stmt
elif isinstance(node, ir.Statement):
stmt = node
else:
continue
if stmt in seen_statements:
continue
seen_statements.add(stmt)
if isinstance(value, Must):
for qubit_id, gate_name in value.violations:
errors.append(QubitValidationError(stmt, qubit_id, gate_name))
elif isinstance(value, May):
for gate_name, condition in value.violations:
errors.append(
PotentialQubitValidationError(stmt, gate_name, condition)
)
return errors

def count_violations(self, frame: Any) -> int:
"""Count individual violations from the frame, same as test helper."""
from .lattice import May, Must

total = 0
for node, value in frame.entries.items():
if isinstance(value, Must):
total += len(value.violations)
elif isinstance(value, May):
total += len(value.violations)
return total


class NoCloningValidation(ValidationPass):
"""Validates the no-cloning theorem by tracking qubit addresses."""

def __init__(self):
self._analysis: _NoCloningAnalysis | None = None
self._cached_address_frame = None

def name(self) -> str:
return "No-Cloning Validation"

def get_required_analyses(self) -> list[type]:
"""Declare dependency on AddressAnalysis."""
return [AddressAnalysis]

def set_analysis_cache(self, cache: dict[type, Any]) -> None:
"""Use cached AddressAnalysis result."""
self._cached_address_frame = cache.get(AddressAnalysis)

def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]:
"""Run the no-cloning validation analysis."""
if self._analysis is None:
self._analysis = _NoCloningAnalysis(method.dialects)

self._analysis.initialize()
if self._cached_address_frame is not None:
self._analysis._address_frame = self._cached_address_frame

frame, _ = self._analysis.run(method)
errors = self._analysis.extract_errors_from_frame(frame)

return frame, errors

def print_validation_errors(self):
"""Print all collected errors with formatted snippets."""
if self._analysis is None:
return

if self._analysis.state._current_frame:
frame = self._analysis.state._current_frame
errors = self._analysis.extract_errors_from_frame(frame)

for err in errors:
if isinstance(err, QubitValidationError):
print(
f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate"
)
elif isinstance(err, PotentialQubitValidationError):
print(
f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}"
)
else:
print(
f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}"
)
print(err.hint())
59 changes: 59 additions & 0 deletions src/bloqade/analysis/validation/nocloning/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from kirin import interp
from kirin.analysis import ForwardFrame
from kirin.dialects import scf

from .lattice import May, Top, Must, Bottom, QubitValidation
from .analysis import _NoCloningAnalysis


@scf.dialect.register(key="validate.nocloning")
class Scf(interp.MethodTable):
@interp.impl(scf.IfElse)
def if_else(
self,
interp_: _NoCloningAnalysis,
frame: ForwardFrame[QubitValidation],
stmt: scf.IfElse,
):
try:
cond_validation = frame.get(stmt.cond)
except Exception:
cond_validation = Top()

with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation)

then_state = Bottom()
for node, val in then_frame.entries.items():
if isinstance(val, (Must, May)):
then_state = then_state.join(val)

else_state = Bottom()
if stmt.else_body:
with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
interp_.frame_call_region(
else_frame, stmt, stmt.else_body, cond_validation
)

for node, val in else_frame.entries.items():
if isinstance(val, (Must, May)):
else_state = else_state.join(val)

merged = then_state.join(else_state)

if isinstance(merged, May):
then_has = not isinstance(then_state, Bottom)
else_has = not isinstance(else_state, Bottom)

if then_has and not else_has:
new_violations = frozenset(
(gate, ", when condition is true") for gate, _ in merged.violations
)
merged = May(violations=new_violations)
elif else_has and not then_has:
new_violations = frozenset(
(gate, ", when condition is false") for gate, _ in merged.violations
)
merged = May(violations=new_violations)

return (merged,)
Loading
Loading