Skip to content

Commit 9b33397

Browse files
refactor code base, update dependencies
1 parent 8204abb commit 9b33397

File tree

7 files changed

+595
-536
lines changed

7 files changed

+595
-536
lines changed

fastapi_sa_orm_filter/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@
22
from fastapi_sa_orm_filter.main import FilterCore # noqa
33
from fastapi_sa_orm_filter.operators import Operators as ops # noqa
44

5-
__version__ = "0.2.2"
5+
__version__ = "0.2.3"
6+
7+
from .main import FilterCore as FilterCore # noqa
8+
from .operators import Operators as Operators # noqa
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from abc import ABC, abstractmethod
2+
3+
from fastapi_sa_orm_filter.dto import ParsedFilter
4+
from fastapi_sa_orm_filter.operators import Operators as ops
5+
6+
7+
class QueryParser(ABC):
8+
9+
@abstractmethod
10+
def __init__(self, custom_filter: str, allowed_filters: dict[str, list[ops]]) -> None:
11+
self.custom_filter = custom_filter
12+
self.allowed_filters = allowed_filters
13+
14+
@abstractmethod
15+
def get_parsed_filter(self) -> tuple[list[list[ParsedFilter]], list[str]] | tuple[list, list]:
16+
pass

fastapi_sa_orm_filter/main.py

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from starlette import status
99
from sqlalchemy.sql import Select
1010

11+
from fastapi_sa_orm_filter.dto import ParsedFilter
1112
from fastapi_sa_orm_filter.exceptions import SAFilterOrmException
13+
from fastapi_sa_orm_filter.interfaces import QueryParser
1214
from fastapi_sa_orm_filter.operators import Operators as ops
13-
from fastapi_sa_orm_filter.parsers import FilterQueryParser, OrderByQueryParser
14-
from fastapi_sa_orm_filter.sa_expression_builder import SAFilterExpressionBuilder
15+
from fastapi_sa_orm_filter.parsers import StringQueryParser
16+
from fastapi_sa_orm_filter.sa_expression_builder import SAFilterExpressionBuilder, SAOrderByExpressionBuilder
1517

1618

