Skip to content

Commit 271a2d8

Browse files
committed
Fix an issue causing us to hit the interpreter when converting OneOf('someString', 'someOtherString') to str.
1 parent e409c04 commit 271a2d8

File tree

5 files changed

+73
-18
lines changed

5 files changed

+73
-18
lines changed

typed_python/compiler/python_to_native_converter.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -957,8 +957,6 @@ def _installInflightFunctions(self, name):
957957
outboundTargets
958958
)
959959

960-
961-
962960
if identifier not in self._inflight_definitions:
963961
raise Exception(
964962
f"Expected a definition for {identifier} depended on by:\n"

typed_python/compiler/tests/one_of_compilation_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,3 +717,10 @@ def checkIt(x: OneOf(None, Tuple(str, str))):
717717

718718
assert checkIt(('a', 'b')) == 'ab'
719719
assert checkIt(None) == 'empty'
720+
721+
def test_convert_one_of_constants_to_string(self):
722+
@Entrypoint
723+
def toString(x: OneOf("hi", "bye")): # noqa
724+
return f"its: {x}"
725+
726+
toString('hi')

typed_python/compiler/tests/string_compilation_test.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616
import unittest
1717
import time
18+
import os
1819

1920
from flaky import flaky
2021
from typed_python import _types, ListOf, TupleOf, Dict, ConstDict, Compiled, Entrypoint, OneOf
@@ -23,7 +24,7 @@
2324
strEndswith, strRangeEndswith, strEndswithTuple, strRangeEndswithTuple, \
2425
strReplace, strPartition, strRpartition, strCenter, strRjust, strLjust, strExpandtabs, strZfill
2526
from typed_python.test_util import currentMemUsageMb, compilerPerformanceComparison
26-
from typed_python.compiler.runtime import PrintNewFunctionVisitor
27+
from typed_python.compiler.runtime import PrintNewFunctionVisitor, RuntimeEventVisitor
2728

2829

2930
someStrings = [
@@ -1445,3 +1446,42 @@ def test_string_internal_fns(self):
14451446
self.assertEqual(strExpandtabs(v, 8), v.expandtabs(8))
14461447
self.assertEqual(strZfill(v, 20), v.zfill(20))
14471448
self.assertEqual(strZfill('+123', 20), '+123'.zfill(20))
1449+
1450+
def test_fstring_of_string_doesnt_hit_interpreter(self):
1451+
# if the cache is on, this won't work
1452+
if os.getenv("TP_COMPILER_CACHE"):
1453+
return
1454+
1455+
class Visitor(RuntimeEventVisitor):
1456+
"""Base class for a Visitor that gets to see what's going on in the runtime.
1457+
1458+
Clients should subclass this and pass it to 'addEventVisitor' in the runtime
1459+
to find out about events like function typing assignments.
1460+
"""
1461+
def onNewFunction(
1462+
self,
1463+
identifier,
1464+
functionConverter,
1465+
nativeFunction,
1466+
funcName,
1467+
funcCode,
1468+
funcGlobals,
1469+
closureVars,
1470+
inputTypes,
1471+
outputType,
1472+
yieldType,
1473+
variableTypes,
1474+
conversionType,
1475+
calledFunctions,
1476+
):
1477+
if funcName == "toString":
1478+
self.nativeFunction = nativeFunction
1479+
1480+
@Entrypoint
1481+
def toString(x: OneOf("hi", "bye")): # noqa
1482+
return f"its: {x}"
1483+
1484+
with Visitor() as vis:
1485+
toString("hi")
1486+
1487+
assert 'PythonObjectOfTypeWrapper' not in str(vis.nativeFunction)

typed_python/compiler/type_wrappers/string_wrapper.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typed_python.compiler.type_wrappers.typed_list_masquerading_as_list_wrapper import TypedListMasqueradingAsList
2222
import typed_python.compiler.type_wrappers.runtime_functions as runtime_functions
2323
from typed_python.compiler.type_wrappers.bound_method_wrapper import BoundMethodWrapper
24+
from typed_python.compiler.conversion_level import ConversionLevel
2425

2526
import typed_python.compiler.native_ast as native_ast
2627
import typed_python.compiler
@@ -1010,7 +1011,31 @@ def _can_convert_to_type(self, targetType, conversionLevel):
10101011
Float32, Int8, Int16, Int32, UInt8, UInt16, UInt32, UInt64, float, int, bool, str
10111012
)
10121013

1014+
def convert_to_self_with_target(self, context, targetVal, sourceVal, level: ConversionLevel, mayThrowOnFailure=False):
1015+
if sourceVal.expr_type.typeRepresentation is str:
1016+
targetVal.convert_copy_initialize(targetVal)
1017+
return context.constant(True)
1018+
1019+
if level.isNewOrHigher():
1020+
if not sourceVal.isReference:
1021+
sourceVal = context.pushMove(sourceVal)
1022+
1023+
return context.pushPod(
1024+
bool,
1025+
runtime_functions.np_try_pyobj_to_str.call(
1026+
sourceVal.expr.cast(VoidPtr),
1027+
targetVal.expr.cast(VoidPtr),
1028+
context.getTypePointer(sourceVal.expr_type.typeRepresentation)
1029+
)
1030+
)
1031+
1032+
return super().convert_to_self_with_target(context, targetVal, sourceVal, level, mayThrowOnFailure)
1033+
10131034
def convert_to_type_with_target(self, context, instance, targetVal, conversionLevel, mayThrowOnFailure=False):
1035+
if targetVal.expr_type.typeRepresentation is str:
1036+
targetVal.convert_copy_initialize(instance)
1037+
return context.constant(True)
1038+
10141039
if not conversionLevel.isNewOrHigher():
10151040
return super().convert_to_type_with_target(context, instance, targetVal, conversionLevel, mayThrowOnFailure)
10161041

typed_python/compiler/type_wrappers/wrapper.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -665,21 +665,6 @@ def convert_to_type_with_target(self, context, expr, targetVal, level: Conversio
665665
assert isinstance(level, ConversionLevel)
666666
assert targetVal.isReference
667667

668-
if level.isNewOrHigher() and targetVal.expr_type.typeRepresentation is str:
669-
if not expr.isReference:
670-
expr = context.pushMove(expr)
671-
672-
expr = expr.convert_to_type(object, ConversionLevel.Signature)
673-
674-
return context.pushPod(
675-
bool,
676-
runtime_functions.np_try_pyobj_to_str.call(
677-
expr.expr.cast(VoidPtr),
678-
targetVal.expr.cast(VoidPtr),
679-
context.getTypePointer(expr.expr_type.typeRepresentation)
680-
)
681-
)
682-
683668
return targetVal.expr_type.convert_to_self_with_target(context, targetVal, expr, level, mayThrowOnFailure)
684669

685670
def convert_to_self_with_target(self, context, targetVal, sourceVal, level: ConversionLevel, mayThrowOnFailure=False):

0 commit comments

Comments
 (0)