diff --git a/src/max_min_by.rs b/src/max_min_by.rs index bfa3754..5de4056 100644 --- a/src/max_min_by.rs +++ b/src/max_min_by.rs @@ -13,6 +13,7 @@ make_udaf_expr_and_func!( #[derive(Eq, Hash, PartialEq)] pub struct MaxByFunction { + null_first: bool, signature: logical_expr::Signature, } @@ -27,13 +28,14 @@ impl fmt::Debug for MaxByFunction { } impl Default for MaxByFunction { fn default() -> Self { - Self::new() + Self::new(true) } } impl MaxByFunction { - pub fn new() -> Self { + pub fn new(null_first: bool) -> Self { Self { + null_first, signature: logical_expr::Signature::user_defined(logical_expr::Volatility::Immutable), } } @@ -80,6 +82,7 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction { ) -> error::Result> { common::exec_err!("should not reach here") } + fn coerce_types( &self, arg_types: &[arrow::datatypes::DataType], @@ -88,25 +91,25 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction { } fn simplify(&self) -> Option { - let simplify = |mut aggr_func: logical_expr::expr::AggregateFunction, - _: &dyn logical_expr::simplify::SimplifyInfo| { + let null_first = self.null_first; + let simplify = move |mut aggr_func: logical_expr::expr::AggregateFunction, + _: &dyn logical_expr::simplify::SimplifyInfo| { let mut order_by = aggr_func.params.order_by; let (second_arg, first_arg) = ( aggr_func.params.args.remove(1), aggr_func.params.args.remove(0), ); - let sort = logical_expr::expr::Sort::new(second_arg, true, false); + let sort = logical_expr::expr::Sort::new(second_arg, true, null_first); order_by.push(sort); - let func = logical_expr::expr::Expr::AggregateFunction( - logical_expr::expr::AggregateFunction::new_udf( - functions_aggregate::first_last::last_value_udaf(), - vec![first_arg], - aggr_func.params.distinct, - aggr_func.params.filter, - order_by, - aggr_func.params.null_treatment, - ), + let func = logical_expr::expr::AggregateFunction::new_udf( + functions_aggregate::first_last::last_value_udaf(), + vec![first_arg], + aggr_func.params.distinct, + aggr_func.params.filter, + order_by, + aggr_func.params.null_treatment, ); + let func = logical_expr::expr::Expr::AggregateFunction(func); Ok(func) }; Some(Box::new(simplify)) @@ -123,6 +126,7 @@ make_udaf_expr_and_func!( #[derive(Eq, Hash, PartialEq)] pub struct MinByFunction { + null_first: bool, signature: logical_expr::Signature, } @@ -138,13 +142,14 @@ impl fmt::Debug for MinByFunction { impl Default for MinByFunction { fn default() -> Self { - Self::new() + Self::new(true) } } impl MinByFunction { - pub fn new() -> Self { + pub fn new(null_first: bool) -> Self { Self { + null_first, signature: logical_expr::Signature::user_defined(logical_expr::Volatility::Immutable), } } @@ -185,26 +190,26 @@ impl logical_expr::AggregateUDFImpl for MinByFunction { } fn simplify(&self) -> Option { - let simplify = |mut aggr_func: logical_expr::expr::AggregateFunction, - _: &dyn logical_expr::simplify::SimplifyInfo| { + let null_first = self.null_first; + let simplify = move |mut aggr_func: logical_expr::expr::AggregateFunction, + _: &dyn logical_expr::simplify::SimplifyInfo| { let mut order_by = aggr_func.params.order_by; let (second_arg, first_arg) = ( aggr_func.params.args.remove(1), aggr_func.params.args.remove(0), ); - let sort = logical_expr::expr::Sort::new(second_arg, false, false); + let sort = logical_expr::expr::Sort::new(second_arg, false, null_first); order_by.push(sort); // false for ascending sort - let func = logical_expr::expr::Expr::AggregateFunction( - logical_expr::expr::AggregateFunction::new_udf( - functions_aggregate::first_last::last_value_udaf(), - vec![first_arg], - aggr_func.params.distinct, - aggr_func.params.filter, - order_by, - aggr_func.params.null_treatment, - ), + let func = logical_expr::expr::AggregateFunction::new_udf( + functions_aggregate::first_last::last_value_udaf(), + vec![first_arg], + aggr_func.params.distinct, + aggr_func.params.filter, + order_by, + aggr_func.params.null_treatment, ); + let func = logical_expr::expr::Expr::AggregateFunction(func); Ok(func) }; Some(Box::new(simplify)) @@ -325,6 +330,7 @@ mod tests { #[cfg(test)] mod max_by { + use super::*; #[tokio::test] @@ -387,9 +393,26 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_max_by_ignores_nulls() -> error::Result<()> { + let query = r#" + SELECT max_by(v, k) + FROM ( + VALUES + ('a', 1), + ('b', CAST(NULL AS INT)), + ('c', 2) + ) AS t(v, k) + "#; + let df = ctx()?.sql(query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, "c", "max_by should ignore NULLs"); + Ok(()) + } + fn ctx() -> error::Result { let ctx = test_ctx()?; - let max_by_udaf = MaxByFunction::new(); + let max_by_udaf = MaxByFunction::default(); ctx.register_udaf(max_by_udaf.into()); Ok(ctx) } @@ -460,9 +483,26 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_min_by_ignores_nulls() -> error::Result<()> { + let query = r#" + SELECT min_by(v, k) + FROM ( + VALUES + ('a', 1), + ('b', CAST(NULL AS INT)), + ('c', 2) + ) AS t(v, k) + "#; + let df = ctx()?.sql(query).await?; + let result = extract_single_value::(df).await?; + assert_eq!(result, "a", "min_by should ignore NULLs"); + Ok(()) + } + fn ctx() -> error::Result { let ctx = test_ctx()?; - let min_by_udaf = MinByFunction::new(); + let min_by_udaf = MinByFunction::default(); ctx.register_udaf(min_by_udaf.into()); Ok(ctx) } diff --git a/tests/main.rs b/tests/main.rs index 351f884..aedb016 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -185,7 +185,7 @@ async fn test_max_by_and_min_by() { - +---------------------+ - "| max_by(tab.x,tab.y) |" - +---------------------+ - - "| 2 |" + - "| 3 |" - +---------------------+ "###); @@ -200,7 +200,7 @@ async fn test_max_by_and_min_by() { - +---------------------+ - "| min_by(tab.x,tab.y) |" - +---------------------+ - - "| 2 |" + - "| |" - +---------------------+ "###);