@@ -40,7 +40,7 @@ def _get_symnode(debug_name: str) -> int:
4040 raise AssertionError ("this should never be called" )
4141
4242
43- @_decorators .codegen (_get_symnode )
43+ @_decorators .codegen (_get_symnode , "triton" )
4444def _ (state : CodegenState ) -> ast .AST :
4545 val = state .fx_node .meta ["val" ] # pyright: ignore[reportOptionalMemberAccess]
4646
@@ -69,7 +69,7 @@ def _host_tensor(debug_name: str) -> torch.Tensor:
6969 raise AssertionError ("this should never be called" )
7070
7171
72- @_decorators .codegen (_host_tensor )
72+ @_decorators .codegen (_host_tensor , "triton" )
7373def _ (state : CodegenState ) -> ast .AST :
7474 return expr_from_string ("_host_tensor" ) # should be unused
7575
@@ -83,7 +83,7 @@ def _for_loop(
8383 raise AssertionError ("this should never be called" )
8484
8585
86- @_decorators .codegen (_for_loop )
86+ @_decorators .codegen (_for_loop , "triton" )
8787def _ (state : CodegenState ) -> None :
8888 return HostFunction .current ().device_ir .graphs [state .proxy_arg (0 )].codegen (state ) # pyright: ignore[reportArgumentType,reportCallIssue]
8989
@@ -100,7 +100,7 @@ def _while_loop(
100100 raise AssertionError ("this should never be called" )
101101
102102
103- @_decorators .codegen (_while_loop )
103+ @_decorators .codegen (_while_loop , "triton" )
104104def _ (state : CodegenState ) -> None :
105105 return HostFunction .current ().device_ir .graphs [state .proxy_arg (1 )].codegen (state ) # pyright: ignore[reportArgumentType,reportCallIssue]
106106
@@ -112,7 +112,7 @@ def _if(test: object, graph_id: int, args: list[object]) -> list[object]:
112112 raise AssertionError ("this should never be called" )
113113
114114
115- @_decorators .codegen (_if )
115+ @_decorators .codegen (_if , "triton" )
116116def _ (state : CodegenState ) -> None :
117117 return HostFunction .current ().device_ir .graphs [state .proxy_arg (1 )].codegen (state ) # pyright: ignore[reportArgumentType,reportCallIssue]
118118
@@ -139,7 +139,7 @@ def _(lhs: object, rhs: object) -> object:
139139 return torch .empty_like (lhs )
140140
141141
142- @_decorators .codegen (_phi )
142+ @_decorators .codegen (_phi , "triton" )
143143def _ (state : CodegenState ) -> ast .Name :
144144 lhs = state .ast_arg (0 )
145145 assert isinstance (lhs , ast .Name ), lhs
@@ -180,7 +180,7 @@ def _and(left: object, right: object) -> object:
180180 raise NotInsideKernel
181181
182182
183- @_decorators .codegen (_and )
183+ @_decorators .codegen (_and , "triton" )
184184def _ (state : CodegenState ) -> None :
185185 return expr_from_string (
186186 "{lhs} and {rhs}" , lhs = state .ast_arg (0 ), rhs = state .ast_arg (1 )
@@ -233,7 +233,7 @@ def _(left: object, right: object) -> object:
233233 return env .shape_env .create_unbacked_symbool ()
234234
235235
236- @_decorators .codegen (_or )
236+ @_decorators .codegen (_or , "triton" )
237237def _ (state : CodegenState ) -> None :
238238 return expr_from_string (
239239 "{lhs} or {rhs}" , lhs = state .ast_arg (0 ), rhs = state .ast_arg (1 )
@@ -258,7 +258,7 @@ def _(left: object) -> object:
258258 return env .shape_env .create_unbacked_symbool ()
259259
260260
261- @_decorators .codegen (_not )
261+ @_decorators .codegen (_not , "triton" )
262262def _ (state : CodegenState ) -> ast .AST :
263263 return expr_from_string (
264264 "not {lhs}" ,
@@ -289,7 +289,7 @@ def _(tensor: torch.Tensor, other: float) -> torch.Tensor:
289289 return torch .empty_like (tensor )
290290
291291
292- @_decorators .codegen (_mask_to )
292+ @_decorators .codegen (_mask_to , "triton" )
293293def _ (state : CodegenState ) -> ast .AST :
294294 tensor = state .proxy_arg (0 )
295295 assert isinstance (tensor , torch .Tensor )
@@ -351,7 +351,7 @@ def _(value: _T) -> _T:
351351 raise NotImplementedError (f"Unsupported type for _new_var: { type (value )} " )
352352
353353
354- @_decorators .codegen (_new_var )
354+ @_decorators .codegen (_new_var , "triton" )
355355def _ (state : CodegenState ) -> ast .AST :
356356 value = state .ast_arg (0 )
357357 assert isinstance (value , ast .AST )
0 commit comments