Skip to content

Commit 5d4b0f8

Browse files
authored
Make MethodType explicit (#582)
This PR address #579 and contain partly of #580
1 parent d133226 commit 5d4b0f8

File tree

23 files changed

+292
-56
lines changed

23 files changed

+292
-56
lines changed

docs/cookbook/foodlang/cf_rewrite.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def food(self):
132132

133133
if fold:
134134
fold_pass(mt)
135-
135+
136136
if hungry:
137137
Walk(NewFoodAndNap()).rewrite(mt.code)
138138

example/food/script.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# type: ignore
2-
from emit import EmitReceptMain
32
from group import food
43
from stmts import Eat, Nap, Cook, NewFood
54
from recept import FeeAnalysis
65

6+
from emit import EmitReceptMain
77
from interp import FoodMethods as FoodMethods
88
from lattice import AtLeastXItem
99
from rewrite import NewFoodAndNap

example/quantum/script.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from enum import Enum
55
from typing import ClassVar
66
from dataclasses import dataclass
7+
78
from qulacs import QuantumState
89

910

@@ -27,7 +28,7 @@ class Basis(Enum):
2728

2829
# [section]
2930
from kirin import ir, types, lowering
30-
from kirin.decl import statement, info
31+
from kirin.decl import info, statement
3132
from kirin.prelude import basic
3233

3334
# our language definitions and compiler begins
@@ -161,8 +162,9 @@ def main(state: QuantumState):
161162
# we need to implement the runtime for the quantum circuit
162163
# let's just import qulacs a quantum circuit simulator
163164

165+
from qulacs import QuantumState, gate
166+
164167
from kirin import interp
165-
from qulacs import gate, QuantumState
166168

167169

168170
@dialect.register

src/kirin/analysis/typeinfer/solve.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ def substitute(self, typ: types.TypeAttribute) -> types.TypeAttribute:
6969
)
7070
elif isinstance(typ, types.Union):
7171
return types.Union(self.substitute(t) for t in typ.types)
72+
elif isinstance(typ, types.FunctionType):
73+
return types.FunctionType(
74+
params_type=tuple(self.substitute(t) for t in typ.params_type),
75+
return_type=(
76+
self.substitute(typ.return_type) if typ.return_type else None
77+
),
78+
)
7279
return typ
7380

