Skip to content

Commit

Permalink
Merge pull request #327 from firewall413/feature/update-partitioning-…
Browse files Browse the repository at this point in the history
…to-aws-glue

Added support for dynamically registering/adding partitions
  • Loading branch information
jwills authored Feb 21, 2024
2 parents d0bbafe + ea364f2 commit c69bed8
Showing 1 changed file with 78 additions and 10 deletions.
88 changes: 78 additions & 10 deletions dbt/adapters/duckdb/plugins/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from mypy_boto3_glue import GlueClient
from mypy_boto3_glue.type_defs import ColumnTypeDef
from mypy_boto3_glue.type_defs import GetTableResponseTypeDef
from mypy_boto3_glue.type_defs import PartitionInputTypeDef
from mypy_boto3_glue.type_defs import SerDeInfoTypeDef
from mypy_boto3_glue.type_defs import StorageDescriptorTypeDef
from mypy_boto3_glue.type_defs import TableInputTypeDef
Expand Down Expand Up @@ -132,12 +133,50 @@ def _convert_columns(column_list: Sequence[Column]) -> Sequence["ColumnTypeDef"]
return column_types


def _create_table(client: "GlueClient", database: str, table_def: "TableInputTypeDef") -> None:
def _create_table(
client: "GlueClient",
database: str,
table_def: "TableInputTypeDef",
partition_columns: List[Dict[str, str]],
) -> None:
client.create_table(DatabaseName=database, TableInput=table_def)
# Create partition if relevant
if partition_columns != []:
partition_input, partition_values = _parse_partition_columns(partition_columns, table_def)

client.create_partition(
DatabaseName=database, TableName=table_def["Name"], PartitionInput=partition_input
)


def _update_table(client: "GlueClient", database: str, table_def: "TableInputTypeDef") -> None:
def _update_table(
client: "GlueClient",
database: str,
table_def: "TableInputTypeDef",
partition_columns: List[Dict[str, str]],
) -> None:
client.update_table(DatabaseName=database, TableInput=table_def)
# Update or create partition if relevant
if partition_columns != []:
partition_input, partition_values = _parse_partition_columns(partition_columns, table_def)

try:
client.get_partition(
DatabaseName=database,
TableName=table_def["Name"],
PartitionValues=partition_values,
)
client.update_partition(
DatabaseName=database,
TableName=table_def["Name"],
PartitionValueList=partition_values,
PartitionInput=partition_input,
)

except client.exceptions.EntityNotFoundException:
client.create_partition(
DatabaseName=database, TableName=table_def["Name"], PartitionInput=partition_input
)


def _get_table(
Expand All @@ -163,7 +202,9 @@ def _get_column_type_def(
return None


def _add_partition_columns(table_def: TableInputTypeDef, partition_columns) -> TableInputTypeDef:
def _add_partition_columns(
table_def: TableInputTypeDef, partition_columns: List[Dict[str, str]]
) -> TableInputTypeDef:
partition_keys = []
if "PartitionKeys" not in table_def:
table_def["PartitionKeys"] = []
Expand All @@ -172,18 +213,35 @@ def _add_partition_columns(table_def: TableInputTypeDef, partition_columns) -> T
partition_keys.append(partition_column)
table_def["PartitionKeys"] = partition_keys
# Remove columns from StorageDescriptor if they match with partition columns to avoid duplicate columns
for partition_column in partition_columns:
for p_column in partition_columns:
table_def["StorageDescriptor"]["Columns"] = [
column
for column in table_def["StorageDescriptor"]["Columns"]
if not (
column["Name"] == partition_column["Name"]
and column["Type"] == partition_column["Type"]
)
if not (column["Name"] == p_column["Name"] and column["Type"] == p_column["Type"])
]
return table_def


def _parse_partition_columns(
partition_columns: List[Dict[str, str]], table_def: TableInputTypeDef
):
partition_input = None
if partition_columns:
partition_values = [column["Value"] for column in partition_columns]
partition_location = table_def["StorageDescriptor"]["Location"]
partition_components = [partition_location]
for c in partition_columns:
partition_components.append("=".join((c["Name"], c["Value"])))
partition_location = "/".join(partition_components)

partition_input = PartitionInputTypeDef()
partition_input["Values"] = partition_values
partition_input["StorageDescriptor"] = table_def["StorageDescriptor"]
partition_input["StorageDescriptor"]["Location"] = partition_location

return partition_input, partition_values


def _get_table_def(
table: str,
s3_parent: str,
Expand Down Expand Up @@ -252,9 +310,19 @@ def create_or_update_table(
glue_columns = _get_column_type_def(glue_table)
# Create new version only if columns are changed
if glue_columns != columns:
_update_table(client=client, database=database, table_def=table_def)
_update_table(
client=client,
database=database,
table_def=table_def,
partition_columns=partition_columns,
)
else:
_create_table(client=client, database=database, table_def=table_def)
_create_table(
client=client,
database=database,
table_def=table_def,
partition_columns=partition_columns,
)


class Plugin(BasePlugin):
Expand Down

0 comments on commit c69bed8

Please sign in to comment.