Skip to content

Commit baf63b4

Browse files
committed
Make sure we can compile expressions like 'x, y = OneOf(None, Tuple(int, int))((1, 2))'.
1 parent 271a2d8 commit baf63b4

File tree

6 files changed

+103
-1
lines changed

6 files changed

+103
-1
lines changed

typed_python/compiler/tests/one_of_compilation_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,3 +724,11 @@ def toString(x: OneOf("hi", "bye")): # noqa
724724
return f"its: {x}"
725725

726726
toString('hi')
727+
728+
def test_unpack_oneof_none_or_tuple(self):
729+
@Entrypoint
730+
def unpackIt(x: OneOf(None, Tuple(int, int))):
731+
a, b = x
732+
return a + b
733+
734+
unpackIt((1, 2))

typed_python/compiler/tests/tuple_compilation_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,3 +398,18 @@ def checkIn():
398398
return None
399399

400400
assert checkIn.resultTypeFor().typeRepresentation == int
401+
402+
def test_iterate_tuple(self):
403+
T = Tuple(int, float, str)
404+
405+
@Entrypoint
406+
def iterateIt(x: T):
407+
res = ListOf(OneOf(int, float, str))()
408+
409+
for e in x:
410+
res.append(e)
411+
412+
return res
413+
414+
tup = T((1, 2, '3'))
415+
assert iterateIt(tup) == ListOf(OneOf(int, float, str))(tup)

typed_python/compiler/type_wrappers/none_wrapper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ def __init__(self):
2929
def getCompileTimeConstant(self):
3030
return None
3131

32+
def isIterable(self):
33+
return False
34+
3235
def convert_default_initialize(self, context, target):
3336
pass
3437

typed_python/compiler/type_wrappers/one_of_wrapper.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,34 @@ def getNativeLayoutType(self):
5757
def convert_which_native(self, expr):
5858
return expr.ElementPtrIntegers(0, 0).load()
5959

60+
def get_iteration_expressions(self, context, expr):
61+
itExprs = None
62+
63+
for ix in range(len(self.typeRepresentation.Types)):
64+
T = self.typeRepresentation.Types[ix]
65+
66+
if typeWrapper(T).isIterable == "Maybe":
67+
return
68+
elif typeWrapper(T).isIterable is False:
69+
pass
70+
else:
71+
# we don't know how to union two sets of iteration expressions
72+
if itExprs is not None:
73+
return None
74+
75+
itExprs = typeWrapper(T).get_iteration_expressions(context, expr.refAs(ix))
76+
itExprIx = ix
77+
78+
if itExprs is not None:
79+
with context.ifelse(self.convert_which_native(expr.expr).eq(itExprIx)) as (ifTrue, ifFalse):
80+
with ifFalse:
81+
context.pushException(
82+
TypeError,
83+
"Instance of type {self.typeRepresentation.__name__} is not iterable"
84+
)
85+
86+
return itExprs
87+
6088
def unwrap(self, context, expr, generator):
6189
"""Call 'generator' on 'expr' cast down to each subtype and combine the results.
6290
"""

typed_python/compiler/type_wrappers/tuple_of_wrapper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,9 @@ def __init__(self, t):
318318
('data', native_ast.UInt8Ptr)
319319
), name='TupleOfLayout' if self.is_tuple else 'ListOfLayout').pointer()
320320

321+
def isIterable(self):
322+
return True
323+
321324
def has_fastnext_iter(self):
322325
"""If we call '__iter__' on instances of this type, will they support __fastnext__?"""
323326
return True

typed_python/compiler/type_wrappers/tuple_wrapper.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from typed_python.compiler.type_wrappers.wrapper import Wrapper
1616
from typed_python.compiler.merge_type_wrappers import mergeTypes
1717
from typed_python import (
18-
_types, Int32, Tuple, NamedTuple, Function, Dict, Set, ConstDict, ListOf, TupleOf
18+
_types, Int32, Tuple, NamedTuple, Function, Dict, Set, ConstDict, ListOf,
19+
TupleOf, PointerTo, pointerTo, TypeFunction, OneOf, Class, Member, Final
1920
)
2021
import typed_python._types
2122
from typed_python.compiler.conversion_level import ConversionLevel
@@ -195,6 +196,9 @@ def __init__(self, t):
195196
self._is_pod = all(typeWrapper(possibility).is_pod for possibility in self.subTypeWrappers)
196197
self.is_default_constructible = _types.is_default_constructible(t)
197198

199+
def isIterable(self):
200+
return True
201+
198202
@property
199203
def unionType(self):
200204
if self._unionType is None and self.typeRepresentation.ElementTypes:
@@ -220,6 +224,23 @@ def is_pod(self):
220224
def getNativeLayoutType(self):
221225
return self.layoutType
222226

227+
def convert_attribute(self, context, instance, attribute):
228+
if attribute in ["__iter__"]:
229+
return instance.changeType(BoundMethodWrapper.Make(self, attribute))
230+
231+
return super().convert_attribute(context, instance, attribute)
232+
233+
def convert_method_call(self, context, instance, methodname, args, kwargs):
234+
if methodname == '__iter__' and not args:
235+
return typeWrapper(TupleIterator(self.typeRepresentation)).convert_type_call(
236+
context,
237+
None,
238+
[],
239+
dict(pos=context.constant(-1), tup=instance)
240+
)
241+
242+
return super().convert_method_call(context, instance, methodname, args, kwargs)
243+
223244
def convert_initialize_from_args(self, context, target, *args):
224245
assert len(args) <= len(self.byteOffsets)
225246

@@ -608,6 +629,9 @@ def __init__(self, t):
608629
self.namesToIndices = {n: i for i, n in enumerate(t.ElementNames)}
609630
self.namesToTypes = {n: t.ElementTypes[i] for i, n in enumerate(t.ElementNames)}
610631

632+
def isIterable(self):
633+
return True
634+
611635
def has_fastnext_iter(self):
612636
if self.isSubclassOfNamedTuple:
613637
return "__iter__" in self.typeRepresentation.__dict__
@@ -869,3 +893,24 @@ def convert_bin_op_reverse(self, context, r, op, l, inplace):
869893
return self.convert_method_call(context, r, magic, (l,), {})
870894

871895
return super().convert_bin_op_reverse(context, r, op, l, inplace)
896+
897+
898+
@TypeFunction
899+
def TupleIterator(T):
900+
EltT = OneOf(*T.ElementTypes)
901+
902+
class TupleIterator(Class, Final, __name__=f"TupleIterator({T.__name__})"):
903+
pos = Member(int, nonempty=True)
904+
elt = Member(EltT)
905+
tup = Member(T, nonempty=True)
906+
907+
def __fastnext__(self):
908+
self.pos += 1
909+
910+
if self.pos < len(self.tup):
911+
self.elt = self.tup[self.pos]
912+
return pointerTo(self).elt
913+
else:
914+
return PointerTo(EltT)()
915+
916+
return TupleIterator

0 commit comments

Comments
 (0)