diff --git a/docs/src/piccolo/schema/index.rst b/docs/src/piccolo/schema/index.rst index ec9b887e6..f7a4c07e5 100644 --- a/docs/src/piccolo/schema/index.rst +++ b/docs/src/piccolo/schema/index.rst @@ -8,6 +8,7 @@ The schema is how you define your database tables, columns and relationships. ./defining ./column_types + ./indexes ./m2m ./one_to_one ./advanced diff --git a/docs/src/piccolo/schema/indexes.rst b/docs/src/piccolo/schema/indexes.rst new file mode 100644 index 000000000..850ffa6b6 --- /dev/null +++ b/docs/src/piccolo/schema/indexes.rst @@ -0,0 +1,27 @@ +======= +Indexes +======= + +Single column index +=================== + +Index can be added to a single column using the ``index=True`` +argument of ``Column``: + +.. code-block:: python + + class Band(Table): + name = Varchar(index=True) + +Multi-column (composite) index +============================== + +To manually create and drop multi-column indexes, we can use Piccolo's +built-in methods ``create_index`` and ``drop_index``. + +If you are using automatic migrations, we can specify the ``CompositeIndex`` +argument and they handle the creation and deletion of these composite indexes. + +.. currentmodule:: piccolo.composite_index + +.. autoclass:: CompositeIndex \ No newline at end of file diff --git a/piccolo/apps/migrations/auto/diffable_table.py b/piccolo/apps/migrations/auto/diffable_table.py index a3e80500f..1b6d74793 100644 --- a/piccolo/apps/migrations/auto/diffable_table.py +++ b/piccolo/apps/migrations/auto/diffable_table.py @@ -5,14 +5,17 @@ from piccolo.apps.migrations.auto.operations import ( AddColumn, + AddCompositeIndex, AlterColumn, DropColumn, + DropCompositeIndex, ) from piccolo.apps.migrations.auto.serialisation import ( deserialise_params, serialise_params, ) from piccolo.columns.base import Column +from piccolo.composite_index import Composite from piccolo.table import Table, create_table_class @@ -62,6 +65,12 @@ class TableDelta: add_columns: list[AddColumn] = field(default_factory=list) drop_columns: list[DropColumn] = field(default_factory=list) alter_columns: list[AlterColumn] = field(default_factory=list) + add_composite_indexes: list[AddCompositeIndex] = field( + default_factory=list + ) + drop_composite_indexes: list[DropCompositeIndex] = field( + default_factory=list + ) def __eq__(self, value: TableDelta) -> bool: # type: ignore """ @@ -92,6 +101,22 @@ def __eq__(self, value) -> bool: return False +@dataclass +class CompositeIndexComparison: + composite_index: Composite + + def __hash__(self) -> int: + return self.composite_index.__hash__() + + def __eq__(self, value) -> bool: + if isinstance(value, CompositeIndexComparison): + return ( + self.composite_index._meta.name + == value.composite_index._meta.name + ) + return False + + @dataclass class DiffableTable: """ @@ -103,6 +128,7 @@ class DiffableTable: tablename: str schema: Optional[str] = None columns: list[Column] = field(default_factory=list) + composite_indexes: list[Composite] = field(default_factory=list) previous_class_name: Optional[str] = None def __post_init__(self) -> None: @@ -196,10 +222,54 @@ def __sub__(self, value: DiffableTable) -> TableDelta: ) ) + add_composite_indexes = [ + AddCompositeIndex( + table_class_name=self.class_name, + composite_index_name=i.composite_index._meta.name, + composite_index_class_name=i.composite_index.__class__.__name__, # noqa: E501 + composite_index_class=i.composite_index.__class__, + params=i.composite_index._meta.params, + schema=self.schema, + ) + for i in sorted( + { + CompositeIndexComparison(composite_index=composite_index) + for composite_index in self.composite_indexes + } + - { + CompositeIndexComparison(composite_index=composite_index) + for composite_index in value.composite_indexes + }, + key=lambda x: x.composite_index._meta.name, + ) + ] + + drop_composite_indexes = [ + DropCompositeIndex( + table_class_name=self.class_name, + composite_index_name=i.composite_index._meta.name, + tablename=value.tablename, + schema=self.schema, + ) + for i in sorted( + { + CompositeIndexComparison(composite_index=composite_index) + for composite_index in value.composite_indexes + } + - { + CompositeIndexComparison(composite_index=composite_index) + for composite_index in self.composite_indexes + }, + key=lambda x: x.composite_index._meta.name, + ) + ] + return TableDelta( add_columns=add_columns, drop_columns=drop_columns, alter_columns=alter_columns, + add_composite_indexes=add_composite_indexes, + drop_composite_indexes=drop_composite_indexes, ) def __hash__(self) -> int: @@ -225,10 +295,14 @@ def to_table_class(self) -> type[Table]: """ Converts the DiffableTable into a Table subclass. """ + class_members: dict[str, Any] = {} + for column in self.columns: + class_members[column._meta.name] = column + for composite_index in self.composite_indexes: + class_members[composite_index._meta.name] = composite_index + return create_table_class( class_name=self.class_name, class_kwargs={"tablename": self.tablename, "schema": self.schema}, - class_members={ - column._meta.name: column for column in self.columns - }, + class_members=class_members, ) diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index 53f4a3a7c..cf6911269 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -11,12 +11,14 @@ AlterColumn, ChangeTableSchema, DropColumn, + DropCompositeIndex, RenameColumn, RenameTable, ) from piccolo.apps.migrations.auto.serialisation import deserialise_params from piccolo.columns import Column, column_types from piccolo.columns.column_types import ForeignKey, Serial +from piccolo.composite_index import Composite, CompositeIndex from piccolo.engine import engine_finder from piccolo.query import Query from piccolo.query.base import DDL @@ -125,6 +127,69 @@ def table_class_names(self) -> list[str]: return list({i.table_class_name for i in self.alter_columns}) +@dataclass +class AddCompositeIndexClass: + composite_index: Composite + table_class_name: str + tablename: str + schema: Optional[str] + + +@dataclass +class AddCompositeIndexCollection: + add_composite_indexes: list[AddCompositeIndexClass] = field( + default_factory=list + ) + + def append(self, add_composite_index: AddCompositeIndexClass): + self.add_composite_indexes.append(add_composite_index) + + def for_table_class_name( + self, table_class_name: str + ) -> list[AddCompositeIndexClass]: + return [ + i + for i in self.add_composite_indexes + if i.table_class_name == table_class_name + ] + + def composite_indexes_for_table_class_name( + self, table_class_name: str + ) -> list[Composite]: + return [ + i.composite_index + for i in self.add_composite_indexes + if i.table_class_name == table_class_name + ] + + @property + def table_class_names(self) -> list[str]: + return list({i.table_class_name for i in self.add_composite_indexes}) + + +@dataclass +class DropCompositeIndexCollection: + drop_composite_indexes: list[DropCompositeIndex] = field( + default_factory=list + ) + + def append(self, drop_composite_index: DropCompositeIndex): + self.drop_composite_indexes.append(drop_composite_index) + + def for_table_class_name( + self, table_class_name: str + ) -> list[DropCompositeIndex]: + return [ + i + for i in self.drop_composite_indexes + if i.table_class_name == table_class_name + ] + + @property + def table_class_names(self) -> list[str]: + return list({i.table_class_name for i in self.drop_composite_indexes}) + + AsyncFunction = Callable[[], Coroutine] @@ -171,6 +236,12 @@ class MigrationManager: alter_columns: AlterColumnCollection = field( default_factory=AlterColumnCollection ) + add_composite_indexes: AddCompositeIndexCollection = field( + default_factory=AddCompositeIndexCollection + ) + drop_composite_indexes: DropCompositeIndexCollection = field( + default_factory=DropCompositeIndexCollection + ) raw: list[Union[Callable, AsyncFunction]] = field(default_factory=list) raw_backwards: list[Union[Callable, AsyncFunction]] = field( default_factory=list @@ -358,6 +429,48 @@ def alter_column( ) ) + def add_composite_index( + self, + table_class_name: str, + tablename: str, + composite_index_name: str, + composite_index_class: type[Composite], + params: dict[str, Any], + schema: Optional[str] = None, + ): + if composite_index_class is CompositeIndex: + composite_index = CompositeIndex(**params) + else: + raise ValueError("Unrecognised composite index type") + + composite_index._meta.name = composite_index_name + composite_index.columns = params["columns"] + + self.add_composite_indexes.append( + AddCompositeIndexClass( + composite_index=composite_index, + table_class_name=table_class_name, + tablename=tablename, + schema=schema, + ) + ) + + def drop_composite_index( + self, + table_class_name: str, + tablename: str, + composite_index_name: str, + schema: Optional[str] = None, + ): + self.drop_composite_indexes.append( + DropCompositeIndex( + table_class_name=table_class_name, + composite_index_name=composite_index_name, + tablename=tablename, + schema=schema, + ) + ) + def add_raw(self, raw: Union[Callable, AsyncFunction]): """ A migration manager can execute arbitrary functions or coroutines when @@ -417,8 +530,8 @@ async def _print_query(query: Union[DDL, Query, SchemaDDLBase]): async def _run_query(self, query: Union[DDL, Query, SchemaDDLBase]): """ - If MigrationManager is in preview mode then it just print the query - instead of executing it. + If MigrationManager is not in the preview mode, + executes the queries. else, prints the query. """ if self.preview: await self._print_query(query) @@ -786,16 +899,26 @@ async def _run_add_tables(self, backwards: bool = False): add_columns: list[AddColumnClass] = ( self.add_columns.for_table_class_name(add_table.class_name) ) + add_composite_indexes: list[AddCompositeIndexClass] = ( + self.add_composite_indexes.for_table_class_name( + add_table.class_name + ) + ) + class_members: dict[str, Any] = {} + for add_column in add_columns: + class_members[add_column.column._meta.name] = add_column.column + for add_composite_index in add_composite_indexes: + class_members[ + add_composite_index.composite_index._meta.name + ] = add_composite_index.composite_index + _Table: type[Table] = create_table_class( class_name=add_table.class_name, class_kwargs={ "tablename": add_table.tablename, "schema": add_table.schema, }, - class_members={ - add_column.column._meta.name: add_column.column - for add_column in add_columns - }, + class_members=class_members, ) table_classes.append(_Table) @@ -968,6 +1091,121 @@ async def _run_change_table_schema(self, backwards: bool = False): ) ) + async def _run_add_composite_indexes(self, backwards: bool = False): + if backwards: + for ( + add_composite_index + ) in self.add_composite_indexes.add_composite_indexes: + if add_composite_index.table_class_name in [ + i.class_name for i in self.add_tables + ]: + # Don't reverse the add composite index as the table + # is going to be deleted. + continue + + _Table = create_table_class( + class_name=add_composite_index.table_class_name, + class_kwargs={ + "tablename": add_composite_index.tablename, + "schema": add_composite_index.schema, + }, + ) + + await self._run_query( + _Table.drop_index( + columns=add_composite_index.composite_index._meta.params[ # noqa: E501 + "columns" + ], + name=add_composite_index.composite_index._meta.name, + ) + ) + else: + for ( + table_class_name + ) in self.add_composite_indexes.table_class_names: + add_composite_indexes: list[AddCompositeIndexClass] = ( + self.add_composite_indexes.for_table_class_name( + table_class_name + ) + ) + + _Table = create_table_class( + class_name=add_composite_indexes[0].table_class_name, + class_kwargs={ + "tablename": add_composite_indexes[0].tablename, + "schema": add_composite_indexes[0].schema, + }, + ) + + await self._run_query( + _Table.create_index( + columns=add_composite_indexes[ + 0 + ].composite_index._meta.params["columns"], + method=add_composite_indexes[ + 0 + ].composite_index._meta.params["index_type"], + name=add_composite_indexes[ + 0 + ].composite_index._meta.name, + ) + ) + + async def _run_drop_composite_indexes(self, backwards: bool = False): + if backwards: + for ( + drop_composite_index + ) in self.drop_composite_indexes.drop_composite_indexes: + _Table = await self.get_table_from_snapshot( + table_class_name=drop_composite_index.table_class_name, + app_name=self.app_name, + offset=-1, + ) + composite_index_to_restore = ( + _Table._meta.get_composite_index_by_name( + drop_composite_index.composite_index_name + ) + ) + + await self._run_query( + _Table.create_index( + columns=composite_index_to_restore._meta.params[ + "columns" + ], + method=composite_index_to_restore._meta.params[ + "index_type" + ], + name=composite_index_to_restore._meta._name, + ) + ) + else: + for ( + table_class_name + ) in self.drop_composite_indexes.table_class_names: + composite_indexes = ( + self.drop_composite_indexes.for_table_class_name( + table_class_name + ) + ) + + if not composite_indexes: + continue + + _Table = create_table_class( + class_name=table_class_name, + class_kwargs={ + "tablename": composite_indexes[0].tablename, + "schema": composite_indexes[0].schema, + }, + ) + + await self._run_query( + _Table.drop_index( + columns=[], # placeholder value + name=composite_indexes[0].composite_index_name, + ) + ) + async def run(self, backwards: bool = False): direction = "backwards" if backwards else "forwards" if self.preview: @@ -1003,6 +1241,8 @@ async def run(self, backwards: bool = False): await self._run_drop_columns(backwards=backwards) await self._run_drop_tables(backwards=backwards) await self._run_rename_columns(backwards=backwards) + await self._run_add_composite_indexes(backwards=backwards) + await self._run_drop_composite_indexes(backwards=backwards) # We can remove this for cockroach when resolved. # https://github.com/cockroachdb/cockroach/issues/49351 # "ALTER COLUMN TYPE is not supported inside a transaction" diff --git a/piccolo/apps/migrations/auto/operations.py b/piccolo/apps/migrations/auto/operations.py index 84e0d261a..603ccf38e 100644 --- a/piccolo/apps/migrations/auto/operations.py +++ b/piccolo/apps/migrations/auto/operations.py @@ -2,6 +2,7 @@ from typing import Any, Optional from piccolo.columns.base import Column +from piccolo.composite_index import Composite @dataclass @@ -63,3 +64,21 @@ class AddColumn: column_class: type[Column] params: dict[str, Any] schema: Optional[str] = None + + +@dataclass +class AddCompositeIndex: + table_class_name: str + composite_index_name: str + composite_index_class_name: str + composite_index_class: type[Composite] + params: dict[str, Any] + schema: Optional[str] = None + + +@dataclass +class DropCompositeIndex: + table_class_name: str + composite_index_name: str + tablename: str + schema: Optional[str] = None diff --git a/piccolo/apps/migrations/auto/schema_differ.py b/piccolo/apps/migrations/auto/schema_differ.py index 7dbc9a469..22dbc02ef 100644 --- a/piccolo/apps/migrations/auto/schema_differ.py +++ b/piccolo/apps/migrations/auto/schema_differ.py @@ -629,6 +629,76 @@ def rename_columns(self) -> AlterStatements: return alter_statements + @property + def add_composite_indexes(self) -> AlterStatements: + response: list[str] = [] + extra_imports: list[Import] = [] + extra_definitions: list[Definition] = [] + for table in self.schema: + snapshot_table = self._get_snapshot_table(table.class_name) + if snapshot_table: + delta: TableDelta = table - snapshot_table + else: + continue + + for add_composite_index in delta.add_composite_indexes: + params = serialise_params(add_composite_index.params) + cleaned_params = params.params + extra_imports.extend(params.extra_imports) + extra_definitions.extend(params.extra_definitions) + + composite_index_class = ( + add_composite_index.composite_index_class + ) + extra_imports.append( + Import( + module=composite_index_class.__module__, + target=composite_index_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{composite_index_class.__name__.upper()}", + None, + ), + ) + ) + + schema_str = ( + "None" + if add_composite_index.schema is None + else f'"{add_composite_index.schema}"' + ) + + response.append( + f"manager.add_composite_index(table_class_name='{table.class_name}', tablename='{table.tablename}', composite_index_name='{add_composite_index.composite_index_name}', composite_index_class={composite_index_class.__name__}, params={str(cleaned_params)}, schema={schema_str})" # noqa: E501 + ) + return AlterStatements( + statements=response, + extra_imports=extra_imports, + extra_definitions=extra_definitions, + ) + + @property + def drop_composite_indexes(self) -> AlterStatements: + response = [] + for table in self.schema: + snapshot_table = self._get_snapshot_table(table.class_name) + if snapshot_table: + delta: TableDelta = table - snapshot_table + else: + continue + + for drop_composite_index in delta.drop_composite_indexes: + schema_str = ( + "None" + if drop_composite_index.schema is None + else f'"{drop_composite_index.schema}"' + ) + + response.append( + f"manager.drop_composite_index(table_class_name='{table.class_name}', tablename='{table.tablename}', composite_index_name='{drop_composite_index.composite_index_name}', schema={schema_str})" # noqa: E501 + ) + return AlterStatements(statements=response) + ########################################################################### @property @@ -680,6 +750,48 @@ def new_table_columns(self) -> AlterStatements: extra_definitions=extra_definitions, ) + @property + def new_table_composite_indexes(self) -> AlterStatements: + new_tables: list[DiffableTable] = list( + set(self.schema) - set(self.schema_snapshot) + ) + + response: list[str] = [] + extra_imports: list[Import] = [] + extra_definitions: list[Definition] = [] + for table in new_tables: + if ( + table.class_name + in self.rename_tables_collection.new_class_names + ): + continue + + for composite_index in table.composite_indexes: + extra_imports.append( + Import( + module=composite_index.__class__.__module__, + target=composite_index.__class__.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{composite_index.__class__.__name__.upper()}", # noqa: E501 + None, + ), + ) + ) + + schema_str = ( + "None" if table.schema is None else f'"{table.schema}"' + ) + + response.append( + f"manager.add_composite_index(table_class_name='{table.class_name}', tablename='{table.tablename}', composite_index_name='{composite_index._meta.name}', composite_index_class={composite_index.__class__.__name__}, params={composite_index._meta.params}, schema={schema_str})" # noqa: E501 + ) + return AlterStatements( + statements=response, + extra_imports=extra_imports, + extra_definitions=extra_definitions, + ) + ########################################################################### def get_alter_statements(self) -> list[AlterStatements]: @@ -692,10 +804,13 @@ def get_alter_statements(self) -> list[AlterStatements]: "Renamed tables": self.rename_tables, "Tables which changed schema": self.change_table_schemas, "Created table columns": self.new_table_columns, + "Created table composite indexes": self.new_table_composite_indexes, # noqa: E501 "Dropped columns": self.drop_columns, "Columns added to existing tables": self.add_columns, "Renamed columns": self.rename_columns, "Altered columns": self.alter_columns, + "Dropped composite index": self.drop_composite_indexes, + "Composite index added to existing tables": self.add_composite_indexes, # noqa: E501 } for message, statements in alter_statements.items(): diff --git a/piccolo/apps/migrations/auto/schema_snapshot.py b/piccolo/apps/migrations/auto/schema_snapshot.py index 5bf343063..a546f66fe 100644 --- a/piccolo/apps/migrations/auto/schema_snapshot.py +++ b/piccolo/apps/migrations/auto/schema_snapshot.py @@ -111,4 +111,22 @@ def get_snapshot(self) -> list[DiffableTable]: rename_column.new_db_column_name ) + add_composite_indexes = manager.add_composite_indexes.composite_indexes_for_table_class_name( # noqa: E501 + table.class_name + ) + table.composite_indexes.extend(add_composite_indexes) + + drop_composite_indexes = ( + manager.drop_composite_indexes.for_table_class_name( + table.class_name + ) + ) + for drop_composite_index in drop_composite_indexes: + table.composite_indexes = [ + i + for i in table.composite_indexes + if i._meta.name + != drop_composite_index.composite_index_name + ] + return tables diff --git a/piccolo/apps/migrations/commands/new.py b/piccolo/apps/migrations/commands/new.py index 172de96ed..1e59c412d 100644 --- a/piccolo/apps/migrations/commands/new.py +++ b/piccolo/apps/migrations/commands/new.py @@ -193,6 +193,7 @@ async def get_alter_statements( class_name=i.__name__, tablename=i._meta.tablename, columns=i._meta.non_default_columns, + composite_indexes=i._meta.composite_indexes, schema=i._meta.schema, ) for i in app_config.table_classes diff --git a/piccolo/columns/indexes.py b/piccolo/columns/indexes.py index 79060277f..cdfc448d8 100644 --- a/piccolo/columns/indexes.py +++ b/piccolo/columns/indexes.py @@ -11,6 +11,7 @@ class IndexMethod(str, Enum): hash = "hash" gist = "gist" gin = "gin" + brin = "brin" def __str__(self): return f"{self.__class__.__name__}.{self.name}" diff --git a/piccolo/composite_index.py b/piccolo/composite_index.py new file mode 100644 index 000000000..b821bfc9e --- /dev/null +++ b/piccolo/composite_index.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +from piccolo.columns.indexes import IndexMethod + + +class Composite: + """ + All other composite indexes inherit from ``Composite``. + Don't use it directly. + """ + + def __init__(self, **kwargs) -> None: + self._meta = CompositeMeta(params=kwargs) + + def __hash__(self): + return hash(self._meta.name) + + +@dataclass +class CompositeMeta: + """ + This is used to store info about the composite index. + """ + + # Used for representing the table in migrations. + params: dict[str, Any] = field(default_factory=dict) + + # Set by the Table Metaclass: + _name: Optional[str] = None + + @property + def name(self) -> str: + if not self._name: + raise ValueError( + "`_name` isn't defined - the Table Metaclass should set it." + ) + return self._name + + @name.setter + def name(self, value: str): + self._name = value + + +class CompositeIndex(Composite): + + def __init__( + self, + columns: list[str], + index_type: IndexMethod = IndexMethod.btree, + **kwargs, + ) -> None: + """ + Add a composite index to multiple columns. For example:: + + from piccolo.columns import Varchar, Boolean + from piccolo.composite_index import CompositeIndex + from piccolo.table import Table + + class Album(Table): + name = Varchar() + released = Boolean(default=False) + name_released_idx = CompositeIndex(["name", "released"]) + + This way we create composite index ``name_released_idx`` + on ``Album`` table. + + To drop the composite index, simply delete or comment out + the composite index argument and perform another migration. + + :param columns: + The table column name that should be in composite index. + + :param index_type: + Index type for a composite index. Default to ``B-tree``. + An Postgres extension must be created to use an index + type other than a ``B-tree``. + + """ + self.columns = columns + self.index_type = index_type + kwargs.update({"columns": columns, "index_type": index_type}) + super().__init__(**kwargs) diff --git a/piccolo/query/methods/create_index.py b/piccolo/query/methods/create_index.py index 64ae4b4d8..82e17147f 100644 --- a/piccolo/query/methods/create_index.py +++ b/piccolo/query/methods/create_index.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Optional, Union from piccolo.columns import Column from piccolo.columns.indexes import IndexMethod @@ -18,11 +18,13 @@ def __init__( columns: Union[list[Column], list[str]], method: IndexMethod = IndexMethod.btree, if_not_exists: bool = False, + name: Optional[str] = None, **kwargs, ): self.columns = columns self.method = method self.if_not_exists = if_not_exists + self.name = name super().__init__(table, **kwargs) @property @@ -42,7 +44,10 @@ def prefix(self) -> str: @property def postgres_ddl(self) -> Sequence[str]: column_names = self.column_names - index_name = self.table._get_index_name(column_names) + if self.name is not None: + index_name = self.name + else: + index_name = self.table._get_index_name(column_names) tablename = self.table._meta.get_formatted_tablename() method_name = self.method.value column_names_str = ", ".join([f'"{i}"' for i in self.column_names]) @@ -60,7 +65,10 @@ def cockroach_ddl(self) -> Sequence[str]: @property def sqlite_ddl(self) -> Sequence[str]: column_names = self.column_names - index_name = self.table._get_index_name(column_names) + if self.name is not None: + index_name = self.name + else: + index_name = self.table._get_index_name(column_names) tablename = self.table._meta.get_formatted_tablename() method_name = self.method.value diff --git a/piccolo/query/methods/drop_index.py b/piccolo/query/methods/drop_index.py index 1b2d9f082..1fe65eb60 100644 --- a/piccolo/query/methods/drop_index.py +++ b/piccolo/query/methods/drop_index.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Optional, Union from piccolo.columns.base import Column from piccolo.query.base import Query @@ -17,10 +17,12 @@ def __init__( table: type[Table], columns: Union[list[Column], list[str]], if_exists: bool = True, + name: Optional[str] = None, **kwargs, ): self.columns = columns self.if_exists = if_exists + self.name = name super().__init__(table, **kwargs) @property @@ -32,7 +34,10 @@ def column_names(self) -> list[str]: @property def default_querystrings(self) -> Sequence[QueryString]: column_names = self.column_names - index_name = self.table._get_index_name(column_names) + if self.name is not None: + index_name = self.name + else: + index_name = self.table._get_index_name(column_names) query = "DROP INDEX" if self.if_exists: query += " IF EXISTS" diff --git a/piccolo/table.py b/piccolo/table.py index bdfda2cdd..988f12ba7 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -29,6 +29,7 @@ ) from piccolo.columns.readable import Readable from piccolo.columns.reference import LAZY_COLUMN_REFERENCES +from piccolo.composite_index import Composite from piccolo.custom_types import TableInstance from piccolo.engine import Engine, engine_finder from piccolo.query import ( @@ -84,6 +85,7 @@ class TableMeta: primary_key: Column = field(default_factory=Column) json_columns: list[Union[JSON, JSONB]] = field(default_factory=list) secret_columns: list[Column] = field(default_factory=list) + composite_indexes: list[Composite] = field(default_factory=list) auto_update_columns: list[Column] = field(default_factory=list) tags: list[str] = field(default_factory=list) help_text: Optional[str] = None @@ -173,6 +175,17 @@ def get_column_by_name(self, name: str) -> Column: return column_object + def get_composite_index_by_name(self, name: str) -> Composite: + """ + Returns a composite index which matches the given name. + """ + for composite_index in self.composite_indexes: + if composite_index._meta.name == name: + return composite_index + raise ValueError( + f"No matching composite index found with name == {name}" + ) + def get_auto_update_values(self) -> dict[Column, Any]: """ If columns have ``auto_update`` defined, then we retrieve these values. @@ -279,6 +292,7 @@ def __init_subclass__( auto_update_columns: list[Column] = [] primary_key: Optional[Column] = None m2m_relationships: list[M2M] = [] + composite_indexes: list[Composite] = [] attribute_names = itertools.chain( *[i.__dict__.keys() for i in reversed(cls.__mro__)] @@ -331,6 +345,10 @@ def __init_subclass__( attribute._meta._table = cls m2m_relationships.append(attribute) + if isinstance(attribute, Composite): + attribute._meta.name = attribute_name + composite_indexes.append(attribute) + if not primary_key: primary_key = cls._create_serial_primary_key() setattr(cls, "id", primary_key) @@ -355,6 +373,7 @@ def __init_subclass__( _db=db, m2m_relationships=m2m_relationships, schema=schema, + composite_indexes=composite_indexes, ) for foreign_key_column in foreign_key_columns: @@ -1355,6 +1374,7 @@ def create_index( columns: Union[list[Column], list[str]], method: IndexMethod = IndexMethod.btree, if_not_exists: bool = False, + name: Optional[str] = None, ) -> CreateIndex: """ Create a table index. If multiple columns are specified, this refers @@ -1370,6 +1390,7 @@ def create_index( columns=columns, method=method, if_not_exists=if_not_exists, + name=name, ) @classmethod @@ -1377,6 +1398,7 @@ def drop_index( cls, columns: Union[list[Column], list[str]], if_exists: bool = True, + name: Optional[str] = None, ) -> DropIndex: """ Drop a table index. If multiple columns are specified, this refers @@ -1387,7 +1409,12 @@ def drop_index( await Band.drop_index([Band.name]) """ - return DropIndex(table=cls, columns=columns, if_exists=if_exists) + return DropIndex( + table=cls, + columns=columns, + if_exists=if_exists, + name=name, + ) ########################################################################### diff --git a/tests/apps/migrations/auto/test_migration_manager.py b/tests/apps/migrations/auto/test_migration_manager.py index 0952e1895..527bd428c 100644 --- a/tests/apps/migrations/auto/test_migration_manager.py +++ b/tests/apps/migrations/auto/test_migration_manager.py @@ -10,6 +10,8 @@ from piccolo.columns import Text, Varchar from piccolo.columns.base import OnDelete, OnUpdate from piccolo.columns.column_types import ForeignKey +from piccolo.columns.indexes import IndexMethod +from piccolo.composite_index import CompositeIndex from piccolo.conf.apps import AppConfig from piccolo.engine import engine_finder from piccolo.query.constraints import get_fk_constraint_rules @@ -1099,6 +1101,172 @@ def test_change_table_schema(self): ' - 1 [preview forwards]... CREATE SCHEMA IF NOT EXISTS "schema_1"\nALTER TABLE "manager" SET SCHEMA "schema_1"\n', # noqa: E501 ) + @engines_only("postgres", "cockroach") + def test_add_table_with_composite_index(self): + self.run_sync("DROP TABLE IF EXISTS musician;") + + manager = MigrationManager() + manager.add_table(class_name="Musician", tablename="musician") + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager.add_composite_index( + table_class_name="Musician", + tablename="musician", + composite_index_name="name_label", + composite_index_class=CompositeIndex, + params={ + "columns": ["name", "label"], + "index_type": IndexMethod.btree, + }, + schema=None, + ) + asyncio.run(manager.run()) + result = self.run_sync( + "SELECT indexname FROM pg_indexes WHERE tablename='musician'" + ) + self.assertEqual(result[-1]["indexname"], "name_label") + self.assertEqual(len(result), 2) + + # Reverse + asyncio.run(manager.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + + @engines_only("postgres", "cockroach") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_add_composite_index_to_existing_table( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + self.run_sync("DROP TABLE IF EXISTS musician;") + + manager = MigrationManager() + manager.add_table(class_name="Musician", tablename="musician") + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + asyncio.run(manager.run()) + + manager_2 = MigrationManager() + manager_2.add_composite_index( + table_class_name="Musician", + tablename="musician", + composite_index_name="name_label", + composite_index_class=CompositeIndex, + params={ + "columns": ["name", "label"], + "index_type": IndexMethod.btree, + }, + ) + asyncio.run(manager_2.run()) + + sql = "SELECT indexname FROM pg_indexes WHERE tablename='musician'" + + result = self.run_sync(sql) + self.assertEqual(result[-1]["indexname"], "name_label") + self.assertEqual(len(result), 2) + + # Reverse + asyncio.run(manager_2.run(backwards=True)) + + result = self.run_sync(sql) + self.assertEqual(result, [{"indexname": "musician_pkey"}]) + self.assertEqual(len(result), 1) + + asyncio.run(manager_2.run()) + + result = self.run_sync(sql) + self.assertEqual(result[-1]["indexname"], "name_label") + self.assertEqual(len(result), 2) + + manager_2 = MigrationManager() + manager_2.drop_composite_index( + table_class_name="Musician", + tablename="musician", + composite_index_name="name_label", + ) + asyncio.run(manager_2.run()) + result = self.run_sync(sql) + self.assertEqual(result, [{"indexname": "musician_pkey"}]) + self.assertEqual(len(result), 1) + + self.run_sync("DROP TABLE IF EXISTS musician;") + + @engines_only("postgres", "cockroach") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_drop_composite_index_from_existing_table( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + self.run_sync("DROP TABLE IF EXISTS musician;") + + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager_1.add_composite_index( + table_class_name="Musician", + tablename="musician", + composite_index_name="name_label", + composite_index_class=CompositeIndex, + params={"columns": ["name", "label"]}, + ) + asyncio.run(manager_1.run()) + + sql = "SELECT indexname FROM pg_indexes WHERE tablename='musician'" + + result = self.run_sync(sql) + self.assertEqual(result[-1]["indexname"], "name_label") + self.assertEqual(len(result), 2) + + manager_2 = MigrationManager() + manager_2.drop_composite_index( + table_class_name="Musician", + tablename="musician", + composite_index_name="name_label", + ) + asyncio.run(manager_2.run()) + + result = self.run_sync(sql) + self.assertEqual(result, [{"indexname": "musician_pkey"}]) + self.assertEqual(len(result), 1) + + # Reverse + asyncio.run(manager_1.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + class TestWrapInTransaction(IsolatedAsyncioTestCase): diff --git a/tests/postgres_conf.py b/tests/postgres_conf.py index af21dcbc5..36763b7eb 100644 --- a/tests/postgres_conf.py +++ b/tests/postgres_conf.py @@ -8,7 +8,7 @@ "host": os.environ.get("PG_HOST", "localhost"), "port": os.environ.get("PG_PORT", "5432"), "user": os.environ.get("PG_USER", "postgres"), - "password": os.environ.get("PG_PASSWORD", ""), + "password": os.environ.get("PG_PASSWORD", "postgres"), "database": os.environ.get("PG_DATABASE", "piccolo"), } )