from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence

import boto3
from mypy_boto3_glue import GlueClient
from mypy_boto3_glue.type_defs import ColumnTypeDef
from mypy_boto3_glue.type_defs import GetTableResponseTypeDef
from mypy_boto3_glue.type_defs import PartitionInputTypeDef
from mypy_boto3_glue.type_defs import SerDeInfoTypeDef
from mypy_boto3_glue.type_defs import StorageDescriptorTypeDef
from mypy_boto3_glue.type_defs import TableInputTypeDef

from . import BasePlugin
from ..utils import TargetConfig
from dbt.adapters.base.column import Column


class UnsupportedFormatType(Exception):
    """UnsupportedFormatType exception."""


class UnsupportedType(Exception):
    """UnsupportedType exception."""


class UndetectedType(Exception):
    """UndetectedType exception."""


# Lower-cased DuckDB type name (without any "(...)" suffix) -> Glue type name.
# "signed" maps to "int": the original lookup order matched it there first.
_GLUE_TYPE_BY_DUCKDB_TYPE: Dict[str, str] = {
    "int1": "tinyint",
    "tinyint": "tinyint",
    "int2": "smallint",
    "smallint": "smallint",
    "short": "smallint",
    "utinyint": "smallint",
    "int4": "int",
    "int": "int",
    "integer": "int",
    "signed": "int",
    "usmallint": "int",
    "int8": "bigint",
    "long": "bigint",
    "bigint": "bigint",
    "uinteger": "bigint",
    "float4": "float",
    "float": "float",
    "real": "float",
    "float8": "double",
    "numeric": "double",
    "decimal": "double",
    "double": "double",
    "boolean": "boolean",
    "bool": "boolean",
    "logical": "boolean",
    "varchar": "string",
    "char": "string",
    "bpchar": "string",
    "text": "string",
    "string": "string",
    "uuid": "string",
    "timestamp": "timestamp",
    "datetime": "timestamp",
    "timestamptz": "timestamp",
    "timestamp with time zone": "timestamp",
    "date": "date",
    "blob": "binary",
    "bytea": "binary",
    "binary": "binary",
    "varbinary": "binary",
}

# DuckDB types that are recognised but deliberately rejected.
_UNSUPPORTED_DUCKDB_TYPES = frozenset({"hugeint", "ubigint"})


def _dbt2glue(dtype: str, ignore_null: bool = False) -> str:  # pragma: no cover
    """DuckDB to Glue data types conversion.

    Everything from the first "(" onwards is dropped and the remainder is
    matched case-insensitively.

    :param dtype: DuckDB type name, e.g. ``"VARCHAR"`` or ``"DECIMAL(10,2)"``.
    :param ignore_null: When ``dtype`` is ``None``, return ``""`` instead of raising.
    :return: The Glue type name.
    :raises UndetectedType: If ``dtype`` is ``None`` and ``ignore_null`` is false.
    :raises UnsupportedType: If the type is hugeint/ubigint or is not recognised.
    """
    # The None check must come before any string operation on dtype; otherwise
    # it can never be reached (``None.split`` raises AttributeError first).
    if dtype is None:
        if ignore_null:
            return ""
        raise UndetectedType("We can not infer the data type from an entire null object column")

    data_type = dtype.split("(")[0].lower()
    if data_type in _UNSUPPORTED_DUCKDB_TYPES:
        raise UnsupportedType(
            "There is no support for hugeint or ubigint, please consider bigint or uinteger."
        )
    try:
        return _GLUE_TYPE_BY_DUCKDB_TYPE[data_type]
    except KeyError:
        raise UnsupportedType(f"Unsupported type: {dtype}") from None


