Skip to content

Commit f25584f

Browse files
committed
Planning: Rule to push projection through join into TableScan
A new rule `PushProjectionThroughJoinIntoTableScan` is introduced to push projections that appear above a join down to the table scans on either side of the join. This optimization is particularly beneficial for cross-connector joins where the join itself cannot be pushed down to the connectors. The rule applies when: - All projection expressions are deterministic. - Each projection expression references columns from only one side of the join. - For outer joins, projections on the non-preserved side are not pushed to maintain correctness. This transformation creates new `ProjectNode`s on each side of the join, potentially followed by the original `ProjectNode` if some expressions could not be pushed down. The rule also ensures that symbols required by the join criteria and filter, as well as the original project's output symbols, are preserved through identity projections on the respective sides.
1 parent 97f1088 commit f25584f

File tree

5 files changed

+613
-14
lines changed

5 files changed

+613
-14
lines changed

core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@
176176
import io.trino.sql.planner.iterative.rule.PushPredicateThroughProjectIntoWindow;
177177
import io.trino.sql.planner.iterative.rule.PushProjectionIntoTableScan;
178178
import io.trino.sql.planner.iterative.rule.PushProjectionThroughExchange;
179+
import io.trino.sql.planner.iterative.rule.PushProjectionThroughJoinIntoTableScan;
179180
import io.trino.sql.planner.iterative.rule.PushProjectionThroughUnion;
180181
import io.trino.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId;
181182
import io.trino.sql.planner.iterative.rule.PushSampleIntoTableScan;
@@ -666,6 +667,7 @@ public PlanOptimizers(
666667
ImmutableSet.<Rule<?>>builder()
667668
.addAll(projectionPushdownRules)
668669
.add(new PushProjectionIntoTableScan(plannerContext, scalarStatsCalculator))
670+
.add(new PushProjectionThroughJoinIntoTableScan())
669671
.build());
670672

