Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.
Open
191 changes: 191 additions & 0 deletions numba_typing/overload_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import numba
from numba import types
from numba.extending import overload
from type_annotations import product_annotations, get_func_annotations
import typing
from numba.typed import List, Dict
from inspect import getfullargspec


def overload_list(orig_func):
def overload_inner(ovld_list):
def wrapper(*args):
func_list = ovld_list()
sig_list = []
for func in func_list:
sig_list.append((product_annotations(
get_func_annotations(func)), func))
args_orig_func = getfullargspec(orig_func)
values_dict = {name: typ for name, typ in zip(args_orig_func.args, args)}
defaults_dict = {}
if args_orig_func.defaults:
defaults_dict = {name: value for name, value in zip(
args_orig_func.args[::-1], args_orig_func.defaults[::-1])}
if valid_signature(sig_list, values_dict, defaults_dict):
result = choose_func_by_sig(sig_list, values_dict)

if result is None:
raise TypeError(f'Unsupported types a={a}, b={b}')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks both a and b are undefined.


return result

return overload(orig_func, strict=False)(wrapper)

return overload_inner


def valid_signature(list_signature, values_dict, defaults_dict):
def check_defaults(sig_def):
for name, val in defaults_dict.items():
if sig_def.get(name) is None:
raise AttributeError(f'{name} does not match the signature of the function passed to overload_list')
if not sig_def[name] == val:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition looks pretty strange. Maybe if sig_def[name] != val?

raise ValueError(f'The default arguments are not equal: {name}: {val} != {sig_def[name]}')

for sig, _ in list_signature:
for param in sig.parameters:
if len(param) != len(values_dict.items()):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to call method items() here.

check_defaults(sig.defaults)

return True


def check_int_type(n_type):
return isinstance(n_type, types.Integer)


def check_float_type(n_type):
return isinstance(n_type, types.Float)


def check_bool_type(n_type):
return isinstance(n_type, types.Boolean)


def check_str_type(n_type):
return isinstance(n_type, types.UnicodeType)


def check_list_type(self, p_type, n_type):
res = isinstance(n_type, (types.List, types.ListType))
if p_type == list:
return res
else:
return res and self.match(p_type.__args__[0], n_type.dtype)


def check_tuple_type(self, p_type, n_type):
if not isinstance(n_type, (types.Tuple, types.UniTuple)):
return False
try:
if len(p_type.__args__) != len(n_type.types):
return False
except AttributeError: # if p_type == tuple
return True

for p_val, n_val in zip(p_type.__args__, n_type.types):
if not self.match(p_val, n_val):
return False

return True


def check_dict_type(self, p_type, n_type):
res = False
if isinstance(n_type, types.DictType):
res = True
if isinstance(p_type, type):
return res
for p_val, n_val in zip(p_type.__args__, n_type.keyvalue_type):
res = res and self.match(p_val, n_val)
return res


class TypeChecker:

_types_dict: dict = {}

def __init__(self):
self._typevars_dict = {}

def clear_typevars_dict(self):
self._typevars_dict.clear()

@classmethod
def add_type_check(cls, type_check, func):
cls._types_dict[type_check] = func

@staticmethod
def _is_generic(p_obj):
if isinstance(p_obj, typing._GenericAlias):
return True

if isinstance(p_obj, typing._SpecialForm):
return p_obj not in {typing.Any}

return False

@staticmethod
def _get_origin(p_obj):
return p_obj.__origin__

def match(self, p_type, n_type):
if p_type == typing.Any:
return True
try:
if self._is_generic(p_type):
origin_type = self._get_origin(p_type)
if origin_type == typing.Generic:
return self.match_generic(p_type, n_type)

return self._types_dict[origin_type](self, p_type, n_type)

if isinstance(p_type, typing.TypeVar):
return self.match_typevar(p_type, n_type)

if p_type in (list, tuple, dict):
return self._types_dict[p_type](self, p_type, n_type)

return self._types_dict[p_type](n_type)

