24 changes: 13 additions & 11 deletions pytensor/tensor/rewriting/basic.py
@@ -77,6 +77,7 @@
register_infer_shape,
switch,
tensor_copy,
tile,
zeros,
zeros_like,
)
@@ -913,14 +914,15 @@ def local_join_make_vector(fgraph, node):
@register_canonicalize
@node_rewriter([Join])
def local_join_to_repeat(fgraph, node):
"""Join(axis, x, x, x, ...) -> repeat(x, n, axis)
"""Join(axis, x, x, x, ...) -> tile(x, reps)

When the same tensor is concatenated multiple times along an axis
where it has size 1, replace with a repeat operation which is more efficient.
When the same tensor is concatenated multiple times along an axis,
    replace it with a single tile operation, which is more efficient.

Examples
--------
concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0)
join(0, x, x, x) -> tile(x, (3, 1, 1, ...))
join(1, x, x) -> tile(x, (1, 2, 1, ...))
"""
# Extract axis and the tensors being joined
axis, *tensors = node.inputs
@@ -940,19 +942,19 @@ def local_join_to_repeat(fgraph, node):
if not all(t == tensors[0] for t in tensors[1:]):
return

# Only optimize if the tensor has size 1 along the join axis
n_reps = len(tensors)
first_tensor = tensors[0]
if first_tensor.type.shape[axis_val] != 1:
return
ndim = first_tensor.ndim

# Replace with repeat operation
from pytensor.tensor.extra_ops import repeat
# Build reps tuple to repeat only along the join axis
# For shape (a, b, c) joining at axis 1: reps = (1, n_reps, 1)
    # This replicates the tensor n_reps times along axis_val, matching the join
reps = tuple(n_reps if i == axis_val else 1 for i in range(ndim))

result = repeat(first_tensor, len(tensors), axis_val)
result = tile(first_tensor, reps)

# Preserve debugging information
copy_stack_trace(node.outputs[0], result)

return [result]
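
As a sanity check on the reps construction above, here is a minimal NumPy-only sketch (illustrative, not part of the diff) showing that tiling with n_reps at the join axis and 1 elsewhere matches concatenating n_reps copies of the same tensor:

import numpy as np

x = np.arange(6.0).reshape(2, 3)  # arbitrary test tensor
n_reps, axis_val = 4, 1           # join the same tensor 4 times along axis 1
# Same construction as in the rewrite: n_reps at the join axis, 1 elsewhere
reps = tuple(n_reps if i == axis_val else 1 for i in range(x.ndim))
assert np.array_equal(np.tile(x, reps), np.concatenate([x] * n_reps, axis=axis_val))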


121 changes: 73 additions & 48 deletions tests/tensor/rewriting/test_basic.py
@@ -1237,133 +1237,158 @@ def test_local_join_1():
assert len([n for n in e if isinstance(n.op, Join)]) == 0
assert f.maker.fgraph.outputs[0].dtype == config.floatX

    # test we don't apply when there are 2 inputs
# test that join with 2 identical inputs now gets optimized to tile
s = join(1, a, a)
f = function([a], s, mode=rewrite_mode)
val = f([[1]])
assert np.all(val == [[1]])
assert np.all(val == [[1, 1]]) # joined along axis 1
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert len([n for n in e if isinstance(n.op, Join)]) == 0 # optimized away
assert f.maker.fgraph.outputs[0].dtype == config.floatX


def test_local_join_to_repeat():
"""Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)
def test_local_join_to_tile():
[Review comment — Member]: I think this test can be simplified a lot. It carries baggage from the first iterations that were using repeat and only worked when joining over dummy dimensions. Let's start over and test fewer, cleaner cases.

"""Join(axis, x, x, ...) is rewritten to tile(x, reps) with reps[axis] = k.

This optimization applies when joining the same tensor multiple times
along an axis where it has size 1 (e.g., after ExpandDims).
This optimization applies whenever we concatenate the *same* tensor multiple
times along a given axis (no need for size-1 dims / ExpandDims). It replaces
the Join/concatenate with a single Tile op.
"""

# Test with vector expanded to (1, n) - concatenate along axis 0
# Helpers to inspect the graph without depending on concrete Op classes
def count_op(ops, cls_name):
return sum(1 for n in ops if n.op.__class__.__name__ == cls_name)

def has_no_join(fgraph_ops):
return count_op(fgraph_ops, "Join") == 0

# ---- Case 1: vector expanded to (1, n), concat along axis 0 ----
x = vector("x")
x_expanded = x[None] # Shape: (1, n)
s = join(0, x_expanded, x_expanded, x_expanded) # Shape: (3, n)
x_expanded = x[None] # (1, n)
s = join(0, x_expanded, x_expanded, x_expanded) # (3, n)
f = function([x], s, mode=rewrite_mode)

# Check numerical correctness
test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
result = f(test_val)
expected = np.array(
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX
)
assert np.allclose(result, expected)

# Check that Join was replaced with Alloc (repeat with scalar repeats becomes Alloc)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
assert has_no_join(ops)
# Note: Tile may be further optimized to Alloc, so we don't check for it

# Test with matrix - add dimension and concatenate along new axis
a = matrix("a") # Shape: (m, n)
a_expanded = a[None, :, :] # Shape: (1, m, n)
s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # Shape: (4, m, n)
# ---- Case 2: matrix, concat along new leading axis ----
a = matrix("a") # (m, n)
a_expanded = a[None, :, :] # (1, m, n)
s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # (4, m, n)
f = function([a], s, mode=rewrite_mode)

test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
result = f(test_mat)
expected = np.array([test_mat, test_mat, test_mat, test_mat])
expected = np.array([test_mat, test_mat, test_mat, test_mat], dtype=config.floatX)
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
assert has_no_join(ops)

# Test with matrix - expand along axis 1 and concatenate
a_expanded_ax1 = a[:, None, :] # Shape: (m, 1, n)
s = join(1, a_expanded_ax1, a_expanded_ax1) # Shape: (m, 2, n)
# ---- Case 3: matrix, expand along axis 1 then concat ----
a_expanded_ax1 = a[:, None, :] # (m, 1, n)
s = join(1, a_expanded_ax1, a_expanded_ax1) # (m, 2, n)
f = function([a], s, mode=rewrite_mode)

result = f(test_mat)
expected = np.array([[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]])
expected = np.array(
[[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]],
dtype=config.floatX,
)
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
assert has_no_join(ops)

# Test that it does NOT apply when tensors are different
# ---- Case 4: different tensors -> should NOT optimize ----
y = vector("y")
s = join(0, x[None], y[None])
s = join(0, x[None], y[None]) # inputs differ
f = function([x, y], s, mode=rewrite_mode)

test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
result = f(test_vec1, test_vec2)
expected = np.array([[1.0, 2.0], [3.0, 4.0]])
expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
assert np.allclose(result, expected)

# Join should still be present (not optimized)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 1
# Join should still be present since inputs aren't identical
assert count_op(ops, "Join") == 1

# Test that it does NOT apply when tensor doesn't have size 1 along join axis
# (regular concatenation without ExpandDims)
s = join(0, x, x, x) # Shape: (3n,) not using ExpandDims
# ---- Case 5: plain concat without ExpandDims should now optimize ----
s = join(0, x, x, x) # (3n,)
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
assert np.allclose(result, expected)

# Join should still be present (optimization doesn't apply)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 1
assert has_no_join(ops)

# Test with 5 repetitions to ensure it works with larger counts
s = join(0, x[None], x[None], x[None], x[None], x[None])
# ---- Case 6: larger repetition count ----
s = join(0, x[None], x[None], x[None], x[None], x[None]) # (5, n)
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX)
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1
assert has_no_join(ops)
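
To make the note from Case 1 concrete (Tile may be lowered further by later rewrites), here is a small exploratory sketch, not part of the diff, assuming this module's existing imports (vector, join, function, rewrite_mode):

x = vector("x")
f = function([x], join(0, x[None], x[None], x[None]), mode=rewrite_mode)
# List the op names left after rewriting: "Join" should be gone, and what
# remains may be "Tile" or a further-lowered op such as "Alloc".
print([type(node.op).__name__ for node in f.maker.fgraph.toposort()])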


def test_local_join_empty():
# Vector case
# Vector case - empty tensors should be removed and join optimized
empty_vec = np.asarray([], dtype=config.floatX)
vec = vector("vec")
s = pt.join(0, vec, vec, empty_vec)
new_s = rewrite_graph(s)
assert equal_computations([new_s], [join(0, vec, vec)])
# Verify dtype is preserved
assert new_s.dtype == s.dtype
# Verify no Join in the optimized graph
f = function([vec], new_s, mode=rewrite_mode)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
# Verify numerical correctness
test_vec = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
result = f(test_vec)
expected = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX)
assert np.allclose(result, expected)

# Matrix case
# Matrix case - empty tensors should be removed and join optimized
empty_mat = np.zeros((2, 0), dtype=config.floatX)
empty_sym_mat = matrix("m", shape=(2, 0))
mat = matrix("mat", shape=(2, 10))
s = join(1, empty_mat, mat, empty_sym_mat, mat, mat)
new_s = rewrite_graph(s)
assert equal_computations([new_s], [join(1, mat, mat, mat)])
# Verify dtype is preserved
assert new_s.dtype == s.dtype
# Verify no Join in the optimized graph
f = function([mat], new_s, mode=rewrite_mode)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
# Verify numerical correctness
test_mat = np.array(
[
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
],
dtype=config.floatX,
)
result = f(test_mat)
expected = np.concatenate([test_mat, test_mat, test_mat], axis=1)
assert np.allclose(result, expected)

# Join can be completely removed, but casting and specify_shape are propagated
int_mat = matrix("int_mat", dtype=int)
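
The empty-join cases above rest on a simple identity, shown here as a NumPy-only illustration (not part of the diff): concatenating a zero-size array is a no-op, so such inputs can be dropped from the Join before other rewrites run.

import numpy as np

a = np.ones((2, 3))
empty = np.zeros((2, 0))  # zero-size along the join axis
assert np.array_equal(
    np.concatenate([empty, a, empty, a], axis=1),
    np.concatenate([a, a], axis=1),
)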