Skip to content

Commit f791da0

Browse files
authored
rms_norm: get weight from function args (#664)
1 parent c3a59f2 commit f791da0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/rms_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.
192192
# Benchmark Wrapper
193193
# --------------
194194
def rms_norm_tritonbench(
195-
tb_op: object, H: int, inp: torch.Tensor
195+
tb_op: object, H: int, inp: torch.Tensor, weight: torch.Tensor
196196
) -> Callable[[], torch.Tensor]:
197197
"""
198198
Wrapper for tritonbench that matches expected interface.
@@ -201,11 +201,11 @@ def rms_norm_tritonbench(
201201
tb_op: TritonBench operator instance
202202
H: Hidden dimension size
203203
inp: Input tensor
204+
weight: Weight tensor
204205
205206
Returns:
206207
Callable that returns normalized tensor
207208
"""
208-
weight = torch.ones(H, device=inp.device, dtype=inp.dtype, requires_grad=True)
209209
return lambda: rms_norm(inp, weight, eps=1e-6)
210210

211211

0 commit comments

Comments
 (0)