11from collections .abc import Callable
22from dataclasses import dataclass
3- from typing import Any
3+ from typing import Any , Generator
44import duckdb
55import pyarrow as pa
66import sqlglot
@@ -476,7 +476,9 @@ def _evaluate_sql_node(
476476 f"Supported types: Connector, Predicate, Not, Boolean, Case, Null"
477477 )
478478
def get_matching_files(
    self, exp: sqlglot.expressions.Expression | str, *, dialect: str = "duckdb"
) -> Generator[str, None, None]:
    """
    Lazily yield the names of files that may match the given SQL expression.

    The expression is parsed (when given as a string), validated and
    simplified as soon as this method is called, so invalid input is reported
    at the call site instead of on the first iteration of the result. Only the
    per-file evaluation is deferred until the returned generator is consumed.

    Args:
        exp: A sqlglot expression, or a SQL string to be parsed into one.
        dialect: sqlglot dialect name used to parse ``exp`` when it is a
            string. Ignored when ``exp`` is already an expression.

    Returns:
        A generator producing the filename of every entry in ``self.files``
        for which the expression evaluates to ``True`` or to ``None``
        (``None`` meaning the expression could not be evaluated for that file).

    Raises:
        ValueError: If the value obtained from ``exp`` is not a sqlglot
            expression. Errors raised by sqlglot while parsing propagate
            unchanged.
    """
    if isinstance(exp, str):
        # Parse the expression if it is a string.
        expression = sqlglot.parse_one(exp, dialect=dialect)
    else:
        expression = exp

    if not isinstance(expression, sqlglot.expressions.Expression):
        raise ValueError(f"Expected a sqlglot expression, got {type(expression)}")

    # Simplify the parsed expression, move all of the literals to the right side
    expression = sqlglot.optimizer.simplify.simplify(expression)

    def _iter_matches() -> Generator[str, None, None]:
        # Kept as an inner generator so that the checks above run eagerly:
        # a ``yield`` directly in the outer body would postpone them until
        # the caller starts iterating.
        for filename, file_info in self.files:
            eval_result = self._evaluate_sql_node(expression, file_info)
            if eval_result is None or eval_result is True:
                # If the expression evaluates to True or cannot be evaluated,
                # yield the file anyway since the caller will be able to
                # filter the rows further.
                yield filename

    return _iter_matches()
0 commit comments