Skip to content

Commit 503519c

Browse files
committed
Planning: Rule to push projection through join into TableScan
A new rule `PushProjectionThroughJoinIntoTableScan` is introduced to push projections that appear above a join down to the table scans on either side of the join. This optimization is particularly beneficial for cross-connector joins, where the join itself cannot be pushed down to the connectors.

The rule applies when:
- All projection expressions are deterministic.
- Each projection expression references columns from only one side of the join.
- For outer joins, projections on the non-preserved side are not pushed, to maintain correctness.

This transformation creates new `ProjectNode`s on each side of the join, potentially followed by the original `ProjectNode` if some expressions could not be pushed down. The rule also ensures that symbols required by the join criteria and filter, as well as the original project's output symbols, are preserved through identity projections on the respective sides.
1 parent 97f1088 commit 503519c

File tree

10 files changed

+736
-161
lines changed

10 files changed

+736
-161
lines changed

core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@
176176
import io.trino.sql.planner.iterative.rule.PushPredicateThroughProjectIntoWindow;
177177
import io.trino.sql.planner.iterative.rule.PushProjectionIntoTableScan;
178178
import io.trino.sql.planner.iterative.rule.PushProjectionThroughExchange;
179+
import io.trino.sql.planner.iterative.rule.PushProjectionThroughJoinIntoTableScan;
179180
import io.trino.sql.planner.iterative.rule.PushProjectionThroughUnion;
180181
import io.trino.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId;
181182
import io.trino.sql.planner.iterative.rule.PushSampleIntoTableScan;
@@ -666,6 +667,7 @@ public PlanOptimizers(
666667
ImmutableSet.<Rule<?>>builder()
667668
.addAll(projectionPushdownRules)
668669
.add(new PushProjectionIntoTableScan(plannerContext, scalarStatsCalculator))
670+
.add(new PushProjectionThroughJoinIntoTableScan())
669671
.build());
670672

