Skip to content

Commit 27aedbd

Browse files
committed
fix: change to a generator
1 parent 33742df commit 27aedbd

File tree

3 files changed

+13
-14
lines changed

3 files changed

+13
-14
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ files = [
6666
planner = Planner(files)
6767

6868
# Filter files based on SQL expressions
69-
matching_files = planner.get_matching_files("sales_amount > 40000 AND region = 'US'")
69+
matching_files = set(planner.get_matching_files("sales_amount > 40000 AND region = 'US'"))
7070
print(matching_files) # {'data_2023_q1.parquet', 'data_2023_q2.parquet'}
7171

7272
# More complex queries
73-
matching_files = planner.get_matching_files("region IN ('EU', 'UK')")
73+
matching_files = set(planner.get_matching_files("region IN ('EU', 'UK')"))
7474
print(matching_files) # {'data_2023_q2.parquet'}
7575
```
7676

src/query_farm_sql_scan_planning/planner.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Callable
22
from dataclasses import dataclass
3-
from typing import Any
3+
from typing import Any, Generator
44
import duckdb
55
import pyarrow as pa
66
import sqlglot
@@ -476,7 +476,9 @@ def _evaluate_sql_node(
476476
f"Supported types: Connector, Predicate, Not, Boolean, Case, Null"
477477
)
478478

479-
def get_matching_files(self, exp: sqlglot.expressions.Expression | str) -> set[str]:
479+
def get_matching_files(
480+
self, exp: sqlglot.expressions.Expression | str, *, dialect: str = "duckdb"
481+
) -> Generator[str, None, None]:
480482
"""
481483
Get a set of files that match the given SQL expression.
482484
Args:
@@ -487,24 +489,19 @@ def get_matching_files(self, exp: sqlglot.expressions.Expression | str) -> set[s
487489
"""
488490
if isinstance(exp, str):
489491
# Parse the expression if it is a string.
490-
expression = sqlglot.parse_one(exp, dialect="duckdb")
492+
expression = sqlglot.parse_one(exp, dialect=dialect)
491493
else:
492494
expression = exp
493495

494-
assert isinstance(expression, sqlglot.expressions.Expression), (
495-
f"Expected a sqlglot expression, got {type(expression)}"
496-
)
496+
if not isinstance(expression, sqlglot.expressions.Expression):
497+
raise ValueError(f"Expected a sqlglot expression, got {type(expression)}")
497498

498499
# Simplify the parsed expression, move all of the literals to the right side
499500
expression = sqlglot.optimizer.simplify.simplify(expression)
500501

501-
matching_files = set()
502-
503502
for filename, file_info in self.files:
504503
eval_result = self._evaluate_sql_node(expression, file_info)
505504
if eval_result is None or eval_result is True:
506505
# If the expression evaluates to True or cannot be evaluated, add the file
507506
# to the result set since the caller will be able to filter the rows further.
508-
matching_files.add(filename)
509-
510-
return matching_files
507+
yield filename

src/query_farm_sql_scan_planning/test_planner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,9 @@ def test_scan_planning(
305305
filter_obj = Planner(sample_files)
306306

307307
# Apply the filter
308-
result = filter_obj.get_matching_files(sqlglot.parse_one(clause, dialect="duckdb"))
308+
result = set(
309+
filter_obj.get_matching_files(sqlglot.parse_one(clause, dialect="duckdb"))
310+
)
309311

310312
# Check if files were filtered as expected
311313
if result != expected_files:

0 commit comments

Comments
 (0)