671673
builder.add(
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.sql.planner.iterative.rule;
15+
16+
import com.google.common.collect.ImmutableSet;
17+
import io.trino.matching.Capture;
18+
import io.trino.matching.Captures;
19+
import io.trino.matching.Pattern;
20+
import io.trino.sql.ir.Expression;
21+
import io.trino.sql.ir.Reference;
22+
import io.trino.sql.planner.DeterminismEvaluator;
23+
import io.trino.sql.planner.PlanNodeIdAllocator;
24+
import io.trino.sql.planner.Symbol;
25+
import io.trino.sql.planner.SymbolsExtractor;
26+
import io.trino.sql.planner.iterative.Rule;
27+
import io.trino.sql.planner.plan.Assignments;
28+
import io.trino.sql.planner.plan.JoinNode;
29+
import io.trino.sql.planner.plan.JoinType;
30+
import io.trino.sql.planner.plan.PlanNode;
31+
import io.trino.sql.planner.plan.ProjectNode;
32+
33+
import java.util.List;
34+
import java.util.Map;
35+
import java.util.Set;
36+
import java.util.stream.Stream;
37+
38+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
39+
import static io.trino.matching.Capture.newCapture;
40+
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
41+
import static io.trino.sql.planner.plan.Patterns.join;
42+
import static io.trino.sql.planner.plan.Patterns.project;
43+
import static io.trino.sql.planner.plan.Patterns.source;
44+
45+
/**
46+
* Pushes projections that appear above a join down to the table scans on either side of the join.
47+
* This is particularly useful for cross-connector joins where the join cannot be pushed down, but
48+
* individual projections on each side can still be pushed to their respective connectors.
49+
*
50+
* <pre>
51+
* Transforms:
52+
* Project(x := f(a), y := g(b))
53+
* Join(a = b)
54+
* TableScan(a, ...)
55+
* TableScan(b, ...)
56+
*
57+
* Into:
58+
* Project(x, y) -- identity projections
59+
* Join(a = b)
60+
* Project(x := f(a), a) -- pushed down
61+
* TableScan(a, ...)
62+
* Project(y := g(b), b) -- pushed down
63+
* TableScan(b, ...)
64+
* </pre>
65+
*
66+
* <p>The rule only applies when: - All projection expressions are deterministic - Each projection
67+
* expression references columns from only one side of the join - For outer joins, projections on
68+
* the non-preserved side are not pushed (to maintain correctness)
69+
*/
70+
public class PushProjectionThroughJoinIntoTableScan
71+
implements Rule<ProjectNode>
72+
{
73+
private static final Capture<JoinNode> JOIN = newCapture();
74+
75+
private static final Pattern<ProjectNode> PATTERN = project().with(source().matching(join().capturedAs(JOIN)));
76+
77+
@Override
78+
public Pattern<ProjectNode> getPattern()
79+
{
80+
return PATTERN;
81+
}
82+
83+
@Override
84+
public Result apply(ProjectNode project, Captures captures, Context context)
85+
{
86+
JoinNode join = captures.get(JOIN);
87+
88+
// Only apply to deterministic projections
89+
if (!project.getAssignments().expressions().stream().allMatch(DeterminismEvaluator::isDeterministic)) {
90+
return Result.empty();
91+
}
92+
93+
// Skip if all projections are identity - nothing to push down
94+
if (project.getAssignments().isIdentity()) {
95+
return Result.empty();
96+
}
97+
98+
Set<Symbol> joinCriteriaSymbols = join.getCriteria().stream().flatMap(criteria -> Stream.of(criteria.getLeft(), criteria.getRight())).collect(toImmutableSet());
99+
Set<Symbol> joinFilterSymbols = join.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of());
100+
Set<Symbol> joinRequiredSymbols = ImmutableSet.<Symbol>builder().addAll(joinCriteriaSymbols).addAll(joinFilterSymbols).build();
101+
102+
// Separate projections by which side of the join they reference
103+
Assignments.Builder leftProjections = Assignments.builder();
104+
Assignments.Builder rightProjections = Assignments.builder();
105+
Assignments.Builder remainingProjections = Assignments.builder();
106+
107+
Set<Symbol> leftSymbols = ImmutableSet.copyOf(join.getLeft().getOutputSymbols());
108+
Set<Symbol> rightSymbols = ImmutableSet.copyOf(join.getRight().getOutputSymbols());
109+
110+
// Track if we're pushing down any non-identity projections
111+
boolean hasNonIdentityProjectionsToPush = false;
112+
113+
for (Map.Entry<Symbol, Expression> assignment : project.getAssignments().entrySet()) {
114+
Symbol outputSymbol = assignment.getKey();
115+
Expression expression = assignment.getValue();
116+
Set<Symbol> referencedSymbols = extractUnique(expression);
117+
118+
boolean referencesLeft = leftSymbols.containsAll(referencedSymbols);
119+
boolean referencesRight = rightSymbols.containsAll(referencedSymbols);
120+
boolean isIdentity = expression instanceof Reference && ((Reference) expression).name().equals(outputSymbol.name());
121+
122+
if (referencesLeft && !referencesRight) {
123+
// Can potentially push to left side
124+
if (join.getType() == JoinType.RIGHT || join.getType() == JoinType.FULL) {
125+
remainingProjections.put(outputSymbol, expression);
126+
}
127+
else {
128+
leftProjections.put(outputSymbol, expression);
129+
if (!isIdentity) {
130+
hasNonIdentityProjectionsToPush = true;
131+
}
132+
}
133+
}
134+
else if (referencesRight && !referencesLeft) {
135+
// Can potentially push to right side
136+
if (join.getType() == JoinType.LEFT || join.getType() == JoinType.FULL) {
137+
remainingProjections.put(outputSymbol, expression);
138+
}
139+
else {
140+
rightProjections.put(outputSymbol, expression);
141+
if (!isIdentity) {
142+
hasNonIdentityProjectionsToPush = true;
143+
}
144+
}
145+
}
146+
else {
147+
// References both sides or neither - keep above join
148+
remainingProjections.put(outputSymbol, expression);
149+
}
150+
}
151+
152+
// If no non-identity projections can be pushed down, return empty
153+
if (!hasNonIdentityProjectionsToPush) {
154+
return Result.empty();
155+
}
156+
157+
// Add identity projections for symbols required by the join
158+
for (Symbol symbol : joinRequiredSymbols) {
159+
if (leftSymbols.contains(symbol)) {
160+
leftProjections.putIdentity(symbol);
161+
}
162+
if (rightSymbols.contains(symbol)) {
163+
rightProjections.putIdentity(symbol);
164+
}
165+
}
166+
167+
// Also add identity projections for any symbols that appear in the project's output
168+
// These need to flow through the join even if they're not computed
169+
for (Symbol symbol : project.getOutputSymbols()) {
170+
if (leftSymbols.contains(symbol)) {
171+
leftProjections.putIdentity(symbol);
172+
}
173+
if (rightSymbols.contains(symbol)) {
174+
rightProjections.putIdentity(symbol);
175+
}
176+
}
177+
178+
Assignments leftAssignments = leftProjections.build();
179+
Assignments rightAssignments = rightProjections.build();
180+
181+
// Create new project nodes on each side if there are projections to push
182+
PlanNodeIdAllocator idAllocator = context.getIdAllocator();
183+
PlanNode newLeft = join.getLeft();
184+
PlanNode newRight = join.getRight();
185+
186+
if (!leftAssignments.isEmpty()) {
187+
newLeft = new ProjectNode(idAllocator.getNextId(), join.getLeft(), leftAssignments);
188+
}
189+
190+
if (!rightAssignments.isEmpty()) {
191+
newRight = new ProjectNode(idAllocator.getNextId(), join.getRight(), rightAssignments);
192+
}
193+
194+
// Build the list of output symbols for the new join
195+
// Include all symbols from both sides - the Project above will filter as needed
196+
List<Symbol> newLeftOutputSymbols = newLeft.getOutputSymbols();
197+
List<Symbol> newRightOutputSymbols = newRight.getOutputSymbols();
198+
199+
// Create new join with pushed-down projections
200+
JoinNode newJoin = new JoinNode(
201+
join.getId(),
202+
join.getType(),
203+
newLeft,
204+
newRight,
205+
join.getCriteria(),
206+
newLeftOutputSymbols,
207+
newRightOutputSymbols,
208+
join.isMaySkipOutputDuplicates(),
209+
join.getFilter(),
210+
join.getDistributionType(),
211+
join.isSpillable(),
212+
join.getDynamicFilters(),
213+
join.getReorderJoinStatsAndCost());
214+
215+
// If there are remaining projections, keep them above the join
216+
if (!remainingProjections.build().isEmpty()) {
217+
// Create identity projections for pushed-down symbols
218+
for (Symbol symbol : project.getOutputSymbols()) {
219+
if (remainingProjections.build().get(symbol) == null) {
220+
remainingProjections.putIdentity(symbol);
221+
}
222+
}
223+
224+
return Result.ofPlanNode(new ProjectNode(project.getId(), newJoin, remainingProjections.build()));
225+
}
226+
227+
// All projections were pushed down, just return the join
228+
// But we may need to restrict outputs to match the original project's outputs
229+
if (!newJoin.getOutputSymbols().equals(project.getOutputSymbols())) {
230+
return Result.ofPlanNode(new ProjectNode(project.getId(), newJoin, Assignments.identity(project.getOutputSymbols())));
231+
}
232+
233+
return Result.ofPlanNode(newJoin);
234+
}
235+
}

