feat(ingest): standardize sql type mappings (datahub-project#11982)
hsheth2 authored and sleeperdeep committed Dec 17, 2024
1 parent 984f45f commit c9f45c6
Showing 8 changed files with 161 additions and 137 deletions.
68 changes: 7 additions & 61 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py
@@ -53,19 +53,7 @@
make_assertion_from_test,
make_assertion_result_from_test,
)
from datahub.ingestion.source.sql.sql_types import (
ATHENA_SQL_TYPES_MAP,
BIGQUERY_TYPES_MAP,
POSTGRES_TYPES_MAP,
SNOWFLAKE_TYPES_MAP,
SPARK_SQL_TYPES_MAP,
TRINO_SQL_TYPES_MAP,
VERTICA_SQL_TYPES_MAP,
resolve_athena_modified_type,
resolve_postgres_modified_type,
resolve_trino_modified_type,
resolve_vertica_modified_type,
)
from datahub.ingestion.source.sql.sql_types import resolve_sql_type
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
StaleEntityRemovalSourceReport,
@@ -89,17 +77,11 @@
from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import DatasetSnapshot
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
BooleanTypeClass,
DateTypeClass,
MySqlDDL,
NullTypeClass,
NumberTypeClass,
RecordType,
SchemaField,
SchemaFieldDataType,
SchemaMetadata,
StringTypeClass,
TimeTypeClass,
)
from datahub.metadata.schema_classes import (
DataPlatformInstanceClass,
@@ -804,28 +786,6 @@ def make_mapping_upstream_lineage(
)


# See https://github.com/fishtown-analytics/dbt/blob/master/core/dbt/adapters/sql/impl.py
_field_type_mapping = {
"boolean": BooleanTypeClass,
"date": DateTypeClass,
"time": TimeTypeClass,
"numeric": NumberTypeClass,
"text": StringTypeClass,
"timestamp with time zone": DateTypeClass,
"timestamp without time zone": DateTypeClass,
"integer": NumberTypeClass,
"float8": NumberTypeClass,
"struct": RecordType,
**POSTGRES_TYPES_MAP,
**SNOWFLAKE_TYPES_MAP,
**BIGQUERY_TYPES_MAP,
**SPARK_SQL_TYPES_MAP,
**TRINO_SQL_TYPES_MAP,
**ATHENA_SQL_TYPES_MAP,
**VERTICA_SQL_TYPES_MAP,
}


def get_column_type(
report: DBTSourceReport,
dataset_name: str,
@@ -835,24 +795,10 @@
"""
Maps known DBT types to datahub types
"""
TypeClass: Any = _field_type_mapping.get(column_type) if column_type else None

if TypeClass is None and column_type:
# resolve a modified type
if dbt_adapter == "trino":
TypeClass = resolve_trino_modified_type(column_type)
elif dbt_adapter == "athena":
TypeClass = resolve_athena_modified_type(column_type)
elif dbt_adapter == "postgres" or dbt_adapter == "redshift":
# Redshift uses a variant of Postgres, so we can use the same logic.
TypeClass = resolve_postgres_modified_type(column_type)
elif dbt_adapter == "vertica":
TypeClass = resolve_vertica_modified_type(column_type)
elif dbt_adapter == "snowflake":
# Snowflake types are uppercase, so we check that.
TypeClass = _field_type_mapping.get(column_type.upper())

# if still not found, report the warning

TypeClass = resolve_sql_type(column_type, dbt_adapter)

# if still not found, report a warning
if TypeClass is None:
if column_type:
report.info(
Expand All @@ -861,9 +807,9 @@ def get_column_type(
context=f"{dataset_name} - {column_type}",
log=False,
)
TypeClass = NullTypeClass
TypeClass = NullTypeClass()

return SchemaFieldDataType(type=TypeClass())
return SchemaFieldDataType(type=TypeClass)


@platform_name("dbt")
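
In short, get_column_type now delegates all adapter-specific resolution to resolve_sql_type and instantiates the NullTypeClass fallback up front, so SchemaFieldDataType wraps a type instance rather than a class. A minimal sketch of the resulting flow (illustrative only; the report/dataset_name handling is elided, and sketch_get_column_type is a hypothetical name, not part of the commit):

from typing import Optional

from datahub.ingestion.source.sql.sql_types import resolve_sql_type
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
    NullTypeClass,
    SchemaFieldDataType,
)


def sketch_get_column_type(
    column_type: Optional[str], dbt_adapter: str
) -> SchemaFieldDataType:
    # resolve_sql_type returns a type *instance* (e.g. BooleanTypeClass()) or None.
    resolved = resolve_sql_type(column_type, dbt_adapter)
    if resolved is None:
        # The real code also reports a warning before falling back.
        resolved = NullTypeClass()
    return SchemaFieldDataType(type=resolved)
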
@@ -15,6 +15,7 @@
TimeType,
)

# TODO: Replace with standardized types in sql_types.py
FIELD_TYPE_MAPPING: Dict[
str,
Type[
@@ -222,6 +222,7 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):
```
"""

# TODO: Replace with standardized types in sql_types.py
REDSHIFT_FIELD_TYPE_MAPPINGS: Dict[
str,
Type[
@@ -103,6 +103,7 @@
logger = logging.getLogger(__name__)

# https://docs.snowflake.com/en/sql-reference/intro-summary-data-types.html
# TODO: Move to the standardized types in sql_types.py
SNOWFLAKE_FIELD_TYPE_MAPPINGS = {
"DATE": DateType,
"BIGINT": NumberType,
79 changes: 72 additions & 7 deletions metadata-ingestion/src/datahub/ingestion/source/sql/sql_types.py
@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, ValuesView
from typing import Any, Dict, Optional, Type, Union, ValuesView

from datahub.metadata.com.linkedin.pegasus2avro.schema import (
ArrayType,
@@ -16,14 +16,28 @@
UnionType,
)

# these can be obtained by running `select format_type(oid, null),* from pg_type;`
# we've omitted the types without a meaningful DataHub type (e.g. postgres-specific types, index vectors, etc.)
# (run `\copy (select format_type(oid, null),* from pg_type) to 'pg_type.csv' csv header;` to get a CSV)
DATAHUB_FIELD_TYPE = Union[
ArrayType,
BooleanType,
BytesType,
DateType,
EnumType,
MapType,
NullType,
NumberType,
RecordType,
StringType,
TimeType,
UnionType,
]

# we map from format_type since this is what dbt uses
# see https://github.com/fishtown-analytics/dbt/blob/master/plugins/postgres/dbt/include/postgres/macros/catalog.sql#L22

# see https://www.npgsql.org/dev/types.html for helpful type annotations
# These can be obtained by running `select format_type(oid, null),* from pg_type;`
# We've omitted the types without a meaningful DataHub type (e.g. postgres-specific types, index vectors, etc.)
# (run `\copy (select format_type(oid, null),* from pg_type) to 'pg_type.csv' csv header;` to get a CSV)
# We map from format_type since this is what dbt uses.
# See https://github.com/fishtown-analytics/dbt/blob/master/plugins/postgres/dbt/include/postgres/macros/catalog.sql#L22
# See https://www.npgsql.org/dev/types.html for helpful type annotations
POSTGRES_TYPES_MAP: Dict[str, Any] = {
"boolean": BooleanType,
"bytea": BytesType,
@@ -430,3 +444,54 @@ def resolve_vertica_modified_type(type_string: str) -> Any:
"geography": None,
"uuid": StringType,
}


_merged_mapping = {
"boolean": BooleanType,
"date": DateType,
"time": TimeType,
"numeric": NumberType,
"text": StringType,
"timestamp with time zone": DateType,
"timestamp without time zone": DateType,
"integer": NumberType,
"float8": NumberType,
"struct": RecordType,
**POSTGRES_TYPES_MAP,
**SNOWFLAKE_TYPES_MAP,
**BIGQUERY_TYPES_MAP,
**SPARK_SQL_TYPES_MAP,
**TRINO_SQL_TYPES_MAP,
**ATHENA_SQL_TYPES_MAP,
**VERTICA_SQL_TYPES_MAP,
}


def resolve_sql_type(
column_type: Optional[str],
platform: Optional[str] = None,
) -> Optional[DATAHUB_FIELD_TYPE]:
# In theory, we should use the platform-specific mapping where available.
# However, the types don't ever conflict, so the merged mapping is fine.
TypeClass: Optional[Type[DATAHUB_FIELD_TYPE]] = (
_merged_mapping.get(column_type) if column_type else None
)

if TypeClass is None and column_type:
# resolve a modified type
if platform == "trino":
TypeClass = resolve_trino_modified_type(column_type)
elif platform == "athena":
TypeClass = resolve_athena_modified_type(column_type)
elif platform == "postgres" or platform == "redshift":
# Redshift uses a variant of Postgres, so we can use the same logic.
TypeClass = resolve_postgres_modified_type(column_type)
elif platform == "vertica":
TypeClass = resolve_vertica_modified_type(column_type)
elif platform == "snowflake":
# Snowflake types are uppercase, so we check that.
TypeClass = _merged_mapping.get(column_type.upper())

if TypeClass:
return TypeClass()
return None
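
A quick usage sketch for the new helper (illustrative, not part of the commit; the first two assertions mirror the new unit test below, the others follow from the branches above):

from datahub.ingestion.source.sql.sql_types import resolve_sql_type
from datahub.metadata.schema_classes import BooleanTypeClass, StringTypeClass

assert resolve_sql_type("boolean") == BooleanTypeClass()  # exact hit in the merged mapping
assert resolve_sql_type("varchar") == StringTypeClass()  # also an exact hit
assert resolve_sql_type("decimal(10,0)", platform="trino") is not None  # parametrized types resolve per-platform
assert resolve_sql_type("not_a_real_type") is None  # unknown types yield None rather than raising
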
@@ -33,6 +33,7 @@

logger = logging.getLogger(__name__)

# TODO: (maybe) Replace with standardized types in sql_types.py
DATA_TYPE_REGISTRY: dict = {
ColumnTypeName.BOOLEAN: BooleanTypeClass,
ColumnTypeName.BYTE: BytesTypeClass,
69 changes: 0 additions & 69 deletions metadata-ingestion/tests/integration/dbt/test_dbt.py
@@ -11,12 +11,6 @@
from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
from datahub.ingestion.source.dbt.dbt_common import DBTEntitiesEnabled, EmitDirective
from datahub.ingestion.source.dbt.dbt_core import DBTCoreConfig, DBTCoreSource
from datahub.ingestion.source.sql.sql_types import (
ATHENA_SQL_TYPES_MAP,
TRINO_SQL_TYPES_MAP,
resolve_athena_modified_type,
resolve_trino_modified_type,
)
from tests.test_helpers import mce_helpers, test_connection_helpers

FROZEN_TIME = "2022-02-03 07:00:00"
@@ -362,69 +356,6 @@ def test_dbt_tests(test_resources_dir, pytestconfig, tmp_path, mock_time, **kwar
)


@pytest.mark.parametrize(
"data_type, expected_data_type",
[
("boolean", "boolean"),
("tinyint", "tinyint"),
("smallint", "smallint"),
("int", "int"),
("integer", "integer"),
("bigint", "bigint"),
("real", "real"),
("double", "double"),
("decimal(10,0)", "decimal"),
("varchar(20)", "varchar"),
("char", "char"),
("varbinary", "varbinary"),
("json", "json"),
("date", "date"),
("time", "time"),
("time(12)", "time"),
("timestamp", "timestamp"),
("timestamp(3)", "timestamp"),
("row(x bigint, y double)", "row"),
("array(row(x bigint, y double))", "array"),
("map(varchar, varchar)", "map"),
],
)
def test_resolve_trino_modified_type(data_type, expected_data_type):
assert (
resolve_trino_modified_type(data_type)
== TRINO_SQL_TYPES_MAP[expected_data_type]
)


@pytest.mark.parametrize(
"data_type, expected_data_type",
[
("boolean", "boolean"),
("tinyint", "tinyint"),
("smallint", "smallint"),
("int", "int"),
("integer", "integer"),
("bigint", "bigint"),
("float", "float"),
("double", "double"),
("decimal(10,0)", "decimal"),
("varchar(20)", "varchar"),
("char", "char"),
("binary", "binary"),
("date", "date"),
("timestamp", "timestamp"),
("timestamp(3)", "timestamp"),
("struct<x timestamp(3), y timestamp>", "struct"),
("array<struct<x bigint, y double>>", "array"),
("map<varchar, varchar>", "map"),
],
)
def test_resolve_athena_modified_type(data_type, expected_data_type):
assert (
resolve_athena_modified_type(data_type)
== ATHENA_SQL_TYPES_MAP[expected_data_type]
)


@pytest.mark.integration
@freeze_time(FROZEN_TIME)
def test_dbt_tests_only_assertions(
78 changes: 78 additions & 0 deletions metadata-ingestion/tests/unit/test_sql_types.py
@@ -0,0 +1,78 @@
import pytest

from datahub.ingestion.source.sql.sql_types import (
ATHENA_SQL_TYPES_MAP,
TRINO_SQL_TYPES_MAP,
resolve_athena_modified_type,
resolve_sql_type,
resolve_trino_modified_type,
)
from datahub.metadata.schema_classes import BooleanTypeClass, StringTypeClass


@pytest.mark.parametrize(
"data_type, expected_data_type",
[
("boolean", "boolean"),
("tinyint", "tinyint"),
("smallint", "smallint"),
("int", "int"),
("integer", "integer"),
("bigint", "bigint"),
("real", "real"),
("double", "double"),
("decimal(10,0)", "decimal"),
("varchar(20)", "varchar"),
("char", "char"),
("varbinary", "varbinary"),
("json", "json"),
("date", "date"),
("time", "time"),
("time(12)", "time"),
("timestamp", "timestamp"),
("timestamp(3)", "timestamp"),
("row(x bigint, y double)", "row"),
("array(row(x bigint, y double))", "array"),
("map(varchar, varchar)", "map"),
],
)
def test_resolve_trino_modified_type(data_type, expected_data_type):
assert (
resolve_trino_modified_type(data_type)
== TRINO_SQL_TYPES_MAP[expected_data_type]
)


@pytest.mark.parametrize(
"data_type, expected_data_type",
[
("boolean", "boolean"),
("tinyint", "tinyint"),
("smallint", "smallint"),
("int", "int"),
("integer", "integer"),
("bigint", "bigint"),
("float", "float"),
("double", "double"),
("decimal(10,0)", "decimal"),
("varchar(20)", "varchar"),
("char", "char"),
("binary", "binary"),
("date", "date"),
("timestamp", "timestamp"),
("timestamp(3)", "timestamp"),
("struct<x timestamp(3), y timestamp>", "struct"),
("array<struct<x bigint, y double>>", "array"),
("map<varchar, varchar>", "map"),
],
)
def test_resolve_athena_modified_type(data_type, expected_data_type):
assert (
resolve_athena_modified_type(data_type)
== ATHENA_SQL_TYPES_MAP[expected_data_type]
)


def test_resolve_sql_type() -> None:
assert resolve_sql_type("boolean") == BooleanTypeClass()
assert resolve_sql_type("varchar") == StringTypeClass()