except KeyError:
raise TypeError(f'A check for the {p_type} was not found.')

def match_typevar(self, p_type, n_type):
if isinstance(n_type, types.List):
n_type = types.ListType(n_type.dtype)
if not self._typevars_dict.get(p_type):
self._typevars_dict[p_type] = n_type
return True
return self._typevars_dict.get(p_type) == n_type
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it should be self.match. E.g. list and 'types.List' are synonyms but will fail equality check ('list != types.List').
And I'm assuming you don't have such tests?


def match_generic(self, p_type, n_type):
raise SystemError


TypeChecker.add_type_check(int, check_int_type)
TypeChecker.add_type_check(float, check_float_type)
TypeChecker.add_type_check(str, check_str_type)
TypeChecker.add_type_check(bool, check_bool_type)
TypeChecker.add_type_check(list, check_list_type)
TypeChecker.add_type_check(tuple, check_tuple_type)
TypeChecker.add_type_check(dict, check_dict_type)


def choose_func_by_sig(sig_list, values_dict):
def check_signature(sig_params, types_dict):
checker = TypeChecker()
for name, typ in types_dict.items(): # name,type = 'a',int64
if isinstance(typ, types.Literal):
typ = typ.literal_type
if not checker.match(sig_params[name], typ):
return False

return True

for sig, func in sig_list: # sig = (Signature,func)
for param in sig.parameters: # param = {'a':int,'b':int}
if check_signature(param, values_dict):
return func

return None
146 changes: 146 additions & 0 deletions numba_typing/test_overload_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import overload_list
from overload_list import List, Dict
from overload_list import types
import unittest
import typing
from numba import njit, core


T = typing.TypeVar('T')
K = typing.TypeVar('K')


class TestOverloadList(unittest.TestCase):
maxDiff = None

def test_myfunc_literal_type_default(self):
def foo(a, b=0):
...

@overload_list.overload_list(foo)
def foo_ovld_list():

def foo_int_literal(a: int, b: int = 0):
return ('literal', a, b)

return (foo_int_literal,)

@njit
def jit_func(a):
return foo(a, 2)

self.assertEqual(jit_func(1), ('literal', 1, 2))

def test_myfunc_tuple_type_error(self):
def foo(a, b=(0, 0)):
...

@overload_list.overload_list(foo)
def foo_ovld_list():

def foo_tuple(a: typing.Tuple[int, int], b: tuple = (0, 0)):
return ('tuple_', a, b)

return (foo_tuple,)

@njit
def jit_func(a, b):
return foo(a, b)

self.assertRaises(core.errors.TypingError, jit_func, (1, 2, 3), ('3', False))


def generator_test(func_name, param, values_dict, defaults_dict={}):

def check_type(typ):
if isinstance(typ, type):
return typ.__name__
return typ

value_keys = ", ".join("{}".format(key) for key in values_dict.keys())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use f'{key}' instead. Here and everywhere

defaults_keys = ", ".join("{}".format(key) for key in defaults_dict.keys())
value_str = ", ".join("{}: {}".format(key, check_type(val)) for key, val in values_dict.items())
defaults_str = ", ".join("{} = {}".format(key, val) if not isinstance(
val, str) else "{} = '{}'".format(key, val) for key, val in defaults_dict.items())
defaults_str_type = ", ".join("{}: {} = {}".format(key, check_type(type(val)), val) if not isinstance(val, str)
else "{}: {} = '{}'".format(key, check_type(type(val)), val)
for key, val in defaults_dict.items())
value_type = ", ".join("{}".format(val) for val in values_dict.values())
defaults_type = ", ".join("{}".format(type(val)) for val in defaults_dict.values())
param_qwe = ", ".join("{}".format(i) for i in param)
test = f"""
def test_myfunc_{func_name}_type_default(self):
def foo({value_keys},{defaults_str}):
...

@overload_list.overload_list(foo)
def foo_ovld_list():

def foo_{func_name}({value_str},{defaults_str_type}):
return ("{value_type}","{defaults_type}")

return (foo_{func_name},)

@njit
def jit_func({value_keys},{defaults_str}):
return foo({value_keys},{defaults_keys})

self.assertEqual(jit_func({param_qwe}), ("{value_type}", "{defaults_type}"))
"""
loc = {}
exec(test, globals(), loc)
return loc


