feat(ingest/dbt): add experimental prefer_sql_parser_lineage flag (#…
hsheth2 authored Jul 31, 2024
1 parent 2333304 commit 4b9844d
Showing 11 changed files with 3,421 additions and 184 deletions.
125 changes: 98 additions & 27 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py
@@ -366,6 +366,13 @@ class DBTCommonConfig(
         description="When enabled, includes the compiled code in the emitted metadata.",
     )
 
+    prefer_sql_parser_lineage: bool = Field(
+        default=False,
+        description="Normally we use dbt's metadata to generate table lineage. When enabled, we prefer results from the SQL parser when generating lineage instead. "
+        "This can be useful when dbt models reference tables directly, instead of using the ref() macro. "
+        "This requires that `skip_sources_in_lineage` is enabled.",
+    )
+
     @validator("target_platform")
     def validate_target_platform_value(cls, target_platform: str) -> str:
         if target_platform.lower() == DBT_PLATFORM:
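For orientation, here is a minimal sketch of enabling the new flag in a dbt ingestion recipe, written as a Python dict rather than YAML. The file paths and target platform are illustrative, and the pairing with `skip_sources_in_lineage` is required by the validator added in the next hunk.

```python
# Hypothetical recipe fragment; paths and platform values are made up.
recipe = {
    "source": {
        "type": "dbt",
        "config": {
            "manifest_path": "./target/manifest.json",
            "catalog_path": "./target/catalog.json",
            "target_platform": "snowflake",
            # Required by `prefer_sql_parser_lineage` (enforced by a config validator).
            "skip_sources_in_lineage": True,
            "prefer_sql_parser_lineage": True,  # the new experimental flag
        },
    },
}
```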
@@ -447,6 +454,16 @@ def validate_skip_sources_in_lineage(
 
         return skip_sources_in_lineage
 
+    @validator("prefer_sql_parser_lineage")
+    def validate_prefer_sql_parser_lineage(
+        cls, prefer_sql_parser_lineage: bool, values: Dict
+    ) -> bool:
+        if prefer_sql_parser_lineage and not values.get("skip_sources_in_lineage"):
+            raise ValueError(
+                "`prefer_sql_parser_lineage` requires that `skip_sources_in_lineage` is enabled."
+            )
+        return prefer_sql_parser_lineage
+
 
 @dataclass
 class DBTColumn:
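The cross-field check above relies on pydantic v1 validator semantics: `values` only contains fields declared before the field being validated, so `skip_sources_in_lineage` must precede `prefer_sql_parser_lineage` in the model. A standalone sketch of the same pattern (the model name here is made up):

```python
from pydantic import BaseModel, validator


class LineageFlags(BaseModel):
    # Declaration order matters: this field must come first so that it is
    # already present in `values` when prefer_sql_parser_lineage is validated.
    skip_sources_in_lineage: bool = False
    prefer_sql_parser_lineage: bool = False

    @validator("prefer_sql_parser_lineage")
    def _requires_skip_sources(cls, v: bool, values: dict) -> bool:
        if v and not values.get("skip_sources_in_lineage"):
            raise ValueError(
                "`prefer_sql_parser_lineage` requires that `skip_sources_in_lineage` is enabled."
            )
        return v


LineageFlags(skip_sources_in_lineage=True, prefer_sql_parser_lineage=True)  # ok
# LineageFlags(prefer_sql_parser_lineage=True)  # raises a ValidationError
```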
@@ -516,6 +533,9 @@ class DBTNode:
     columns: List[DBTColumn] = field(default_factory=list)
     upstream_nodes: List[str] = field(default_factory=list)  # list of upstream dbt_name
     upstream_cll: List[DBTColumnLineageInfo] = field(default_factory=list)
+    raw_sql_parsing_result: Optional[
+        SqlParsingResult
+    ] = None  # only set for nodes that don't depend on ephemeral models
     cll_debug_info: Optional[SqlParsingDebugInfo] = None
 
     meta: Dict[str, Any] = field(default_factory=dict)
@@ -1130,6 +1150,7 @@ def _infer_schemas_and_update_cll( # noqa: C901
 
         # Run sql parser to infer the schema + generate column lineage.
         sql_result = None
+        depends_on_ephemeral_models = False
         if node.node_type in {"source", "test", "seed"}:
             # For sources, we generate CLL as a 1:1 mapping.
             # We don't support CLL for tests (assertions) or seeds.
@@ -1148,15 +1169,21 @@
                     upstream_node.name, schema_resolver.platform
                 )
             }
+            if cte_mapping:
+                depends_on_ephemeral_models = True
+
             sql_result = self._parse_cll(node, cte_mapping, schema_resolver)
         else:
             self.report.sql_parser_skipped_missing_code.append(node.dbt_name)
 
         # Save the column lineage.
         if self.config.include_column_lineage and sql_result:
-            # We only save the debug info here. We'll report errors based on it later, after
-            # applying the configured node filters.
+            # We save the raw info here. We use this for supporting `prefer_sql_parser_lineage`.
+            if not depends_on_ephemeral_models:
+                node.raw_sql_parsing_result = sql_result
+
+            # We use this for error reporting. However, we only want to report errors
+            # after node filters are applied.
             node.cll_debug_info = sql_result.debug_info
 
             if sql_result.column_lineage:
@@ -1171,6 +1198,7 @@ def _infer_schemas_and_update_cll( # noqa: C901
                     for column_lineage_info in sql_result.column_lineage
                     for upstream_column in column_lineage_info.upstreams
                     # Only include the CLL if the table is in the upstream list.
+                    # TODO: Add some telemetry around this - how frequently does it filter stuff out?
                     if target_platform_urn_to_dbt_name.get(upstream_column.table)
                     in node.upstream_nodes
                 ]
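For context, `cte_mapping` handles ephemeral dbt models: they have no physical table, so their compiled SQL is spliced into the downstream model's query as a CTE before parsing. That is also why `raw_sql_parsing_result` is only cached when `depends_on_ephemeral_models` is false; the raw parser output would otherwise name synthetic CTEs rather than real upstream tables. A rough, self-contained sketch of the inlining idea (the SQL strings and names are made up):

```python
# Hypothetical ephemeral model SQL, keyed by the name the downstream
# model uses to reference it.
cte_mapping = {
    "clean_orders": "select id, amount from raw_db.orders where amount > 0",
}
downstream_sql = "select id, sum(amount) as total from clean_orders group by id"

# Splice the ephemeral SQL in as a CTE so the parser can resolve the name.
rewritten = (
    "with "
    + ", ".join(f"{name} as ({sql})" for name, sql in cte_mapping.items())
    + " "
    + downstream_sql
)
print(rewritten)
```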
@@ -1813,33 +1841,76 @@ def _translate_dbt_name_to_upstream_urn(dbt_name: str) -> str:
 
         if node.cll_debug_info and node.cll_debug_info.error:
             self.report.report_warning(
-                node.dbt_name,
-                f"Error parsing SQL to generate column lineage: {node.cll_debug_info.error}",
+                "Error parsing SQL to generate column lineage",
+                context=node.dbt_name,
+                exc=node.cll_debug_info.error,
             )
-        cll = [
-            FineGrainedLineage(
-                upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
-                downstreamType=FineGrainedLineageDownstreamType.FIELD_SET,
-                upstreams=[
-                    mce_builder.make_schema_field_urn(
-                        _translate_dbt_name_to_upstream_urn(
-                            upstream_column.upstream_dbt_name
-                        ),
-                        upstream_column.upstream_col,
-                    )
-                    for upstream_column in upstreams
-                ],
-                downstreams=[
-                    mce_builder.make_schema_field_urn(node_urn, downstream)
-                ],
-                confidenceScore=(
-                    node.cll_debug_info.confidence if node.cll_debug_info else None
-                ),
-            )
-            for downstream, upstreams in itertools.groupby(
-                node.upstream_cll, lambda x: x.downstream_col
-            )
-        ]
+
+        cll = None
+        if self.config.prefer_sql_parser_lineage and node.raw_sql_parsing_result:
+            sql_parsing_result = node.raw_sql_parsing_result
+            if sql_parsing_result and not sql_parsing_result.debug_info.table_error:
+                # If we have some table lineage from SQL parsing, use that.
+                upstream_urns = sql_parsing_result.in_tables
+
+                cll = []
+                for column_lineage in sql_parsing_result.column_lineage or []:
+                    cll.append(
+                        FineGrainedLineage(
+                            upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
+                            downstreamType=FineGrainedLineageDownstreamType.FIELD,
+                            upstreams=[
+                                mce_builder.make_schema_field_urn(
+                                    upstream.table, upstream.column
+                                )
+                                for upstream in column_lineage.upstreams
+                            ],
+                            downstreams=[
+                                mce_builder.make_schema_field_urn(
+                                    node_urn, column_lineage.downstream.column
+                                )
+                            ],
+                            confidenceScore=sql_parsing_result.debug_info.confidence,
+                        )
+                    )
+
+        else:
+            if self.config.prefer_sql_parser_lineage:
+                if node.upstream_cll:
+                    self.report.report_warning(
+                        "SQL parser lineage is not available for this node, falling back to dbt-based column lineage.",
+                        context=node.dbt_name,
+                    )
+                else:
+                    # SQL parsing failed entirely, which is already reported above.
+                    pass
+
+            cll = [
+                FineGrainedLineage(
+                    upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
+                    downstreamType=FineGrainedLineageDownstreamType.FIELD,
+                    upstreams=[
+                        mce_builder.make_schema_field_urn(
+                            _translate_dbt_name_to_upstream_urn(
+                                upstream_column.upstream_dbt_name
+                            ),
+                            upstream_column.upstream_col,
+                        )
+                        for upstream_column in upstreams
+                    ],
+                    downstreams=[
+                        mce_builder.make_schema_field_urn(node_urn, downstream)
+                    ],
+                    confidenceScore=(
+                        node.cll_debug_info.confidence
+                        if node.cll_debug_info
+                        else None
+                    ),
+                )
+                for downstream, upstreams in itertools.groupby(
+                    node.upstream_cll, lambda x: x.downstream_col
+                )
+            ]
 
         if not upstream_urns:
             return None
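A note on the dbt-based fallback path: it relies on `itertools.groupby`, which only merges consecutive items with equal keys, so the per-column edges in `upstream_cll` are assumed to arrive ordered by downstream column. A self-contained sketch of that grouping pattern, with tuples standing in for `DBTColumnLineageInfo` entries (all names are made up):

```python
import itertools
from typing import List, Tuple

# (downstream_col, upstream_model, upstream_col) edges; hypothetical data.
edges: List[Tuple[str, str, str]] = [
    ("id", "model.proj.orders", "order_id"),
    ("total", "model.proj.orders", "amount"),
    ("total", "model.proj.fees", "fee"),
]

# groupby only merges *consecutive* equal keys, so order by downstream first.
edges.sort(key=lambda e: e[0])
for downstream, group in itertools.groupby(edges, key=lambda e: e[0]):
    upstreams = [(model, col) for _, model, col in group]
    print(f"{downstream} <- {upstreams}")
```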
