Skip to content

Commit a7a3f28

Browse files
authored
Making ilist.rewrite.HintLen less aggressive (#589)
because of issue: #512 I am adding a slight modification to `HintLen` so it will not add hints inside For/IfElse bodies as that is causing some other passes to fail for example in `AddressiveUnroll` in bloqade-circuit the following example doesn't unroll properly: ```python from bloqade import qubit, squin from kirin.dialects import ilist from bloqade.rewrite.passes.aggressive_unroll import AggressiveUnroll @squin.kernel(typeinfer=True, fold=True) def log_depth_ghz(): size = 8 q0 = qubit.new() squin.h(q0) reg = ilist.IList([q0]) for i in range(size): current = len(reg) missing = size - current if missing > current: num_alloc = current else: num_alloc = missing if num_alloc > 0: new_qubits = qubit.qalloc(num_alloc) squin.broadcast.cx(reg[:num_alloc], new_qubits) reg = reg + new_qubits unroll = AggressiveUnroll(log_depth_ghz.dialects, no_raise=True) result = unroll.fixpoint(log_depth_ghz) log_depth_ghz.print() ``` Before this PR: ```mlir func.func @log_depth_ghz() -> !py.NoneType { ^0(%log_depth_ghz_self): │ %q0 = qubit.new() : !py.Qubit │ %reg = py.ilist.new(values=(%q0)){elem_type=!py.Qubit} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.h(qubits=%reg) │ %0 = qubit.new() : !py.Qubit │ %new_qubits = py.ilist.new(values=(%0)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.cx(controls=%reg, targets=%new_qubits) │ %1 = qubit.new() : !py.Qubit │ %new_qubits_1 = py.ilist.new(values=(%1)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.cx(controls=%reg, targets=%new_qubits_1) │ %2 = qubit.new() : !py.Qubit │ %new_qubits_2 = py.ilist.new(values=(%2)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.cx(controls=%reg, targets=%new_qubits_2) │ %3 = qubit.new() : !py.Qubit │ %new_qubits_3 = py.ilist.new(values=(%3)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.cx(controls=%reg, targets=%new_qubits_3) │ %4 = qubit.new() : !py.Qubit │ %new_qubits_4 = py.ilist.new(values=(%4)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.cx(controls=%reg, targets=%new_qubits_4) │ %5 = qubit.new() : !py.Qubit │ %new_qubits_5 = py.ilist.new(values=(%5)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.cx(controls=%reg, targets=%new_qubits_5) │ %6 = qubit.new() : !py.Qubit │ %new_qubits_6 = py.ilist.new(values=(%6)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.cx(controls=%reg, targets=%new_qubits_6) │ %7 = qubit.new() : !py.Qubit │ %new_qubits_7 = py.ilist.new(values=(%7)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.cx(controls=%reg, targets=%new_qubits_7) │ %8 = func.const.none() : !py.NoneType │ func.return %8 ``` After this PR: ```mlir func.func @main() -> !py.NoneType { ^0(%main_self): │ %q0 = qubit.new() : !py.Qubit │ %reg = py.ilist.new(values=(%q0)){elem_type=!py.Qubit} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.h(qubits=%reg) │ %0 = qubit.new() : !py.Qubit │ %new_qubits = py.ilist.new(values=(%0)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(1,int)] │ squin.gate.cx(controls=%reg, targets=%new_qubits) │ %1 = qubit.new() : !py.Qubit │ %2 = qubit.new() : !py.Qubit │ %new_qubits_1 = py.ilist.new(values=(%1, %2)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(2,int)] │ %controls = py.ilist.new(values=(%q0, %0)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(2,int)] │ squin.gate.cx(controls=%controls, targets=%new_qubits_1) │ %3 = qubit.new() : !py.Qubit │ %4 = qubit.new() : !py.Qubit │ %5 = qubit.new() : !py.Qubit │ %6 = qubit.new() : !py.Qubit │ %new_qubits_2 = py.ilist.new(values=(%3, %4, %5, %6)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(4,int)] │ %controls_1 = py.ilist.new(values=(%q0, %0, %1, %2)){elem_type=!Any} : !py.IList[!py.Qubit, Literal(4,int)] │ squin.gate.cx(controls=%controls_1, targets=%new_qubits_2) │ %7 = func.const.none() : !py.NoneType │ func.return %7 } // func.func main ``` which is clearly correct The issue is because the type inference result for `reg` is wrong inside the for loop that means that `current = len(reg)` is being replaced with `current = 1` which after unrolling propagates an incorrect value for every iteration after unrolling the loop. If you do not add hints inside the for loop body this will not happen so the program unrolls correctly. In principle we do not need to be as aggressive with this rewrite rule anyways as we do not unroll inner loops, therefore, adding hints inside the body of a loop is not really necessary.
1 parent df50bd1 commit a7a3f28

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/kirin/dialects/ilist/rewrite/hint_len.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from kirin import ir, types
22
from kirin.analysis import const
3-
from kirin.dialects import py
3+
from kirin.dialects import py, scf
44
from kirin.rewrite.abc import RewriteRule, RewriteResult
55
from kirin.dialects.ilist.stmts import IListType
66

@@ -26,7 +26,11 @@ def _get_collection_len(self, collection: ir.SSAValue):
2626
return None
2727

2828
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
29-
if not isinstance(node, py.Len):
29+
30+
if not (
31+
isinstance(node, py.Len)
32+
and not isinstance(node.parent_stmt, (scf.For, scf.IfElse))
33+
):
3034
return RewriteResult()
3135

3236
if (coll_len := self._get_collection_len(node.value)) is None:

0 commit comments

Comments
 (0)