Skip to content

Commit

Permalink
feat: add overwrite_method to postgresql.to_sql (#2820)
Browse files (browse the repository at this point in the history)
  • Loading branch information
LeonLuttenberger authored May 14, 2024
1 parent 44ae3fb commit 6ed9850
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 4 deletions.
34 changes: 31 additions & 3 deletions awswrangler/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,22 @@ def _validate_connection(con: "pg8000.Connection") -> None:
)


def _drop_table(cursor: "pg8000.Cursor", schema: str | None, table: str) -> None:
def _drop_table(cursor: "pg8000.Cursor", schema: str | None, table: str, cascade: bool) -> None:
schema_str = f"{_identifier(schema)}." if schema else ""
sql = f"DROP TABLE IF EXISTS {schema_str}{_identifier(table)}"
cascade_str = "CASCADE" if cascade else "RESTRICT"
sql = f"DROP TABLE IF EXISTS {schema_str}{_identifier(table)} {cascade_str}"
_logger.debug("Drop table query:\n%s", sql)
cursor.execute(sql)


def _truncate_table(cursor: "pg8000.Cursor", schema: str | None, table: str, cascade: bool) -> None:
    """Remove all rows from a table with a ``TRUNCATE TABLE`` statement.

    Parameters
    ----------
    cursor
        Cursor used to execute the statement.
    schema
        Schema qualifying the table name. When falsy, the table name is used unqualified.
    table
        Name of the table to truncate.
    cascade
        If True the statement ends with ``CASCADE``, otherwise with ``RESTRICT``.
    """
    # Both name parts go through `_identifier` before being placed in the SQL text.
    qualifier = ""
    if schema:
        qualifier = f"{_identifier(schema)}."

    behaviour = "RESTRICT"
    if cascade:
        behaviour = "CASCADE"

    statement = f"TRUNCATE TABLE {qualifier}{_identifier(table)} {behaviour}"
    _logger.debug("Truncate table query:\n%s", statement)
    cursor.execute(statement)


def _does_table_exist(cursor: "pg8000.Cursor", schema: str | None, table: str) -> bool:
schema_str = f"TABLE_SCHEMA = {pg8000_native.literal(schema)} AND" if schema else ""
cursor.execute(
Expand All @@ -66,12 +75,21 @@ def _create_table(
table: str,
schema: str,
mode: str,
overwrite_method: _ToSqlOverwriteModeLiteral,
index: bool,
dtype: dict[str, str] | None,
varchar_lengths: dict[str, int] | None,
) -> None:
if mode == "overwrite":
_drop_table(cursor=cursor, schema=schema, table=table)
if overwrite_method in ["drop", "cascade"]:
_drop_table(cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "cascade"))
elif overwrite_method in ["truncate", "truncate cascade"]:
if _does_table_exist(cursor=cursor, schema=schema, table=table):
_truncate_table(
cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "truncate cascade")
)
else:
raise exceptions.InvalidArgumentValue(f"Invalid overwrite_method: {overwrite_method}")
elif _does_table_exist(cursor=cursor, schema=schema, table=table):
return
postgresql_types: dict[str, str] = _data_types.database_types_from_pandas(
Expand Down Expand Up @@ -485,6 +503,7 @@ def read_sql_table(


_ToSqlModeLiteral = Literal["append", "overwrite", "upsert"]
_ToSqlOverwriteModeLiteral = Literal["drop", "cascade", "truncate", "truncate cascade"]


@_utils.check_optional_dependency(pg8000, "pg8000")
Expand All @@ -495,6 +514,7 @@ def to_sql(
table: str,
schema: str,
mode: _ToSqlModeLiteral = "append",
overwrite_method: _ToSqlOverwriteModeLiteral = "drop",
index: bool = False,
dtype: dict[str, str] | None = None,
varchar_lengths: dict[str, int] | None = None,
Expand Down Expand Up @@ -522,6 +542,13 @@ def to_sql(
overwrite: Drops table and recreates.
upsert: Perform an upsert which checks for conflicts on columns given by `upsert_conflict_columns` and
sets the new values on conflicts. Note that `upsert_conflict_columns` is required for this mode.
overwrite_method : str
Drop, cascade, truncate, or truncate cascade. Only applicable in overwrite mode.
"drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it.
"cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it.
"truncate" - ``TRUNCATE ... RESTRICT`` - truncates the table. Fails if any of the tables have foreign-key references from tables that are not listed in the command.
"truncate cascade" - ``TRUNCATE ... CASCADE`` - truncates the table, and all tables that have foreign-key references to any of the named tables.
index : bool
True to store the DataFrame index as a column in the table,
otherwise False to ignore it.
Expand Down Expand Up @@ -583,6 +610,7 @@ def to_sql(
table=table,
schema=schema,
mode=mode,
overwrite_method=overwrite_method,
index=index,
dtype=dtype,
varchar_lengths=varchar_lengths,
Expand Down
46 changes: 45 additions & 1 deletion tests/unit/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,51 @@ def test_read_sql_query_simple(databases_parameters):

def test_to_sql_simple(postgresql_table, postgresql_con):
df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
wr.postgresql.to_sql(df, postgresql_con, postgresql_table, "public", "overwrite", True)
wr.postgresql.to_sql(
df=df,
con=postgresql_con,
table=postgresql_table,
schema="public",
mode="overwrite",
index=True,
)


@pytest.mark.parametrize("overwrite_method", ["drop", "cascade", "truncate", "truncate cascade"])
def test_to_sql_overwrite(postgresql_table, postgresql_con, overwrite_method):
    """Check that each ``overwrite_method`` replaces the previous table contents.

    The table is written twice in ``overwrite`` mode with two different
    DataFrames; afterwards only the rows of the second DataFrame must remain.
    """
    df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
    wr.postgresql.to_sql(
        df=df,
        con=postgresql_con,
        table=postgresql_table,
        schema="public",
        mode="overwrite",
        overwrite_method=overwrite_method,
    )
    df = pd.DataFrame({"c0": [4, 5, 6], "c1": ["xoo", "yoo", "zoo"]})
    wr.postgresql.to_sql(
        df=df,
        con=postgresql_con,
        table=postgresql_table,
        schema="public",
        mode="overwrite",
        overwrite_method=overwrite_method,
    )
    df_out = wr.postgresql.read_sql_table(table=postgresql_table, schema="public", con=postgresql_con)
    assert df_out.shape == (3, 2)
    # The shape alone would also match if the first batch had been left in
    # place, so verify the values come from the second write. The query has no
    # ORDER BY, hence the order-independent comparison.
    assert sorted(df_out["c0"].tolist()) == [4, 5, 6]
    assert sorted(df_out["c1"].tolist()) == ["xoo", "yoo", "zoo"]


def test_unknown_overwrite_method_error(postgresql_table, postgresql_con):
    """An unsupported ``overwrite_method`` must raise ``InvalidArgumentValue``."""
    frame = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
    write_kwargs = {
        "df": frame,
        "con": postgresql_con,
        "table": postgresql_table,
        "schema": "public",
        "mode": "overwrite",
        # Deliberately not one of the supported overwrite methods.
        "overwrite_method": "unknown",
    }
    with pytest.raises(wr.exceptions.InvalidArgumentValue):
        wr.postgresql.to_sql(**write_kwargs)


def test_sql_types(postgresql_table, postgresql_con):
Expand Down

0 comments on commit 6ed9850

Please sign in to comment.