diff --git a/alembic_postgresql_enum/get_enum_data/declared_enums.py b/alembic_postgresql_enum/get_enum_data/declared_enums.py index e21a516..611da42 100644 --- a/alembic_postgresql_enum/get_enum_data/declared_enums.py +++ b/alembic_postgresql_enum/get_enum_data/declared_enums.py @@ -101,7 +101,7 @@ def get_declared_enums( enum_name_to_values[column_type.name] = get_enum_values(column_type) table_schema = table.schema or default_schema - column_default = get_column_default(connection, table.schema, table.name, column.name) + column_default = get_column_default(connection, table_schema, table.name, column.name) enum_name_to_table_references[column_type.name].add( TableReference( table_schema=table_schema, diff --git a/pyproject.toml b/pyproject.toml index 28bface..60fec30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "alembic-postgresql-enum" -version = "1.1.1" +version = "1.1.2" description = "Alembic autogenerate support for creation, alteration and deletion of enums" authors = ["RustyGuard"] license = "MIT" diff --git a/tests/base/render_and_run.py b/tests/base/render_and_run.py index 7062847..ff39f60 100644 --- a/tests/base/render_and_run.py +++ b/tests/base/render_and_run.py @@ -7,7 +7,7 @@ from sqlalchemy import MetaData from sqlalchemy.dialects import postgresql -from alembic_postgresql_enum import ColumnType +from alembic_postgresql_enum import ColumnType, TableReference from tests.utils.migration_context import create_migration_context if TYPE_CHECKING: @@ -46,6 +46,7 @@ def compare_and_run( "sa": sqlalchemy, "postgresql": postgresql, "ColumnType": ColumnType, + "TableReference": TableReference, }, ) exec( @@ -55,5 +56,6 @@ def compare_and_run( "sa": sqlalchemy, "postgresql": postgresql, "ColumnType": ColumnType, + "TableReference": TableReference, }, ) diff --git a/tests/base/run_migration_test_abc.py b/tests/base/run_migration_test_abc.py new file mode 100644 index 0000000..fb2b7cb --- /dev/null +++ b/tests/base/run_migration_test_abc.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from tests.base.render_and_run import compare_and_run + +if TYPE_CHECKING: + from sqlalchemy import Connection +from sqlalchemy import MetaData + + +class CompareAndRunTestCase(ABC): + @abstractmethod + def get_database_schema(self) -> MetaData: ... + + @abstractmethod + def get_target_schema(self) -> MetaData: ... + + @abstractmethod + def get_expected_upgrade(self) -> str: ... + + @abstractmethod + def get_expected_downgrade(self) -> str: ... + + def test_run(self, connection: "Connection"): + database_schema = self.get_database_schema() + target_schema = self.get_target_schema() + + database_schema.create_all(connection) + + compare_and_run( + connection, + target_schema, + expected_upgrade=self.get_expected_upgrade(), + expected_downgrade=self.get_expected_downgrade(), + ) diff --git a/tests/sync_enum_values/test_render.py b/tests/sync_enum_values/test_render.py index 4a1ed99..f100cae 100644 --- a/tests/sync_enum_values/test_render.py +++ b/tests/sync_enum_values/test_render.py @@ -1,14 +1,19 @@ +import enum from typing import TYPE_CHECKING from alembic import autogenerate from alembic.autogenerate import api from alembic.operations import ops +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import declarative_base from alembic_postgresql_enum.get_enum_data import TableReference from alembic_postgresql_enum.operations import SyncEnumValuesOp +from tests.base.run_migration_test_abc import CompareAndRunTestCase if TYPE_CHECKING: from sqlalchemy import Connection +from sqlalchemy import MetaData, Column, Integer from tests.schemas import ( get_schema_with_enum_variants, @@ -161,3 +166,58 @@ def test_rename_enum_value_diff_tuple(connection: "Connection"): assert affected_columns == [ TableReference(table_schema=DEFAULT_SCHEMA, table_name=USER_TABLE_NAME, column_name=USER_STATUS_COLUMN_NAME) ] + + +class TestServerDefault(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + schema = MetaData() + + Base = declarative_base(metadata=schema) + + class MyEnum(enum.Enum): + one = 1 + two = 2 + three = 3 + + class ExampleTable(Base): + __tablename__ = "example_table" + test_field = Column(Integer, primary_key=True, autoincrement=False) + enum_field = Column(postgresql.ENUM(MyEnum, name="my_enum"), server_default=MyEnum.one.name) + + return schema + + def get_target_schema(self) -> MetaData: + schema = MetaData() + + Base = declarative_base(metadata=schema) + + class NewMyEnum(enum.Enum): + one = 1 + two = 2 + three = 3 + four = 4 # added + + class ExampleTable(Base): + __tablename__ = "example_table" + test_field = Column(Integer, primary_key=True, autoincrement=False) + enum_field = Column(postgresql.ENUM(NewMyEnum, name="my_enum"), server_default=NewMyEnum.one.name) + + return schema + + def get_expected_upgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values('public', 'my_enum', ['one', 'two', 'three', 'four'], + [TableReference(table_schema='public', table_name='example_table', column_name='enum_field', existing_server_default="'one'::my_enum")], + enum_values_to_rename=[]) + # ### end Alembic commands ### + """ + + def get_expected_downgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values('public', 'my_enum', ['one', 'two', 'three'], + [TableReference(table_schema='public', table_name='example_table', column_name='enum_field', existing_server_default="'one'::my_enum")], + enum_values_to_rename=[]) + # ### end Alembic commands ### + """