core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ public void testDereferencePushdownMultiLevel()
6464
assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))), ROW(CAST(ROW(3, 4.0) AS ROW(x BIGINT, y DOUBLE)))) " +
6565
"SELECT a.msg.x, a.msg, b.msg.y FROM t a CROSS JOIN t b",
6666
output(ImmutableList.of("a_msg_x", "a_msg", "b_msg_y"),
67-
strictProject(
68-
ImmutableMap.of(
69-
"a_msg_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg"), 0)),
70-
"a_msg", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg")),
71-
"b_msg_y", expression(new Reference(DOUBLE, "b_msg_y"))),
72-
join(INNER, builder -> builder
73-
.left(values("a_msg"))
74-
.right(
75-
values(ImmutableList.of("b_msg_y"), ImmutableList.of(ImmutableList.of(new Constant(DOUBLE, 2e0)), ImmutableList.of(new Constant(DOUBLE, 4e0)))))))));
67+
join(INNER, builder -> builder
68+
.left(
69+
strictProject(
70+
ImmutableMap.of(
71+
"a_msg", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg")),
72+
"a_msg_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg"), 0))),
73+
values("a_msg")))
74+
.right(
75+
values(ImmutableList.of("b_msg_y"), ImmutableList.of(ImmutableList.of(new Constant(DOUBLE, 2e0)), ImmutableList.of(new Constant(DOUBLE, 4e0))))))));
7676
}
7777

7878
@Test

core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2590,13 +2590,14 @@ WITH RECURSIVE recursive_call (level) AS (
25902590
SELECT * FROM recursive_call
25912591
""",
25922592
output(exchange(
2593-
LOCAL,
2594-
any(unnest(values("array"))),
2595-
any(join(
2593+
// First branch: Project -> Unnest -> Values
2594+
project(unnest(values("array"))),
2595+
// Second branch: CrossJoin (displayed as JoinNode with empty criteria)
2596+
join(
25962597
INNER,
25972598
builder -> builder
2598-
.left(any(unnest(values("array"))))
2599-
.right(exchange(tableScan("nation"))))))));
2599+
.left(project(unnest(values("array"))))
2600+
.right(exchange(tableScan("nation")))))));
26002601
}
26012602

26022603
@Test

0 commit comments

Comments
 (0)