def _get_parquet_table_def(
    table: str, s3_parent: str, columns: Sequence["ColumnTypeDef"]
) -> "TableInputTypeDef":
    """Create table definition for Glue table stored as Parquet.

    See https://docs.aws.amazon.com/glue/latest/webapi/API_CreateTable.html#API_CreateTable_RequestSyntax

    :param table: Table name.
    :param s3_parent: Value used as the storage descriptor ``Location``.
    :param columns: Column definitions for the storage descriptor.
    :return: The table input definition.
    """
    table_def = TableInputTypeDef(
        Name=table,
        TableType="EXTERNAL_TABLE",
        Parameters={"classification": "parquet"},
        StorageDescriptor=StorageDescriptorTypeDef(
            Columns=columns,
            Location=s3_parent,
            InputFormat="org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat",
            OutputFormat="org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat",
            Compressed=False,
            NumberOfBuckets=-1,
            SerdeInfo=SerDeInfoTypeDef(
                SerializationLibrary="org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe",
                Parameters={"serialization.format": "1"},
            ),
            BucketColumns=[],
            StoredAsSubDirectories=False,
            SortColumns=[],
        ),
    )
    return table_def


def _get_csv_table_def(
    table: str, s3_parent: str, columns: Sequence["ColumnTypeDef"], delimiter: str = ","
) -> "TableInputTypeDef":
    """Create table definition for Glue table stored as CSV.

    See https://docs.aws.amazon.com/glue/latest/webapi/API_CreateTable.html#API_CreateTable_RequestSyntax

    :param table: Table name.
    :param s3_parent: Value used as the storage descriptor ``Location``.
    :param columns: Column definitions for the storage descriptor.
    :param delimiter: Passed to the SerDe as ``separatorChar``.
    :return: The table input definition.
    """
    table_def = TableInputTypeDef(
        Name=table,
        TableType="EXTERNAL_TABLE",
        Parameters={"classification": "csv"},
        StorageDescriptor=StorageDescriptorTypeDef(
            Columns=columns,
            Location=s3_parent,
            InputFormat="org.apache.hadoop.mapred.TextInputFormat",
            OutputFormat="org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
            Compressed=False,
            NumberOfBuckets=-1,
            SerdeInfo=SerDeInfoTypeDef(
                SerializationLibrary="org.apache.hadoop.hive.serde2.OpenCSVSerde",
                Parameters={"separatorChar": delimiter},
            ),
            BucketColumns=[],
            StoredAsSubDirectories=False,
            SortColumns=[],
        ),
    )
    return table_def


def _convert_columns(column_list: Sequence[Column]) -> Sequence["ColumnTypeDef"]:
    """Convert dbt schema into Glue compliable Schema.

    :raises UnsupportedType: Propagated from ``_dbt2glue`` for unmapped column types.
    """
    return [
        ColumnTypeDef(Name=column.name, Type=_dbt2glue(column.dtype)) for column in column_list
    ]


def _create_table(
    client: "GlueClient",
    database: str,
    table_def: "TableInputTypeDef",
    partition_columns: List[Dict[str, str]],
) -> None:
    """Create the Glue table and, if partition columns are given, its partition."""
    client.create_table(DatabaseName=database, TableInput=table_def)
    # Create partition if relevant
    if partition_columns:
        partition_input, _ = _parse_partition_columns(partition_columns, table_def)
        client.create_partition(
            DatabaseName=database, TableName=table_def["Name"], PartitionInput=partition_input
        )


def _update_table(
    client: "GlueClient",
    database: str,
    table_def: "TableInputTypeDef",
    partition_columns: List[Dict[str, str]],
) -> None:
    """Update the Glue table and, if partition columns are given, upsert its partition."""
    client.update_table(DatabaseName=database, TableInput=table_def)
    # Update or create partition if relevant
    if partition_columns:
        partition_input, partition_values = _parse_partition_columns(partition_columns, table_def)
        try:
            # get_partition is used as an existence probe: it raises
            # EntityNotFoundException when the partition is missing.
            client.get_partition(
                DatabaseName=database,
                TableName=table_def["Name"],
                PartitionValues=partition_values,
            )
            client.update_partition(
                DatabaseName=database,
                TableName=table_def["Name"],
                PartitionValueList=partition_values,
                PartitionInput=partition_input,
            )
        except client.exceptions.EntityNotFoundException:
            client.create_partition(
                DatabaseName=database, TableName=table_def["Name"], PartitionInput=partition_input
            )


