-
Notifications
You must be signed in to change notification settings - Fork 150
WIP: Add rewrite to fuse nested BlockDiag Ops #1671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,11 +60,36 @@ | |
| solve_triangular, | ||
| ) | ||
|
|
||
| from pytensor.tensor.slinalg import BlockDiagonal | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) | ||
|
|
||
|
|
||
| from pytensor.tensor.slinalg import BlockDiagonal | ||
| from pytensor.graph import Apply | ||
|
|
||
| def fuse_blockdiagonal(node): | ||
|
||
| # Only process if this node is a BlockDiagonal | ||
| if not isinstance(node.owner.op, BlockDiagonal): | ||
| return node | ||
|
||
|
|
||
| new_inputs = [] | ||
| changed = False | ||
| for inp in node.owner.inputs: | ||
| # If input is itself a BlockDiagonal, flatten its inputs | ||
| if inp.owner and isinstance(inp.owner.op, BlockDiagonal): | ||
| new_inputs.extend(inp.owner.inputs) | ||
| changed = True | ||
| else: | ||
| new_inputs.append(inp) | ||
|
|
||
| if changed: | ||
| # Return a new fused BlockDiagonal with all inputs | ||
| return BlockDiagonal(len(new_inputs))(*new_inputs) | ||
| return node | ||
|
||
|
|
||
|
|
||
| def is_matrix_transpose(x: TensorVariable) -> bool: | ||
| """Check if a variable corresponds to a transpose of the last two axes""" | ||
| node = x.owner | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,7 +43,50 @@ | |
| from tests import unittest_tools as utt | ||
| from tests.test_rop import break_op | ||
|
|
||
| from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal | ||
|
|
||
|
|
||
| def test_nested_blockdiag_fusion(): | ||
| # Create matrix variables | ||
| x = pt.matrix("x") | ||
|
||
| y = pt.matrix("y") | ||
| z = pt.matrix("z") | ||
|
|
||
| # Nested BlockDiagonal | ||
| inner = BlockDiagonal(2)(x, y) | ||
| outer = BlockDiagonal(2)(inner, z) | ||
|
|
||
| # Count number of BlockDiagonal ops before fusion | ||
| nodes_before = ancestors([outer]) | ||
| initial_count = sum( | ||
| 1 for node in nodes_before | ||
| if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal) | ||
| ) | ||
| assert initial_count > 1, "Setup failed: should have nested BlockDiagonal" | ||
|
||
|
|
||
| # Apply the rewrite | ||
| fused = fuse_blockdiagonal(outer) | ||
|
||
|
|
||
| # Count number of BlockDiagonal ops after fusion | ||
| nodes_after = ancestors([fused]) | ||
|
||
| fused_count = sum( | ||
| 1 for node in nodes_after | ||
| if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal) | ||
| ) | ||
| assert fused_count == 1, "Nested BlockDiagonal ops were not fused" | ||
|
||
|
|
||
| # Check that all original inputs are preserved | ||
| fused_inputs = [ | ||
| inp | ||
| for node in ancestors([fused]) | ||
| if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal) | ||
| for inp in node.owner.inputs | ||
| ] | ||
| assert set(fused_inputs) == {x, y, z}, "Inputs were not correctly fused" | ||
|
|
||
|
|
||
|
|
||
|
|
||
| def test_matrix_inverse_rop_lop(): | ||
| rtol = 1e-7 if config.floatX == "float64" else 1e-5 | ||
| mx = matrix("mx") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sure you have
pre-commitand you've donepre-commit installin your dev environment. You have doubled imports and other issues this tool with help you check.