Skip to content

Commit

Permalink
Change black line length to 120
Browse files Browse the repository at this point in the history
  • Loading branch information
RustyGuard committed Jan 6, 2024
1 parent 05221ff commit f39dd28
Show file tree
Hide file tree
Showing 29 changed files with 83 additions and 244 deletions.
6 changes: 3 additions & 3 deletions alembic_postgresql_enum/add_create_type_false.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ class ReprWorkaround(postgresql.ENUM):
__module__ = "sqlalchemy.dialects.postgresql"

def __repr__(self):
return f"{super().__repr__()[:-1]}, create_type=False)".replace(
"ReprWorkaround", "ENUM"
).replace(", metadata=MetaData()", "")
return f"{super().__repr__()[:-1]}, create_type=False)".replace("ReprWorkaround", "ENUM").replace(
", metadata=MetaData()", ""
)


def inject_repr_into_enums(column: Column):
Expand Down
12 changes: 3 additions & 9 deletions alembic_postgresql_enum/add_postgres_using_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def reverse(self):


@renderers.dispatch_for(PostgresUsingAlterColumnOp)
def _postgres_using_alter_column(
autogen_context: AutogenContext, op: ops.AlterColumnOp
) -> str:
def _postgres_using_alter_column(autogen_context: AutogenContext, op: ops.AlterColumnOp) -> str:
alter_column_expression = render._alter_column(autogen_context, op)