def _get_table(
    client: "GlueClient", database: str, table: str
) -> Optional["GetTableResponseTypeDef"]:
    """Return the ``get_table`` response, or ``None`` if the entity is not found."""
    try:
        return client.get_table(DatabaseName=database, Name=table)
    except client.exceptions.EntityNotFoundException:  # pragma: no cover
        return None


def _get_column_type_def(
    table_def: "GetTableResponseTypeDef",
) -> Optional[Sequence["ColumnTypeDef"]]:
    """Get columns definition from Glue Table Definition.

    Only ``Name`` and ``Type`` of each column are kept.

    :return: The column definitions, or ``None`` when the response has no (or empty) columns.
    """
    raw_columns = table_def.get("Table", {}).get("StorageDescriptor", {}).get("Columns")
    if not raw_columns:
        return None
    return [ColumnTypeDef(Name=column["Name"], Type=column["Type"]) for column in raw_columns]


def _add_partition_columns(
    table_def: TableInputTypeDef, partition_columns: List[Dict[str, str]]
) -> TableInputTypeDef:
    """Set ``PartitionKeys`` on ``table_def`` and drop the matching storage columns.

    ``table_def`` is modified in place and also returned.

    :param table_def: Table definition to update.
    :param partition_columns: Dicts with at least ``"Name"`` and ``"Type"`` keys.
    :return: The same ``table_def`` object.
    """
    table_def["PartitionKeys"] = [
        ColumnTypeDef(Name=column["Name"], Type=column["Type"]) for column in partition_columns
    ]
    # Remove columns from StorageDescriptor if they match with partition columns
    # (same name AND same type) to avoid duplicate columns.
    partition_pairs = {(column["Name"], column["Type"]) for column in partition_columns}
    storage_descriptor = table_def["StorageDescriptor"]
    storage_descriptor["Columns"] = [
        column
        for column in storage_descriptor["Columns"]
        if (column["Name"], column["Type"]) not in partition_pairs
    ]
    return table_def


def _parse_partition_columns(
    partition_columns: List[Dict[str, str]], table_def: TableInputTypeDef
):
    """Build the partition input and partition values for the given partition columns.

    The partition location is the table location followed by one
    ``Name=Value`` path segment per partition column, joined with ``/``.

    :param partition_columns: Dicts with at least ``"Name"`` and ``"Value"`` keys.
    :param table_def: Table definition whose storage descriptor is used as a template.
    :return: ``(partition_input, partition_values)``; both are ``None`` when
        ``partition_columns`` is empty.
    """
    if not partition_columns:
        return None, None

    partition_values = [column["Value"] for column in partition_columns]
    partition_location = "/".join(
        [table_def["StorageDescriptor"]["Location"]]
        + ["=".join((column["Name"], column["Value"])) for column in partition_columns]
    )

    # Copy the descriptor so that setting the partition location does not
    # overwrite the Location inside the caller's table definition.
    storage_descriptor = table_def["StorageDescriptor"].copy()
    storage_descriptor["Location"] = partition_location

    partition_input = PartitionInputTypeDef()
    partition_input["Values"] = partition_values
    partition_input["StorageDescriptor"] = storage_descriptor
    return partition_input, partition_values


def _get_table_def(
    table: str,
    s3_parent: str,
    columns: Sequence["ColumnTypeDef"],
    file_format: str,
    delimiter: str,
) -> "TableInputTypeDef":
    """Build the table definition for the given file format.

    :param file_format: ``"csv"`` or ``"parquet"``.
    :param delimiter: Only used for the ``"csv"`` format.
    :raises UnsupportedFormatType: For any other ``file_format``.
    """
    if file_format == "csv":
        return _get_csv_table_def(
            table=table,
            s3_parent=s3_parent,
            columns=columns,
            delimiter=delimiter,
        )
    if file_format == "parquet":
        return _get_parquet_table_def(table=table, s3_parent=s3_parent, columns=columns)
    raise UnsupportedFormatType(f"Format {file_format} is not supported in Glue registrar.")