1719
class FilterCore:
@@ -39,7 +41,7 @@ def __init__(
3941
"""
4042
self.model = model
4143
self._allowed_filters = allowed_filters
42-
self.select_query_part = select_query_part
44+
self.select_sql_query = select_query_part
4345

4446
def get_query(self, custom_filter: str) -> Select[Any]:
4547
"""
@@ -63,56 +65,56 @@ def get_query(self, custom_filter: str) -> Select[Any]:
6365
model.category == 'Medicine'
6466
).order_by(model.id.desc())
6567
"""
66-
split_query = self._split_by_order_by(custom_filter)
6768
try:
68-
complete_query = self._get_complete_query(*split_query)
69+
query_parser = self._get_query_parser(custom_filter)
70+
filter_query, order_by_query = query_parser.get_parsed_filter()
71+
complete_query = self._get_complete_query(filter_query, order_by_query)
6972
except SAFilterOrmException as e:
7073
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.args[0])
7174
return complete_query
7275

73-
def _get_complete_query(self, filter_query_str: str, order_by_query_str: str | None = None) -> Select[Any]:
74-
select_query_part = self.get_select_query_part()
75-
filter_query_part = self._get_filter_query_part(filter_query_str)
76-
complete_query = select_query_part.filter(*filter_query_part)
77-
group_query_part = self.get_group_by_query_part()
78-
if group_query_part:
79-
complete_query = complete_query.group_by(*group_query_part)
80-
if order_by_query_str is not None:
81-
order_by_query = self.get_order_by_query_part(order_by_query_str)
82-
complete_query = complete_query.order_by(*order_by_query)
76+
def _get_complete_query(
77+
self, filter_query: list[list[ParsedFilter]] | list, order_by_query: list[str] | list
78+
) -> Select[Any]:
79+
select_sa_query = self.get_select_query_part()
80+
filter_sa_query = self._get_filter_sa_query(filter_query)
81+
group_by_sa_query = self._get_group_by_sa_query()
82+
order_by_sa_query = self._get_order_by_sa_query(order_by_query)
83+
84+
complete_query = (
85+
select_sa_query
86+
.filter(*filter_sa_query)
87+
.group_by(*group_by_sa_query)
88+
.order_by(*order_by_sa_query)
89+
)
8390
return complete_query
8491

8592
def get_select_query_part(self) -> Select[Any]:
86-
if self.select_query_part is not None:
87-
return self.select_query_part
93+
if self.select_sql_query is not None:
94+
return self.select_sql_query
8895
return select(self.model)
8996

90-
def _get_filter_query_part(self, filter_query_str: str) -> list[Any]:
91-
conditions = self._get_filter_query(filter_query_str)
92-
if len(conditions) == 0:
93-
return conditions
97+
def _get_filter_sa_query(self, filter_query: list[list[ParsedFilter]] | list) -> list[BinaryExpression] | list:
98+
if len(filter_query) == 0:
99+
return []
100+
sa_builder = SAFilterExpressionBuilder(self.model)
101+
conditions = sa_builder.get_expressions(filter_query)
94102
return [or_(*conditions)]
95103

96-
def get_group_by_query_part(self) -> list:
97-
return []
104+
def _get_order_by_sa_query(self, order_by_query: list[str] | list) -> list[UnaryExpression]:
105+
if len(order_by_query) == 0:
106+
return []
107+
order_by_parser = SAOrderByExpressionBuilder(self.model)
108+
return order_by_parser.get_order_by_query(order_by_query)
98109

99-
def get_order_by_query_part(self, order_by_query_str: str) -> list[UnaryExpression]:
100-
order_by_parser = OrderByQueryParser(self.model)
101-
return order_by_parser.get_order_by_query(order_by_query_str)
110+
def _get_group_by_sa_query(self) -> list[BinaryExpression] | list:
111+
group_query_part = self.get_group_by_query_part()
112+
if len(group_query_part) == 0:
113+
return []
114+
return group_query_part
102115

103-
def _get_filter_query(self, custom_filter: str) -> list[BinaryExpression]:
104-
filter_conditions = []
105-
if custom_filter == "":
106-
return filter_conditions
116+
def get_group_by_query_part(self) -> list:
117+
return []
107118

108-
parser = FilterQueryParser(custom_filter, self._allowed_filters)
109-
parsed_filters = parser.get_parsed_query()
110-
sa_builder = SAFilterExpressionBuilder(self.model)
111-
return sa_builder.get_expressions(parsed_filters)
112-
113-
@staticmethod
114-
def _split_by_order_by(query) -> list:
115-
split_query = [query_part.strip("&") for query_part in query.split("order_by=")]
116-
if len(split_query) > 2:
117-
raise SAFilterOrmException("Use only one order_by directive")
118-
return split_query
119+
def _get_query_parser(self, custom_filter: str) -> QueryParser:
120+
return StringQueryParser(custom_filter, self._allowed_filters)

fastapi_sa_orm_filter/parsers.py

Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,73 @@
1-
from sqlalchemy.orm import DeclarativeBase
2-
from sqlalchemy.sql.elements import UnaryExpression
3-
41
from fastapi_sa_orm_filter.dto import ParsedFilter
52
from fastapi_sa_orm_filter.exceptions import SAFilterOrmException
3+
from fastapi_sa_orm_filter.interfaces import QueryParser
64
from fastapi_sa_orm_filter.operators import Operators as ops
7-
from fastapi_sa_orm_filter.operators import OrderSequence
85

96

10-
class OrderByQueryParser:
7+
class StringQueryParser(QueryParser):
8+
9+
def __init__(self, custom_filter: str, allowed_filters: dict[str, list[ops]]) -> None:
10+
self.custom_filter = custom_filter
11+
self.allowed_filters = allowed_filters
12+
13+
def get_parsed_filter(self) -> tuple[list[list[ParsedFilter]], list[str]] | tuple[list, list]:
14+
parsed_filter = []
15+
parsed_order_by = []
16+
17+
if self.custom_filter == "":
18+
return parsed_filter, parsed_order_by
19+
20+
split_query = [query_part.strip("&") for query_part in self.custom_filter.split("order_by=")]
21+
22+
if len(split_query) > 2:
23+
raise SAFilterOrmException("Use only one order_by directive")
24+
25+
parsed_filter = self._get_filter_query_part(split_query[0])
26+
27+
if len(split_query) == 2:
28+
parsed_order_by = self._get_order_by_query_part(split_query[1])
29+
30+
return parsed_filter, parsed_order_by
31+
32+
def _get_filter_query_part(self, filter_query_str: str) -> list[list[ParsedFilter]] | list:
33+
if filter_query_str == "":
34+
return []
35+
filter_parser = StringFilterQueryParser(self.allowed_filters)
36+
return filter_parser.get_parsed_query(filter_query_str)
37+
38+
def _get_order_by_query_part(self, order_by_query_str: str) -> list[str] | list:
39+
if order_by_query_str == "":
40+
return []
41+
order_by_parser = StringOrderByQueryParser()
42+
return order_by_parser.get_order_by_query(order_by_query_str)
43+
44+
45+
class StringOrderByQueryParser:
1146
"""
1247
Class parse order by part of request query string.
1348
"""
14-
def __init__(self, model: type[DeclarativeBase]) -> None:
15-
self._model = model
16-
17-
def get_order_by_query(self, order_by_query_str: str) -> list[UnaryExpression]:
18-
order_by_fields = self._validate_order_by_fields(order_by_query_str)
19-
order_by_query = []
20-
for field in order_by_fields:
21-
if '-' in field:
22-
column = getattr(self._model, field.strip('-'))
23-
order_by_query.append(getattr(column, OrderSequence.desc)())
24-
else:
25-
column = getattr(self._model, field.strip('+'))
26-
order_by_query.append(getattr(column, OrderSequence.asc)())
27-
return order_by_query
28-
29-
def _validate_order_by_fields(self, order_by_query_str: str) -> list[str]:
30-
"""
31-
:return:
32-
[
33-
+field_name,
34-
-field_name
35-
]
36-
"""
37-
order_by_fields = order_by_query_str.split(",")
38-
model_fields = self._model.__table__.columns.keys()
39-
for field in order_by_fields:
40-
field = field.strip('+').strip('-')
41-
if field in model_fields:
42-
continue
43-
raise SAFilterOrmException(f"Incorrect order_by field name {field} for model {self._model.__name__}")
44-
return order_by_fields
49+
def get_order_by_query(self, order_by_query_str: str) -> list[str]:
50+
return order_by_query_str.split(",")
4551

4652

47-
class FilterQueryParser:
53+
class StringFilterQueryParser:
4854
"""
4955
Class parse filter part of request query string.
5056
"""
5157

5258
def __init__(
53-
self, query: str,
54-
allowed_filters: dict[str, list[ops]]
59+
self, allowed_filters: dict[str, list[ops]]
5560
) -> None:
56-
self._query = query
5761
self._allowed_filters = allowed_filters
5862

59-
def get_parsed_query(self) -> list[list[ParsedFilter]]:
63+
def get_parsed_query(self, filter_query_str: str) -> list[list[ParsedFilter]]:
6064
"""
6165
:return:
6266
[
6367
[ParsedFilter, ParsedFilter, ParsedFilter]
6468
]
6569
"""
66-
and_blocks = self._parse_by_conjunctions()
70+
and_blocks = self._parse_by_conjunctions(filter_query_str)
6771
parsed_query = []
6872
for and_block in and_blocks:
6973
parsed_and_blocks = []
@@ -74,7 +78,7 @@ def get_parsed_query(self) -> list[list[ParsedFilter]]:
7478
parsed_query.append(parsed_and_blocks)
7579
return parsed_query
7680

77-
def _parse_by_conjunctions(self) -> list[list[str]]:
81+
def _parse_by_conjunctions(self, filter_query_str: str) -> list[list[str]]:
7882
"""
7983
Split request query string by 'OR' and 'AND' conjunctions
8084
to divide query string to field's conditions
@@ -84,7 +88,7 @@ def _parse_by_conjunctions(self) -> list[list[str]]:
8488
['field_name__operator=value']
8589
]
8690
"""
87-
and_blocks = [block.split("&") for block in self._query.split("|")]
91+
and_blocks = [block.split("&") for block in filter_query_str.split("|")]
8892
return and_blocks
8993

9094
def _parse_expression(

fastapi_sa_orm_filter/sa_expression_builder.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
from typing import Any
33

44
import pydantic
5-
from pydantic import create_model
5+
from pydantic import create_model, BaseModel
66
from pydantic._internal._model_construction import ModelMetaclass
7-
from sqlalchemy import inspect, BinaryExpression, and_
7+
from sqlalchemy import inspect, BinaryExpression, and_, UnaryExpression
88
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute
99
from sqlalchemy_to_pydantic import sqlalchemy_to_pydantic
1010

1111
from fastapi_sa_orm_filter.exceptions import SAFilterOrmException
12-
from fastapi_sa_orm_filter.operators import Operators as ops
12+
from fastapi_sa_orm_filter.operators import Operators as ops, OrderSequence
1313

1414

1515
class SAFilterExpressionBuilder:
@@ -32,8 +32,8 @@ def get_expressions(self, parsed_filters) -> list[BinaryExpression]:
3232
model = self.get_relation_model(and_filter.relation)
3333
table = model.__tablename__
3434
column = self.get_column(model, and_filter.field_name)
35-
serialized_dict = self.serialize_expression_value(table, column, and_filter.operator, and_filter.value)
36-
value = serialized_dict[column.name]
35+
serialized_dict = self.serialize_expression_value(table, and_filter.field_name, and_filter.operator, and_filter.value)
36+
value = serialized_dict[and_filter.field_name]
3737
expr = self.get_orm_for_field(column, and_filter.operator, value)
3838
and_expr.append(expr)
3939
or_expr.append(and_(*and_expr))
@@ -83,7 +83,7 @@ class model.__name__(BaseModel):
8383

8484
return serializers
8585

86-
def get_relations_classes(self) -> list:
86+
def get_relations_classes(self) -> list[type[DeclarativeBase]]:
8787
return [relation[1].mapper.class_ for relation in self._relationships]
8888

8989
def get_orm_for_field(
@@ -97,7 +97,7 @@ def get_orm_for_field(
9797
return getattr(column, ops[operator].value)(value)
9898

9999
def serialize_expression_value(
100-
self, table: str, column: InstrumentedAttribute, operator: str, value: str
100+
self, table: str, field_name: str, operator: str, value: str
101101
) -> dict[str, Any]:
102102
"""
103103
Serialize expression value from string to python type value,
@@ -112,14 +112,14 @@ def serialize_expression_value(
112112
model_serializer = self._model_serializers[table]["optional_model"]
113113
else:
114114
model_serializer = self._model_serializers[table]["optional_list_model"]
115-
return model_serializer(**{column.name: value}).model_dump(exclude_none=True)
115+
return model_serializer(**{field_name: value}).model_dump(exclude_none=True)
116116
except pydantic.ValidationError as e:
117117
raise SAFilterOrmException(json.loads(e.json()))
118118
except ValueError:
119119
raise SAFilterOrmException(f"Incorrect filter value '{value}'")
120120

121121
@staticmethod
122-
def get_optional_pydantic_model(model, pydantic_serializer, is_list: bool = False):
122+
def get_optional_pydantic_model(model, pydantic_serializer, is_list: bool = False) -> BaseModel:
123123
fields = {}
124124
for k, v in pydantic_serializer.model_fields.items():
125125
origin_annotation = getattr(v, 'annotation')
@@ -129,3 +129,37 @@ def get_optional_pydantic_model(model, pydantic_serializer, is_list: bool = Fals
129129
fields[k] = (origin_annotation, None)
130130
pydantic_model = create_model(model.__name__, **fields)
131131
return pydantic_model
132+
133+
134+
class SAOrderByExpressionBuilder:
135+
136+
def __init__(self, model: type[DeclarativeBase]) -> None:
137+
self._model = model
138+
139+
def get_order_by_query(self, order_by_query: list[str]) -> list[UnaryExpression]:
140+
order_by_fields = self._validate_order_by_fields(order_by_query)
141+
order_by_sql_query = []
142+
for field in order_by_fields:
143+
if '-' in field:
144+
column = getattr(self._model, field.strip('-'))
145+
order_by_sql_query.append(getattr(column, OrderSequence.desc)())
146+
else:
147+
column = getattr(self._model, field.strip('+'))
148+
order_by_sql_query.append(getattr(column, OrderSequence.asc)())
149+
return order_by_sql_query
150+
151+
def _validate_order_by_fields(self, order_by_fields: list[str]) -> list[str]:
152+
"""
153+
:return:
154+
[
155+
+field_name,
156+
-field_name
157+
]
158+
"""
159+
model_fields = self._model.__table__.columns.keys()
160+
for field in order_by_fields:
161+
field = field.strip('+').strip('-')
162+
if field in model_fields:
163+
continue
164+
raise SAFilterOrmException(f"Incorrect order_by field name {field} for model {self._model.__name__}")
165+
return order_by_fields

0 commit comments

Comments
 (0)