7481
def solve(
@@ -94,6 +101,8 @@ def solve(
94101
return self.solve_Generic(annot, value)
95102
elif isinstance(annot, types.Union):
96103
return self.solve_Union(annot, value)
104+
elif isinstance(annot, types.FunctionType):
105+
return self.solve_FunctionType(annot, value)
97106

98107
if annot.is_subseteq(value):
99108
return Ok
@@ -133,6 +142,24 @@ def solve_Generic(self, annot: types.Generic, value: types.TypeAttribute):
133142
return result
134143
return Ok
135144

145+
def solve_FunctionType(self, annot: types.FunctionType, value: types.TypeAttribute):
146+
if not isinstance(value, types.FunctionType):
147+
return ResolutionError(annot, value)
148+
149+
for var, val in zip(annot.params_type, value.params_type):
150+
result = self.solve(var, val)
151+
if not result:
152+
return result
153+
154+
if not annot.return_type or not value.return_type:
155+
return Ok
156+
157+
result = self.solve(annot.return_type, value.return_type)
158+
if not result:
159+
return result
160+
161+
return Ok
162+
136163
def solve_Union(self, annot: types.Union, value: types.TypeAttribute):
137164
for typ in annot.types:
138165
result = self.solve(typ, value)

src/kirin/dialects/func/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
constprop as constprop,
66
typeinfer as typeinfer,
77
)
8-
from kirin.dialects.func.attrs import Signature as Signature, MethodType as MethodType
8+
from kirin.dialects.func.attrs import Signature as Signature
99
from kirin.dialects.func.stmts import (
1010
Call as Call,
1111
Invoke as Invoke,

src/kirin/dialects/func/attrs.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass
33

44
from kirin import types
5-
from kirin.ir import Method, Attribute
5+
from kirin.ir import Attribute
66
from kirin.print.printer import Printer
77
from kirin.serialization.core.serializationunit import SerializationUnit
88

@@ -12,10 +12,6 @@
1212

1313
from ._dialect import dialect
1414

15-
TypeofMethodType = types.PyClass[Method]
16-
MethodType = types.Generic(
17-
Method, types.TypeVar("Params", types.Tuple), types.TypeVar("Ret")
18-
)
1915
TypeLatticeElem = TypeVar("TypeLatticeElem", bound="types.TypeAttribute")
2016

2117

src/kirin/dialects/func/stmts.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

3-
from types import MethodType as ClassMethodType, FunctionType
3+
from types import MethodType as PyClassMethodType, FunctionType as PyFunctionType
44
from typing import TypeVar
55

66
from kirin import ir, types
77
from kirin.decl import info, statement
88
from kirin.print.printer import Printer
99

10-
from .attrs import Signature, MethodType
10+
from .attrs import Signature
1111
from ._dialect import dialect
1212

1313

@@ -58,7 +58,7 @@ class Function(ir.Statement):
5858
"""The signature of the function at declaration."""
5959
body: ir.Region = info.region(multi=True)
6060
"""The body of the function."""
61-
result: ir.ResultValue = info.result(MethodType)
61+
result: ir.ResultValue = info.result(types.MethodType)
6262
"""The result of the function."""
6363

6464
def print_impl(self, printer: Printer) -> None:
@@ -115,7 +115,7 @@ class Lambda(ir.Statement):
115115
"""The signature of the function at declaration."""
116116
captured: tuple[ir.SSAValue, ...] = info.argument()
117117
body: ir.Region = info.region(multi=True)
118-
result: ir.ResultValue = info.result(MethodType)
118+
result: ir.ResultValue = info.result(types.MethodType)
119119

120120
def check(self) -> None:
121121
assert self.body.blocks, "lambda body must not be empty"
@@ -145,7 +145,7 @@ def print_impl(self, printer: Printer) -> None:
145145
class GetField(ir.Statement):
146146
name = "getfield"
147147
traits = frozenset({ir.Pure()})
148-
obj: ir.SSAValue = info.argument(MethodType)
148+
obj: ir.SSAValue = info.argument(types.MethodType)
149149
field: int = info.attribute()
150150
# NOTE: mypy somehow doesn't understand default init=False
151151
result: ir.ResultValue = info.result(init=False)
@@ -249,15 +249,15 @@ def print_impl(self, printer: Printer) -> None:
249249

250250
def check_type(self) -> None:
251251
if not self.callee.type.is_subseteq(types.MethodType):
252-
if self.callee.type.is_subseteq(types.PyClass(FunctionType)):
252+
if self.callee.type.is_subseteq(types.PyClass(PyFunctionType)):
253253
raise ir.TypeCheckError(
254254
self,
255255
f"callee must be a method type, got {self.callee.type}",
256256
help="did you call a Python function directly? "
257257
"consider decorating it with kernel decorator",
258258
)
259259

260-
if self.callee.type.is_subseteq(types.PyClass(ClassMethodType)):
260+
if self.callee.type.is_subseteq(types.PyClass(PyClassMethodType)):
261261
raise ir.TypeCheckError(
262262
self,
263263
"callee must be a method type, got class method",

src/kirin/dialects/func/typeinfer.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin import ir, types
44
from kirin.interp import Frame, MethodTable, ReturnValue, impl
55
from kirin.analysis import const
6-
from kirin.analysis.typeinfer import TypeInference, TypeResolution
6+
from kirin.analysis.typeinfer import TypeInference
77
from kirin.dialects.func.stmts import (
88
Call,
99
Invoke,
@@ -54,20 +54,11 @@ def call(self, interp_: TypeInference, frame: Frame, stmt: Call):
5454

5555
def _solve_method_type(self, interp: TypeInference, frame: Frame, stmt: Call):
5656
mt_inferred = frame.get(stmt.callee)
57-
if not isinstance(mt_inferred, types.Generic):
58-
return (types.Bottom,)
5957

60-
if len(mt_inferred.vars) != 2:
61-
return (types.Bottom,)
62-
args = mt_inferred.vars[0]
63-
result = mt_inferred.vars[1]
64-
if not args.is_subseteq(types.Tuple):
58+
if not isinstance(mt_inferred, types.FunctionType):
6559
return (types.Bottom,)
6660

67-
resolve = TypeResolution()
68-
# NOTE: we are not using [...] below to be compatible with 3.10
69-
resolve.solve(args, types.Tuple.where(frame.get_values(stmt.inputs)))
70-
return (resolve.substitute(result),)
61+
return (mt_inferred.return_type,)
7162

7263
@impl(Invoke)
7364
def invoke(self, interp_: TypeInference, frame: Frame, stmt: Invoke):

src/kirin/dialects/ilist/stmts.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,7 @@ class Map(ir.Statement):
7575
class Foldr(ir.Statement):
7676
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
7777
purity: bool = info.attribute(default=False)
78-
fn: ir.SSAValue = info.argument(
79-
types.Generic(ir.Method, [ElemT, OutElemT], OutElemT)
80-
)
78+
fn: ir.SSAValue = info.argument(types.MethodType[[ElemT, OutElemT], OutElemT])
8179
collection: ir.SSAValue = info.argument(IListType[ElemT])
8280
init: ir.SSAValue = info.argument(OutElemT)
8381
result: ir.ResultValue = info.result(OutElemT)
@@ -87,9 +85,8 @@ class Foldr(ir.Statement):
8785
class Foldl(ir.Statement):
8886
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
8987
purity: bool = info.attribute(default=False)
90-
fn: ir.SSAValue = info.argument(
91-
types.Generic(ir.Method, [OutElemT, ElemT], OutElemT)
92-
)
88+
fn: ir.SSAValue = info.argument(types.MethodType[[OutElemT, ElemT], OutElemT])
89+
9390
collection: ir.SSAValue = info.argument(IListType[ElemT])
9491
init: ir.SSAValue = info.argument(OutElemT)
9592
result: ir.ResultValue = info.result(OutElemT)
@@ -104,7 +101,7 @@ class Scan(ir.Statement):
104101
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
105102
purity: bool = info.attribute(default=False)
106103
fn: ir.SSAValue = info.argument(
107-
types.Generic(ir.Method, [OutElemT, ElemT], types.Tuple[OutElemT, ResultT])
104+
types.MethodType[[OutElemT, ElemT], types.Tuple[OutElemT, ResultT]]
108105
)
109106
collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen])
110107
init: ir.SSAValue = info.argument(OutElemT)
@@ -117,7 +114,7 @@ class Scan(ir.Statement):
117114
class ForEach(ir.Statement):
118115
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
119116
purity: bool = info.attribute(default=False)
120-
fn: ir.SSAValue = info.argument(types.Generic(ir.Method, [ElemT], types.NoneType))
117+
fn: ir.SSAValue = info.argument(types.MethodType[[ElemT], types.NoneType])
121118
collection: ir.SSAValue = info.argument(IListType[ElemT])
122119

123120

@@ -141,7 +138,7 @@ class Sorted(ir.Statement):
141138
purity: bool = info.attribute(default=False)
142139
collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen])
143140
key: ir.SSAValue = info.argument(
144-
types.Union((types.Generic(ir.Method, [ElemT], ElemT), types.NoneType))
141+
types.Union((types.MethodType[[ElemT], ElemT], types.NoneType))
145142
)
146143
reverse: ir.SSAValue = info.argument(types.Bool)
147144
result: ir.ResultValue = info.result(IListType[ElemT, ListLen])

src/kirin/dialects/lowering/func.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ def lower_FunctionDef(
3535
entries: dict[str, ir.SSAValue] = {}
3636
entr_block = ir.Block()
3737
fn_self = entr_block.args.append_from(
38-
types.Generic(
39-
ir.Method, types.Tuple.where(signature.inputs), signature.output
40-
),
38+
types.MethodType[list(signature.inputs), signature.output],
4139
node.name + "_self",
4240
)
4341
entries[node.name] = fn_self

0 commit comments

Comments
 (0)