-
Notifications
You must be signed in to change notification settings - Fork 1
No-cloning validation analysis #607
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
zhenrongliew
wants to merge
23
commits into
main
Choose a base branch
from
dl/validate-no-cloning
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
64dd86a
Implement no-cloning validation
zhenrongliew 1b773d4
improve error reporting and update test cases for validation errors
zhenrongliew dca2422
Updated ValidationError reporting
zhenrongliew 40f5d6c
Refactor no-cloning validation: enhance error handling and improve te…
zhenrongliew 2a98e55
Shorter error messages
zhenrongliew 5f4500d
clarify join/meet and lattice structure
zhenrongliew ef5a4e1
Merge branch 'main' into dl/validate-no-cloning
zhenrongliew 80c9349
fix linting
zhenrongliew 6429148
Fix import warning
zhenrongliew 2e155d9
fix import errors
zhenrongliew 2a011fc
Merge branch 'main' into dl/validate-no-cloning
zhenrongliew a35636a
Improve Validation framework to compose multiple validation analyses.
zhenrongliew 2bf369e
removed redundant `method` variable
zhenrongliew a2af2f3
Fix commutativity of `join` operation
zhenrongliew c3680c4
Merge branch 'main' into dl/validate-no-cloning
zhenrongliew b572471
updated to work with new Kirin version
zhenrongliew 9ccc42c
moved collecting errors to Kirin's InterpreterABC
zhenrongliew ca8f5cd
fix unused import
zhenrongliew aeaea1c
Moved ValidationPass to Kirin
zhenrongliew 06c7bbb
Remove redundant code in ifelse handling.
zhenrongliew 924d60a
Refactor validation analysis and error handling in NoCloningValidatio…
zhenrongliew c2c825e
Fix import
zhenrongliew 3657581
use `raise_if_invalid` instead of `format_errors`.
zhenrongliew File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| from . import impls as impls | ||
| from .analysis import NoCloningValidation as NoCloningValidation |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,255 @@ | ||
| from typing import Any | ||
|
|
||
| from kirin import ir | ||
| from kirin.analysis import Forward | ||
| from kirin.dialects import func | ||
| from kirin.ir.exception import ( | ||
| ValidationError, | ||
| DefiniteValidationError, | ||
| PotentialValidationError, | ||
| ) | ||
| from kirin.analysis.forward import ForwardFrame | ||
| from kirin.validation.validationpass import ValidationPass | ||
|
|
||
| from bloqade.analysis.address import ( | ||
| Address, | ||
| AddressAnalysis, | ||
| ) | ||
| from bloqade.analysis.address.lattice import ( | ||
| Unknown, | ||
| AddressReg, | ||
| UnknownReg, | ||
| AddressQubit, | ||
| PartialIList, | ||
| PartialTuple, | ||
| UnknownQubit, | ||
| ) | ||
|
|
||
| from .lattice import May, Top, Must, Bottom, QubitValidation | ||
|
|
||
|
|
||
| class QubitValidationError(DefiniteValidationError): | ||
| """ValidationError for definite (Must) violations with concrete qubit addresses.""" | ||
|
|
||
| qubit_id: int | ||
| gate_name: str | ||
|
|
||
| def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str): | ||
| super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.") | ||
| self.qubit_id = qubit_id | ||
| self.gate_name = gate_name | ||
|
|
||
|
|
||
| class PotentialQubitValidationError(PotentialValidationError): | ||
| """ValidationError for potential (May) violations with unknown addresses.""" | ||
|
|
||
| gate_name: str | ||
| condition: str | ||
|
|
||
| def __init__(self, node: ir.IRNode, gate_name: str, condition: str): | ||
| super().__init__(node, f"Potential cloning at {gate_name} gate{condition}.") | ||
| self.gate_name = gate_name | ||
| self.condition = condition | ||
|
|
||
|
|
||
| class _NoCloningAnalysis(Forward[QubitValidation]): | ||
| """Internal forward analysis for tracking qubit cloning violations.""" | ||
|
|
||
| keys = ("validate.nocloning",) | ||
| lattice = QubitValidation | ||
|
|
||
| def __init__(self, dialects): | ||
| super().__init__(dialects) | ||
| self._address_frame: ForwardFrame[Address] | None = None | ||
|
|
||
| def method_self(self, method: ir.Method) -> QubitValidation: | ||
| return self.lattice.bottom() | ||
|
|
||
| def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation): | ||
| if self._address_frame is None: | ||
| addr_analysis = AddressAnalysis(self.dialects) | ||
| addr_analysis.initialize() | ||
| self._address_frame, _ = addr_analysis.run(method) | ||
| return super().run(method, *args, **kwargs) | ||
|
|
||
| def eval_fallback( | ||
| self, frame: ForwardFrame[QubitValidation], node: ir.Statement | ||
| ) -> tuple[QubitValidation, ...]: | ||
| """Check for qubit usage violations and return lattice values.""" | ||
| if not isinstance(node, func.Invoke): | ||
| return tuple(Bottom() for _ in node.results) | ||
|
|
||
| address_frame = self._address_frame | ||
| if address_frame is None: | ||
| return tuple(Top() for _ in node.results) | ||
|
|
||
| concrete_addrs: list[int] = [] | ||
| has_unknown = False | ||
| has_qubit_args = False | ||
| unknown_arg_names: list[str] = [] | ||
|
|
||
| for arg in node.args: | ||
| addr = address_frame.get(arg) | ||
| match addr: | ||
| case AddressQubit(data=qubit_addr): | ||
| has_qubit_args = True | ||
| concrete_addrs.append(qubit_addr) | ||
| case AddressReg(data=addrs): | ||
| has_qubit_args = True | ||
| concrete_addrs.extend(addrs) | ||
| case ( | ||
| UnknownQubit() | ||
| | UnknownReg() | ||
| | PartialIList() | ||
| | PartialTuple() | ||
| | Unknown() | ||
| ): | ||
| has_qubit_args = True | ||
| has_unknown = True | ||
| arg_name = self._get_source_name(arg) | ||
| unknown_arg_names.append(arg_name) | ||
| case _: | ||
| pass | ||
|
|
||
| if not has_qubit_args: | ||
| return tuple(Bottom() for _ in node.results) | ||
|
|
||
| seen: set[int] = set() | ||
| violations: set[tuple[int, str]] = set() | ||
| s_name = getattr(node.callee, "sym_name", "<unknown>") | ||
| gate_name = s_name.upper() | ||
|
|
||
| for qubit_addr in concrete_addrs: | ||
| if qubit_addr in seen: | ||
| violations.add((qubit_addr, gate_name)) | ||
| seen.add(qubit_addr) | ||
|
|
||
| if violations: | ||
| usage = Must(violations=frozenset(violations)) | ||
| elif has_unknown: | ||
| args_str = " == ".join(unknown_arg_names) | ||
| if len(unknown_arg_names) > 1: | ||
| condition = f", when {args_str}" | ||
| else: | ||
| condition = f", with unknown argument {args_str}" | ||
|
|
||
| usage = May(violations=frozenset([(gate_name, condition)])) | ||
| else: | ||
| usage = Bottom() | ||
|
|
||
| return tuple(usage for _ in node.results) if node.results else (usage,) | ||
|
|
||
| def _get_source_name(self, value: ir.SSAValue) -> str: | ||
| """Trace back to get the source variable name.""" | ||
| from kirin.dialects.py.indexing import GetItem | ||
|
|
||
| if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem): | ||
| index_arg = value.stmt.args[1] | ||
| return self._get_source_name(index_arg) | ||
|
|
||
| if isinstance(value, ir.BlockArgument): | ||
| return value.name or f"arg{value.index}" | ||
|
|
||
| if hasattr(value, "name") and value.name: | ||
| return value.name | ||
|
|
||
| return str(value) | ||
|
|
||
| def extract_errors_from_frame( | ||
| self, frame: ForwardFrame[QubitValidation] | ||
| ) -> list[ValidationError]: | ||
| """Extract validation errors from final lattice values. | ||
|
|
||
| Only extracts errors from top-level statements (not nested in regions). | ||
| """ | ||
| errors = [] | ||
| seen_statements = set() | ||
|
|
||
| for node, value in frame.entries.items(): | ||
| if isinstance(node, ir.ResultValue): | ||
| stmt = node.stmt | ||
| elif isinstance(node, ir.Statement): | ||
| stmt = node | ||
| else: | ||
| continue | ||
| if stmt in seen_statements: | ||
| continue | ||
| seen_statements.add(stmt) | ||
| if isinstance(value, Must): | ||
| for qubit_id, gate_name in value.violations: | ||
| errors.append(QubitValidationError(stmt, qubit_id, gate_name)) | ||
| elif isinstance(value, May): | ||
| for gate_name, condition in value.violations: | ||
| errors.append( | ||
| PotentialQubitValidationError(stmt, gate_name, condition) | ||
| ) | ||
| return errors | ||
|
|
||
| def count_violations(self, frame: Any) -> int: | ||
| """Count individual violations from the frame, same as test helper.""" | ||
| from .lattice import May, Must | ||
|
|
||
| total = 0 | ||
| for node, value in frame.entries.items(): | ||
| if isinstance(value, Must): | ||
| total += len(value.violations) | ||
| elif isinstance(value, May): | ||
| total += len(value.violations) | ||
| return total | ||
|
|
||
|
|
||
| class NoCloningValidation(ValidationPass): | ||
| """Validates the no-cloning theorem by tracking qubit addresses.""" | ||
|
|
||
| def __init__(self): | ||
| self._analysis: _NoCloningAnalysis | None = None | ||
| self._cached_address_frame = None | ||
|
|
||
| def name(self) -> str: | ||
| return "No-Cloning Validation" | ||
|
|
||
| def get_required_analyses(self) -> list[type]: | ||
| """Declare dependency on AddressAnalysis.""" | ||
| return [AddressAnalysis] | ||
|
|
||
| def set_analysis_cache(self, cache: dict[type, Any]) -> None: | ||
| """Use cached AddressAnalysis result.""" | ||
| self._cached_address_frame = cache.get(AddressAnalysis) | ||
|
|
||
| def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: | ||
| """Run the no-cloning validation analysis.""" | ||
| if self._analysis is None: | ||
| self._analysis = _NoCloningAnalysis(method.dialects) | ||
|
|
||
| self._analysis.initialize() | ||
| if self._cached_address_frame is not None: | ||
| self._analysis._address_frame = self._cached_address_frame | ||
|
|
||
| frame, _ = self._analysis.run(method) | ||
| errors = self._analysis.extract_errors_from_frame(frame) | ||
|
|
||
| return frame, errors | ||
|
|
||
| def print_validation_errors(self): | ||
| """Print all collected errors with formatted snippets.""" | ||
| if self._analysis is None: | ||
| return | ||
|
|
||
| if self._analysis.state._current_frame: | ||
| frame = self._analysis.state._current_frame | ||
| errors = self._analysis.extract_errors_from_frame(frame) | ||
|
|
||
| for err in errors: | ||
| if isinstance(err, QubitValidationError): | ||
| print( | ||
| f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate" | ||
| ) | ||
| elif isinstance(err, PotentialQubitValidationError): | ||
| print( | ||
| f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}" | ||
| ) | ||
| else: | ||
| print( | ||
| f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}" | ||
| ) | ||
| print(err.hint()) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| from kirin import interp | ||
| from kirin.analysis import ForwardFrame | ||
| from kirin.dialects import scf | ||
|
|
||
| from .lattice import May, Top, Must, Bottom, QubitValidation | ||
| from .analysis import _NoCloningAnalysis | ||
|
|
||
|
|
||
| @scf.dialect.register(key="validate.nocloning") | ||
| class Scf(interp.MethodTable): | ||
| @interp.impl(scf.IfElse) | ||
| def if_else( | ||
| self, | ||
| interp_: _NoCloningAnalysis, | ||
| frame: ForwardFrame[QubitValidation], | ||
| stmt: scf.IfElse, | ||
| ): | ||
| try: | ||
| cond_validation = frame.get(stmt.cond) | ||
| except Exception: | ||
| cond_validation = Top() | ||
|
|
||
| with interp_.new_frame(stmt, has_parent_access=True) as then_frame: | ||
| interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation) | ||
|
|
||
| then_state = Bottom() | ||
| for node, val in then_frame.entries.items(): | ||
| if isinstance(val, (Must, May)): | ||
| then_state = then_state.join(val) | ||
|
|
||
| else_state = Bottom() | ||
| if stmt.else_body: | ||
| with interp_.new_frame(stmt, has_parent_access=True) as else_frame: | ||
| interp_.frame_call_region( | ||
| else_frame, stmt, stmt.else_body, cond_validation | ||
| ) | ||
|
|
||
| for node, val in else_frame.entries.items(): | ||
| if isinstance(val, (Must, May)): | ||
| else_state = else_state.join(val) | ||
|
|
||
| merged = then_state.join(else_state) | ||
|
|
||
| if isinstance(merged, May): | ||
| then_has = not isinstance(then_state, Bottom) | ||
| else_has = not isinstance(else_state, Bottom) | ||
|
|
||
| if then_has and not else_has: | ||
| new_violations = frozenset( | ||
| (gate, ", when condition is true") for gate, _ in merged.violations | ||
| ) | ||
| merged = May(violations=new_violations) | ||
| elif else_has and not then_has: | ||
| new_violations = frozenset( | ||
| (gate, ", when condition is false") for gate, _ in merged.violations | ||
| ) | ||
| merged = May(violations=new_violations) | ||
|
|
||
| return (merged,) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some questions for my understanding:
Why is this not a method impl for
func.Invoke?Also, this seems to be assuming that all
func.Invokeare stdlib functions. What if a user calls their own subroutine? If I understand correctly, this will result in an error for that if the user-defined function has any qubit arguments whose addresses overlap. Is that correct and expected?Just to be clear: I don't really see a (valid) use-case when a user would do that, but then again it's very simple to write a kernel function that is perfectly valid even when called with overlapping qubit arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, the analysis is able to distinguish between safe user-defined functions and unsafe ones.
e.g.
foois called with overlapping qubit1but it's implementation is safe:This is the reported error:
Validation failed with 1 violation(s): No-Cloning Validation: - Qubit[0] cloned at CX gate. File "/Users/Documents/kirin-workspace/bloqade-circuit/test/analysis/validation/test_compose_validation.py", line 12, col 4 │ def kernel(a: int): │ q = squin.qalloc(4) 12│ squin.cx(q[0], q[0]) # definite cloning error │ ^^^^^^^^^^^^^^^^^^^^ │ def foo(q1, q2): │ squin.cx(q1,q[3]) │ foo(q[1], q[1]) (1 sub-exception)And here is an unsafe user defined function:
The report:
Validation failed with 2 violation(s): No-Cloning Validation: - Qubit[0] cloned at CX gate. File "/Users/Documents/kirin-workspace/bloqade-circuit/test/analysis/validation/test_compose_validation.py", line 12, col 4 │ def bad_kernel(a: int): │ q = squin.qalloc(4) 12│ squin.cx(q[0], q[0]) # definite cloning error │ ^^^^^^^^^^^^^^^^^^^^ │ def foo(q1, q2): │ squin.cx(q1,q2) │ foo(q[1], q[1]) - Qubit[1] cloned at FOO gate. File "/Users/Documents/kirin-workspace/bloqade-circuit/test/analysis/validation/test_compose_validation.py", line 15, col 4 │ def foo(q1, q2): │ squin.cx(q1,q2) 15│ foo(q[1], q[1]) │ ^^^ (2 sub-exceptions)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Granted this error message might be a little misleading: "Qubit[1] cloned at FOO gate."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand, but what I mean is this: if I change
fooin the above to e.g.then calling
foo(q[0], q[0])is fine since there's no overlapping operation. Still, this will be reported as an error.Again, I don't see a valid use case for this (why would you want to write something like the above?), but I just wanted to make sure this behavior was expected.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I see what you mean now. No, that behavior is not expected.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we could introduce some heuristic to define what function is a stdlib function and only then apply the analysis. @Roger-luo is there a way to do this via modules in kirin as of yet?