Skip to content

Commit b416d4f

Browse files
authored
Remove unrolling + warp spec (#967)
1 parent f5ba06d commit b416d4f

File tree

3 files changed

+55
-16
lines changed

3 files changed

+55
-16
lines changed

helion/autotuner/config_spec.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,6 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
196196
name, config.get(name, ()), block_ids=static_range_block_ids
197197
)
198198

199-
# Only one range_warp_specializes is allowed, take the last one
200-
range_warp_specializes = cast(
201-
"list[bool | None]", config.get("range_warp_specializes", [])
202-
)
203-
for i in [j for j, val in enumerate(range_warp_specializes) if val][:-1]:
204-
range_warp_specializes[i] = None
205-
206199
for name in (
207200
"loop_orders",
208201
"l2_groupings",
@@ -252,6 +245,23 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
252245
name, config.get(name, ()), block_ids=self.grid_block_ids
253246
)
254247

248+
# Only one range_warp_specializes is allowed, take the last one
249+
range_warp_specializes = cast(
250+
"list[bool | None]", config.get("range_warp_specializes", [])
251+
)
252+
253+
if range_warp_specializes and any(range_warp_specializes):
254+
for i in [j for j, val in enumerate(range_warp_specializes) if val][:-1]:
255+
range_warp_specializes[i] = None
256+
257+
range_unroll_factors = cast(
258+
"list[int]", config.get("range_unroll_factors", [])
259+
)
260+
if range_unroll_factors and range_unroll_factors[-1]:
261+
range_unroll_factors[-1] = 0
262+
config["range_unroll_factors"] = range_unroll_factors
263+
264+
config["range_warp_specializes"] = range_warp_specializes
255265
# Allow tunable parameter keys in addition to VALID_KEYS
256266
allowed_keys = VALID_KEYS | {*self.user_defined_tunables.keys()}
257267
if invalid_keys := ({*config} - allowed_keys):

test/test_autotuner.expected

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,40 @@ Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environmen
33

44
--- assertExpectedJournal(TestAutotuner.test_config_fragment0)
55
helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None])
6-
helion.Config(block_sizes=[32, 128, 64], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 2], range_warp_specializes=[None, True])
6+
helion.Config(block_sizes=[32, 128, 64], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 0], range_warp_specializes=[None, True])
77
helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 3], range_warp_specializes=[None, False])
8-
helion.Config(block_sizes=[16, 32, 256], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[2, 3], range_warp_specializes=[True, None])
8+
helion.Config(block_sizes=[16, 32, 256], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[2, 0], range_warp_specializes=[True, None])
99
helion.Config(block_sizes=[64, 32, 16], indexing='block_ptr', l2_groupings=[2], load_eviction_policies=['first', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, True], range_num_stages=[0, 4], range_unroll_factors=[0, 1], range_warp_specializes=[None, None])
1010
helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[32], load_eviction_policies=['last', 'first'], loop_orders=[[0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None, False], range_multi_buffers=[None, None], range_num_stages=[0, 2], range_unroll_factors=[0, 2], range_warp_specializes=[None, False])
1111
helion.Config(block_sizes=[16, 32, 64], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[1, 0]], num_stages=5, num_warps=16, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, False], range_num_stages=[0, 3], range_unroll_factors=[0, 3], range_warp_specializes=[None, None])
12-
helion.Config(block_sizes=[16, 32, 16], indexing='pointer', l2_groupings=[2], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=8, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[False, None], range_num_stages=[3, 3], range_unroll_factors=[2, 3], range_warp_specializes=[False, True])
13-
helion.Config(block_sizes=[256, 16, 16], indexing='pointer', l2_groupings=[2], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=5, num_warps=32, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, False], range_num_stages=[0, 1], range_unroll_factors=[0, 2], range_warp_specializes=[None, True])
14-
helion.Config(block_sizes=[16, 64, 16], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=3, num_warps=32, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[False, None], range_num_stages=[3, 0], range_unroll_factors=[3, 4], range_warp_specializes=[False, True])
12+
helion.Config(block_sizes=[16, 32, 16], indexing='pointer', l2_groupings=[2], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=8, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[False, None], range_num_stages=[3, 3], range_unroll_factors=[2, 0], range_warp_specializes=[False, True])
13+
helion.Config(block_sizes=[256, 16, 16], indexing='pointer', l2_groupings=[2], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=5, num_warps=32, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, False], range_num_stages=[0, 1], range_unroll_factors=[0, 0], range_warp_specializes=[None, True])
14+
helion.Config(block_sizes=[16, 64, 16], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=3, num_warps=32, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[False, None], range_num_stages=[3, 0], range_unroll_factors=[3, 0], range_warp_specializes=[False, True])
1515

