Skip to content

Commit 8c8b99f

Browse files
authored
Add hl.rand op with seed arg lowering to tl.rand (#652)
1 parent 5ccf6f4 commit 8c8b99f

File tree

3 files changed

+213
-0
lines changed

3 files changed

+213
-0
lines changed

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .matmul_ops import dot as dot
2222
from .memory_ops import load as load
2323
from .memory_ops import store as store
24+
from .random_ops import rand as rand
2425
from .reduce_ops import reduce as reduce
2526
from .scan_ops import associative_scan as associative_scan
2627
from .scan_ops import cumprod as cumprod

helion/language/random_ops.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import torch
6+
7+
from .._compiler.ast_extension import expr_from_string
8+
from .._compiler.compile_environment import CompileEnvironment
9+
from ..exc import NotInsideKernel
10+
from . import _decorators
11+
from .ref_tile import RefTile
12+
13+
if TYPE_CHECKING:
14+
import ast
15+
16+
from .._compiler.inductor_lowering import CodegenState
17+
18+
__all__ = ["rand"]
19+
20+
21+
@_decorators.api(tiles_as_sizes=True)
def rand(
    shape: list[object],
    seed: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    Create a device tensor of random values in ``[0, 1)`` from an explicit seed.

    The main purpose of ``hl.rand`` is to explicitly pass a seed arg for deterministic
    randomness in helion kernels, whereas ``torch.rand_like`` doesn't take a seed arg
    (though it can be seeded globally). ``hl.rand`` lowers to ``tl.rand(seed, offset)``
    with ``offset`` built from a linear range over the allocation and reshaped to the
    given shape.

    Note:
        Only use within ``hl.tile()`` loops for creating local tensors.
        For host allocations, use ``torch.rand()``.

    Args:
        shape: A list of sizes (tiles may be passed and are treated as sizes)
        seed: int seed for the random number generator
        dtype: currently only float32 supported
        device: Device of the result; when ``None`` the device of the current
            compile environment is used

    Returns:
        torch.Tensor: A device tensor of the given shape and dtype filled with random values

    Raises:
        NotInsideKernel: If this function body is executed directly instead of
            being handled inside a helion kernel

    Examples:
        .. code-block:: python

            @helion.kernel
            def process_kernel(x: torch.Tensor, seed: int) -> torch.Tensor:
                output = torch.zeros_like(x)
                (m,) = x.shape
                for (tile_m,) in hl.tile([m]):
                    output[tile_m] = hl.rand([tile_m], seed=seed)
                return output

    """
    raise NotInsideKernel
59+
60+
61+
@_decorators.register_fake(rand)
def _rand_fake(
    shape: list[int | torch.SymInt],
    seed: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Fake implementation of ``rand`` that only produces the result's metadata.

    Args:
        shape: Sizes of the result; must be a list or tuple.
        seed: Accepted for signature parity with ``rand``; not used here.
        dtype: Element type of the returned tensor.
        device: Device of the returned tensor; the current compile
            environment's device is used when ``None``.

    Returns:
        An uninitialized tensor (``torch.empty``) with the requested shape,
        dtype and device.

    Raises:
        TypeError: If ``shape`` is neither a list nor a tuple.
    """
    if isinstance(shape, (list, tuple)):
        environment = CompileEnvironment.current()
        # Record the requested sizes with the compile environment before allocating.
        environment.add_kernel_tensor_size(shape)
        target_device = environment.device if device is None else device
        return torch.empty([*shape], dtype=dtype, device=target_device)
    raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
77+
78+
79+
@_decorators.codegen(rand)
def _rand_codegen(state: CodegenState) -> ast.AST:
    """Build the ``tl.rand(seed, offsets)`` expression emitted for an ``hl.rand`` call.

    The offsets are ``tl.arange(0, numel)`` reshaped to the shape string of the
    fake result, where ``numel`` is the product of the individual size
    expressions.

    NOTE(review): the offsets depend only on the result shape; nothing here adds
    the position of the current tile. With the same seed, every tile of the same
    shape would therefore appear to receive the same values -- confirm whether
    that is intended.

    Args:
        state: Codegen state for the call; supplies the fake result value, the
            tile strategy used to render the shape, and the call arguments.

    Returns:
        The AST of the generated ``tl.rand`` expression.
    """
    result = state.fake_value
    assert isinstance(result, torch.Tensor)
    dims = state.device_function.tile_strategy.shape_str(result.size())

    # The shape string is parsed as a bracketed, comma-separated list of size
    # expressions; joining the pieces with " * " gives the element count.
    element_count = " * ".join(dims.strip("[]").split(","))
    offsets = f"tl.arange(0, {element_count}).reshape({dims})"

    # Argument index 1 is ``seed``; it is spliced into the ``{seed}`` placeholder.
    return expr_from_string(f"tl.rand({{seed}}, {offsets})", seed=state.ast_arg(1))
91+
92+
93+
@_decorators.get_masked_value(rand)
def _(
    node: torch.fx.Node,
) -> float:
    """Report ``0`` as the masked value for ``rand`` nodes.

    Registered through ``_decorators.get_masked_value``. ``node`` is not
    inspected, so the same constant is returned for every call.

    NOTE(review): the codegen registered for ``rand`` in this file emits
    ``tl.rand`` without applying any mask -- confirm that ``0`` is the value
    the compiler should assume for masked-out elements.
    """
    return 0
98+
99+
100+
@_decorators.ref(rand)
def _(
    shape: list[int | RefTile],
    seed: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Implementation of ``rand`` registered through ``_decorators.ref``.

    Args:
        shape: Sizes of the result; a ``RefTile`` entry contributes its extent
            (``end - begin``), any other entry is converted with ``int``.
        seed: Seed given to a fresh ``torch.Generator``.
        dtype: Element type of the returned tensor.
        device: Device for both the generator and the result; the current
            compile environment's device is used when ``None``.

    Returns:
        A tensor from ``torch.rand`` drawn with the seeded generator.
    """
    sizes: list[int] = [
        dim.end - dim.begin if isinstance(dim, RefTile) else int(dim)
        for dim in shape
    ]
    environment = CompileEnvironment.current()
    target_device = environment.device if device is None else device
    # A dedicated generator keeps the result a function of ``seed`` alone.
    generator = torch.Generator(device=target_device)
    generator.manual_seed(seed)
    return torch.rand(sizes, dtype=dtype, generator=generator, device=target_device)

test/test_rng.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,97 @@ def randn_kernel_3d(x: torch.Tensor) -> torch.Tensor:
348348
f"Slice {b_idx} std {slice_std} is not well distributed",
349349
)
350350

351+
def test_hl_rand_1d(self):
    # hl.rand over 1-D tiles: the seed changes the output, the same seed
    # reproduces it, and every value lies in [0, 1).
    @helion.kernel
    def rand_kernel_tiled_1d(x: torch.Tensor, seed: int) -> torch.Tensor:
        output = torch.zeros_like(x)
        (m,) = x.shape
        for (tile_m,) in hl.tile([m]):
            output[tile_m] = hl.rand([tile_m], seed=seed)
        return output

    template = torch.ones(128, device=DEVICE)

    def generate(seed: int) -> torch.Tensor:
        # code_and_output yields a pair; the second item is the kernel output.
        return code_and_output(rand_kernel_tiled_1d, (template, seed))[1]

    baseline = generate(42)
    different_seed = generate(1337)
    self.assertFalse(
        torch.allclose(baseline, different_seed),
        "Different seeds should produce different outputs",
    )

    repeated = generate(42)
    self.assertTrue(
        torch.allclose(baseline, repeated),
        "Same seed should produce identical outputs",
    )

    # Check that all values are in [0, 1) range
    non_negative = torch.all(baseline >= 0.0)
    below_one = torch.all(baseline < 1.0)
    self.assertTrue(non_negative, "All values should be >= 0")
    self.assertTrue(below_one, "All values should be < 1")
378+
379+
def test_hl_rand_2d(self):
    # hl.rand over 2-D tiles: the seed changes the output, the same seed
    # reproduces it, and every value lies in [0, 1).
    @helion.kernel
    def rand_kernel_tiled_2d(x: torch.Tensor, seed: int) -> torch.Tensor:
        output = torch.zeros_like(x)
        m, n = x.shape
        for tile_m, tile_n in hl.tile([m, n]):
            output[tile_m, tile_n] = hl.rand([tile_m, tile_n], seed=seed)
        return output

    template = torch.ones(128, 128, device=DEVICE)

    def generate(seed: int) -> torch.Tensor:
        # code_and_output yields a pair; the second item is the kernel output.
        return code_and_output(rand_kernel_tiled_2d, (template, seed))[1]

    baseline = generate(42)
    different_seed = generate(1337)
    self.assertFalse(
        torch.allclose(baseline, different_seed),
        "Different seeds should produce different outputs",
    )

    repeated = generate(42)
    self.assertTrue(
        torch.allclose(baseline, repeated),
        "Same seed should produce identical outputs",
    )

    # Every sample must fall inside the half-open unit interval.
    non_negative = torch.all(baseline >= 0.0)
    below_one = torch.all(baseline < 1.0)
    self.assertTrue(non_negative, "All values should be >= 0")
    self.assertTrue(below_one, "All values should be < 1")
405+
406+
def test_hl_rand_3d(self):
    # hl.rand over 3-D tiles: the seed changes the output, the same seed
    # reproduces it, values lie in [0, 1), and the mean is close to 0.5.
    @helion.kernel
    def rand_kernel_tiled_3d(x: torch.Tensor, seed: int) -> torch.Tensor:
        output = torch.zeros_like(x)
        b, m, n = x.shape
        for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
            output[tile_b, tile_m, tile_n] = hl.rand(
                [tile_b, tile_m, tile_n], seed=seed
            )
        return output

    template = torch.ones(16, 32, 64, device=DEVICE)

    def generate(seed: int) -> torch.Tensor:
        # code_and_output yields a pair; the second item is the kernel output.
        return code_and_output(rand_kernel_tiled_3d, (template, seed))[1]

    baseline = generate(42)
    different_seed = generate(1337)
    self.assertFalse(
        torch.allclose(baseline, different_seed),
        "Different seeds should produce different outputs",
    )

    repeated = generate(42)
    self.assertTrue(
        torch.allclose(baseline, repeated),
        "Same seed should produce identical outputs",
    )

    # Every sample must fall inside the half-open unit interval.
    non_negative = torch.all(baseline >= 0.0)
    below_one = torch.all(baseline < 1.0)
    self.assertTrue(non_negative, "All values should be >= 0")
    self.assertTrue(below_one, "All values should be < 1")

    # Check distribution properties
    mean_val = baseline.mean().item()
    mean_is_central = 0.4 < mean_val < 0.6
    self.assertTrue(
        mean_is_central,
        f"Mean {mean_val:.3f} should be around 0.5 for uniform distribution",
    )
441+
351442

352443
# Run this module's tests when the file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)