diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index c9ade02a00..e309c9f485 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -77,6 +77,7 @@
     register_infer_shape,
     switch,
     tensor_copy,
+    tile,
     zeros,
     zeros_like,
 )
@@ -910,6 +911,57 @@ def local_join_make_vector(fgraph, node):
     return [ret]


+@register_canonicalize
+@node_rewriter([Join])
+def local_join_to_tile(fgraph, node):
+    """Join(axis, x, x, x, ...) -> tile(x, reps)
+
+    When the same tensor is concatenated multiple times along an axis,
+    replace the Join with a single, more efficient tile operation.
+
+    Examples
+    --------
+    join(0, x, x, x) -> tile(x, (3, 1, 1, ...))
+    join(1, x, x) -> tile(x, (1, 2, 1, ...))
+    """
+    # The first input is the join axis; the rest are the joined tensors
+    axis, *tensors = node.inputs
+
+    # The rewrite only applies when the axis is a compile-time constant
+    if not isinstance(axis, Constant):
+        return None
+
+    # Extract the Python integer from the constant
+    axis_val = int(axis.data)
+
+    # Need at least 2 tensors for the rewrite to be worthwhile
+    if len(tensors) <= 1:
+        return None
+
+    # All joined inputs must be the very same variable (identity, not symbolic equality)
+    if not all(t is tensors[0] for t in tensors[1:]):
+        return None
+
+    n_reps = len(tensors)
+    first_tensor = tensors[0]
+    ndim = first_tensor.ndim
+
+    # Normalize a negative axis so it can be compared positionally below
+    if axis_val < 0:
+        axis_val += ndim
+
+    # Build a reps tuple that repeats only along the join axis.
+    # For shape (a, b, c) joined at axis 1: reps = (1, n_reps, 1),
+    # which concatenates n_reps copies along axis_val.
+    reps = tuple(n_reps if i == axis_val else 1 for i in range(ndim))
+
+    result = tile(first_tensor, reps)
+
+    # Preserve debugging information
+    copy_stack_trace(node.outputs[0], result)
+    return [result]
+
+
 @register_specialize
 @register_canonicalize
 @register_useless
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index d9eb2ad7ad..cad4e7b606 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -1237,33 +1237,98 @@ def test_local_join_1():
     assert len([n for n in e if isinstance(n.op, Join)]) == 0
     assert f.maker.fgraph.outputs[0].dtype == config.floatX

-    # test we don't apply when their is 2 inputs
-    s = join(1, a, a)
+    # Test that a join of two different inputs remains (is not rewritten away)
+    s = join(1, a, a[:, ::-1])
     f = function([a], s, mode=rewrite_mode)
-    val = f([[1]])
-    assert np.all(val == [[1]])
+    val = f([[1, 2]])
+    assert np.all(val == [[1, 2, 2, 1]])  # joined along axis 1
     e = f.maker.fgraph.toposort()
-    assert len([n for n in e if isinstance(n.op, Join)]) == 1
+    assert len([n for n in e if isinstance(n.op, Join)]) == 1  # join remains
     assert f.maker.fgraph.outputs[0].dtype == config.floatX


+def test_local_join_to_tile():
+    """Join(axis, x, x, ...) is rewritten to tile(x, reps) with reps[axis] = k.
+
+    The rewrite applies whenever the *same* tensor is concatenated multiple
+    times along a given axis. It replaces the Join/concatenate with a Tile op.
+    """
+
+    # ---- Case 1: joining the same vector along axis 0 ----
+    x = vector("x")
+    s = join(0, x, x, x)  # shape (3n,)
+    f = function([x], s, mode=rewrite_mode)
+
+    test_val = np.array([1.0, 2.0], dtype=config.floatX)
+    result = f(test_val)
+    expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
+    assert np.allclose(result, expected)
+
+    # The Join should be rewritten away
+    ops = f.maker.fgraph.toposort()
+    assert not any(isinstance(n.op, Join) for n in ops)
+
+    # ---- Case 2: joining the same matrix along axis 0 ----
+    a = matrix("a")
+    s = join(0, a, a)  # shape (2m, n)
+    f = function([a], s, mode=rewrite_mode)
+
+    test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
+    result = f(test_mat)
+    expected = np.vstack([test_mat, test_mat])
+    assert np.allclose(result, expected)
+
+    ops = f.maker.fgraph.toposort()
+    assert not any(isinstance(n.op, Join) for n in ops)
+
+    # ---- Case 3: joining the same matrix along axis 1 ----
+    s = join(1, a, a, a)  # shape (m, 3n)
+    f = function([a], s, mode=rewrite_mode)
+
+    result = f(test_mat)
+    expected = np.hstack([test_mat, test_mat, test_mat])
+    assert np.allclose(result, expected)
+
+    ops = f.maker.fgraph.toposort()
+    assert not any(isinstance(n.op, Join) for n in ops)
+
+    # ---- Case 4: different tensors -> must NOT be rewritten ----
+    y = vector("y")
+    s = join(0, x, y)  # inputs differ
+    f = function([x, y], s, mode=rewrite_mode)
+
+    test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
+    test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
+    result = f(test_vec1, test_vec2)
+    expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=config.floatX)
+    assert np.allclose(result, expected)
+
+    # The Join must still be present since the inputs aren't identical
+    ops = f.maker.fgraph.toposort()
+    assert any(isinstance(n.op, Join) for n in ops)
+
+
 def test_local_join_empty():
-    # Vector case
+    # Vector case: empty tensors should be removed from the join
     empty_vec = np.asarray([], dtype=config.floatX)
     vec = vector("vec")
-    s = pt.join(0, vec, vec, empty_vec)
+    s = pt.join(0, vec, vec[::-1], empty_vec)
     new_s = rewrite_graph(s)
-    assert equal_computations([new_s], [join(0, vec, vec)])
     assert new_s.dtype == s.dtype
+    # Verify that the empty tensor is removed from the join
+    expected = pt.join(0, vec, vec[::-1])
+    assert equal_computations([new_s], [expected])

-    # Matrix case
+    # Matrix case: empty tensors should be removed from the join
     empty_mat = np.zeros((2, 0), dtype=config.floatX)
     empty_sym_mat = matrix("m", shape=(2, 0))
     mat = matrix("mat", shape=(2, 10))
-    s = join(1, empty_mat, mat, empty_sym_mat, mat, mat)
+    s = join(1, empty_mat, mat, empty_sym_mat, mat[:, ::-1])
     new_s = rewrite_graph(s)
-    assert equal_computations([new_s], [join(1, mat, mat, mat)])
     assert new_s.dtype == s.dtype
+    # Verify that the empty tensors are removed from the join
+    expected = join(1, mat, mat[:, ::-1])
+    assert equal_computations([new_s], [expected])

     # Join can be completely removed, but casting and specify_shape are propagated
     int_mat = matrix("int_mat", dtype=int)