-
Notifications
You must be signed in to change notification settings - Fork 62
automatic generation of type checks for overload #911
base: numba_typing
Are you sure you want to change the base?
Changes from 8 commits
8bd7bc0
d3f4a5d
b54da17
05c745b
b7446ca
09289bf
1cb60da
967ae29
716402c
2096e94
5ec33ac
7da564b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,191 @@ | ||
| import numba | ||
| from numba import types | ||
| from numba.extending import overload | ||
| from type_annotations import product_annotations, get_func_annotations | ||
| import typing | ||
| from numba.typed import List, Dict | ||
| from inspect import getfullargspec | ||
|
|
||
|
|
||
| def overload_list(orig_func): | ||
| def overload_inner(ovld_list): | ||
| def wrapper(*args): | ||
| func_list = ovld_list() | ||
| sig_list = [] | ||
| for func in func_list: | ||
| sig_list.append((product_annotations( | ||
| get_func_annotations(func)), func)) | ||
| args_orig_func = getfullargspec(orig_func) | ||
| values_dict = {name: typ for name, typ in zip(args_orig_func.args, args)} | ||
| defaults_dict = {} | ||
| if args_orig_func.defaults: | ||
| defaults_dict = {name: value for name, value in zip( | ||
| args_orig_func.args[::-1], args_orig_func.defaults[::-1])} | ||
| if valid_signature(sig_list, values_dict, defaults_dict): | ||
| result = choose_func_by_sig(sig_list, values_dict) | ||
|
|
||
| if result is None: | ||
| raise TypeError(f'Unsupported types a={a}, b={b}') | ||
|
|
||
| return result | ||
|
|
||
| return overload(orig_func, strict=False)(wrapper) | ||
|
|
||
| return overload_inner | ||
|
|
||
|
|
||
| def valid_signature(list_signature, values_dict, defaults_dict): | ||
| def check_defaults(sig_def): | ||
| for name, val in defaults_dict.items(): | ||
| if sig_def.get(name) is None: | ||
| raise AttributeError(f'{name} does not match the signature of the function passed to overload_list') | ||
| if not sig_def[name] == val: | ||
|
||
| raise ValueError(f'The default arguments are not equal: {name}: {val} != {sig_def[name]}') | ||
|
|
||
| for sig, _ in list_signature: | ||
| for param in sig.parameters: | ||
| if len(param) != len(values_dict.items()): | ||
|
||
| check_defaults(sig.defaults) | ||
|
|
||
| return True | ||
|
|
||
|
|
||
| def check_int_type(n_type): | ||
| return isinstance(n_type, types.Integer) | ||
|
|
||
|
|
||
| def check_float_type(n_type): | ||
| return isinstance(n_type, types.Float) | ||
|
|
||
|
|
||
| def check_bool_type(n_type): | ||
| return isinstance(n_type, types.Boolean) | ||
|
|
||
|
|
||
| def check_str_type(n_type): | ||
| return isinstance(n_type, types.UnicodeType) | ||
|
|
||
|
|
||
| def check_list_type(self, p_type, n_type): | ||
| res = isinstance(n_type, (types.List, types.ListType)) | ||
| if p_type == list: | ||
| return res | ||
| else: | ||
| return res and self.match(p_type.__args__[0], n_type.dtype) | ||
|
|
||
|
|
||
| def check_tuple_type(self, p_type, n_type): | ||
| if not isinstance(n_type, (types.Tuple, types.UniTuple)): | ||
| return False | ||
| try: | ||
| if len(p_type.__args__) != len(n_type.types): | ||
| return False | ||
| except AttributeError: # if p_type == tuple | ||
| return True | ||
|
|
||
| for p_val, n_val in zip(p_type.__args__, n_type.types): | ||
| if not self.match(p_val, n_val): | ||
| return False | ||
|
|
||
| return True | ||
|
|
||
|
|
||
| def check_dict_type(self, p_type, n_type): | ||
| res = False | ||
| if isinstance(n_type, types.DictType): | ||
| res = True | ||
| if isinstance(p_type, type): | ||
| return res | ||
| for p_val, n_val in zip(p_type.__args__, n_type.keyvalue_type): | ||
| res = res and self.match(p_val, n_val) | ||
| return res | ||
|
|
||
|
|
||
| class TypeChecker: | ||
|
|
||
| _types_dict: dict = {} | ||
|
|
||
| def __init__(self): | ||
| self._typevars_dict = {} | ||
|
|
||
| def clear_typevars_dict(self): | ||
| self._typevars_dict.clear() | ||
|
|
||
| @classmethod | ||
| def add_type_check(cls, type_check, func): | ||
| cls._types_dict[type_check] = func | ||
|
|
||
| @staticmethod | ||
| def _is_generic(p_obj): | ||
| if isinstance(p_obj, typing._GenericAlias): | ||
| return True | ||
|
|
||
| if isinstance(p_obj, typing._SpecialForm): | ||
| return p_obj not in {typing.Any} | ||
|
|
||
| return False | ||
|
|
||
| @staticmethod | ||
| def _get_origin(p_obj): | ||
| return p_obj.__origin__ | ||
|
|
||
| def match(self, p_type, n_type): | ||
| if p_type == typing.Any: | ||
| return True | ||
| try: | ||
| if self._is_generic(p_type): | ||
| origin_type = self._get_origin(p_type) | ||
| if origin_type == typing.Generic: | ||
| return self.match_generic(p_type, n_type) | ||
|
|
||
| return self._types_dict[origin_type](self, p_type, n_type) | ||
|
|
||
| if isinstance(p_type, typing.TypeVar): | ||
| return self.match_typevar(p_type, n_type) | ||
|
|
||
| if p_type in (list, tuple, dict): | ||
| return self._types_dict[p_type](self, p_type, n_type) | ||
|
|
||
| return self._types_dict[p_type](n_type) | ||
|
|
||
| except KeyError: | ||
| raise TypeError(f'A check for the {p_type} was not found.') | ||
|
|
||
| def match_typevar(self, p_type, n_type): | ||
| if isinstance(n_type, types.List): | ||
| n_type = types.ListType(n_type.dtype) | ||
| if not self._typevars_dict.get(p_type): | ||
| self._typevars_dict[p_type] = n_type | ||
| return True | ||
| return self._typevars_dict.get(p_type) == n_type | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe it should be |
||
|
|
||
| def match_generic(self, p_type, n_type): | ||
| raise SystemError | ||
|
|
||
|
|
||
| TypeChecker.add_type_check(int, check_int_type) | ||
| TypeChecker.add_type_check(float, check_float_type) | ||
| TypeChecker.add_type_check(str, check_str_type) | ||
| TypeChecker.add_type_check(bool, check_bool_type) | ||
| TypeChecker.add_type_check(list, check_list_type) | ||
| TypeChecker.add_type_check(tuple, check_tuple_type) | ||
| TypeChecker.add_type_check(dict, check_dict_type) | ||
|
|
||
|
|
||
| def choose_func_by_sig(sig_list, values_dict): | ||
| def check_signature(sig_params, types_dict): | ||
| checker = TypeChecker() | ||
| for name, typ in types_dict.items(): # name,type = 'a',int64 | ||
| if isinstance(typ, types.Literal): | ||
| typ = typ.literal_type | ||
| if not checker.match(sig_params[name], typ): | ||
| return False | ||
|
|
||
| return True | ||
|
|
||
| for sig, func in sig_list: # sig = (Signature,func) | ||
| for param in sig.parameters: # param = {'a':int,'b':int} | ||
| if check_signature(param, values_dict): | ||
| return func | ||
|
|
||
| return None | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| import overload_list | ||
| from overload_list import List, Dict | ||
| from overload_list import types | ||
| import unittest | ||
| import typing | ||
| from numba import njit, core | ||
|
|
||
|
|
||
| T = typing.TypeVar('T') | ||
| K = typing.TypeVar('K') | ||
|
|
||
|
|
||
| class TestOverloadList(unittest.TestCase): | ||
| maxDiff = None | ||
|
|
||
| def test_myfunc_literal_type_default(self): | ||
| def foo(a, b=0): | ||
| ... | ||
|
|
||
| @overload_list.overload_list(foo) | ||
| def foo_ovld_list(): | ||
|
|
||
| def foo_int_literal(a: int, b: int = 0): | ||
| return ('literal', a, b) | ||
|
|
||
| return (foo_int_literal,) | ||
|
|
||
| @njit | ||
| def jit_func(a): | ||
| return foo(a, 2) | ||
|
|
||
| self.assertEqual(jit_func(1), ('literal', 1, 2)) | ||
|
|
||
| def test_myfunc_tuple_type_error(self): | ||
| def foo(a, b=(0, 0)): | ||
| ... | ||
|
|
||
| @overload_list.overload_list(foo) | ||
| def foo_ovld_list(): | ||
|
|
||
| def foo_tuple(a: typing.Tuple[int, int], b: tuple = (0, 0)): | ||
| return ('tuple_', a, b) | ||
|
|
||
| return (foo_tuple,) | ||
|
|
||
| @njit | ||
| def jit_func(a, b): | ||
| return foo(a, b) | ||
|
|
||
| self.assertRaises(core.errors.TypingError, jit_func, (1, 2, 3), ('3', False)) | ||
|
|
||
|
|
||
| def generator_test(func_name, param, values_dict, defaults_dict={}): | ||
|
|
||
| def check_type(typ): | ||
| if isinstance(typ, type): | ||
| return typ.__name__ | ||
| return typ | ||
|
|
||
| value_keys = ", ".join("{}".format(key) for key in values_dict.keys()) | ||
|
||
| defaults_keys = ", ".join("{}".format(key) for key in defaults_dict.keys()) | ||
| value_str = ", ".join("{}: {}".format(key, check_type(val)) for key, val in values_dict.items()) | ||
| defaults_str = ", ".join("{} = {}".format(key, val) if not isinstance( | ||
| val, str) else "{} = '{}'".format(key, val) for key, val in defaults_dict.items()) | ||
| defaults_str_type = ", ".join("{}: {} = {}".format(key, check_type(type(val)), val) if not isinstance(val, str) | ||
| else "{}: {} = '{}'".format(key, check_type(type(val)), val) | ||
| for key, val in defaults_dict.items()) | ||
| value_type = ", ".join("{}".format(val) for val in values_dict.values()) | ||
| defaults_type = ", ".join("{}".format(type(val)) for val in defaults_dict.values()) | ||
| param_qwe = ", ".join("{}".format(i) for i in param) | ||
| test = f""" | ||
| def test_myfunc_{func_name}_type_default(self): | ||
| def foo({value_keys},{defaults_str}): | ||
| ... | ||
|
|
||
| @overload_list.overload_list(foo) | ||
| def foo_ovld_list(): | ||
|
|
||
| def foo_{func_name}({value_str},{defaults_str_type}): | ||
| return ("{value_type}","{defaults_type}") | ||
|
|
||
| return (foo_{func_name},) | ||
|
|
||
| @njit | ||
| def jit_func({value_keys},{defaults_str}): | ||
| return foo({value_keys},{defaults_keys}) | ||
|
|
||
| self.assertEqual(jit_func({param_qwe}), ("{value_type}", "{defaults_type}")) | ||
| """ | ||
| loc = {} | ||
| exec(test, globals(), loc) | ||
| return loc | ||
|
|
||
|
|
||
| L = List([1, 2, 3]) | ||
| L_int = List([List([1, 2])]) | ||
| L_float = List([List([List([3.0, 4.0])])]) | ||
| L_f = List([1.0, 2.0]) | ||
| D = Dict.empty(key_type=types.unicode_type, value_type=types.int64) | ||
| D_1 = Dict.empty(key_type=types.int64, value_type=types.boolean) | ||
| D['qwe'] = 1 | ||
| D['qaz'] = 2 | ||
| D_1[1] = True | ||
| D_1[0] = False | ||
| list_type = types.ListType(types.int64) | ||
| D_list = Dict.empty(key_type=types.unicode_type, value_type=list_type) | ||
| D_list['qwe'] = List([3, 4, 5]) | ||
| str_1 = 'qwe' | ||
| str_2 = 'qaz' | ||
| test_cases = [('int', [1, 2], {'a': int, 'b': int}), ('float', [1.0, 2.0], {'a': float, 'b': float}), | ||
|
||
| ('bool', [True, True], {'a': bool, 'b': bool}), ('str', ['str_1', 'str_2'], {'a': str, 'b': str}), | ||
| ('list', [[1, 2], [3, 4]], {'a': typing.List[int], 'b':list}), | ||
|
||
| ('List_typed', [L, [3, 4]], {'a': typing.List[int], 'b':list}), | ||
| ('tuple', [(1, 2.0), ('3', False)], {'a': typing.Tuple[int, float], 'b':tuple}), | ||
| ('dict', ['D', 'D_1'], {'a': typing.Dict[str, int], 'b': typing.Dict[int, bool]}), | ||
| ('union_1', [1, False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}), | ||
| ('union_2', ['str_1', False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}), | ||
|
||
| ('nested_list', ['L_int', 'L_float'], {'a': typing.List[typing.List[int]], | ||
| 'b': typing.List[typing.List[typing.List[float]]]}), | ||
| ('TypeVar_TT', ['L_f', [3.0, 4.0]], {'a': 'T', 'b': 'T'}), | ||
|
||
| ('TypeVar_TK', [1.0, 2], {'a': 'T', 'b': 'K'}), | ||
| ('TypeVar_ListT_T', ['L', 5], {'a': 'typing.List[T]', 'b': 'T'}), | ||
| ('TypeVar_ListT_DictKT', ['L', 'D'], {'a': 'typing.List[T]', 'b': 'typing.Dict[K, T]'}), | ||
| ('TypeVar_ListT_DictK_ListT', ['L', 'D_list'], {'a': 'typing.List[T]', | ||
| 'b': 'typing.Dict[K, typing.List[T]]'})] | ||
|
||
|
|
||
| test_cases_default = [('int_defaults', [1], {'a': int}, {'b': 0}), ('float_defaults', [1.0], {'a': float}, {'b': 0.0}), | ||
|
||
| ('bool_defaults', [True], {'a': bool}, {'b': False}), | ||
| ('str_defaults', ['str_1'], {'a': str}, {'b': '0'}), | ||
| ('tuple_defaults', [(1, 2)], {'a': tuple}, {'b': (0, 0)})] | ||
|
|
||
|
|
||
| for name, val, annotation in test_cases: | ||
|
||
| run_generator = generator_test(name, val, annotation) | ||
| test_name = list(run_generator.keys())[0] | ||
| setattr(TestOverloadList, test_name, run_generator[test_name]) | ||
|
|
||
|
|
||
| for name, val, annotation, defaults in test_cases_default: | ||
| run_generator = generator_test(name, val, annotation, defaults) | ||
| test_name = list(run_generator.keys())[0] | ||
| setattr(TestOverloadList, test_name, run_generator[test_name]) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks both
aandbare undefined.