Skip to content

Commit

Permalink
feat(ingest/dbt): speed up dbt CLL with node_name_patterns (datahub-p…
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored and sleeperdeep committed Dec 17, 2024
1 parent c992c0f commit 3696c87
Showing 1 changed file with 53 additions and 5 deletions.
58 changes: 53 additions & 5 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import auto
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import more_itertools
import pydantic
Expand Down Expand Up @@ -46,6 +46,7 @@
from datahub.ingestion.api.source import MetadataWorkUnitProcessor
from datahub.ingestion.api.source_helpers import auto_workunit
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.dbt.dbt_tests import (
DBTTest,
DBTTestResult,
Expand Down Expand Up @@ -1024,12 +1025,15 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
all_nodes_map,
)

def _is_allowed_node(self, key: str) -> bool:
    """Report whether the dbt node name ``key`` passes the configured filter.

    Delegates to ``self.config.node_name_pattern.allowed`` so that every
    caller applies the same node-name filtering rule.
    """
    name_pattern = self.config.node_name_pattern
    return name_pattern.allowed(key)

def _filter_nodes(self, all_nodes: List[DBTNode]) -> List[DBTNode]:
nodes = []
for node in all_nodes:
key = node.dbt_name

if not self.config.node_name_pattern.allowed(key):
if not self._is_allowed_node(key):
self.report.nodes_filtered.append(key)
continue

Expand All @@ -1041,6 +1045,36 @@ def _filter_nodes(self, all_nodes: List[DBTNode]) -> List[DBTNode]:
def _to_schema_info(schema_fields: List[SchemaField]) -> SchemaInfo:
    """Build a mapping from each field's ``fieldPath`` to its ``nativeDataType``.

    If several fields share the same ``fieldPath``, the last one wins.
    """
    schema_info: SchemaInfo = {}
    for schema_field in schema_fields:
        schema_info[schema_field.fieldPath] = schema_field.nativeDataType
    return schema_info

def _determine_cll_required_nodes(
    self, all_nodes_map: Dict[str, DBTNode]
) -> Tuple[Set[str], Set[str]]:
    """Work out which nodes need schema inference and which need CLL.

    Based on the filter patterns, schema inference and column-level lineage
    (CLL) are only required for a subset of nodes.

    Args:
        all_nodes_map: Every known dbt node, keyed by its dbt name.

    Returns:
        A ``(schema_nodes, cll_nodes)`` tuple of dbt names.

    Invariants:
        * If a node depends on an ephemeral model, the ephemeral model is
          also in the CLL set.
        * Every node in the CLL set is also in the schema set.
        * Every upstream of a node in the CLL set is in the schema set.
    """
    schema_nodes: Set[str] = set()
    cll_nodes: Set[str] = set()

    def add_node_to_cll_list(dbt_name: str) -> None:
        if dbt_name in cll_nodes:
            return
        for upstream in all_nodes_map[dbt_name].upstream_nodes:
            schema_nodes.add(upstream)

            # An upstream name is not guaranteed to be present in
            # all_nodes_map (the topological sort in
            # _infer_schemas_and_update_cll guards against the same case),
            # so look it up defensively instead of raising KeyError.
            upstream_node = all_nodes_map.get(upstream)
            if upstream_node is not None and upstream_node.is_ephemeral_model():
                add_node_to_cll_list(upstream)

        cll_nodes.add(dbt_name)
        schema_nodes.add(dbt_name)

    for dbt_name in all_nodes_map:
        if self._is_allowed_node(dbt_name):
            add_node_to_cll_list(dbt_name)

    return schema_nodes, cll_nodes

def _infer_schemas_and_update_cll( # noqa: C901
self, all_nodes_map: Dict[str, DBTNode]
) -> None:
Expand All @@ -1067,7 +1101,7 @@ def _infer_schemas_and_update_cll( # noqa: C901
)
return

graph = self.ctx.graph
graph: Optional[DataHubGraph] = self.ctx.graph

schema_resolver = SchemaResolver(
platform=self.config.target_platform,
Expand All @@ -1079,7 +1113,7 @@ def _infer_schemas_and_update_cll( # noqa: C901

# Iterate over the dbt nodes in topological order.
# This ensures that we process upstream nodes before downstream nodes.
node_order = topological_sort(
all_node_order = topological_sort(
list(all_nodes_map.keys()),
edges=list(
(upstream, node.dbt_name)
Expand All @@ -1088,7 +1122,17 @@ def _infer_schemas_and_update_cll( # noqa: C901
if upstream in all_nodes_map
),
)
for dbt_name in node_order:
schema_required_nodes, cll_required_nodes = self._determine_cll_required_nodes(
all_nodes_map
)

for dbt_name in all_node_order:
if dbt_name not in schema_required_nodes:
logger.debug(
f"Skipping {dbt_name} because it is filtered out by patterns"
)
continue

node = all_nodes_map[dbt_name]
logger.debug(f"Processing CLL/schemas for {node.dbt_name}")

Expand Down Expand Up @@ -1163,6 +1207,10 @@ def _infer_schemas_and_update_cll( # noqa: C901
# For sources, we generate CLL as a 1:1 mapping.
# We don't support CLL for tests (assertions) or seeds.
pass
elif node.dbt_name not in cll_required_nodes:
logger.debug(
f"Not generating CLL for {node.dbt_name} because we don't need it."
)
elif node.compiled_code:
# Add CTE stops based on the upstreams list.
cte_mapping = {
Expand Down

0 comments on commit 3696c87

Please sign in to comment.