L = List([1, 2, 3])
L_int = List([List([1, 2])])
L_float = List([List([List([3.0, 4.0])])])
L_f = List([1.0, 2.0])
D = Dict.empty(key_type=types.unicode_type, value_type=types.int64)
D_1 = Dict.empty(key_type=types.int64, value_type=types.boolean)
D['qwe'] = 1
D['qaz'] = 2
D_1[1] = True
D_1[0] = False
list_type = types.ListType(types.int64)
D_list = Dict.empty(key_type=types.unicode_type, value_type=list_type)
D_list['qwe'] = List([3, 4, 5])
str_1 = 'qwe'
str_2 = 'qaz'
test_cases = [('int', [1, 2], {'a': int, 'b': int}), ('float', [1.0, 2.0], {'a': float, 'b': float}),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{'a': int, 'b': int} why do you need two parameters of the same type?
What are you actually testing?

('bool', [True, True], {'a': bool, 'b': bool}), ('str', ['str_1', 'str_2'], {'a': str, 'b': str}),
('list', [[1, 2], [3, 4]], {'a': typing.List[int], 'b':list}),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need second list parameter?

('List_typed', [L, [3, 4]], {'a': typing.List[int], 'b':list}),
('tuple', [(1, 2.0), ('3', False)], {'a': typing.Tuple[int, float], 'b':tuple}),
('dict', ['D', 'D_1'], {'a': typing.Dict[str, int], 'b': typing.Dict[int, bool]}),
('union_1', [1, False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}),
('union_2', ['str_1', False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And why you are not testing second Union parameter?

('nested_list', ['L_int', 'L_float'], {'a': typing.List[typing.List[int]],
'b': typing.List[typing.List[typing.List[float]]]}),
('TypeVar_TT', ['L_f', [3.0, 4.0]], {'a': 'T', 'b': 'T'}),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any negative tests for such case?

('TypeVar_TK', [1.0, 2], {'a': 'T', 'b': 'K'}),
('TypeVar_ListT_T', ['L', 5], {'a': 'typing.List[T]', 'b': 'T'}),
('TypeVar_ListT_DictKT', ['L', 'D'], {'a': 'typing.List[T]', 'b': 'typing.Dict[K, T]'}),
('TypeVar_ListT_DictK_ListT', ['L', 'D_list'], {'a': 'typing.List[T]',
'b': 'typing.Dict[K, typing.List[T]]'})]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are where any tests for TypeVar with specified types restriction?


test_cases_default = [('int_defaults', [1], {'a': int}, {'b': 0}), ('float_defaults', [1.0], {'a': float}, {'b': 0.0}),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And what about both - annotation and default value?

('bool_defaults', [True], {'a': bool}, {'b': False}),
('str_defaults', ['str_1'], {'a': str}, {'b': '0'}),
('tuple_defaults', [(1, 2)], {'a': tuple}, {'b': (0, 0)})]


for name, val, annotation in test_cases:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer you to group all cases into 4-5 tests with subtests. Something like this:

def test_common_types():
    test_cases = [({'a': 0}, {'a': int}),
                 ({'a': 0.}, {'a': float}),
                ...]
    
    for case in test_cases:
        with self.Subtest(case=case):
            run_test(case)

run_generator = generator_test(name, val, annotation)
test_name = list(run_generator.keys())[0]
setattr(TestOverloadList, test_name, run_generator[test_name])


for name, val, annotation, defaults in test_cases_default:
run_generator = generator_test(name, val, annotation, defaults)
test_name = list(run_generator.keys())[0]
setattr(TestOverloadList, test_name, run_generator[test_name])


if __name__ == "__main__":
unittest.main()