def _get_glue_client(settings: Dict[str, Any]) -> "GlueClient":
    """Create a boto3 Glue client.

    With non-empty ``settings`` the ``s3_access_key_id``, ``s3_secret_access_key``,
    ``s3_session_token`` and ``s3_region`` entries are passed to boto3 (missing
    entries are passed as ``None``); otherwise the client is created with no
    explicit arguments.
    """
    if settings:
        return boto3.client(
            "glue",
            aws_access_key_id=settings.get("s3_access_key_id"),
            aws_secret_access_key=settings.get("s3_secret_access_key"),
            aws_session_token=settings.get("s3_session_token"),
            region_name=settings.get("s3_region"),
        )
    return boto3.client("glue")


def create_or_update_table(
    client: GlueClient,
    database: str,
    table: str,
    column_list: Sequence[Column],
    s3_path: str,
    file_format: str,
    delimiter: str,
    partition_columns: Optional[List[Dict[str, str]]] = None,
) -> None:
    """Create the Glue table, or update it when its columns differ.

    :param client: Glue client.
    :param database: Glue database name.
    :param table: Glue table name.
    :param column_list: dbt columns to convert into the Glue schema.
    :param s3_path: Used as the table location when partitioning; otherwise
        everything before its last ``/`` is used.
    :param file_format: ``"csv"`` or ``"parquet"``.
    :param delimiter: CSV separator character.
    :param partition_columns: Optional partition column dicts; ``None`` or an
        empty list means no partitioning.
    :raises UnsupportedFormatType: If ``file_format`` is not supported.
    :raises UnsupportedType: If a column type has no Glue mapping.
    """
    # A None default (instead of a shared mutable []) is normalised here.
    partition_columns = partition_columns or []

    # Set s3 original path if partitioning is used, else use parent path
    if partition_columns:
        s3_parent = s3_path
    else:
        s3_parent = "/".join(s3_path.split("/")[:-1])

    # Existing table in AWS Glue catalog
    glue_table = _get_table(client=client, database=database, table=table)
    columns = _convert_columns(column_list)
    table_def = _get_table_def(
        table=table,
        s3_parent=s3_parent,
        columns=columns,
        file_format=file_format,
        delimiter=delimiter,
    )
    # Add partition columns
    if partition_columns:
        table_def = _add_partition_columns(table_def, partition_columns)

    if not glue_table:
        _create_table(
            client=client,
            database=database,
            table_def=table_def,
            partition_columns=partition_columns,
        )
        return

    # Existing columns in AWS Glue catalog
    glue_columns = _get_column_type_def(glue_table)
    # Create new version only if columns are changed.
    # NOTE(review): partitions are only registered inside _update_table, so a
    # table whose columns compare equal gets no partition update - confirm intended.
    if glue_columns != columns:
        _update_table(
            client=client,
            database=database,
            table_def=table_def,
            partition_columns=partition_columns,
        )


class Plugin(BasePlugin):
    """Plugin that registers a stored target as a table in the AWS Glue catalog."""

    def initialize(self, config: Dict[str, Any]):
        """Create the Glue client and read ``glue_database`` / ``delimiter`` from config."""
        self.client = _get_glue_client(config)
        self.database = config.get("glue_database", "default")
        self.delimiter = config.get("delimiter", ",")

    def store(self, target_config: TargetConfig):
        """Create or update the Glue table described by ``target_config``."""
        assert target_config.location is not None
        assert target_config.relation.identifier is not None
        table: str = target_config.relation.identifier
        partition_columns = target_config.config.get("partition_columns", [])

        create_or_update_table(
            self.client,
            self.database,
            table,
            target_config.column_list,
            target_config.location.path,
            target_config.location.format,
            self.delimiter,
            partition_columns,
        )