diff --git a/kernel_tuner/searchspace.py b/kernel_tuner/searchspace.py index d3d00052..42d0be4d 100644 --- a/kernel_tuner/searchspace.py +++ b/kernel_tuner/searchspace.py @@ -7,6 +7,7 @@ from warnings import warn from copy import deepcopy from collections import defaultdict, deque +from inspect import signature import numpy as np from scipy.stats.qmc import LatinHypercube @@ -495,6 +496,13 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver: def __add_restrictions(self, parameter_space: Problem) -> Problem: """Add the user-specified restrictions as constraints on the parameter space.""" restrictions = deepcopy(self.restrictions) + # differentiate between old style monolithic with single 'p' argument and newer *args style + if (len(restrictions) == 1 + and not isinstance(restrictions[0], (Constraint, FunctionConstraint, str)) + and callable(restrictions[0]) + and len(signature(restrictions[0]).parameters) == 1 + and len(self.param_names) > 1): + restrictions = restrictions[0] if isinstance(restrictions, list): for restriction in restrictions: required_params = self.param_names @@ -504,10 +512,6 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem: required_params = restriction[1] restriction = restriction[0] if callable(restriction) and not isinstance(restriction, Constraint): - # def restrictions_wrapper(*args): - # return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False) - # print(restriction, isinstance(restriction, Constraint)) - # restriction = FunctionConstraint(restrictions_wrapper) restriction = FunctionConstraint(restriction, required_params) # add as a Constraint @@ -529,6 +533,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem: elif callable(restrictions): def restrictions_wrapper(*args): + """Wrap old-style monolithic restrictions to work with multiple arguments.""" return check_instance_restrictions(restrictions, dict(zip(self.param_names, args)), False) parameter_space.addConstraint(FunctionConstraint(restrictions_wrapper), self.param_names) diff --git a/test/test_searchspace.py b/test/test_searchspace.py index f742a4b7..32f58bec 100644 --- a/test/test_searchspace.py +++ b/test/test_searchspace.py @@ -623,3 +623,37 @@ def test_full_searchspace(compare_against_bruteforce=False): compare_two_searchspace_objects(searchspace, searchspace_bruteforce) else: assert searchspace.size == len(searchspace.list) == 349853 + +def test_restriction_backwards_compatibility(): + """Test whether the backwards compatibility code for restrictions (list of strings) works as expected.""" + # create a searchspace with mixed parameter types + max_threads = 1024 + tune_params = dict() + tune_params["N_PER_BLOCK"] = [32, 64, 128, 256, 512, 1024] + tune_params["M_PER_BLOCK"] = [32, 64, 128, 256, 512, 1024] + tune_params["block_size_y"] = [1, 2, 4, 8, 16, 32] + tune_params["block_size_z"] = [1, 2, 4, 8, 16, 32] + + # old style monolithic restriction function + def restrict(p): + n_global_per_warp = int(p["N_PER_BLOCK"] // p["block_size_y"]) + m_global_per_warp = int(p["M_PER_BLOCK"] // p["block_size_z"]) + if n_global_per_warp == 0 or m_global_per_warp == 0: + return False + + searchspace_callable = Searchspace(tune_params, restrict, max_threads) + + def restrict_args(N_PER_BLOCK, M_PER_BLOCK, block_size_y, block_size_z): + n_global_per_warp = int(N_PER_BLOCK // block_size_y) + m_global_per_warp = int(M_PER_BLOCK // block_size_z) + if n_global_per_warp == 0 or m_global_per_warp == 0: + return False + + # args-style restriction + searchspace_str = Searchspace(tune_params, restrict_args, max_threads) + + # check the size + assert searchspace_str.size == searchspace_callable.size + + # check that both searchspaces are identical in outcome + compare_two_searchspace_objects(searchspace_str, searchspace_callable)