@@ -3,7 +3,7 @@ use itertools::Itertools;
33use sqlparser:: ast:: { Expr , OrderByExpr } ;
44use std:: collections:: HashSet ;
55
6- use crate :: binder:: { BindError , InputRefType } ;
6+ use crate :: binder:: BindError ;
77use crate :: planner:: LogicalPlan ;
88use crate :: storage:: Transaction ;
99use crate :: {
@@ -28,7 +28,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
2828 select_items : & mut [ ScalarExpression ] ,
2929 ) -> Result < ( ) , BindError > {
3030 for column in select_items {
31- self . visit_column_agg_expr ( column, true ) ?;
31+ self . visit_column_agg_expr ( column) ?;
3232 }
3333 Ok ( ( ) )
3434 }
@@ -55,7 +55,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
5555 // Extract having expression.
5656 let return_having = if let Some ( having) = having {
5757 let mut having = self . bind_expr ( having) ?;
58- self . visit_column_agg_expr ( & mut having, false ) ?;
58+ self . visit_column_agg_expr ( & mut having) ?;
5959
6060 Some ( having)
6161 } else {
@@ -72,7 +72,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
7272 nulls_first,
7373 } = orderby;
7474 let mut expr = self . bind_expr ( expr) ?;
75- self . visit_column_agg_expr ( & mut expr, false ) ?;
75+ self . visit_column_agg_expr ( & mut expr) ?;
7676
7777 return_orderby. push ( SortField :: new (
7878 expr,
@@ -87,77 +87,30 @@ impl<'a, T: Transaction> Binder<'a, T> {
8787 Ok ( ( return_having, return_orderby) )
8888 }
8989
90- fn visit_column_agg_expr (
91- & mut self ,
92- expr : & mut ScalarExpression ,
93- is_select : bool ,
94- ) -> Result < ( ) , BindError > {
95- let ref_columns = expr. referenced_columns ( ) ;
96-
90+ fn visit_column_agg_expr ( & mut self , expr : & mut ScalarExpression ) -> Result < ( ) , BindError > {
9791 match expr {
98- ScalarExpression :: AggCall {
99- ty : return_type, ..
100- } => {
101- let ty = return_type. clone ( ) ;
102- if is_select {
103- let index = self . context . input_ref_index ( InputRefType :: AggCall ) ;
104- let input_ref = ScalarExpression :: InputRef {
105- index,
106- ty,
107- ref_columns,
108- } ;
109- match std:: mem:: replace ( expr, input_ref) {
110- ScalarExpression :: AggCall {
111- kind,
112- args,
113- ty,
114- distinct,
115- } => {
116- self . context . agg_calls . push ( ScalarExpression :: AggCall {
117- distinct,
118- kind,
119- args,
120- ty,
121- } ) ;
122- }
123- _ => unreachable ! ( ) ,
124- }
125- } else {
126- let ( index, _) = self
127- . context
128- . agg_calls
129- . iter ( )
130- . find_position ( |agg_expr| agg_expr == & expr)
131- . ok_or_else ( || BindError :: AggMiss ( format ! ( "{:?}" , expr) ) ) ?;
132-
133- let _ = std:: mem:: replace (
134- expr,
135- ScalarExpression :: InputRef {
136- index,
137- ty,
138- ref_columns,
139- } ,
140- ) ;
141- }
142- }
143-
144- ScalarExpression :: TypeCast { expr, .. } => {
145- self . visit_column_agg_expr ( expr, is_select) ?
92+ ScalarExpression :: AggCall { .. } => {
93+ self . context . agg_calls . push ( expr. clone ( ) ) ;
14694 }
147- ScalarExpression :: IsNull { expr, .. } => self . visit_column_agg_expr ( expr, is_select) ?,
148- ScalarExpression :: Unary { expr, .. } => self . visit_column_agg_expr ( expr, is_select) ?,
149- ScalarExpression :: Alias { expr, .. } => self . visit_column_agg_expr ( expr, is_select) ?,
95+ ScalarExpression :: TypeCast { expr, .. } => self . visit_column_agg_expr ( expr) ?,
96+ ScalarExpression :: IsNull { expr, .. } => self . visit_column_agg_expr ( expr) ?,
97+ ScalarExpression :: Unary { expr, .. } => self . visit_column_agg_expr ( expr) ?,
98+ ScalarExpression :: Alias { expr, .. } => self . visit_column_agg_expr ( expr) ?,
15099 ScalarExpression :: Binary {
151100 left_expr,
152101 right_expr,
153102 ..
154103 } => {
155- self . visit_column_agg_expr ( left_expr, is_select ) ?;
156- self . visit_column_agg_expr ( right_expr, is_select ) ?;
104+ self . visit_column_agg_expr ( left_expr) ?;
105+ self . visit_column_agg_expr ( right_expr) ?;
157106 }
158- ScalarExpression :: Constant ( _)
159- | ScalarExpression :: ColumnRef { .. }
160- | ScalarExpression :: InputRef { .. } => { }
107+ ScalarExpression :: In { expr, args, .. } => {
108+ self . visit_column_agg_expr ( expr) ?;
109+ for arg in args {
110+ self . visit_column_agg_expr ( arg) ?;
111+ }
112+ }
113+ ScalarExpression :: Constant ( _) | ScalarExpression :: ColumnRef { .. } => { }
161114 }
162115
163116 Ok ( ( ) )
@@ -239,44 +192,13 @@ impl<'a, T: Transaction> Binder<'a, T> {
239192 false
240193 }
241194 } ) {
242- let index = self . context . input_ref_index ( InputRefType :: GroupBy ) ;
243- let mut select_item = & mut select_list[ i] ;
244- let ref_columns = select_item. referenced_columns ( ) ;
245- let return_type = select_item. return_type ( ) ;
246-
247- self . context . group_by_exprs . push ( std:: mem:: replace (
248- & mut select_item,
249- ScalarExpression :: InputRef {
250- index,
251- ty : return_type,
252- ref_columns,
253- } ,
254- ) ) ;
195+ self . context . group_by_exprs . push ( select_list[ i] . clone ( ) ) ;
255196 return ;
256197 }
257198 }
258199
259200 if let Some ( i) = select_list. iter ( ) . position ( |column| column == expr) {
260- let expr = & mut select_list[ i] ;
261- let ref_columns = expr. referenced_columns ( ) ;
262-
263- match expr {
264- ScalarExpression :: Constant ( _) | ScalarExpression :: ColumnRef { .. } => {
265- self . context . group_by_exprs . push ( expr. clone ( ) )
266- }
267- _ => {
268- let index = self . context . input_ref_index ( InputRefType :: GroupBy ) ;
269-
270- self . context . group_by_exprs . push ( std:: mem:: replace (
271- expr,
272- ScalarExpression :: InputRef {
273- index,
274- ty : expr. return_type ( ) ,
275- ref_columns,
276- } ,
277- ) )
278- }
279- }
201+ self . context . group_by_exprs . push ( select_list[ i] . clone ( ) )
280202 }
281203 }
282204
@@ -320,6 +242,13 @@ impl<'a, T: Transaction> Binder<'a, T> {
320242 ScalarExpression :: TypeCast { expr, .. } => self . validate_having_orderby ( expr) ,
321243 ScalarExpression :: IsNull { expr, .. } => self . validate_having_orderby ( expr) ,
322244 ScalarExpression :: Unary { expr, .. } => self . validate_having_orderby ( expr) ,
245+ ScalarExpression :: In { expr, args, .. } => {
246+ self . validate_having_orderby ( expr) ?;
247+ for arg in args {
248+ self . validate_having_orderby ( arg) ?;
249+ }
250+ Ok ( ( ) )
251+ }
323252 ScalarExpression :: Binary {
324253 left_expr,
325254 right_expr,
@@ -330,7 +259,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
330259 Ok ( ( ) )
331260 }
332261
333- ScalarExpression :: Constant ( _) | ScalarExpression :: InputRef { .. } => Ok ( ( ) ) ,
262+ ScalarExpression :: Constant ( _) => Ok ( ( ) ) ,
334263 }
335264 }
336265}
0 commit comments