pytensor/tensor/rewriting/basic.py (48 additions, 0 deletions)
@@ -77,6 +77,7 @@
    register_infer_shape,
    switch,
    tensor_copy,
    tile,
    zeros,
    zeros_like,
)
@@ -910,6 +911,53 @@ def local_join_make_vector(fgraph, node):
    return [ret]


@register_canonicalize
@node_rewriter([Join])
def local_join_to_tile(fgraph, node):
    """Join(axis, x, x, x, ...) -> tile(x, reps)

    When the same tensor is concatenated multiple times along an axis,
    replace the Join with a single tile operation, which is more efficient.

    Examples
    --------
    join(0, x, x, x) -> tile(x, (3, 1, ..., 1))
    join(1, x, x) -> tile(x, (1, 2, 1, ..., 1))
    """
    # Extract the axis and the tensors being joined
    axis, *tensors = node.inputs

    # The rewrite only applies when the axis is a compile-time constant
    if not isinstance(axis, Constant):
        return None

    # Need at least 2 tensors for the rewrite to be worthwhile
    if len(tensors) <= 1:
        return None

    # All joined tensors must be the same variable (object identity,
    # not merely elementwise equality)
    first_tensor = tensors[0]
    if not all(t is first_tensor for t in tensors[1:]):
        return None

    n_reps = len(tensors)
    ndim = first_tensor.ndim

    # Extract the Python integer from the constant and normalize a
    # negative axis so it indexes correctly into the reps tuple
    axis_val = int(axis.data)
    if axis_val < 0:
        axis_val += ndim

    # Build a reps tuple that repeats only along the join axis.
    # For shape (a, b, c) joined at axis 1: reps = (1, n_reps, 1),
    # which concatenates n_reps copies along axis_val.
    reps = tuple(n_reps if i == axis_val else 1 for i in range(ndim))

    result = tile(first_tensor, reps)

    # Preserve debugging information from the original output
    copy_stack_trace(node.outputs[0], result)
    return [result]


@register_specialize
@register_canonicalize
@register_useless
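The correctness of the reps construction can be sanity-checked outside PyTensor. The following NumPy sketch (illustrative only, not part of the diff) confirms that tiling with reps equal to 1 everywhere except the join axis reproduces repeated concatenation along that axis:

```python
import numpy as np

# Tiling along a single axis is equivalent to concatenating the same
# array n_reps times along that axis, which is what the rewrite relies on.
a = np.array([[1.0, 2.0], [3.0, 4.0]])
n_reps, axis = 3, 1

reps = tuple(n_reps if i == axis else 1 for i in range(a.ndim))  # -> (1, 3)
assert np.array_equal(np.tile(a, reps), np.concatenate([a] * n_reps, axis=axis))
```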
tests/tensor/rewriting/test_basic.py (74 additions, 9 deletions)
@@ -1237,33 +1237,98 @@ def test_local_join_1():
    assert len([n for n in e if isinstance(n.op, Join)]) == 0
    assert f.maker.fgraph.outputs[0].dtype == config.floatX

    # test that join with 2 identical inputs now gets optimized to tile
    s = join(1, a, a)
    f = function([a], s, mode=rewrite_mode)
    val = f([[1]])
    assert np.all(val == [[1, 1]])  # joined along axis 1
    e = f.maker.fgraph.toposort()
    assert len([n for n in e if isinstance(n.op, Join)]) == 0  # optimized away
    assert f.maker.fgraph.outputs[0].dtype == config.floatX


def test_local_join_to_tile():
    """Join(axis, x, x, ...) is rewritten to tile(x, reps), with reps[axis]
    equal to the number of joined copies.

    This optimization applies whenever we concatenate the *same* tensor
    multiple times along a given axis. It replaces the Join/concatenate
    with a Tile op.
    """

    # ---- Case 1: joining the same vector along axis 0 ----
    x = vector("x")
    s = join(0, x, x, x)  # (3n,)
    f = function([x], s, mode=rewrite_mode)

    test_val = np.array([1.0, 2.0], dtype=config.floatX)
    result = f(test_val)
    expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
    assert np.allclose(result, expected)

    # Join should be optimized away
    ops = f.maker.fgraph.toposort()
    assert not any(isinstance(n.op, Join) for n in ops)

    # ---- Case 2: joining the same matrix along axis 0 ----
    a = matrix("a")
    s = join(0, a, a)  # (2m, n)
    f = function([a], s, mode=rewrite_mode)

    test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
    result = f(test_mat)
    expected = np.vstack([test_mat, test_mat])
    assert np.allclose(result, expected)

    ops = f.maker.fgraph.toposort()
    assert not any(isinstance(n.op, Join) for n in ops)

    # ---- Case 3: joining the same matrix along axis 1 ----
    s = join(1, a, a, a)  # (m, 3n)
    f = function([a], s, mode=rewrite_mode)

    result = f(test_mat)
    expected = np.hstack([test_mat, test_mat, test_mat])
    assert np.allclose(result, expected)

    ops = f.maker.fgraph.toposort()
    assert not any(isinstance(n.op, Join) for n in ops)

    # ---- Case 4: different tensors -> should NOT be optimized ----
    y = vector("y")
    s = join(0, x, y)  # inputs differ
    f = function([x, y], s, mode=rewrite_mode)

    test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
    test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
    result = f(test_vec1, test_vec2)
    expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=config.floatX)
    assert np.allclose(result, expected)

    # Join should still be present since the inputs are not identical
    ops = f.maker.fgraph.toposort()
    assert any(isinstance(n.op, Join) for n in ops)


def test_local_join_empty():
    # Vector case - empty tensors should be removed
    empty_vec = np.asarray([], dtype=config.floatX)
    vec = vector("vec")
    s = pt.join(0, vec, vec[::-1], empty_vec)
    new_s = rewrite_graph(s)
    assert new_s.dtype == s.dtype
    # Compare to the expected form (without rewriting expected)
    expected = pt.join(0, vec, vec[::-1])
    assert equal_computations([new_s], [expected])

    # Matrix case - empty tensors should be removed
    empty_mat = np.zeros((2, 0), dtype=config.floatX)
    empty_sym_mat = matrix("m", shape=(2, 0))
    mat = matrix("mat", shape=(2, 10))
    s = join(1, empty_mat, mat, empty_sym_mat, mat[:, ::-1])
    new_s = rewrite_graph(s)
    assert new_s.dtype == s.dtype
    # Compare to the expected form (without rewriting expected)
    expected = join(1, mat, mat[:, ::-1])
    assert equal_computations([new_s], [expected])

    # Join can be completely removed, but casting and specify_shape are propagated
    int_mat = matrix("int_mat", dtype=int)
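To see the rewrite in action outside the test suite, here is a minimal sketch (assuming a PyTensor build that includes this rewrite in the default canonicalization passes) that compiles join(0, x, x, x) and checks that no Join op survives in the optimized graph:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Join

x = pt.vector("x")
f = pytensor.function([x], pt.join(0, x, x, x))

# With the rewrite applied, the compiled graph uses a tile-based
# computation instead of a Join.
ops = [node.op for node in f.maker.fgraph.toposort()]
assert not any(isinstance(op, Join) for op in ops)
```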