Commit 8452d52
Refactor _decorators.codegen to allow multiple backends (#1099)
1 parent a1ec064 commit 8452d52

23 files changed: +86 -62 lines
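
Most of the diff is mechanical: every @_decorators.codegen(fn) registration gains an explicit backend name, and the lookup sites resolve the codegen callable through the new CompileEnvironment.backend field. Below is a minimal, self-contained sketch of the resulting shape, using plain-Python stand-ins (FakeAPIFunc and FakeCompileEnvironment are illustrative, not Helion's real classes):

from typing import Callable, Dict


class FakeAPIFunc:
    """Stand-in for _decorators.APIFunc: _codegen is now a per-backend mapping."""

    def __init__(self, qualname: str) -> None:
        self.__qualname__ = qualname
        self._codegen: Dict[str, Callable[[object], object]] = {}


class FakeCompileEnvironment:
    """Stand-in for CompileEnvironment; the real field is hard-coded to "triton" for now."""

    def __init__(self) -> None:
        self.backend = "triton"


full = FakeAPIFunc("full")
full._codegen["triton"] = lambda state: "tl.full(...)"  # what the decorator records

env = FakeCompileEnvironment()
codegen_fn = full._codegen.get(env.backend)  # lookup keyed on the environment's backend
assert codegen_fn is not None
print(codegen_fn(None))  # -> tl.full(...)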

helion/_compiler/compile_environment.py

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,8 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
         super().__init__()
         self.device = device
         self.settings = settings
+        # TODO(jansel): make backend configurable
+        self.backend = "triton"
         self.shape_env = ShapeEnv(
             specialize_zero_one=True,
             duck_shape=False,
helion/_compiler/generate_ast.py

Lines changed: 17 additions & 7 deletions
@@ -272,7 +272,13 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
         assert fn_node._type_info is not None
         fn = fn_node._type_info.proxy()
         assert is_api_func(fn)
-        assert fn._codegen is not None
+        env = CompileEnvironment.current()
+        codegen_fn = fn._codegen.get(env.backend)
+        if codegen_fn is None:
+            raise exc.BackendImplementationMissing(
+                env.backend,
+                f"codegen for API function {fn.__qualname__}",
+            )
         bound = fn._signature.bind(*args, **kwargs)
         bound.apply_defaults()

@@ -285,7 +291,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
             ast_args=None,  # pyright: ignore[reportArgumentType]
         )

-        fn._codegen(state)
+        codegen_fn(state)
         assert node._root_id is not None
         codegen_call_with_graph(
             self,
@@ -376,11 +382,15 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
                 [x.from_config(self.device_function.config) for x in block_infos]
             )
         )
-        elif (
-            isinstance(fn_type_info := func_node._type_info, CallableType)
-            and is_api_func(api := fn_type_info.value)
-            and api._codegen is not None
+        elif isinstance(fn_type_info := func_node._type_info, CallableType) and (
+            is_api_func(api := fn_type_info.value)
         ):
+            codegen_fn = api._codegen.get(env.backend)
+            if codegen_fn is None:
+                raise exc.BackendImplementationMissing(
+                    env.backend,
+                    f"codegen for API function {api.__qualname__}",
+                )
             ast_args = []
             ast_kwargs = {}
             proxy_args = []
@@ -401,7 +411,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
             proxy_params = api._signature.bind(*proxy_args, **proxy_kwargs)
             ast_params.apply_defaults()
             proxy_params.apply_defaults()
-            return api._codegen(  # pyright: ignore[reportReturnType]
+            return codegen_fn(  # pyright: ignore[reportReturnType]
                 CodegenState(
                     self,
                     None,
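
Both lookup sites above share the same resolve-or-raise shape: fetch the backend's entry from the function's _codegen mapping and raise BackendImplementationMissing if it is absent. A small, self-contained sketch of that pattern (resolve_codegen is a hypothetical helper written for illustration; the commit inlines this logic at each call site):

from typing import Callable, Dict


class BackendImplementationMissing(RuntimeError):
    """Stand-in for helion.exc.BackendImplementationMissing."""

    def __init__(self, backend: str, detail: str) -> None:
        super().__init__(f"Backend '{backend}' is missing required implementation: {detail}")


def resolve_codegen(
    registry: Dict[str, Callable[[object], object]],
    backend: str,
    qualname: str,
) -> Callable[[object], object]:
    # Mirrors visit_For / visit_Call: dict.get, then a targeted error if absent.
    codegen_fn = registry.get(backend)
    if codegen_fn is None:
        raise BackendImplementationMissing(backend, f"codegen for API function {qualname}")
    return codegen_fn


registry = {"triton": lambda state: "..."}  # only a Triton entry is registered
resolve_codegen(registry, "triton", "full")  # succeeds
try:
    resolve_codegen(registry, "other", "full")  # any unregistered backend fails loudly
except BackendImplementationMissing as err:
    print(err)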

helion/_compiler/inductor_lowering.py

Lines changed: 8 additions & 4 deletions
@@ -626,7 +626,6 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
             self.buffer.data.inner_fn(indices, reduction_indices)
         )

-        from .. import exc
         from .generate_ast import GenerateAST

         if not isinstance(ctx.cg, GenerateAST):
@@ -744,14 +743,19 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         ast_args = [*map_arg(node.args, lambda arg: ctx.env[arg])]
         proxy_args = [*map_arg(node.args, lambda arg: arg.meta["val"])]

-        assert self.api_func._codegen is not None
-        from .. import exc
+        env = CompileEnvironment.current()
+        codegen_fn = self.api_func._codegen.get(env.backend)
+        if codegen_fn is None:
+            raise exc.BackendImplementationMissing(
+                env.backend,
+                f"codegen for API function {self.api_func.__qualname__}",
+            )
         from .generate_ast import GenerateAST

         if not isinstance(ctx.cg, GenerateAST):
             raise exc.NotAllowedInHelperFunction

-        return self.api_func._codegen(
+        return codegen_fn(
             CodegenState(
                 ctx.cg,
                 fx_node=node,

helion/exc.py

Lines changed: 7 additions & 0 deletions
@@ -52,6 +52,13 @@ class AutotuneError(BaseError):
     message = "{0}"


+class BackendImplementationMissing(BaseError):
+    message = "Backend '{backend}' is missing required implementation: {detail}"
+
+    def __init__(self, backend: str, detail: str) -> None:
+        super().__init__(backend=backend, detail=detail)
+
+
 class CacheAssertionError(BaseError):
     message = "Expected cache hit for kernel '{0}', but got cache miss. See stderr for diagnostic information."
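
The new exception follows the existing pattern in this file: a class-level message template plus an __init__ that forwards its arguments to BaseError. BaseError itself is not part of this diff, so the sketch below assumes it applies str.format to the template; treat it as an illustration of the intended error text, not the real class:

class BaseError(Exception):
    """Simplified stand-in; assumes the real BaseError formats `message` with its args."""

    message = "{0}"

    def __init__(self, *args: object, **kwargs: object) -> None:
        super().__init__(self.message.format(*args, **kwargs))


class BackendImplementationMissing(BaseError):
    message = "Backend '{backend}' is missing required implementation: {detail}"

    def __init__(self, backend: str, detail: str) -> None:
        super().__init__(backend=backend, detail=detail)


print(BackendImplementationMissing("triton", "codegen for API function full"))
# Backend 'triton' is missing required implementation: codegen for API function full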

helion/language/_decorators.py

Lines changed: 7 additions & 6 deletions
@@ -55,7 +55,7 @@ class APIFunc(Protocol):
         _cache_type: Whether to cache the type information for repeated calls.
         _type_function: A callable that determines the return type of this function
             during type propagation phase.
-        _codegen: A callable that generates the device code for this function.
+        _codegen: Mapping of backend names to callables that generate device code.
         _fake_fn: A callable that provides a "fake" implementation used during
             tracing and compilation.
         _prepare_args: A callable that preprocesses the arguments before they're
@@ -72,7 +72,7 @@ class APIFunc(Protocol):
     _tiles_as_sizes: bool
     _cache_type: bool
     _type_function: Callable[..., TypeInfo] | None
-    _codegen: Callable[[CodegenState], object] | None
+    _codegen: dict[str, Callable[[CodegenState], object]]
     _fake_fn: Callable[..., object] | None
     _prepare_args: Callable[[tuple[object, ...]], tuple[object, ...]]
     _get_masked_value: Callable[[torch.fx.Node], float | bool | None] | None
@@ -189,7 +189,7 @@ def wrapper(*args: object, **kwargs: object) -> object:
     api._prepare_args = no_op_prepare_args
     api._cache_type = cache_type
     api._type_function = None
-    api._codegen = None
+    api._codegen = {}
     api._fake_fn = None
     api._get_masked_value = None
     api._to_device_ir = None
@@ -254,15 +254,16 @@ def _impl(

 def codegen(
     original_fn: Callable[..., object],
+    backend: str,
 ) -> _NoReturnDecorator[object]:
     def _impl(codegen_fn: Callable[[CodegenState], object]) -> Callable[..., Never]:
         assert is_api_func(original_fn), (
             f"{type_propagation.__qualname__} can only be used on API functions"
         )
-        assert original_fn._codegen is None, (
-            "codegen can only be used once per function"
+        assert backend not in original_fn._codegen, (
+            f"codegen already registered for backend {backend!r}"
         )
-        original_fn._codegen = codegen_fn
+        original_fn._codegen[backend] = codegen_fn
         return _no_call

     return _impl  # pyright: ignore[reportReturnType]
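
With _codegen now a mapping, an API function can carry one codegen implementation per backend, and re-registering the same backend trips the new assert. The sketch below mirrors the decorator's semantics with stand-in types; the second backend name ("mlir") is purely hypothetical, since this commit only registers "triton" implementations:

from typing import Callable, Dict


class FakeAPIFunc:
    """Stand-in for an @_decorators.api function carrying a per-backend registry."""

    def __init__(self, qualname: str) -> None:
        self.__qualname__ = qualname
        self._codegen: Dict[str, Callable[[object], object]] = {}


def codegen(original_fn: FakeAPIFunc, backend: str):
    """Stand-in for _decorators.codegen; the real one returns a no-call sentinel."""

    def _impl(codegen_fn: Callable[[object], object]) -> Callable[[object], object]:
        assert backend not in original_fn._codegen, (
            f"codegen already registered for backend {backend!r}"
        )
        original_fn._codegen[backend] = codegen_fn
        return codegen_fn

    return _impl


breakpoint_api = FakeAPIFunc("breakpoint")


@codegen(breakpoint_api, "triton")
def _(state: object) -> None:
    print("breakpoint()")


@codegen(breakpoint_api, "mlir")  # hypothetical second backend, for illustration only
def _(state: object) -> None:
    print("// some other backend's lowering")


assert set(breakpoint_api._codegen) == {"triton", "mlir"}
# A second @codegen(breakpoint_api, "triton") registration would fail the assert.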

helion/language/_tracing_ops.py

Lines changed: 11 additions & 11 deletions
@@ -40,7 +40,7 @@ def _get_symnode(debug_name: str) -> int:
     raise AssertionError("this should never be called")


-@_decorators.codegen(_get_symnode)
+@_decorators.codegen(_get_symnode, "triton")
 def _(state: CodegenState) -> ast.AST:
     val = state.fx_node.meta["val"]  # pyright: ignore[reportOptionalMemberAccess]

@@ -69,7 +69,7 @@ def _host_tensor(debug_name: str) -> torch.Tensor:
     raise AssertionError("this should never be called")


-@_decorators.codegen(_host_tensor)
+@_decorators.codegen(_host_tensor, "triton")
 def _(state: CodegenState) -> ast.AST:
     return expr_from_string("_host_tensor")  # should be unused

@@ -83,7 +83,7 @@ def _for_loop(
     raise AssertionError("this should never be called")


-@_decorators.codegen(_for_loop)
+@_decorators.codegen(_for_loop, "triton")
 def _(state: CodegenState) -> None:
     return HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state)  # pyright: ignore[reportArgumentType,reportCallIssue]

@@ -100,7 +100,7 @@ def _while_loop(
     raise AssertionError("this should never be called")


-@_decorators.codegen(_while_loop)
+@_decorators.codegen(_while_loop, "triton")
 def _(state: CodegenState) -> None:
     return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state)  # pyright: ignore[reportArgumentType,reportCallIssue]

@@ -112,7 +112,7 @@ def _if(test: object, graph_id: int, args: list[object]) -> list[object]:
     raise AssertionError("this should never be called")


-@_decorators.codegen(_if)
+@_decorators.codegen(_if, "triton")
 def _(state: CodegenState) -> None:
     return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state)  # pyright: ignore[reportArgumentType,reportCallIssue]

@@ -139,7 +139,7 @@ def _(lhs: object, rhs: object) -> object:
     return torch.empty_like(lhs)


-@_decorators.codegen(_phi)
+@_decorators.codegen(_phi, "triton")
 def _(state: CodegenState) -> ast.Name:
     lhs = state.ast_arg(0)
     assert isinstance(lhs, ast.Name), lhs
@@ -180,7 +180,7 @@ def _and(left: object, right: object) -> object:
     raise NotInsideKernel


-@_decorators.codegen(_and)
+@_decorators.codegen(_and, "triton")
 def _(state: CodegenState) -> None:
     return expr_from_string(
         "{lhs} and {rhs}", lhs=state.ast_arg(0), rhs=state.ast_arg(1)
@@ -233,7 +233,7 @@ def _(left: object, right: object) -> object:
     return env.shape_env.create_unbacked_symbool()


-@_decorators.codegen(_or)
+@_decorators.codegen(_or, "triton")
 def _(state: CodegenState) -> None:
     return expr_from_string(
         "{lhs} or {rhs}", lhs=state.ast_arg(0), rhs=state.ast_arg(1)
@@ -258,7 +258,7 @@ def _(left: object) -> object:
     return env.shape_env.create_unbacked_symbool()


-@_decorators.codegen(_not)
+@_decorators.codegen(_not, "triton")
 def _(state: CodegenState) -> ast.AST:
     return expr_from_string(
         "not {lhs}",
@@ -289,7 +289,7 @@ def _(tensor: torch.Tensor, other: float) -> torch.Tensor:
     return torch.empty_like(tensor)


-@_decorators.codegen(_mask_to)
+@_decorators.codegen(_mask_to, "triton")
 def _(state: CodegenState) -> ast.AST:
     tensor = state.proxy_arg(0)
     assert isinstance(tensor, torch.Tensor)
@@ -351,7 +351,7 @@ def _(value: _T) -> _T:
     raise NotImplementedError(f"Unsupported type for _new_var: {type(value)}")


-@_decorators.codegen(_new_var)
+@_decorators.codegen(_new_var, "triton")
 def _(state: CodegenState) -> ast.AST:
     value = state.ast_arg(0)
     assert isinstance(value, ast.AST)

helion/language/atomic_ops.py

Lines changed: 8 additions & 8 deletions
@@ -263,7 +263,7 @@ def apply(t: torch.Tensor, idx_tuple: tuple, v: object) -> None:
     return prev


-@_decorators.codegen(atomic_add)
+@_decorators.codegen(atomic_add, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_add", state, _to_ast_values([value_expr]))
@@ -343,7 +343,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_xchg)
+@_decorators.codegen(atomic_xchg, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_xchg", state, _to_ast_values([value_expr]))
@@ -420,7 +420,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_and)
+@_decorators.codegen(atomic_and, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_and", state, _to_ast_values([value_expr]))
@@ -494,7 +494,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_or)
+@_decorators.codegen(atomic_or, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_or", state, _to_ast_values([value_expr]))
@@ -568,7 +568,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_xor)
+@_decorators.codegen(atomic_xor, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_xor", state, _to_ast_values([value_expr]))
@@ -634,7 +634,7 @@ def apply(t: torch.Tensor, idx: tuple, v: object) -> None:
     _ref_apply(target, index, apply, value)


-@_decorators.codegen(atomic_max)
+@_decorators.codegen(atomic_max, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_max", state, _to_ast_values([value_expr]))
@@ -709,7 +709,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_min)
+@_decorators.codegen(atomic_min, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_min", state, _to_ast_values([value_expr]))
@@ -810,7 +810,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_cas)
+@_decorators.codegen(atomic_cas, "triton")
 def _(state: CodegenState) -> ast.AST:
     exp_expr = state.ast_args[2]
     val_expr = state.ast_args[3]

helion/language/constexpr.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ def handle_symint(symint: torch.SymInt) -> int:
     return TypeInfo.from_example(specialized, origin=origin)


-@_decorators.codegen(specialize)
+@_decorators.codegen(specialize, "triton")
 def _(state: CodegenState) -> ast.AST:
     value = state.proxy_arg(0)
     specialized = _convert_specializable(value)

helion/language/creation_ops.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ def _full_fake(
     )


-@_decorators.codegen(full)
+@_decorators.codegen(full, "triton")
 def _full_codegen(state: CodegenState) -> ast.AST:
     fake_value = state.fake_value
     assert isinstance(fake_value, torch.Tensor)

helion/language/debug_ops.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def _(*args: object, origin: Origin, **kwargs: object) -> TypeInfo:
     return LiteralType(origin, None)


-@_decorators.codegen(breakpoint)
+@_decorators.codegen(breakpoint, "triton")
 def _(state: CodegenState) -> None:
     state.add_statement("breakpoint()")