postgresql_using = op.kw.get("postgresql_using", None)
Expand All @@ -42,9 +40,7 @@ def _postgres_using_alter_column(

def add_postgres_using_to_alter_operation(op: AlterColumnOp):
op.kw["postgresql_using"] = f"{op.column_name}::{op.modify_type.name}"
log.info(
"postgresql_using added to %r.%r alteration", op.table_name, op.column_name
)
log.info("postgresql_using added to %r.%r alteration", op.table_name, op.column_name)
op.__class__ = PostgresUsingAlterColumnOp


Expand All @@ -54,7 +50,5 @@ def add_postgres_using_to_text(upgrade_ops: UpgradeOps):
if isinstance(group_op, ModifyTableOps):
for i, op in enumerate(group_op.ops):
if isinstance(op, AlterColumnOp):
if isinstance(op.existing_type, String) and column_type_is_enum(
op.modify_type
):
if isinstance(op.existing_type, String) and column_type_is_enum(op.modify_type):
add_postgres_using_to_alter_operation(op)
9 changes: 2 additions & 7 deletions alembic_postgresql_enum/compare_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ def compare_enums(
# Issue #40
# Add schema if it is gonna be created inside the migration
for operations_group in upgrade_ops.ops:
if (
isinstance(operations_group, CreateTableOp)
and operations_group.schema not in schema_names
):
if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names:
schema_names.append(operations_group.schema)

for schema in schema_names:
Expand All @@ -48,9 +45,7 @@ def compare_enums(
schema = default_schema

definitions = get_defined_enums(autogen_context.connection, schema)
declarations = get_declared_enums(
autogen_context.metadata, schema, default_schema, autogen_context.connection
)
declarations = get_declared_enums(autogen_context.metadata, schema, default_schema, autogen_context.connection)

create_new_enums(definitions, declarations.enum_values, schema, upgrade_ops)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,4 @@ def create_new_enums(
for name, new_values in declared_enums.items():
if name not in defined_enums:
log.info("Detected added enum %r with values %r", name, new_values)
upgrade_ops.ops.insert(
0, CreateEnumOp(name=name, schema=schema, enum_values=new_values)
)
upgrade_ops.ops.insert(0, CreateEnumOp(name=name, schema=schema, enum_values=new_values))
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,4 @@ def drop_unused_enums(
for name, new_values in defined_enums.items():
if name not in declared_enums:
log.info("Detected unused enum %r with values %r", name, new_values)
upgrade_ops.ops.append(
DropEnumOp(name=name, schema=schema, enum_values=new_values)
)
upgrade_ops.ops.append(DropEnumOp(name=name, schema=schema, enum_values=new_values))
16 changes: 4 additions & 12 deletions alembic_postgresql_enum/get_enum_data/declared_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ def get_enum_values(enum_type: sqlalchemy.Enum) -> "Tuple[str, ...]":
dialect = postgresql.dialect

def value_processor(value):
return enum_type.process_bind_param(
enum_type.impl.result_processor(dialect, enum_type)(value), dialect
)
return enum_type.process_bind_param(enum_type.impl.result_processor(dialect, enum_type)(value), dialect)

else:

Expand Down Expand Up @@ -74,9 +72,7 @@ def get_declared_enums(
}
"""
enum_name_to_values = dict()
enum_name_to_table_references: defaultdict[str, Set[TableReference]] = defaultdict(
set
)
enum_name_to_table_references: defaultdict[str, Set[TableReference]] = defaultdict(set)

if isinstance(metadata, list):
metadata_list = metadata
Expand Down Expand Up @@ -104,13 +100,9 @@ def get_declared_enums(
if column_type.name not in enum_name_to_values:
enum_name_to_values[column_type.name] = get_enum_values(column_type)

column_default = get_column_default(
connection, schema, table.name, column.name
)
column_default = get_column_default(connection, schema, table.name, column.name)
enum_name_to_table_references[column_type.name].add(
TableReference(
table.name, column.name, column_type_wrapper, column_default
)
TableReference(table.name, column.name, column_type_wrapper, column_default)
)

return DeclaredEnumValues(
Expand Down
5 changes: 1 addition & 4 deletions alembic_postgresql_enum/get_enum_data/defined_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,4 @@ def get_defined_enums(connection: "Connection", schema: str) -> EnumNamesToValue
"my_enum": tuple(["a", "b", "c"]),
}
"""
return {
_remove_schema_prefix(name, schema): tuple(values)
for name, values in get_all_enums(connection, schema)
}
return {_remove_schema_prefix(name, schema): tuple(values) for name, values in get_all_enums(connection, schema)}
29 changes: 6 additions & 23 deletions alembic_postgresql_enum/operations/sync_enum_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def _set_enum_values(
rename_type(connection, schema, enum_name, temporary_enum_name)
create_type(connection, schema, enum_name, new_values)

create_comparison_operators(
connection, schema, enum_name, temporary_enum_name, enum_values_to_rename
)
create_comparison_operators(connection, schema, enum_name, temporary_enum_name, enum_values_to_rename)

for table_reference in affected_columns:
column_default = table_reference.existing_server_default
Expand All @@ -102,9 +100,7 @@ def _set_enum_values(
) from error

if column_default is not None:
column_default = rename_default_if_required(
schema, column_default, enum_name, enum_values_to_rename
)
column_default = rename_default_if_required(schema, column_default, enum_name, enum_values_to_rename)

set_default(connection, schema, table_reference, column_default)

Expand Down Expand Up @@ -155,21 +151,13 @@ def sync_enum_values(
column_type = affected_column[2]
else:
column_type = ColumnType.COMMON
column_default = get_column_default(
connection, schema, table_name, column_name
)
table_references.append(
TableReference(
table_name, column_name, column_type, column_default
)
)
column_default = get_column_default(connection, schema, table_name, column_name)
table_references.append(TableReference(table_name, column_name, column_type, column_default))

elif isinstance(affected_column, TableReference):
table_references.append(affected_column)
else:
raise ValueError(
"Affected columns must contain tuples or TableReferences"
)
raise ValueError("Affected columns must contain tuples or TableReferences")

cls._set_enum_values(
connection,
Expand All @@ -190,12 +178,7 @@ def to_diff_tuple(self) -> Tuple[Any, ...]:

@property
def is_column_type_import_needed(self) -> bool:
return any(
(
affected_column.is_column_type_import_needed
for affected_column in self.affected_columns
)
)
return any((affected_column.is_column_type_import_needed for affected_column in self.affected_columns))


@alembic.autogenerate.render.renderers.dispatch_for(SyncEnumValuesOp)
Expand Down
8 changes: 2 additions & 6 deletions alembic_postgresql_enum/sql_commands/column_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,8 @@ def rename_default_if_required(
column_default_value = default_value[: default_value.find("::")]

for old_value, new_value in enum_values_to_rename:
column_default_value = column_default_value.replace(
f"'{old_value}'", f"'{new_value}'"
)
column_default_value = column_default_value.replace(
f'"{old_value}"', f'"{new_value}"'
)
column_default_value = column_default_value.replace(f"'{old_value}'", f"'{new_value}'")
column_default_value = column_default_value.replace(f'"{old_value}"', f'"{new_value}"')

suffix = "[]" if is_array else ""
if schema:
Expand Down
4 changes: 1 addition & 3 deletions alembic_postgresql_enum/sql_commands/comparison_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,4 @@ def drop_comparison_operators(
old_enum_name: str,
):
for _, comparison_function_name in OPERATORS_TO_CREATE:
_drop_comparison_operator(
connection, schema, enum_name, old_enum_name, comparison_function_name
)
_drop_comparison_operator(connection, schema, enum_name, old_enum_name, comparison_function_name)
14 changes: 3 additions & 11 deletions alembic_postgresql_enum/sql_commands/enum_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,11 @@ def drop_type(connection: "Connection", schema: str, type_name: str):
connection.execute(sqlalchemy.text(f"""DROP TYPE {schema}.{type_name}"""))


def rename_type(
connection: "Connection", schema: str, type_name: str, new_type_name: str
):
connection.execute(
sqlalchemy.text(
f"""ALTER TYPE {schema}.{type_name} RENAME TO {new_type_name}"""
)
)
def rename_type(connection: "Connection", schema: str, type_name: str, new_type_name: str):
connection.execute(sqlalchemy.text(f"""ALTER TYPE {schema}.{type_name} RENAME TO {new_type_name}"""))


def create_type(
connection: "Connection", schema: str, type_name: str, enum_values: List[str]
):
def create_type(connection: "Connection", schema: str, type_name: str, enum_values: List[str]):
connection.execute(
sqlalchemy.text(
f"""CREATE TYPE {schema}.{type_name} AS ENUM({', '.join(f"'{value}'" for value in enum_values)})"""
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ alembic = ">=1.7"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.black]
line-length = 120
8 changes: 2 additions & 6 deletions tests/base/render_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,8 @@ def compare_and_run(
expected_upgrade = textwrap.dedent(expected_upgrade).strip("\n ")
expected_downgrade = textwrap.dedent(expected_downgrade).strip("\n ")

assert (
upgrade_code == expected_upgrade
), f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}"
assert (
downgrade_code == expected_downgrade
), f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}"
assert upgrade_code == expected_upgrade, f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}"
assert downgrade_code == expected_downgrade, f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}"

exec(
upgrade_code,
Expand Down
12 changes: 3 additions & 9 deletions tests/get_enum_data/test_get_declared_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,18 @@ def test_with_user_schema(connection: "Connection"):
enum_variants = ["active", "passive"]
declared_schema = get_schema_with_enum_variants(enum_variants)

function_result = get_declared_enums(
declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection
)
function_result = get_declared_enums(declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection)

assert function_result.enum_values == {USER_STATUS_ENUM_NAME: tuple(enum_variants)}
assert function_result.enum_table_references == {
USER_STATUS_ENUM_NAME: frozenset(
(TableReference(USER_TABLE_NAME, USER_STATUS_COLUMN_NAME),)
)
USER_STATUS_ENUM_NAME: frozenset((TableReference(USER_TABLE_NAME, USER_STATUS_COLUMN_NAME),))
}


def test_with_multiple_enums(connection: "Connection"):
declared_enum_values = get_declared_enum_values_with_orders_and_users()
declared_schema = get_schema_by_declared_enum_values(declared_enum_values)

function_result = get_declared_enums(
declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection
)
function_result = get_declared_enums(declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection)

assert function_result == declared_enum_values
8 changes: 2 additions & 6 deletions tests/get_enum_data/test_type_decorator_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,11 @@ def get_schema_with_custom_enum() -> MetaData:
def test_get_declared_enums_for_custom_enum(connection: "Connection"):
declared_schema = get_schema_with_custom_enum()

function_result = get_declared_enums(
declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection
)
function_result = get_declared_enums(declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection)

assert function_result.enum_values == {
# All declared enum variants must be taken from OrderDeliveryStatus values, see ValuesEnum
ORDER_DELIVERY_STATUS_ENUM_NAME: tuple(
enum_item.value for enum_item in OrderDeliveryStatus
)
ORDER_DELIVERY_STATUS_ENUM_NAME: tuple(enum_item.value for enum_item in OrderDeliveryStatus)
}
assert function_result.enum_table_references == {
ORDER_DELIVERY_STATUS_ENUM_NAME: frozenset(
Expand Down
20 changes: 5 additions & 15 deletions tests/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ def get_declared_enum_values_with_orders_and_users() -> DeclaredEnumValues:
)
),
"order_status_enum": frozenset((TableReference("orders", "order_status"),)),
"car_color_enum": frozenset(
(TableReference("cars", "colors", ColumnType.ARRAY),)
),
"car_color_enum": frozenset((TableReference("cars", "colors", ColumnType.ARRAY),)),
},
)

Expand All @@ -119,28 +117,20 @@ def _enum_column_factory(
column_type: ColumnType,
) -> Column:
if column_type == ColumnType.COMMON:
return Column(
column_name, postgresql.ENUM(*target.enum_values[enum_name], name=enum_name)
)
return Column(column_name, postgresql.ENUM(*target.enum_values[enum_name], name=enum_name))
return Column(
column_name,
column_type.value(
postgresql.ENUM(*target.enum_values[enum_name], name=enum_name)
),
column_type.value(postgresql.ENUM(*target.enum_values[enum_name], name=enum_name)),
)


def get_schema_by_declared_enum_values(target: DeclaredEnumValues) -> MetaData:
schema = MetaData()

tables_to_columns: DefaultDict[Any, Set[Tuple[str, str, ColumnType]]] = defaultdict(
set
)
tables_to_columns: DefaultDict[Any, Set[Tuple[str, str, ColumnType]]] = defaultdict(set)
for enum_name, references in target.enum_table_references.items():
for reference in references:
tables_to_columns[reference.table_name].add(
(reference.column_name, enum_name, reference.column_type)
)
tables_to_columns[reference.table_name].add((reference.column_name, enum_name, reference.column_type))

for table_name, columns_with_enum_names in tables_to_columns.items():
Table(
Expand Down
11 changes: 3 additions & 8 deletions tests/sync_enum_values/test_array_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def test_add_new_enum_value_render_with_array(connection: "Connection"):
autogenerate._render_migration_diffs(context, template_args)

assert template_args["imports"] == (
"from alembic_postgresql_enum import ColumnType"
"\nfrom alembic_postgresql_enum import TableReference"
"from alembic_postgresql_enum import ColumnType" "\nfrom alembic_postgresql_enum import TableReference"
)

assert (
Expand Down Expand Up @@ -125,9 +124,7 @@ def test_remove_enum_value_diff_tuple_with_array(connection: "Connection"):
assert operation_name == SyncEnumValuesOp.operation_name
assert old_values == old_enum_variants
assert new_values == new_enum_variants
assert affected_columns == [
TableReference(CAR_TABLE_NAME, CAR_COLORS_COLUMN_NAME, ColumnType.ARRAY)
]
assert affected_columns == [TableReference(CAR_TABLE_NAME, CAR_COLORS_COLUMN_NAME, ColumnType.ARRAY)]


def test_rename_enum_value_diff_tuple_with_array(connection: "Connection"):
Expand Down Expand Up @@ -164,6 +161,4 @@ def test_rename_enum_value_diff_tuple_with_array(connection: "Connection"):
assert operation_name == SyncEnumValuesOp.operation_name
assert old_values == old_enum_variants
assert new_values == new_enum_variants
assert affected_columns == [
TableReference(CAR_TABLE_NAME, CAR_COLORS_COLUMN_NAME, ColumnType.ARRAY)
]
assert affected_columns == [TableReference(CAR_TABLE_NAME, CAR_COLORS_COLUMN_NAME, ColumnType.ARRAY)]
Loading

0 comments on commit f39dd28

Please sign in to comment.