diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 443229a3cb77..3fa602f12554 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -114,10 +114,10 @@ impl Default for SessionConfig { } /// A type map for storing extensions. -/// +/// /// Extensions are indexed by their type `T`. If multiple values of the same type are provided, only the last one /// will be kept. -/// +/// /// Extensions are opaque objects that are unknown to DataFusion itself but can be downcast by optimizer rules, /// execution plans, or other components that have access to the session config. /// They provide a flexible way to attach extra data or behavior to the session config. diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 80d6ee0a7b91..d06629d03c5e 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -24,10 +24,13 @@ use datafusion_common::{ }; use std::sync::Arc; -use crate::PhysicalExpr; +use crate::{simplifier::not::simplify_not_expr, PhysicalExpr}; +pub mod not; pub mod unwrap_cast; +const MAX_LOOP_COUNT: usize = 5; + /// Simplifies physical expressions by applying various optimizations /// /// This can be useful after adapting expressions from a table schema @@ -48,7 +51,17 @@ impl<'a> PhysicalExprSimplifier<'a> { &mut self, expr: Arc, ) -> Result> { - Ok(expr.rewrite(self)?.data) + let mut current_expr = expr; + let mut count = 0; + while count < MAX_LOOP_COUNT { + count += 1; + let result = current_expr.rewrite(self)?; + if !result.transformed { + return Ok(result.data); + } + current_expr = result.data; + } + Ok(current_expr) } } @@ -56,24 +69,32 @@ impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> { type Node = Arc; fn f_up(&mut self, node: Self::Node) -> Result> { - // Apply unwrap cast optimization #[cfg(test)] let original_type = node.data_type(self.schema).unwrap(); - let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, self.schema)?; + + // Apply NOT expression simplification first, then unwrap cast optimization + let rewritten = + simplify_not_expr(&node, self.schema)?.transform_data(|node| { + unwrap_cast::unwrap_cast_in_comparison(node, self.schema) + })?; + #[cfg(test)] assert_eq!( - unwrapped.data.data_type(self.schema).unwrap(), + rewritten.data.data_type(self.schema).unwrap(), original_type, "Simplified expression should have the same data type as the original" ); - Ok(unwrapped) + + Ok(rewritten) } } #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; + use crate::expressions::{ + col, in_list, lit, BinaryExpr, CastExpr, Literal, NotExpr, TryCastExpr, + }; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; @@ -86,6 +107,42 @@ mod tests { ]) } + fn not_test_schema() -> Schema { + Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]) + } + + /// Helper function to extract a Literal from a PhysicalExpr + fn as_literal(expr: &Arc) -> &Literal { + expr.as_any() + .downcast_ref::() + .unwrap_or_else(|| panic!("Expected Literal, got: {expr}")) + } + + /// Helper function to extract a BinaryExpr from a PhysicalExpr + fn as_binary(expr: &Arc) -> &BinaryExpr { + expr.as_any() + .downcast_ref::() + .unwrap_or_else(|| panic!("Expected BinaryExpr, got: {expr}")) + } + + /// Assert that simplifying `input` produces `expected` + fn assert_not_simplify( + simplifier: &mut PhysicalExprSimplifier, + input: Arc, + expected: Arc, + ) { + let result = simplifier.simplify(Arc::clone(&input)).unwrap(); + assert_eq!( + &result, + &expected, + "Simplification should transform:\n input: {input}\n to: {expected}\n got: {result}" + ); + } + #[test] fn test_simplify() { let schema = test_schema(); @@ -101,7 +158,7 @@ mod tests { // Apply full simplification (uses TreeNodeRewriter) let optimized = simplifier.simplify(binary_expr).unwrap(); - let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + let optimized_binary = as_binary(&optimized); // Should be optimized to: c2 != INT64(99) (c2 is INT64, literal cast to match) let left_expr = optimized_binary.left(); @@ -109,11 +166,7 @@ mod tests { left_expr.as_any().downcast_ref::().is_none() && left_expr.as_any().downcast_ref::().is_none() ); - let right_literal = optimized_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = as_literal(optimized_binary.right()); assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(99))); } @@ -138,14 +191,10 @@ mod tests { // Apply simplification let optimized = simplifier.simplify(or_expr).unwrap(); - let or_binary = optimized.as_any().downcast_ref::().unwrap(); + let or_binary = as_binary(&optimized); // Verify left side: c1 > INT32(5) - let left_binary = or_binary - .left() - .as_any() - .downcast_ref::() - .unwrap(); + let left_binary = as_binary(or_binary.left()); let left_left_expr = left_binary.left(); assert!( left_left_expr.as_any().downcast_ref::().is_none() @@ -154,19 +203,11 @@ mod tests { .downcast_ref::() .is_none() ); - let left_literal = left_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let left_literal = as_literal(left_binary.right()); assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(5))); // Verify right side: c2 <= INT64(10) - let right_binary = or_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_binary = as_binary(or_binary.right()); let right_left_expr = right_binary.left(); assert!( right_left_expr @@ -178,11 +219,276 @@ mod tests { .downcast_ref::() .is_none() ); - let right_literal = right_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = as_literal(right_binary.right()); assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(10))); } + + #[test] + fn test_double_negation_elimination() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(NOT(c > 5)) -> c > 5 + let inner_expr: Arc = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Gt, + lit(ScalarValue::Int32(Some(5))), + )); + let inner_not = Arc::new(NotExpr::new(Arc::clone(&inner_expr))); + let double_not: Arc = Arc::new(NotExpr::new(inner_not)); + + let expected = inner_expr; + assert_not_simplify(&mut simplifier, double_not, expected); + Ok(()) + } + + #[test] + fn test_not_literal() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(TRUE) -> FALSE + let not_true = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(true))))); + let expected = lit(ScalarValue::Boolean(Some(false))); + assert_not_simplify(&mut simplifier, not_true, expected); + + // NOT(FALSE) -> TRUE + let not_false = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(false))))); + let expected = lit(ScalarValue::Boolean(Some(true))); + assert_not_simplify(&mut simplifier, not_false, expected); + + Ok(()) + } + + #[test] + fn test_negate_comparison() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c = 5) -> c != 5 + let not_eq = Arc::new(NotExpr::new(Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(5))), + )))); + let expected = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::NotEq, + lit(ScalarValue::Int32(Some(5))), + )); + assert_not_simplify(&mut simplifier, not_eq, expected); + + Ok(()) + } + + #[test] + fn test_demorgans_law_and() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(a AND b) -> NOT a OR NOT b + let and_expr = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::And, + col("b", &schema)?, + )); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + Arc::new(NotExpr::new(col("a", &schema)?)), + Operator::Or, + Arc::new(NotExpr::new(col("b", &schema)?)), + )); + assert_not_simplify(&mut simplifier, not_and, expected); + + Ok(()) + } + + #[test] + fn test_demorgans_law_or() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(a OR b) -> NOT a AND NOT b + let or_expr = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::Or, + col("b", &schema)?, + )); + let not_or: Arc = Arc::new(NotExpr::new(or_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + Arc::new(NotExpr::new(col("a", &schema)?)), + Operator::And, + Arc::new(NotExpr::new(col("b", &schema)?)), + )); + assert_not_simplify(&mut simplifier, not_or, expected); + + Ok(()) + } + + #[test] + fn test_demorgans_with_comparison_simplification() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c = 1 AND c = 2) -> c != 1 OR c != 2 + let eq1 = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(1))), + )); + let eq2 = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(2))), + )); + let and_expr = Arc::new(BinaryExpr::new(eq1, Operator::And, eq2)); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::NotEq, + lit(ScalarValue::Int32(Some(1))), + )), + Operator::Or, + Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::NotEq, + lit(ScalarValue::Int32(Some(2))), + )), + )); + assert_not_simplify(&mut simplifier, not_and, expected); + + Ok(()) + } + + #[test] + fn test_not_of_not_and_not() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(NOT(a) AND NOT(b)) -> a OR b + let not_a = Arc::new(NotExpr::new(col("a", &schema)?)); + let not_b = Arc::new(NotExpr::new(col("b", &schema)?)); + let and_expr = Arc::new(BinaryExpr::new(not_a, Operator::And, not_b)); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::Or, + col("b", &schema)?, + )); + assert_not_simplify(&mut simplifier, not_and, expected); + + Ok(()) + } + + #[test] + fn test_not_in_list() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c IN (1, 2, 3)) -> c NOT IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?; + let not_in: Arc = Arc::new(NotExpr::new(in_list_expr)); + + let expected = in_list(col("c", &schema)?, list, &true, &schema)?; + assert_not_simplify(&mut simplifier, not_in, expected); + + Ok(()) + } + + #[test] + fn test_not_not_in_list() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c NOT IN (1, 2, 3)) -> c IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let not_in_list_expr = in_list(col("c", &schema)?, list.clone(), &true, &schema)?; + let not_not_in: Arc = Arc::new(NotExpr::new(not_in_list_expr)); + + let expected = in_list(col("c", &schema)?, list, &false, &schema)?; + assert_not_simplify(&mut simplifier, not_not_in, expected); + + Ok(()) + } + + #[test] + fn test_double_not_in_list() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(NOT(c IN (1, 2, 3))) -> c IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?; + let not_in = Arc::new(NotExpr::new(in_list_expr)); + let double_not: Arc = Arc::new(NotExpr::new(not_in)); + + let expected = in_list(col("c", &schema)?, list, &false, &schema)?; + assert_not_simplify(&mut simplifier, double_not, expected); + + Ok(()) + } + + #[test] + fn test_deeply_nested_not() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // Create a deeply nested NOT expression: NOT(NOT(NOT(...NOT(c > 5)...))) + // This tests that we don't get stack overflow with many nested NOTs. + // With recursive_protection enabled (default), this should work by + // automatically growing the stack as needed. + let inner_expr: Arc = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Gt, + lit(ScalarValue::Int32(Some(5))), + )); + + let mut expr = Arc::clone(&inner_expr); + // Create 200 layers of NOT to test deep recursion handling + for _ in 0..200 { + expr = Arc::new(NotExpr::new(expr)); + } + + // With 200 NOTs (even number), should simplify back to the original expression + let expected = inner_expr; + assert_not_simplify(&mut simplifier, Arc::clone(&expr), expected); + + // Manually dismantle the deep input expression to avoid Stack Overflow on Drop + // If we just let `expr` go out of scope, Rust's recursive Drop will blow the stack + // even with recursive_protection, because Drop doesn't use the #[recursive] attribute. + // We peel off layers one by one to avoid deep recursion in Drop. + while let Some(not_expr) = expr.as_any().downcast_ref::() { + // Clone the child (Arc increment). + // Now child has 2 refs: one in parent, one in `child`. + let child = Arc::clone(not_expr.arg()); + + // Reassign `expr` to `child`. + // This drops the old `expr` (Parent). + // Parent refcount -> 0, Parent is dropped. + // Parent drops its reference to Child. + // Child refcount decrements 2 -> 1. + // Child is NOT dropped recursively because we still hold it in `expr` + expr = child; + } + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/simplifier/not.rs b/datafusion/physical-expr/src/simplifier/not.rs new file mode 100644 index 000000000000..1ea969f58ff9 --- /dev/null +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Simplify NOT expressions in physical expressions +//! +//! This module provides optimizations for NOT expressions such as: +//! - Double negation elimination: NOT(NOT(expr)) -> expr +//! - NOT with binary comparisons: NOT(a = b) -> a != b +//! - NOT with IN expressions: NOT(a IN (list)) -> a NOT IN (list) +//! - De Morgan's laws: NOT(A AND B) -> NOT A OR NOT B +//! - Constant folding: NOT(TRUE) -> FALSE, NOT(FALSE) -> TRUE +//! +//! This function is designed to work with TreeNodeRewriter's f_up traversal, +//! which means children are already simplified when this function is called. +//! The TreeNodeRewriter will automatically call this function repeatedly until +//! no more transformations are possible. + +use std::sync::Arc; + +use arrow::datatypes::Schema; +use datafusion_common::{tree_node::Transformed, Result, ScalarValue}; +use datafusion_expr::Operator; + +use crate::expressions::{in_list, lit, BinaryExpr, InListExpr, Literal, NotExpr}; +use crate::PhysicalExpr; + +/// Attempts to simplify NOT expressions by applying one level of transformation +/// +/// This function applies a single simplification rule and returns. When used with +/// TreeNodeRewriter, multiple passes will automatically be applied until no more +/// transformations are possible. +pub fn simplify_not_expr( + expr: &Arc, + schema: &Schema, +) -> Result>> { + // Check if this is a NOT expression + let not_expr = match expr.as_any().downcast_ref::() { + Some(not_expr) => not_expr, + None => return Ok(Transformed::no(Arc::clone(expr))), + }; + + let inner_expr = not_expr.arg(); + + // Handle NOT(NOT(expr)) -> expr (double negation elimination) + if let Some(inner_not) = inner_expr.as_any().downcast_ref::() { + return Ok(Transformed::yes(Arc::clone(inner_not.arg()))); + } + + // Handle NOT(literal) -> !literal + if let Some(literal) = inner_expr.as_any().downcast_ref::() { + if let ScalarValue::Boolean(Some(val)) = literal.value() { + return Ok(Transformed::yes(lit(ScalarValue::Boolean(Some(!val))))); + } + if let ScalarValue::Boolean(None) = literal.value() { + return Ok(Transformed::yes(lit(ScalarValue::Boolean(None)))); + } + } + + // Handle NOT(IN list) -> NOT IN list + if let Some(in_list_expr) = inner_expr.as_any().downcast_ref::() { + let negated = !in_list_expr.negated(); + let new_in_list = in_list( + Arc::clone(in_list_expr.expr()), + in_list_expr.list().to_vec(), + &negated, + schema, + )?; + return Ok(Transformed::yes(new_in_list)); + } + + // Handle NOT(binary_expr) + if let Some(binary_expr) = inner_expr.as_any().downcast_ref::() { + if let Some(negated_op) = binary_expr.op().negate() { + let new_binary = Arc::new(BinaryExpr::new( + Arc::clone(binary_expr.left()), + negated_op, + Arc::clone(binary_expr.right()), + )); + return Ok(Transformed::yes(new_binary)); + } + + // Handle De Morgan's laws for AND/OR + match binary_expr.op() { + Operator::And => { + // NOT(A AND B) -> NOT A OR NOT B + let not_left: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.left()))); + let not_right: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.right()))); + let new_binary = + Arc::new(BinaryExpr::new(not_left, Operator::Or, not_right)); + return Ok(Transformed::yes(new_binary)); + } + Operator::Or => { + // NOT(A OR B) -> NOT A AND NOT B + let not_left: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.left()))); + let not_right: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.right()))); + let new_binary = + Arc::new(BinaryExpr::new(not_left, Operator::And, not_right)); + return Ok(Transformed::yes(new_binary)); + } + _ => {} + } + } + + // If no simplification possible, return the original expression + Ok(Transformed::no(Arc::clone(expr))) +}