Skip to content

Commit 9188954

Browse files
authored
Lowering for syntax sugar for binop (#587)
This PR expand the capability for lowering of kernel with type hint of something with: ```python @basic def main(x: str | float | int): return x main.print() tps = main.arg_types assert len(tps) == 1 assert tps[0] == types.Union([types.String, types.Float, types.Int]) ``` Note that before this PR, only Binding can allow this semantic, but will error if one trying to do that for kernel program
1 parent 5d4b1a3 commit 9188954

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

src/kirin/lowering/python/dialect.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,45 @@ def lower(self, state: State[ast.AST], node: ast.AST) -> Result:
7575
def unreachable(self, state: State[ast.AST], node: ast.AST) -> Result:
7676
raise BuildError(f"unreachable reached for {node.__class__.__name__}")
7777

78+
@staticmethod
79+
def _flatten_hint_binop(node: ast.expr) -> list[ast.expr]:
80+
"""Flatten a binary operation tree into a list of expressions.
81+
82+
This is useful for handling union types represented as binary operations.
83+
"""
84+
hints = []
85+
86+
def _recurse(n: ast.expr):
87+
if isinstance(n, ast.BinOp):
88+
_recurse(n.left)
89+
_recurse(n.right)
90+
else:
91+
hints.append(n)
92+
93+
_recurse(node)
94+
return hints
95+
7896
@staticmethod
7997
def get_hint(state: State[ast.AST], node: ast.expr | None) -> types.TypeAttribute:
8098
if node is None:
8199
return types.AnyType()
82100

101+
# deal with union syntax
102+
if isinstance(node, ast.BinOp):
103+
hint_nodes = FromPythonAST._flatten_hint_binop(node)
104+
hint_ts = []
105+
for i in range(len(hint_nodes)):
106+
hint_ts.append(
107+
FromPythonAST.get_hint(
108+
state,
109+
hint_nodes[i],
110+
)
111+
)
112+
return types.Union(hint_ts)
113+
83114
try:
84115
t = state.get_global(node).data
116+
85117
return types.hint2type(t)
86118
except Exception as e: # noqa: E722
87119
raise BuildError(f"expect a type hint, got {ast.unparse(node)}") from e
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from typing import Any
2+
3+
from kirin import types
4+
from kirin.prelude import basic
5+
from kirin.dialects import ilist
6+
7+
8+
def test_method_union_binop_hint():
9+
10+
@basic
11+
def main(x: ilist.IList[float, Any] | list[float]) -> float:
12+
return x[0]
13+
14+
main.print()
15+
16+
tps = main.arg_types
17+
18+
assert len(tps) == 1
19+
assert tps[0] == types.Union(
20+
[ilist.IListType[types.Float, types.Any], types.List[types.Float]]
21+
)
22+
23+
24+
def test_method_union_multi_hint():
25+
26+
@basic
27+
def main(x: str | float | int):
28+
return x
29+
30+
main.print()
31+
32+
tps = main.arg_types
33+
34+
assert len(tps) == 1
35+
assert tps[0] == types.Union([types.String, types.Float, types.Int])

0 commit comments

Comments
 (0)