Skip to content

Commit f5ba06d

Browse files
authored
Add initial backwards compatibility tests (#958)
1 parent cfe3d9b commit f5ba06d

File tree

3 files changed

+242
-3
lines changed

3 files changed

+242
-3
lines changed

helion/runtime/config.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __eq__(self, other: object) -> bool:
109109
return self.config == other.config
110110

111111
def __hash__(self) -> int:
112-
return hash(frozenset([(k, _list_to_tuple(v)) for k, v in self.config.items()]))
112+
return hash(frozenset([(k, _to_hashable(v)) for k, v in self.config.items()]))
113113

114114
def __getstate__(self) -> dict[str, object]:
115115
return dict(self.config)
@@ -207,7 +207,9 @@ def indexing(self) -> IndexingLiteral:
207207
return self.config.get("indexing", "pointer") # type: ignore[return-value]
208208

209209

210-
def _list_to_tuple(x: object) -> object:
210+
def _to_hashable(x: object) -> object:
211211
if isinstance(x, list):
212-
return tuple([_list_to_tuple(i) for i in x])
212+
return tuple([_to_hashable(i) for i in x])
213+
if isinstance(x, dict):
214+
return tuple(sorted([(k, _to_hashable(v)) for k, v in x.items()]))
213215
return x

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ filecheck
55
expecttest
66
numpy
77
rich
8+
hypothesis

test/test_config_api.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
from __future__ import annotations
2+
3+
import importlib
4+
import inspect
5+
import pickle
6+
from typing import Any
7+
import unittest
8+
9+
from hypothesis import given
10+
from hypothesis import settings
11+
from hypothesis import strategies as st
12+
13+
import helion
14+
from helion._testing import TestCase
15+
16+
17+
def _json_safe_values() -> st.SearchStrategy[Any]:
18+
# JSON-safe primitives/containers
19+
scalar = st.one_of(
20+
st.integers(), st.floats(allow_nan=False), st.booleans(), st.text()
21+
)
22+
leaf = st.one_of(scalar, st.none())
23+
return st.recursive(
24+
leaf,
25+
lambda children: st.one_of(
26+
st.lists(children, max_size=4),
27+
st.dictionaries(st.text(min_size=0, max_size=8), children, max_size=4),
28+
),
29+
max_leaves=8,
30+
)
31+
32+
33+
def _known_keys_strategy() -> st.SearchStrategy[dict[str, Any]]:
34+
# For known keys, None values are omitted by constructor; favor non-None
35+
return st.fixed_dictionaries(
36+
{
37+
"block_sizes": st.lists(
38+
st.integers(min_value=1, max_value=4096), max_size=4
39+
),
40+
"loop_orders": st.lists(
41+
st.lists(st.integers(min_value=0, max_value=4), max_size=4),
42+
max_size=3,
43+
),
44+
"flatten_loops": st.lists(st.booleans(), max_size=4),
45+
"l2_groupings": st.lists(
46+
st.integers(min_value=1, max_value=128), max_size=4
47+
),
48+
"reduction_loops": st.lists(
49+
st.one_of(st.integers(min_value=0, max_value=8), st.none()),
50+
max_size=4,
51+
),
52+
"range_unroll_factors": st.lists(
53+
st.integers(min_value=1, max_value=16), max_size=4
54+
),
55+
"range_warp_specializes": st.lists(
56+
st.one_of(st.booleans(), st.none()), max_size=4
57+
),
58+
"range_num_stages": st.lists(
59+
st.integers(min_value=1, max_value=8), max_size=4
60+
),
61+
"range_multi_buffers": st.lists(
62+
st.one_of(st.booleans(), st.none()), max_size=4
63+
),
64+
"range_flattens": st.lists(st.one_of(st.booleans(), st.none()), max_size=4),
65+
"static_ranges": st.lists(st.booleans(), max_size=4),
66+
"load_eviction_policies": st.lists(
67+
st.sampled_from(["", "first", "last"]), max_size=4
68+
),
69+
"num_warps": st.integers(min_value=1, max_value=64),
70+
"num_stages": st.integers(min_value=1, max_value=16),
71+
"pid_type": st.sampled_from(
72+
["flat", "xyz", "persistent_blocked", "persistent_interleaved"]
73+
),
74+
"indexing": st.sampled_from(["pointer", "tensor_descriptor", "block_ptr"]),
75+
}
76+
)
77+
78+
79+
def _unknown_keys_strategy() -> st.SearchStrategy[dict[str, Any]]:
80+
key = st.from_regex(r"[A-Za-z_][A-Za-z0-9_]{0,12}")
81+
# Avoid colliding with known keys and enforce distinctness
82+
return st.dictionaries(
83+
keys=key.filter(
84+
lambda k: k
85+
not in {
86+
"block_sizes",
87+
"loop_orders",
88+
"flatten_loops",
89+
"l2_groupings",
90+
"reduction_loops",
91+
"range_unroll_factors",
92+
"range_warp_specializes",
93+
"range_num_stages",
94+
"range_multi_buffers",
95+
"range_flattens",
96+
"static_ranges",
97+
"load_eviction_policies",
98+
"num_warps",
99+
"num_stages",
100+
"pid_type",
101+
"indexing",
102+
}
103+
),
104+
values=_json_safe_values(),
105+
max_size=4,
106+
)
107+
108+
109+
class TestConfigAPI(TestCase):
110+
def test_config_import_path_stability(self) -> None:
111+
runtime = importlib.import_module("helion.runtime")
112+
113+
self.assertIs(helion.Config, runtime.Config)
114+
self.assertIs(helion.Config, helion.runtime.Config)
115+
116+
def test_config_constructor_signature_contains_expected_kwargs(self) -> None:
117+
# Keep this list in sync with public kwargs; removal/rename should fail tests
118+
expected = {
119+
"block_sizes",
120+
"loop_orders",
121+
"flatten_loops",
122+
"l2_groupings",
123+
"reduction_loops",
124+
"range_unroll_factors",
125+
"range_warp_specializes",
126+
"range_num_stages",
127+
"range_multi_buffers",
128+
"range_flattens",
129+
"static_ranges",
130+
"load_eviction_policies",
131+
"num_warps",
132+
"num_stages",
133+
"pid_type",
134+
"indexing",
135+
}
136+
137+
sig = inspect.signature(helion.Config.__init__)
138+
kwonly = {
139+
name
140+
for name, p in sig.parameters.items()
141+
if p.kind is inspect.Parameter.KEYWORD_ONLY
142+
}
143+
# Expected kwargs must be present as keyword-only
144+
self.assertTrue(expected.issubset(kwonly))
145+
146+
def test_mapping_behavior_len_iter_dict_roundtrip(self) -> None:
147+
data = {
148+
"block_sizes": [64, 32],
149+
"num_warps": 8,
150+
"custom_extra": {"a": 1},
151+
}
152+
cfg = helion.Config(**data)
153+
154+
# Supports Mapping protocol
155+
self.assertEqual(len(cfg), len(cfg.config))
156+
self.assertEqual(dict(cfg), cfg.config)
157+
self.assertEqual(set(iter(cfg)), set(cfg.config.keys()))
158+
159+
# Equality and hash coherence
160+
cfg2 = helion.Config(**data)
161+
self.assertEqual(cfg, cfg2)
162+
self.assertEqual(hash(cfg), hash(cfg2))
163+
164+
@settings(deadline=None)
165+
@given(
166+
st.builds(lambda a, b: (a, b), _known_keys_strategy(), _unknown_keys_strategy())
167+
)
168+
def test_json_roundtrip_preserves_keys_and_values(
169+
self, pair: tuple[dict[str, Any], dict[str, Any]]
170+
) -> None:
171+
known, unknown = pair
172+
data = {**known, **unknown}
173+
cfg = helion.Config(**data)
174+
175+
# JSON round-trip
176+
json_str = cfg.to_json()
177+
restored = helion.Config.from_json(json_str)
178+
179+
# Compare as dicts; JSON dumps may reorder keys
180+
self.assertEqual(dict(restored), dict(cfg))
181+
182+
# Unknown keys must persist
183+
for k in unknown:
184+
self.assertIn(k, restored)
185+
self.assertEqual(restored[k], unknown[k])
186+
187+
@settings(deadline=None)
188+
@given(_known_keys_strategy(), _unknown_keys_strategy())
189+
def test_pickle_roundtrip_preserves_equality_and_hash(
190+
self, known: dict[str, Any], unknown: dict[str, Any]
191+
) -> None:
192+
data = {**known, **unknown}
193+
cfg = helion.Config(**data)
194+
blob = pickle.dumps(cfg)
195+
restored = pickle.loads(blob)
196+
197+
self.assertEqual(restored, cfg)
198+
self.assertEqual(hash(restored), hash(cfg))
199+
200+
def test_list_tuple_hash_equivalence(self) -> None:
201+
cfg_list = helion.Config(block_sizes=[32, 64], loop_orders=[[1, 0]])
202+
cfg_tuple = helion.Config(block_sizes=[32, 64], loop_orders=[[1, 0]])
203+
204+
# Same content should be equal and have equal hashes
205+
self.assertEqual(cfg_list, cfg_tuple)
206+
self.assertEqual(hash(cfg_list), hash(cfg_tuple))
207+
208+
def test_pre_serialized_json_backward_compat(self) -> None:
209+
# Simulated config JSON saved in a prior release (hand-written, stable keys)
210+
json_str = (
211+
"{\n"
212+
' "block_sizes": [64, 32],\n'
213+
' "num_warps": 8,\n'
214+
' "indexing": "pointer",\n'
215+
' "custom_extra": {"alpha": 1, "beta": [1, 2]}\n'
216+
"}\n"
217+
)
218+
219+
restored = helion.Config.from_json(json_str)
220+
221+
expected = {
222+
"block_sizes": [64, 32],
223+
"num_warps": 8,
224+
"indexing": "pointer",
225+
"custom_extra": {"alpha": 1, "beta": [1, 2]},
226+
}
227+
self.assertEqual(dict(restored), expected)
228+
229+
# Ensure we can still serialize it back and preserve content
230+
rejson = restored.to_json()
231+
reread = helion.Config.from_json(rejson)
232+
self.assertEqual(dict(reread), expected)
233+
234+
235+
if __name__ == "__main__":
236+
unittest.main()

0 commit comments

Comments
 (0)