11from collections .abc import Callable
22from dataclasses import dataclass
3- from decimal import Decimal
43from typing import Any
5-
4+ import duckdb
5+ import pyarrow as pa
66import sqlglot
77import sqlglot .expressions
88import sqlglot .optimizer .simplify
@@ -19,36 +19,82 @@ class BaseFieldInfo:
1919
2020
2121@dataclass
22- class RangeFieldInfo [ T : Any ] (BaseFieldInfo ):
22+ class RangeFieldInfo (BaseFieldInfo ):
2323 """
2424 Information about a field that has a min and max value.
2525 """
2626
27- min_value : T
28- max_value : T
27+ min_value : pa . Scalar
28+ max_value : pa . Scalar
2929
3030
3131@dataclass
32- class SetFieldInfo [ T : Any ] (BaseFieldInfo ):
32+ class SetFieldInfo (BaseFieldInfo ):
3333 """
3434 Information about a field where the set of values are known.
3535 The information about what values that are contained can produce
3636 false positives.
3737 """
3838
3939 values : set [
40- T
40+ pa . Scalar
4141 ] # Set of values that are known to be present in the field, false positives are okay.
4242
4343
44- AnyFieldInfo = (
45- SetFieldInfo [Decimal ]
46- | SetFieldInfo [float ]
47- | SetFieldInfo [str ]
48- | SetFieldInfo [int ]
49- | RangeFieldInfo [int ]
50- | RangeFieldInfo [None ]
51- )
44+ AnyFieldInfo = SetFieldInfo | RangeFieldInfo
45+
46+
47+ def _scalar_value_op (
48+ a : pa .Scalar , b : pa .Scalar , op : Callable [[Any , Any ], bool ]
49+ ) -> bool :
50+ assert not pa .types .is_null (a .type ), (
51+ f"Expected a non-null scalar value, got { a } of type { a .type } "
52+ )
53+ assert not pa .types .is_null (b .type ), (
54+ f"Expected a non-null scalar value, got { b } of type { b .type } "
55+ )
56+
57+ # If we have integers or floats we can do that comparision regardless of their types.
58+ if pa .types .is_integer (a .type ) and pa .types .is_integer (b .type ):
59+ return op (a .as_py (), b .as_py ())
60+
61+ if pa .types .is_floating (a .type ) and pa .types .is_floating (b .type ):
62+ return op (a .as_py (), b .as_py ())
63+
64+ if pa .types .is_string (a .type ) and pa .types .is_string (b .type ):
65+ return op (a .as_py (), b .as_py ())
66+
67+ if pa .types .is_boolean (a .type ) and pa .types .is_boolean (b .type ):
68+ return op (a .as_py (), b .as_py ())
69+
70+ if pa .types .is_decimal (a .type ) and pa .types .is_decimal (b .type ):
71+ return op (a .as_py (), b .as_py ())
72+
73+ assert type (a ) is type (b ), (
74+ f"Expected same type for comparison, got { type (a )} and { type (b )} "
75+ )
76+
77+ return op (a .as_py (), b .as_py ())
78+
79+
80+ def _scalar_value_lte (a : pa .Scalar , b : pa .Scalar ) -> bool :
81+ return _scalar_value_op (a , b , lambda x , y : x <= y )
82+
83+
84+ def _scalar_value_lt (a : pa .Scalar , b : pa .Scalar ) -> bool :
85+ return _scalar_value_op (a , b , lambda x , y : x < y )
86+
87+
88+ def _scalar_value_gt (a : pa .Scalar , b : pa .Scalar ) -> bool :
89+ return _scalar_value_op (a , b , lambda x , y : x > y )
90+
91+
92+ def _scalar_value_gte (a : pa .Scalar , b : pa .Scalar ) -> bool :
93+ return _scalar_value_op (a , b , lambda x , y : x >= y )
94+
95+
96+ def _scalar_value_eq (a : pa .Scalar , b : pa .Scalar ) -> bool :
97+ return _scalar_value_op (a , b , lambda x , y : x == y )
5298
5399
54100FileFieldInfo = dict [str , AnyFieldInfo ]
@@ -93,18 +139,29 @@ def _eval_predicate(
93139 if not isinstance (node .left , sqlglot .expressions .Column ):
94140 return None
95141
142+ if node .right .find (sqlglot .expressions .Column ) is not None :
143+ # Can't evaluate this since it has a right hand column ref, ideally
144+ # this should be removed further up.
145+ return None
146+
96147 # The thing on the right side should be something that can be evaluated against a range.
97148 # ideally, its going to be a
98- assert isinstance (
99- node .right ,
100- sqlglot .expressions .Literal
101- | sqlglot .expressions .Null
102- | sqlglot .expressions .Neg ,
103- ), (
104- f"Expected a literal or null on righthand side of predicate { node } got a { type (node .right )} "
105- )
149+ if True : # isinstance(node.right, sqlglot.expressions.Cast):
150+ connection = duckdb .connect (":memory:" )
151+ value_result = connection .execute (
152+ f"select { node .right .sql ('duckdb' )} "
153+ ).arrow ()
154+ assert value_result .num_rows == 1 , (
155+ f"Expected a single row result from cast, got { value_result .num_rows } rows"
156+ )
157+ assert value_result .num_columns == 1 , (
158+ f"Expected a single column result from cast, got { value_result .num_columns } columns"
159+ )
106160
107- right_val = node .right .to_py ()
161+ right_val = value_result .column (0 )[0 ]
162+ # This is an interesting behavior, null is returned with an int32 type.
163+ if type (right_val ) is pa .Int32Scalar and right_val .as_py () is None :
164+ right_val = pa .scalar (None , type = pa .null ())
108165
109166 left_val = node .left
110167 assert isinstance (left_val , sqlglot .expressions .Column ), (
@@ -117,17 +174,19 @@ def _eval_predicate(
117174
118175 field_info = file_info .get (referenced_field_name )
119176
177+ # Right now if the field is not present in the file,
178+ # just note that we couldn't evaluate the expression.
120179 if field_info is None :
121180 return None
122181
123182 if isinstance (field_info , SetFieldInfo ):
124183 match type (node ):
125184 case sqlglot .expressions .EQ :
126- if right_val is None :
185+ if pa . types . is_null ( right_val . type ) :
127186 return False
128187 return right_val in field_info .values
129188 case sqlglot .expressions .NEQ :
130- if right_val is None :
189+ if pa . types . is_null ( right_val . type ) :
131190 return False
132191 return right_val not in field_info .values
133192 case _:
@@ -136,44 +195,70 @@ def _eval_predicate(
136195 )
137196
138197 if type (node ) is sqlglot .expressions .NullSafeNEQ :
139- if right_val is not None and field_info .has_non_nulls is False :
198+ if (
199+ not pa .types .is_null (right_val .type )
200+ and field_info .has_non_nulls is False
201+ ):
140202 return True
141- return not (field_info .min_value == field_info .max_value == right_val )
203+
204+ if pa .types .is_null (right_val .type ):
205+ return field_info .has_non_nulls
206+
207+ return not (
208+ _scalar_value_eq (field_info .min_value , field_info .max_value )
209+ and _scalar_value_eq (field_info .min_value , right_val )
210+ )
211+
142212 elif type (node ) is sqlglot .expressions .NullSafeEQ :
143- if right_val is None and field_info .has_non_nulls :
213+ if pa . types . is_null ( right_val . type ) and field_info .has_non_nulls :
144214 return True
145215 if field_info .min_value is None or field_info .max_value is None :
146216 return False
147- assert right_val is not None
148- return field_info .min_value <= right_val <= field_info .max_value
217+ assert not pa .types .is_null (right_val .type )
218+ return _scalar_value_lte (
219+ field_info .min_value , right_val
220+ ) and _scalar_value_lte (right_val , field_info .max_value )
149221
150222 if field_info .min_value is None or field_info .max_value is None :
151223 return False
152224
153- if right_val is None :
225+ if pa . types . is_null ( right_val . type ) :
154226 return False
155227
156228 match type (node ):
157229 case sqlglot .expressions .EQ :
158- return field_info .min_value <= right_val <= field_info .max_value
230+ return _scalar_value_lte (
231+ field_info .min_value , right_val
232+ ) and _scalar_value_lte (right_val , field_info .max_value )
159233 case sqlglot .expressions .NEQ :
160- return not (field_info .min_value == field_info .max_value == right_val )
234+ return not (
235+ _scalar_value_eq (field_info .min_value , field_info .max_value )
236+ and _scalar_value_eq (field_info .min_value , right_val )
237+ )
161238 case sqlglot .expressions .LT :
162- return field_info .min_value < right_val
239+ return _scalar_value_lt ( field_info .min_value , right_val )
163240 case sqlglot .expressions .LTE :
164- return field_info .min_value <= right_val
241+ return _scalar_value_lte ( field_info .min_value , right_val )
165242 case sqlglot .expressions .GT :
166- return field_info .max_value > right_val
243+ return _scalar_value_gt ( field_info .max_value , right_val )
167244 case sqlglot .expressions .GTE :
168- return field_info .max_value >= right_val
245+ return _scalar_value_gte ( field_info .max_value , right_val )
169246 case sqlglot .expressions .NullSafeEQ :
170- if right_val is None and field_info .has_non_nulls :
247+ if pa . types . is_null ( right_val . type ) and field_info .has_non_nulls :
171248 return True
172- return field_info .min_value <= right_val <= field_info .max_value
249+ return _scalar_value_lte (
250+ field_info .min_value , right_val
251+ ) and _scalar_value_lte (right_val , field_info .max_value )
173252 case sqlglot .expressions .NullSafeNEQ :
174- if right_val is not None and field_info .has_non_nulls is False :
253+ if (
254+ not pa .types .is_null (right_val .type )
255+ and field_info .has_non_nulls is False
256+ ):
175257 return True
176- return not (field_info .min_value == field_info .max_value == right_val )
258+ return not (
259+ _scalar_value_eq (field_info .min_value , field_info .max_value )
260+ and _scalar_value_eq (field_info .min_value , right_val )
261+ )
177262 case _:
178263 raise ValueError (f"Unsupported operator type: { type (node )} " )
179264
@@ -234,14 +319,6 @@ def _evaluate_node_in(
234319 return False
235320
236321 for in_exp in node .expressions :
237- assert isinstance (
238- in_exp ,
239- sqlglot .expressions .Literal
240- | sqlglot .expressions .Neg
241- | sqlglot .expressions .Null ,
242- ), (
243- f"Expected a literal in in side of { node } , got { in_exp } type { type (in_exp )} "
244- )
245322 if self ._eval_predicate (
246323 file_info ,
247324 sqlglot .expressions .EQ (this = in_val , expression = in_exp ),
@@ -381,9 +458,7 @@ def _evaluate_sql_node(
381458
382459 return False
383460
384- def get_matching_files (
385- self , expression : str , * , dialect : str = "duckdb"
386- ) -> set [str ]:
461+ def get_matching_files (self , exp : sqlglot .expressions .Expression | str ) -> set [str ]:
387462 """
388463 Get a set of files that match the given SQL expression.
389464 Args:
@@ -392,15 +467,23 @@ def get_matching_files(
392467 Returns:
393468 A set of filenames that match the expression.
394469 """
395- parse_result = sqlglot .parse_one (expression , dialect = dialect )
470+ if isinstance (exp , str ):
471+ # Parse the expression if it is a string.
472+ expression = sqlglot .parse_one (exp , dialect = "duckdb" )
473+ else :
474+ expression = exp
475+
476+ assert isinstance (expression , sqlglot .expressions .Expression ), (
477+ f"Expected a sqlglot expression, got { type (expression )} "
478+ )
396479
397480 # Simplify the parsed expression, move all of the literals to the right side
398- parse_result = sqlglot .optimizer .simplify .simplify (parse_result )
481+ expression = sqlglot .optimizer .simplify .simplify (expression )
399482
400483 matching_files = set ()
401484
402485 for filename , file_info in self .files :
403- eval_result = self ._evaluate_sql_node (parse_result , file_info )
486+ eval_result = self ._evaluate_sql_node (expression , file_info )
404487 if eval_result is None or eval_result is True :
405488 # If the expression evaluates to True or cannot be evaluated, add the file
406489 # to the result set since the caller will be able to filter the rows further.
0 commit comments