|
2 | 2 | # -*- coding: utf-8 -*- |
3 | 3 | from typing import Any, Generic, Iterable, Sequence, Type |
4 | 4 |
|
5 | | -from sqlalchemy import Row, RowMapping, Select, delete, select, update |
| 5 | +from sqlalchemy import Row, RowMapping, Select, delete, inspect, select, update |
6 | 6 | from sqlalchemy.ext.asyncio import AsyncSession |
7 | 7 |
|
8 | | -from sqlalchemy_crud_plus.errors import MultipleResultsError |
| 8 | +from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, MultipleResultsError |
9 | 9 | from sqlalchemy_crud_plus.types import CreateSchema, Model, UpdateSchema |
10 | 10 | from sqlalchemy_crud_plus.utils import apply_sorting, count, parse_filters |
11 | 11 |
|
12 | 12 |
|
13 | 13 | class CRUDPlus(Generic[Model]): |
14 | 14 | def __init__(self, model: Type[Model]): |
15 | 15 | self.model = model |
| 16 | + self.primary_key = self._get_primary_key() |
| 17 | + |
| 18 | + def _get_primary_key(self): |
| 19 | + """ |
| 20 | + Dynamically retrieve the primary key column(s) for the model. |
| 21 | + """ |
| 22 | + mapper = inspect(self.model) |
| 23 | + primary_key = mapper.primary_key |
| 24 | + if len(primary_key) == 1: |
| 25 | + return primary_key[0] |
| 26 | + else: |
| 27 | + raise CompositePrimaryKeysError('Composite primary keys are not supported') |
16 | 28 |
|
17 | 29 | async def create_model( |
18 | 30 | self, |
@@ -69,7 +81,7 @@ async def select_model(self, session: AsyncSession, pk: int) -> Model | None: |
69 | 81 | :param pk: The database primary key value. |
70 | 82 | :return: |
71 | 83 | """ |
72 | | - stmt = select(self.model).where(self.model.id == pk) |
| 84 | + stmt = select(self.model).where(self.primary_key == pk) |
73 | 85 | query = await session.execute(stmt) |
74 | 86 | return query.scalars().first() |
75 | 87 |
|
@@ -166,7 +178,7 @@ async def update_model( |
166 | 178 | instance_data = obj |
167 | 179 | else: |
168 | 180 | instance_data = obj.model_dump(exclude_unset=True) |
169 | | - stmt = update(self.model).where(self.model.id == pk).values(**instance_data) |
| 181 | + stmt = update(self.model).where(self.primary_key == pk).values(**instance_data) |
170 | 182 | result = await session.execute(stmt) |
171 | 183 | if commit: |
172 | 184 | await session.commit() |
@@ -218,7 +230,7 @@ async def delete_model( |
218 | 230 | :param commit: If `True`, commits the transaction immediately. Default is `False`. |
219 | 231 | :return: |
220 | 232 | """ |
221 | | - stmt = delete(self.model).where(self.model.id == pk) |
| 233 | + stmt = delete(self.model).where(self.primary_key == pk) |
222 | 234 | result = await session.execute(stmt) |
223 | 235 | if commit: |
224 | 236 | await session.commit() |
|
0 commit comments