-
Notifications
You must be signed in to change notification settings - Fork 74
Add epilogue subtiling #948
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -63,6 +63,7 @@ | |||||||
| from .type_propagation import _eval_binary | ||||||||
| from .type_propagation import _eval_compare | ||||||||
| from .type_propagation import _eval_unary | ||||||||
| from .utils import _use_epilogue_subtile | ||||||||
|
|
||||||||
| if TYPE_CHECKING: | ||||||||
| from collections.abc import Callable | ||||||||
|
|
@@ -1191,6 +1192,11 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: | |||||||
| total_load_count, loads_without_eviction_policy, store_count | ||||||||
| ) | ||||||||
|
|
||||||||
| if _use_epilogue_subtile(): | ||||||||
| for graph in device_ir.graphs: | ||||||||
| # Epilogue subtiling only for Blackwell | ||||||||
| epilogue_subtiling_pass(graph.graph, store_count) | ||||||||
|
|
||||||||
| return device_ir | ||||||||
|
|
||||||||
|
|
||||||||
|
|
@@ -1348,3 +1354,69 @@ def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None: | |||||||
| user.args = tuple(new_args) | ||||||||
| if len(node.users) == 0: | ||||||||
| graph.erase_node(node) | ||||||||
|
|
||||||||
|
|
||||||||
| def epilogue_subtiling_pass(graph: torch.fx.Graph, store_count: int) -> None: | ||||||||
| """ | ||||||||
| Replace epilogue subtile with a tunable value. | ||||||||
| """ | ||||||||
| if store_count == 0: | ||||||||
| return | ||||||||
|
|
||||||||
| from ..autotuner.config_fragment import EnumFragment | ||||||||
| from ..autotuner.config_fragment import ListOf | ||||||||
| from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES | ||||||||
| from .inductor_lowering import PointwiseLowering | ||||||||
|
|
||||||||
| env = CompileEnvironment.current() | ||||||||
| # Register a tunable for epilogue subtile for all device stores | ||||||||
| fragment = ListOf( | ||||||||
| EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=store_count | ||||||||
| ) | ||||||||
| env.config_spec.epilogue_subtiling = fragment | ||||||||
|
|
||||||||
| def collect_pointwise_epilogue_nodes( | ||||||||
| store_node: torch.fx.Node, | ||||||||
| ) -> dict[torch.fx.Node, None]: | ||||||||
| """Recursively collect all pointwise nodes that can be subtiled in the epilogue. | ||||||||
|
|
||||||||
| Starting from a store node, traverse backwards through all input nodes, | ||||||||
| collecting pointwise operations until we hit non-pointwise nodes. | ||||||||
| Only include pointwise nodes that have a single user to ensure they can be fused. | ||||||||
| """ | ||||||||
| # dict to preserve order | ||||||||
| pointwise_nodes = {} | ||||||||
| visited = set() | ||||||||
| stack = [store_node.args[2]] # Start with the value being stored | ||||||||
|
|
||||||||
| while stack: | ||||||||
| current = stack.pop() | ||||||||
| if current in visited or not isinstance(current, torch.fx.Node): | ||||||||
| continue | ||||||||
|
|
||||||||
| visited.add(current) | ||||||||
|
|
||||||||
| lowering = current.meta.get("lowering") | ||||||||
| # Check if this is a pointwise operation with only one user | ||||||||
| if isinstance(lowering, PointwiseLowering) and len(current.users) == 1: | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you explain the users==1 requirement? Is this meant to ensure everything is contained in the same graph? Maybe we should check this constraint more directly. |
||||||||
| if current not in pointwise_nodes: | ||||||||
| pointwise_nodes[current] = None | ||||||||
|
Comment on lines
+1402
to
+1403
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| stack.extend(current.all_input_nodes) | ||||||||
|
|
||||||||
| return pointwise_nodes | ||||||||
|
|
||||||||
| from ..language import store as store_api | ||||||||
|
|
||||||||
| stores = set() | ||||||||
|
|
||||||||
| for node in graph.nodes: | ||||||||
| if node.op == "call_function" and node.target == store_api: | ||||||||
| stores.add(node) | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is this used? |
||||||||
| # Collect all pointwise nodes that can be subtiled in the epilogue | ||||||||
| pointwise_nodes = collect_pointwise_epilogue_nodes(node) | ||||||||
| if pointwise_nodes: | ||||||||
| # Mark all collected pointwise nodes for epilogue subtiling | ||||||||
| for pw_node in pointwise_nodes: | ||||||||
| pw_node.meta["epilogue_subtile"] = True | ||||||||
| # Store the set of pointwise nodes in the store node's metadata | ||||||||
| node.meta["pointwise_epilogue_nodes"] = pointwise_nodes | ||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move the fragment defnition to config spec.