671673
builder.add(
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.sql.planner.iterative.rule;
15+
16+
import com.google.common.collect.ImmutableSet;
17+
import io.trino.matching.Capture;
18+
import io.trino.matching.Captures;
19+
import io.trino.matching.Pattern;
20+
import io.trino.sql.ir.Expression;
21+
import io.trino.sql.ir.Reference;
22+
import io.trino.sql.planner.DeterminismEvaluator;
23+
import io.trino.sql.planner.PlanNodeIdAllocator;
24+
import io.trino.sql.planner.Symbol;
25+
import io.trino.sql.planner.SymbolsExtractor;
26+
import io.trino.sql.planner.iterative.Rule;
27+
import io.trino.sql.planner.plan.Assignments;
28+
import io.trino.sql.planner.plan.JoinNode;
29+
import io.trino.sql.planner.plan.JoinType;
30+
import io.trino.sql.planner.plan.PlanNode;
31+
import io.trino.sql.planner.plan.ProjectNode;
32+
33+
import java.util.List;
34+
import java.util.Map;
35+
import java.util.Set;
36+
import java.util.stream.Stream;
37+
38+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
39+
import static io.trino.matching.Capture.newCapture;
40+
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
41+
import static io.trino.sql.planner.plan.Patterns.join;
42+
import static io.trino.sql.planner.plan.Patterns.project;
43+
import static io.trino.sql.planner.plan.Patterns.source;
44+
45+
/**
46+
* Pushes projections that appear above a join down to the table scans on either side of the join.
47+
* This is particularly useful for cross-connector joins where the join cannot be pushed down, but
48+
* individual projections on each side can still be pushed to their respective connectors.
49+
*
50+
* <pre>
51+
* Transforms:
52+
* Project(x := f(a), y := g(b))
53+
* Join(a = b)
54+
* TableScan(a, ...)
55+
* TableScan(b, ...)
56+
*
57+
* Into:
58+
* Project(x, y) -- identity projections
59+
* Join(a = b)
60+
* Project(x := f(a), a) -- pushed down
61+
* TableScan(a, ...)
62+
* Project(y := g(b), b) -- pushed down
63+
* TableScan(b, ...)
64+
* </pre>
65+
*
66+
* <p>The rule only applies when: - All projection expressions are deterministic - Each projection
67+
* expression references columns from only one side of the join - For outer joins, projections on
68+
* the non-preserved side are not pushed (to maintain correctness)
69+
*/
70+
public class PushProjectionThroughJoinIntoTableScan
71+
implements Rule<ProjectNode>
72+
{
73+
private static final Capture<JoinNode> JOIN = newCapture();
74+
75+
private static final Pattern<ProjectNode> PATTERN = project().with(source().matching(join().capturedAs(JOIN)));
76+
77+
@Override
78+
public Pattern<ProjectNode> getPattern()
79+
{
80+
return PATTERN;
81+
}
82+
83+
@Override
84+
public Result apply(ProjectNode project, Captures captures, Context context)
85+
{
86+
JoinNode join = captures.get(JOIN);
87+
88+
// Only apply to deterministic projections
89+
if (!project.getAssignments().expressions().stream().allMatch(DeterminismEvaluator::isDeterministic)) {
90+
return Result.empty();
91+
}
92+
93+
// Skip if all projections are identity - nothing to push down
94+
if (project.getAssignments().isIdentity()) {
95+
return Result.empty();
96+
}
97+
98+
Set<Symbol> joinCriteriaSymbols = join.getCriteria().stream().flatMap(criteria -> Stream.of(criteria.getLeft(), criteria.getRight())).collect(toImmutableSet());
99+
Set<Symbol> joinFilterSymbols = join.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of());
100+
Set<Symbol> joinRequiredSymbols = ImmutableSet.<Symbol>builder().addAll(joinCriteriaSymbols).addAll(joinFilterSymbols).build();
101+
102+
// Separate projections by which side of the join they reference
103+
Assignments.Builder leftProjections = Assignments.builder();
104+
Assignments.Builder rightProjections = Assignments.builder();
105+
Assignments.Builder remainingProjections = Assignments.builder();
106+
107+
Set<Symbol> leftSymbols = ImmutableSet.copyOf(join.getLeft().getOutputSymbols());
108+
Set<Symbol> rightSymbols = ImmutableSet.copyOf(join.getRight().getOutputSymbols());
109+
110+
// Track if we're pushing down any non-identity projections
111+
boolean hasNonIdentityProjectionsToPush = false;
112+
113+
for (Map.Entry<Symbol, Expression> assignment : project.getAssignments().entrySet()) {
114+
Symbol outputSymbol = assignment.getKey();
115+
Expression expression = assignment.getValue();
116+
Set<Symbol> referencedSymbols = extractUnique(expression);
117+
118+
boolean referencesLeft = leftSymbols.containsAll(referencedSymbols);
119+
boolean referencesRight = rightSymbols.containsAll(referencedSymbols);
120+
boolean isIdentity = expression instanceof Reference && ((Reference) expression).name().equals(outputSymbol.name());
121+
122+
if (referencesLeft && !referencesRight) {
123+
// Can potentially push to left side
124+
if (join.getType() == JoinType.RIGHT || join.getType() == JoinType.FULL) {
125+
remainingProjections.put(outputSymbol, expression);
126+
}
127+
else {
128+
leftProjections.put(outputSymbol, expression);
129+
if (!isIdentity) {
130+
hasNonIdentityProjectionsToPush = true;
131+
}
132+
}
133+
}
134+
else if (referencesRight && !referencesLeft) {
135+
// Can potentially push to right side
136+
if (join.getType() == JoinType.LEFT || join.getType() == JoinType.FULL) {
137+
remainingProjections.put(outputSymbol, expression);
138+
}
139+
else {
140+
rightProjections.put(outputSymbol, expression);
141+
if (!isIdentity) {
142+
hasNonIdentityProjectionsToPush = true;
143+
}
144+
}
145+
}
146+
else {
147+
// References both sides or neither - keep above join
148+
remainingProjections.put(outputSymbol, expression);
149+
}
150+
}
151+
152+
// If there are remaining projections that will stay above the join,
153+
// ensure their dependencies flow through the join outputs
154+
Assignments remainingAssignments = remainingProjections.build();
155+
Set<Symbol> remainingDependencies = ImmutableSet.of();
156+
if (!remainingAssignments.isEmpty()) {
157+
// Extract all symbols referenced by remaining projections
158+
remainingDependencies = remainingAssignments.expressions().stream()
159+
.flatMap(expr -> extractUnique(expr).stream())
160+
.collect(toImmutableSet());
161+
}
162+
163+
// Add identity projections for symbols required by the join
164+
// Also add identity projections for symbols needed by remaining projections
165+
// Only add to sides that have TableScans
166+
Set<Symbol> requiredSymbols = ImmutableSet.<Symbol>builder()
167+
.addAll(joinRequiredSymbols)
168+
.addAll(remainingDependencies)
169+
.build();
170+
171+
for (Symbol symbol : requiredSymbols) {
172+
if (leftSymbols.contains(symbol)) {
173+
leftProjections.putIdentity(symbol);
174+
}
175+
if (rightSymbols.contains(symbol)) {
176+
rightProjections.putIdentity(symbol);
177+
}
178+
}
179+
180+
Assignments leftAssignments = leftProjections.build();
181+
Assignments rightAssignments = rightProjections.build();
182+
183+
// If no non-identity projections can be pushed down, return empty
184+
// This prevents infinite loops where we keep pushing down only identity projections
185+
if (!hasNonIdentityProjectionsToPush) {
186+
return Result.empty();
187+
}
188+
189+
// Create new project nodes on each side if there are projections to push
190+
PlanNodeIdAllocator idAllocator = context.getIdAllocator();
191+
PlanNode newLeft = join.getLeft();
192+
PlanNode newRight = join.getRight();
193+
194+
if (!leftAssignments.isEmpty()) {
195+
newLeft = new ProjectNode(idAllocator.getNextId(), join.getLeft(), leftAssignments);
196+
}
197+
198+
if (!rightAssignments.isEmpty()) {
199+
newRight = new ProjectNode(idAllocator.getNextId(), join.getRight(), rightAssignments);
200+
}
201+
202+
// Build the list of output symbols for the new join
203+
// Include all symbols from both sides - the Project above will filter as needed
204+
List<Symbol> newLeftOutputSymbols = newLeft.getOutputSymbols();
205+
List<Symbol> newRightOutputSymbols = newRight.getOutputSymbols();
206+
207+
// Create new join with pushed-down projections
208+
JoinNode newJoin = new JoinNode(
209+
join.getId(),
210+
join.getType(),
211+
newLeft,
212+
newRight,
213+
join.getCriteria(),
214+
newLeftOutputSymbols,
215+
newRightOutputSymbols,
216+
join.isMaySkipOutputDuplicates(),
217+
join.getFilter(),
218+
join.getDistributionType(),
219+
join.isSpillable(),
220+
join.getDynamicFilters(),
221+
join.getReorderJoinStatsAndCost());
222+
223+
// If there are remaining projections, keep them above the join
224+
if (!remainingProjections.build().isEmpty()) {
225+
// Create identity projections for pushed-down symbols
226+
for (Symbol symbol : project.getOutputSymbols()) {
227+
if (remainingProjections.build().get(symbol) == null) {
228+
remainingProjections.putIdentity(symbol);
229+
}
230+
}
231+
232+
return Result.ofPlanNode(new ProjectNode(project.getId(), newJoin, remainingProjections.build()));
233+
}
234+
235+
// All projections were pushed down, just return the join
236+
// But we may need to restrict outputs to match the original project's outputs
237+
if (!newJoin.getOutputSymbols().equals(project.getOutputSymbols())) {
238+
return Result.ofPlanNode(new ProjectNode(project.getId(), newJoin, Assignments.identity(project.getOutputSymbols())));
239+
}
240+
241+
return Result.ofPlanNode(newJoin);
242+
}
243+
}