1616
--- assertExpectedJournal(TestAutotuner.test_config_fragment1)
1717
helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
18-
helion.Config(block_sizes=[1, 64, 64], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True])
18+
helion.Config(block_sizes=[1, 64, 64], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True])
1919
helion.Config(block_sizes=[2, 8, 512], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
20-
helion.Config(block_sizes=[1, 512, 1], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[1], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[2], range_warp_specializes=[True])
21-
helion.Config(block_sizes=[1, 4, 256], flatten_loops=[True], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0, 2]], num_stages=2, num_warps=32, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[1], range_warp_specializes=[True])
20+
helion.Config(block_sizes=[1, 512, 1], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[1], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True])
21+
helion.Config(block_sizes=[1, 4, 256], flatten_loops=[True], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0, 2]], num_stages=2, num_warps=32, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True])
2222
helion.Config(block_sizes=[1, 128, 16], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[4], range_warp_specializes=[None])
2323
helion.Config(block_sizes=[8, 32, 256], flatten_loops=[False], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', 'last'], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=8, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[None])
2424
helion.Config(block_sizes=[2, 64, 32], flatten_loops=[False], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[1, 2, 0]], num_stages=5, num_warps=16, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
25-
helion.Config(block_sizes=[4, 32, 1], flatten_loops=[True], indexing='pointer', l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[3], range_warp_specializes=[True])
25+
helion.Config(block_sizes=[4, 32, 1], flatten_loops=[True], indexing='pointer', l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True])
2626
helion.Config(block_sizes=[4, 2, 128], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['', 'first'], loop_orders=[[1, 2, 0]], num_stages=2, num_warps=4, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[1], range_warp_specializes=[False])
2727

28+
--- assertExpectedJournal(TestAutotuner.test_config_warp_specialize_unroll)
29+
helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
30+
helion.Config(block_sizes=[1, 64, 64], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True])
31+
helion.Config(block_sizes=[2, 8, 512], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
32+
helion.Config(block_sizes=[1, 512, 1], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[1], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True])
33+
helion.Config(block_sizes=[1, 4, 256], flatten_loops=[True], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0, 2]], num_stages=2, num_warps=32, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True])
34+
helion.Config(block_sizes=[1, 128, 16], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True])
35+
helion.Config(block_sizes=[8, 32, 256], flatten_loops=[False], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', 'last'], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=8, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True])
36+
helion.Config(block_sizes=[2, 64, 32], flatten_loops=[False], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[1, 2, 0]], num_stages=5, num_warps=16, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
37+
helion.Config(block_sizes=[4, 32, 1], flatten_loops=[True], indexing='pointer', l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True])
38+
helion.Config(block_sizes=[4, 2, 128], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['', 'first'], loop_orders=[[1, 2, 0]], num_stages=2, num_warps=4, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True])
39+
2840
--- assertExpectedJournal(TestAutotuner.test_save_load_config)
2941
{
3042
"block_sizes": [

test/test_autotuner.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,23 @@ def test_config_fragment1(self):
160160
configs = ConfigGeneration(spec).random_population(10)
161161
self.assertExpectedJournal("\n".join(map(repr, configs)))
162162

163+
@patch(
164+
"helion.autotuner.config_generation.warps_to_threads",
165+
lambda num_warps: num_warps * 32,
166+
)
167+
@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
168+
@patch.object(loops, "_supports_warp_specialize", lambda: True)
169+
def test_config_warp_specialize_unroll(self):
170+
args = (
171+
torch.randn([8, 512, 512], device=DEVICE),
172+
torch.randn([8, 512, 512], device=DEVICE),
173+
)
174+
spec = basic_kernels.add.bind(args).config_spec
175+
overrides = {"range_unroll_factors": [4], "range_warp_specializes": ([True])}
176+
# We expect all the unroll factors to be set to 0
177+
configs = ConfigGeneration(spec, overrides=overrides).random_population(10)
178+
self.assertExpectedJournal("\n".join(map(repr, configs)))
179+
163180
def test_config_generation_overrides(self):
164181
args = (
165182
torch.randn([8, 512, 512], device=DEVICE),

0 commit comments

Comments
 (0)