Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error when adding column that uses existing changing enum #77

Merged
merged 9 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/test_on_push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ on:
- tests/**
- alembic_postgresql_enum/**
- .github/workflows/test_on_push.yaml
pull_request: { }

jobs:
run_tests:
Expand Down
10 changes: 10 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
FROM python:latest

COPY ./alembic_postgresql_enum ./alembic_postgresql_enum
COPY ./tests ./tests

WORKDIR ./tests

RUN pip install -r requirements.txt

ENTRYPOINT pytest
17 changes: 10 additions & 7 deletions alembic_postgresql_enum/compare_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def compare_enums(
for each defined enum that has changed new entries when compared to its
declared version.
"""
assert (
autogen_context.dialect is not None
and autogen_context.dialect.default_schema_name is not None
and autogen_context.connection is not None
and autogen_context.metadata is not None
)

if autogen_context.dialect.name != "postgresql":
log.warning(
f"This library only supports postgresql, but you are using {autogen_context.dialect.name}, skipping"
Expand All @@ -49,19 +56,15 @@ def compare_enums(
if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names:
schema_names.append(operations_group.schema)

assert (
autogen_context.dialect is not None
and autogen_context.dialect.default_schema_name is not None
and autogen_context.connection is not None
and autogen_context.metadata is not None
)
for schema in schema_names:
default_schema = autogen_context.dialect.default_schema_name
if schema is None:
schema = default_schema

definitions = get_defined_enums(autogen_context.connection, schema)
declarations = get_declared_enums(autogen_context.metadata, schema, default_schema, autogen_context.connection)
declarations = get_declared_enums(
autogen_context.metadata, schema, default_schema, autogen_context.connection, upgrade_ops
)

create_new_enums(definitions, declarations.enum_values, schema, upgrade_ops)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def sync_changed_enums(
enum_name,
list(old_values),
list(new_values),
list(affected_columns),
sorted( # Sort references alphabetically for consistency of generated text
affected_columns,
key=lambda reference: (reference.table_schema, reference.table_name, reference.column_name),
AlexandrovRoman marked this conversation as resolved.
Show resolved Hide resolved
),
)
upgrade_ops.ops.append(op)
12 changes: 10 additions & 2 deletions alembic_postgresql_enum/get_enum_data/declared_enums.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections import defaultdict
from enum import Enum
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Optional

import sqlalchemy
from alembic.operations.ops import UpgradeOps
from sqlalchemy import MetaData
from sqlalchemy.dialects import postgresql

from alembic_postgresql_enum.get_enum_data.get_default_from_alembic_ops import get_just_added_defaults
from alembic_postgresql_enum.sql_commands.column_default import get_column_default

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,6 +51,7 @@ def get_declared_enums(
schema: str,
default_schema: str,
connection: "Connection",
upgrade_ops: Optional[UpgradeOps] = None,
) -> DeclaredEnumValues:
"""
Return a dict mapping SQLAlchemy declared enumeration types to the set of their values
Expand All @@ -62,6 +64,8 @@ def get_declared_enums(
Default schema name, likely will be "public"
:param connection:
Database connection
:param upgrade_ops:
Upgrade operations in current migration
:returns DeclaredEnumValues:
enum_values: {
"my_enum": tuple(["a", "b", "c"]),
Expand All @@ -75,6 +79,8 @@ def get_declared_enums(
enum_name_to_values = dict()
enum_name_to_table_references: defaultdict[str, Set[TableReference]] = defaultdict(set)

just_added_defaults = get_just_added_defaults(upgrade_ops, default_schema)

if isinstance(metadata, list):
metadata_list = metadata
else:
Expand Down Expand Up @@ -103,6 +109,8 @@ def get_declared_enums(

table_schema = table.schema or default_schema
column_default = get_column_default(connection, table_schema, table.name, column.name)
if (table_schema, table.name, column.name) in just_added_defaults:
column_default = just_added_defaults[table_schema, table.name, column.name]
enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined]
TableReference(
table_schema=table_schema,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Optional, Dict, Tuple

from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, AlterColumnOp, CreateTableOp
from sqlalchemy import Column

SchemaName = str
TableName = str
ColumnName = str
ColumnLocation = Tuple[SchemaName, TableName, ColumnName]


def _get_default_from_add_column_op(op: AddColumnOp, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]:
if op.column.server_default is None:
raise AttributeError("No new server_default")
return (
(op.schema or default_schema, op.table_name, op.column.name),
op.column.server_default.arg.text, # type: ignore[attr-defined]
)


def _get_default_from_alter_column_op(op: AlterColumnOp, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]:
if op.modify_server_default is False:
raise AttributeError("No new server_default")
return (op.schema or default_schema, op.table_name, op.column_name), op.modify_server_default


def _get_default_from_column(column: Column, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]:
if column.server_default is None:
raise AttributeError("No new server_default")
return (
(column.table.schema or default_schema, column.table.name, column.name),
column.server_default.arg.text, # type: ignore[attr-defined]
)


def get_just_added_defaults(
upgrade_ops: Optional[UpgradeOps], default_schema: str
) -> Dict[ColumnLocation, Optional[str]]:
"""Get all server defaults that will be added in current migration"""
if upgrade_ops is None:
return {}

new_server_defaults = {}

for operations_group in upgrade_ops.ops:
if isinstance(operations_group, ModifyTableOps):
for operation in operations_group.ops:
if isinstance(operation, AddColumnOp):
try:
column_location, column_new_default = _get_default_from_add_column_op(operation, default_schema)
new_server_defaults[column_location] = column_new_default
except AttributeError:
pass

elif isinstance(operation, AlterColumnOp):
try:
column_location, column_new_default = _get_default_from_alter_column_op(
operation, default_schema
)
new_server_defaults[column_location] = column_new_default
except AttributeError:
pass

elif isinstance(operations_group, CreateTableOp):
for column in operations_group.columns:
if isinstance(column, Column):
try:
column_location, column_new_default = _get_default_from_column(column, default_schema)
new_server_defaults[column_location] = column_new_default
except AttributeError:
pass

return new_server_defaults
40 changes: 32 additions & 8 deletions alembic_postgresql_enum/sql_commands/column_default.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import TYPE_CHECKING, Union, List, Tuple

import sqlalchemy
Expand Down Expand Up @@ -54,15 +55,38 @@ def rename_default_if_required(
enum_name: str,
enum_values_to_rename: List[Tuple[str, str]],
) -> str:
is_array = default_value.endswith("[]")
if schema:
new_enum = f"{schema}.{enum_name}"
else:
new_enum = enum_name

if default_value.startswith("ARRAY["):
column_default_value = _replace_strings_in_quotes(default_value, enum_values_to_rename)
column_default_value = re.sub(r"::[.\w]+", f"::{new_enum}", column_default_value)
return column_default_value

if default_value.endswith("[]"):

# remove old type postfix
column_default_value = default_value[: default_value.find("::")]

column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename)

return f"{column_default_value}::{new_enum}[]"

# remove old type postfix
column_default_value = default_value[: default_value.find("::")]

for old_value, new_value in enum_values_to_rename:
column_default_value = column_default_value.replace(f"'{old_value}'", f"'{new_value}'")
column_default_value = column_default_value.replace(f'"{old_value}"', f'"{new_value}"')
column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename)

suffix = "[]" if is_array else ""
if schema:
return f"{column_default_value}::{schema}.{enum_name}{suffix}"
return f"{column_default_value}::{enum_name}{suffix}"
return f"{column_default_value}::{new_enum}"


def _replace_strings_in_quotes(
old_default: str,
enum_values_to_rename: List[Tuple[str, str]],
) -> str:
for old_value, new_value in enum_values_to_rename:
old_default = old_default.replace(f"'{old_value}'", f"'{new_value}'")
old_default = old_default.replace(f'"{old_value}"', f'"{new_value}"')
return old_default
28 changes: 28 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
version: "3.8"

services:
run-tests:
# entrypoint: pytest
build: .
stdin_open: true
tty: true
command:
- pytest
environment:
DATABASE_URI: postgresql://test_user:test_password@db:5432/test_db
depends_on:
- db
links:
- "db:database"
db:
image: postgres:12
environment:
POSTGRES_DB: "test_db"
POSTGRES_USER: "test_user"
POSTGRES_PASSWORD: "test_password"
PGUSER: "postgres"

ports:
- "5432:5432"
volumes:
- ./api/db/postgres-test-data:/var/lib/postgresql/data
15 changes: 13 additions & 2 deletions tests/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
# How to run tests

Create database for testing
# With `docker compose`

Just run:
```commandline
docker compose up --build --exit-code-from run-tests
```

# Manually

## Create database

Start postgres through docker compose:

## Env variables

Expand All @@ -24,4 +35,4 @@ pip install -R tests/requirements.txt
Run tests
```
pytest
```
```
33 changes: 33 additions & 0 deletions tests/sync_enum_values/test_rename_default_if_required.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,36 @@ def test_array_default_value_with_schema():
old_default_value = """'{}'::test.order_status_old[]"""

assert rename_default_if_required("test", old_default_value, "order_status", []) == """'{}'::test.order_status[]"""


def test_caps_array_default_value_without_schema():
old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]"""

assert (
rename_default_if_required("test", old_default_value, "my_enum", [])
== """ARRAY['A'::test.my_enum, 'B'::test.my_enum]"""
)


def test_caps_array_default_value_with_schema():
old_default_value = """ARRAY['A'::test.my_old_enum, 'B'::test.my_old_enum]"""

assert (
rename_default_if_required("test", old_default_value, "my_enum", [])
== """ARRAY['A'::test.my_enum, 'B'::test.my_enum]"""
)


def test_caps_array_another_default_value_without_schema():
old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]"""

assert (
rename_default_if_required("test", old_default_value, "my_enum", [])
== """ARRAY['A'::test.my_enum, 'B'::test.my_enum]"""
)


def test_caps_array_another_default_value_with_schema():
old_default_value = """ARRAY['A', 'B']::test.my_old_enum[]"""

assert rename_default_if_required("test", old_default_value, "my_enum", []) == """ARRAY['A', 'B']::test.my_enum[]"""
Loading
Loading