Skip to content
48 changes: 37 additions & 11 deletions pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pytensor import compile, config
from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph, Op
from pytensor.graph.basic import Constant
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
NodeRewriter,
Expand Down Expand Up @@ -82,7 +82,7 @@
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays, repeat
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import Sum, add, eq, variadic_add
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.type import DenseTensorType, TensorType
Expand Down Expand Up @@ -915,26 +915,52 @@ def local_join_make_vector(fgraph, node):
def local_join_to_repeat(fgraph, node):
"""Join(axis, x, x, x, ...) -> repeat(x, n, axis)

When the same tensor is concatenated multiple times,
replace with a single repeat operation which is more efficient.
When the same tensor is concatenated multiple times along an axis
where it has size 1, replace with a repeat operation which is more efficient.

Examples
--------
concatenate([x, x, x], axis=0) -> repeat(x, 3, axis=0)
concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0)
"""
# Extract axis and the tensors being joined
axis, *tensors = node.inputs
axis_sym, *tensors = node.inputs

# Need at least 2 tensors to consider optimization
if len(tensors) <= 1:
return
return None

# Check if all tensors are identical
if not all(t == tensors[0] for t in tensors[1:]):
return
# Extract (and normalize) axis as Python int
try:
axis_val = int(get_scalar_constant_value(axis_sym, only_process_constants=True))
except NotScalarConstantError:
return None

# Get first tensor and check if ndim is known
first = tensors[0]
ndim = first.ndim
if ndim is None:
return None

# Normalize negative axes (e.g., -1 -> ndim-1)
axis_val = axis_val % ndim

# All inputs must be structurally the same tensor
# Use equal_computations to check structural equality, not symbolic ==
for t in tensors[1:]:
if not equal_computations([t], [first]):
return None

# Only apply when size along join axis is statically 1
# (e.g., x[None] has a guaranteed 1 at that axis)
shp = first.type.shape # tuple of ints/None
if shp is None or axis_val >= len(shp) or shp[axis_val] != 1:
return None

# Replace with repeat operation
result = repeat(tensors[0], len(tensors), axis)
from pytensor.tensor.extra_ops import repeat

n = len(tensors)
result = repeat(first, n, axis=axis_val)

# Preserve debugging information
copy_stack_trace(node.outputs[0], result)
Expand Down
77 changes: 48 additions & 29 deletions tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
tile,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.extra_ops import Repeat
from pytensor.tensor.math import (
add,
bitwise_and,
Expand Down Expand Up @@ -1249,83 +1248,103 @@ def test_local_join_1():


def test_local_join_to_repeat():
"""Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)"""
"""Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)

# Test with vector - concatenate same vector 3 times along axis 0
This optimization applies when joining the same tensor multiple times
along an axis where it has size 1 (e.g., after ExpandDims).
"""

# Test with vector expanded to (1, n) - concatenate along axis 0
x = vector("x")
s = join(0, x, x, x)
x_expanded = x[None] # Shape: (1, n)
s = join(0, x_expanded, x_expanded, x_expanded) # Shape: (3, n)
f = function([x], s, mode=rewrite_mode)

# Check numerical correctness
test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
result = f(test_val)
expected = np.array(
[1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=config.floatX
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX
)
assert np.allclose(result, expected)

# Check that Join was replaced with Repeat
# Check that Join was replaced with Alloc (repeat with scalar repeats becomes Alloc)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test with matrix - concatenate same matrix along axis 0
a = matrix("a")
s = join(0, a, a, a, a)
# Test with matrix - add dimension and concatenate along new axis
a = matrix("a") # Shape: (m, n)
a_expanded = a[None, :, :] # Shape: (1, m, n)
s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # Shape: (4, m, n)
f = function([a], s, mode=rewrite_mode)

test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
result = f(test_mat)
expected = np.vstack([test_mat, test_mat, test_mat, test_mat])
expected = np.array([test_mat, test_mat, test_mat, test_mat])
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test with matrix - concatenate along axis 1
s = join(1, a, a)
# Test with matrix - expand along axis 1 and concatenate
a_expanded_ax1 = a[:, None, :] # Shape: (m, 1, n)
s = join(1, a_expanded_ax1, a_expanded_ax1) # Shape: (m, 2, n)
f = function([a], s, mode=rewrite_mode)

result = f(test_mat)
expected = np.hstack([test_mat, test_mat])
expected = np.array([[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]])
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test that it does NOT apply when tensors are different
b = matrix("b")
s = join(0, a, b)
f = function([a, b], s, mode=rewrite_mode)

test_mat1 = np.array([[1.0, 2.0]], dtype=config.floatX)
test_mat2 = np.array([[3.0, 4.0]], dtype=config.floatX)
result = f(test_mat1, test_mat2)
expected = np.vstack([test_mat1, test_mat2])
y = vector("y")
s = join(0, x[None], y[None])
f = function([x, y], s, mode=rewrite_mode)

test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
result = f(test_vec1, test_vec2)
expected = np.array([[1.0, 2.0], [3.0, 4.0]])
assert np.allclose(result, expected)

# Join should still be present (not optimized)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 1

# Test that it does NOT apply when tensor doesn't have size 1 along join axis
# (regular concatenation without ExpandDims)
s = join(0, x, x, x) # Shape: (3n,) not using ExpandDims
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
assert np.allclose(result, expected)

# Join should still be present (not optimized to Repeat)
# Join should still be present (optimization doesn't apply)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 1
assert len([n for n in ops if isinstance(n.op, Repeat)]) == 0

# Test with 5 repetitions to ensure it works with larger counts
s = join(0, x, x, x, x, x)
s = join(0, x[None], x[None], x[None], x[None], x[None])
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.tile(test_val, 5)
expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX)
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1


def test_local_join_empty():
Expand Down