core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ public void testDereferencePushdownMultiLevel()
6464
assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))), ROW(CAST(ROW(3, 4.0) AS ROW(x BIGINT, y DOUBLE)))) " +
6565
"SELECT a.msg.x, a.msg, b.msg.y FROM t a CROSS JOIN t b",
6666
output(ImmutableList.of("a_msg_x", "a_msg", "b_msg_y"),
67-
strictProject(
68-
ImmutableMap.of(
69-
"a_msg_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg"), 0)),
70-
"a_msg", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg")),
71-
"b_msg_y", expression(new Reference(DOUBLE, "b_msg_y"))),
72-
join(INNER, builder -> builder
73-
.left(values("a_msg"))
74-
.right(
75-
values(ImmutableList.of("b_msg_y"), ImmutableList.of(ImmutableList.of(new Constant(DOUBLE, 2e0)), ImmutableList.of(new Constant(DOUBLE, 4e0)))))))));
67+
join(INNER, builder -> builder
68+
.left(
69+
strictProject(
70+
ImmutableMap.of(
71+
"a_msg", expression(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg")),
72+
"a_msg_x", expression(new FieldReference(new Reference(RowType.anonymousRow(BIGINT, DOUBLE), "a_msg"), 0))),
73+
values("a_msg")))
74+
.right(
75+
values(ImmutableList.of("b_msg_y"), ImmutableList.of(ImmutableList.of(new Constant(DOUBLE, 2e0)), ImmutableList.of(new Constant(DOUBLE, 4e0))))))));
7676
}
7777

7878
@Test

core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2590,13 +2590,14 @@ WITH RECURSIVE recursive_call (level) AS (
25902590
SELECT * FROM recursive_call
25912591
""",
25922592
output(exchange(
2593-
LOCAL,
2594-
any(unnest(values("array"))),
2595-
any(join(
2593+
// First branch: Project -> Unnest -> Values
2594+
project(unnest(values("array"))),
2595+
// Second branch: CrossJoin (displayed as JoinNode with empty criteria)
2596+
join(
25962597
INNER,
25972598
builder -> builder
2598-
.left(any(unnest(values("array"))))
2599-
.right(exchange(tableScan("nation"))))))));
2599+
.left(project(unnest(values("array"))))
2600+
.right(exchange(tableScan("nation")))))));
26002601
}
26012602

26022603
@Test

0 commit comments

Comments
 (0)