Skip to content

Commit 5b853b7

Browse files
david-plkaihsin
andauthored
Refactor gemini logical validation to use kirin's infrastructure (#629)
Co-authored-by: kaihsin <kaihsinwu@gmail.com>
1 parent 88b9525 commit 5b853b7

File tree

12 files changed

+498
-655
lines changed

12 files changed

+498
-655
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ requires-python = ">=3.10"
1313
dependencies = [
1414
"numpy>=1.22.0",
1515
"scipy>=1.13.1",
16-
"kirin-toolchain>=0.21.0,<0.23.0",
16+
"kirin-toolchain~=0.22.2",
1717
"rich>=13.9.4",
1818
"pydantic>=1.3.0,<2.11.0",
1919
"pandas>=2.2.3",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from . import impls as impls # NOTE: register methods
2-
from .analysis import GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis
1+
from . import impls as impls, analysis as analysis # NOTE: register methods
2+
from .analysis import GeminiLogicalValidation as GeminiLogicalValidation
Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,40 @@
1+
from typing import Any
2+
from dataclasses import dataclass
3+
14
from kirin import ir
5+
from kirin.lattice import EmptyLattice
6+
from kirin.analysis import Forward, ForwardFrame
7+
from kirin.validation import ValidationPass
28

39
from bloqade import squin
4-
from bloqade.validation.analysis import ValidationFrame, ValidationAnalysis
510

611

7-
class GeminiLogicalValidationAnalysis(ValidationAnalysis):
12+
class _GeminiLogicalValidationAnalysis(Forward[EmptyLattice]):
813
keys = ["gemini.validate.logical"]
914

1015
first_gate = True
16+
lattice = EmptyLattice
1117

12-
def eval_fallback(self, frame: ValidationFrame, node: ir.Statement):
18+
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
1319
if isinstance(node, squin.gate.stmts.Gate):
1420
# NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
1521
self.first_gate = False
1622

17-
return super().eval_fallback(frame, node)
23+
return tuple(self.lattice.bottom() for _ in range(len(node.results)))
24+
25+
def method_self(self, method: ir.Method) -> EmptyLattice:
26+
return self.lattice.bottom()
27+
28+
29+
@dataclass
30+
class GeminiLogicalValidation(ValidationPass):
31+
"""Validates a logical gemini program"""
32+
33+
def name(self) -> str:
34+
return "Gemini Logical Validation"
35+
36+
def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
37+
analysis = _GeminiLogicalValidationAnalysis(method.dialects)
38+
frame, _ = analysis.run(method)
39+
40+
return frame, analysis.get_validation_errors()
Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from kirin import ir, interp as _interp
2-
from kirin.analysis import const
2+
from kirin.analysis import ForwardFrame, const
33
from kirin.dialects import scf, func
44

55
from bloqade.squin import gate
6-
from bloqade.validation.analysis import ValidationFrame
7-
from bloqade.validation.analysis.lattice import Error
86

9-
from .analysis import GeminiLogicalValidationAnalysis
7+
from .analysis import _GeminiLogicalValidationAnalysis
108

119

1210
@scf.dialect.register(key="gemini.validate.logical")
@@ -15,87 +13,82 @@ class __ScfGeminiLogicalValidation(_interp.MethodTable):
1513
@_interp.impl(scf.IfElse)
1614
def if_else(
1715
self,
18-
interp: GeminiLogicalValidationAnalysis,
19-
frame: ValidationFrame,
16+
interp: _GeminiLogicalValidationAnalysis,
17+
frame: ForwardFrame,
2018
stmt: scf.IfElse,
2119
):
22-
frame.errors.append(
20+
interp.add_validation_error(
21+
stmt,
2322
ir.ValidationError(
2423
stmt, "If statements are not supported in logical Gemini programs!"
25-
)
26-
)
27-
return (
28-
Error(
29-
message="If statements are not supported in logical Gemini programs!"
3024
),
3125
)
26+
return (interp.lattice.bottom(),)
3227

3328
@_interp.impl(scf.For)
3429
def for_loop(
3530
self,
36-
interp: GeminiLogicalValidationAnalysis,
37-
frame: ValidationFrame,
31+
interp: _GeminiLogicalValidationAnalysis,
32+
frame: ForwardFrame,
3833
stmt: scf.For,
3934
):
40-
if isinstance(stmt.iterable.hints.get("const"), const.Value):
41-
return (interp.lattice.top(),)
35+
if not isinstance(stmt.iterable.hints.get("const"), const.Value):
4236

43-
frame.errors.append(
44-
ir.ValidationError(
37+
interp.add_validation_error(
4538
stmt,
46-
"Non-constant iterable in for loop is not supported in Gemini logical programs!",
39+
ir.ValidationError(
40+
stmt,
41+
"Non-constant iterable in for loop is not supported in Gemini logical programs!",
42+
),
4743
)
48-
)
4944

50-
return (
51-
Error(
52-
message="Non-constant iterable in for loop is not supported in Gemini logical programs!"
53-
),
54-
)
45+
return (interp.lattice.bottom(),)
5546

5647

5748
@func.dialect.register(key="gemini.validate.logical")
5849
class __FuncGeminiLogicalValidation(_interp.MethodTable):
5950
@_interp.impl(func.Invoke)
6051
def invoke(
6152
self,
62-
interp: GeminiLogicalValidationAnalysis,
63-
frame: ValidationFrame,
53+
interp: _GeminiLogicalValidationAnalysis,
54+
frame: ForwardFrame,
6455
stmt: func.Invoke,
6556
):
66-
frame.errors.append(
57+
interp.add_validation_error(
58+
stmt,
6759
ir.ValidationError(
6860
stmt,
6961
"Function invocations not supported in logical Gemini program!",
7062
help="Make sure to decorate your function with `@logical(inline = True)` or `@logical(aggressive_unroll = True)` to inline function calls",
71-
)
63+
),
7264
)
7365

74-
return tuple(
75-
Error(
76-
message="Function invocations not supported in logical Gemini program!"
77-
)
78-
for _ in stmt.results
79-
)
66+
return tuple(interp.lattice.bottom() for _ in stmt.results)
8067

8168

8269
@gate.dialect.register(key="gemini.validate.logical")
8370
class __GateGeminiLogicalValidation(_interp.MethodTable):
71+
8472
@_interp.impl(gate.stmts.U3)
85-
def u3(
73+
@_interp.impl(gate.stmts.T)
74+
@_interp.impl(gate.stmts.Rx)
75+
@_interp.impl(gate.stmts.Ry)
76+
@_interp.impl(gate.stmts.Rz)
77+
def non_clifford(
8678
self,
87-
interp: GeminiLogicalValidationAnalysis,
88-
frame: ValidationFrame,
89-
stmt: gate.stmts.U3,
79+
interp: _GeminiLogicalValidationAnalysis,
80+
frame: ForwardFrame,
81+
stmt: gate.stmts.SingleQubitGate | gate.stmts.RotationGate,
9082
):
9183
if interp.first_gate:
9284
interp.first_gate = False
9385
return ()
9486

95-
frame.errors.append(
87+
interp.add_validation_error(
88+
stmt,
9689
ir.ValidationError(
9790
stmt,
98-
"U3 gate can only be used for initial state preparation, i.e. as the first gate!",
99-
)
91+
f"Non-clifford gate {stmt.name} can only be used for initial state preparation, i.e. as the first gate!",
92+
),
10093
)
10194
return ()

src/bloqade/gemini/dialects/logical/groups.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
from kirin.passes import Default
55
from kirin.prelude import structural_no_opt
66
from kirin.dialects import py, func, ilist
7+
from kirin.validation import ValidationSuite
78
from typing_extensions import Doc
89
from kirin.passes.inline import InlinePass
910

1011
from bloqade.squin import gate, qubit
11-
from bloqade.validation import KernelValidation
1212
from bloqade.rewrite.passes import AggressiveUnroll
13-
from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidationAnalysis
13+
from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidation
1414

1515
from ._dialect import dialect
1616

@@ -63,8 +63,9 @@ def run_pass(
6363
default_pass.fixpoint(mt)
6464

6565
if verify:
66-
validator = KernelValidation(GeminiLogicalValidationAnalysis)
67-
validator.run(mt, no_raise=no_raise)
66+
validator = ValidationSuite([GeminiLogicalValidation])
67+
validation_result = validator.validate(mt)
68+
validation_result.raise_if_invalid()
6869
mt.verify()
6970

7071
return run_pass

src/bloqade/validation/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

src/bloqade/validation/analysis/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/bloqade/validation/analysis/analysis.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

src/bloqade/validation/analysis/lattice.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)