Skip to content

Commit

Permalink
Refactor/unify public and model nodes (#7891)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored Jun 21, 2023
1 parent 4cdeff1 commit ecf90d6
Show file tree
Hide file tree
Showing 26 changed files with 955 additions and 1,116 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230616-131503.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: 'Refactoring: consolidating public_nodes and nodes'
time: 2023-06-16T13:15:03.668124-04:00
custom:
Author: michelleark
Issue: "7890"
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def _get_cache_schemas(self, manifest: Manifest) -> Set[BaseRelation]:
return {
self.Relation.create_from(self.config, node).without_identifier()
for node in manifest.nodes.values()
if (node.is_relational and not node.is_ephemeral_model)
if (node.is_relational and not node.is_ephemeral_model and not node.is_external_node)
}

def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
Expand Down
11 changes: 2 additions & 9 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,11 +518,7 @@ def resolve(
return self.create_relation(target_model)

def create_relation(self, target_model: ManifestNode) -> RelationProxy:
if target_model.is_public_node:
# Get quoting from publication artifact
pub_metadata = self.manifest.publications[target_model.package_name].metadata
return self.Relation.create_from_node(pub_metadata, target_model)
elif target_model.is_ephemeral_model:
if target_model.is_ephemeral_model:
self.model.set_cte(target_model.unique_id, None)
return self.Relation.create_ephemeral_from_node(self.config, target_model)
else:
Expand All @@ -535,10 +531,7 @@ def validate(
target_package: Optional[str],
target_version: Optional[NodeVersion],
) -> None:
if (
resolved.unique_id not in self.model.depends_on.nodes
and resolved.unique_id not in self.model.depends_on.public_nodes
):
if resolved.unique_id not in self.model.depends_on.nodes:
args = self._repack_args(target_name, target_package, target_version)
raise RefBadContextError(node=self.model, args=args)

Expand Down
36 changes: 11 additions & 25 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing_extensions import Protocol
from uuid import UUID

from dbt.contracts.publication import PublicationConfig, PublicModel
from dbt.contracts.publication import PublicationConfig

from dbt.contracts.graph.nodes import (
BaseNode,
Expand All @@ -33,7 +33,6 @@
Group,
Macro,
ManifestNode,
ManifestOrPublicNode,
Metric,
ModelNode,
RelationalNode,
Expand Down Expand Up @@ -162,7 +161,6 @@ class RefableLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)
self.populate_public_nodes(manifest)

def get_unique_id(
self,
Expand Down Expand Up @@ -206,7 +204,9 @@ def find(
[
UnparsedVersion(v.version)
for v in manifest.nodes.values()
if v.name == node.name and v.version is not None
if isinstance(v, ModelNode)
and v.name == node.name
and v.version is not None
]
)
assert node.latest_version is not None # for mypy, whenever i may find it
Expand All @@ -224,7 +224,7 @@ def find(
return node
return None

def add_node(self, node: ManifestOrPublicNode):
def add_node(self, node: ManifestNode):
if node.resource_type in self._lookup_types:
if node.name not in self.storage:
self.storage[node.name] = {}
Expand All @@ -242,15 +242,9 @@ def populate(self, manifest):
for node in manifest.nodes.values():
self.add_node(node)

def populate_public_nodes(self, manifest):
for node in manifest.public_nodes.values():
self.add_node(node)

def perform_lookup(self, unique_id: UniqueID, manifest) -> ManifestOrPublicNode:
def perform_lookup(self, unique_id: UniqueID, manifest) -> ManifestNode:
if unique_id in manifest.nodes:
node = manifest.nodes[unique_id]
elif unique_id in manifest.public_nodes:
node = manifest.public_nodes[unique_id]
else:
raise dbt.exceptions.DbtInternalError(
f"Node {unique_id} found in cache but not found in manifest"
Expand Down Expand Up @@ -436,7 +430,6 @@ def build_node_edges(nodes: List[ManifestNode]):
forward_edges: Dict[str, List[str]] = {n.unique_id: [] for n in nodes}
for node in nodes:
backward_edges[node.unique_id] = node.depends_on_nodes[:]
backward_edges[node.unique_id].extend(node.depends_on_public_nodes[:])
for unique_id in backward_edges[node.unique_id]:
if unique_id in forward_edges.keys():
forward_edges[unique_id].append(node.unique_id)
Expand Down Expand Up @@ -707,7 +700,6 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
source_patches: MutableMapping[SourceKey, SourcePatch] = field(default_factory=dict)
disabled: MutableMapping[str, List[GraphMemberNode]] = field(default_factory=dict)
env_vars: MutableMapping[str, str] = field(default_factory=dict)
public_nodes: MutableMapping[str, PublicModel] = field(default_factory=dict)
publications: MutableMapping[str, PublicationConfig] = field(default_factory=dict)
semantic_nodes: MutableMapping[str, SemanticModel] = field(default_factory=dict)

Expand Down Expand Up @@ -761,7 +753,6 @@ def build_flat_graph(self):
"metrics": {k: v.to_dict(omit_none=False) for k, v in self.metrics.items()},
"nodes": {k: v.to_dict(omit_none=False) for k, v in self.nodes.items()},
"sources": {k: v.to_dict(omit_none=False) for k, v in self.sources.items()},
"public_nodes": {k: v.to_dict(omit_none=False) for k, v in self.public_nodes.items()},
}

def build_disabled_by_file_id(self):
Expand Down Expand Up @@ -854,7 +845,6 @@ def deepcopy(self):
selectors={k: _deepcopy(v) for k, v in self.selectors.items()},
metadata=self.metadata,
disabled={k: _deepcopy(v) for k, v in self.disabled.items()},
public_nodes={k: _deepcopy(v) for k, v in self.public_nodes.items()},
files={k: _deepcopy(v) for k, v in self.files.items()},
state_check=_deepcopy(self.state_check),
)
Expand All @@ -868,7 +858,6 @@ def build_parent_and_child_maps(self):
self.sources.values(),
self.exposures.values(),
self.metrics.values(),
self.public_nodes.values(),
)
)
forward_edges, backward_edges = build_node_edges(edge_members)
Expand Down Expand Up @@ -912,7 +901,6 @@ def writable_manifest(self) -> "WritableManifest":
selectors=self.selectors,
metadata=self.metadata,
disabled=self.disabled,
public_nodes=self.public_nodes,
child_map=self.child_map,
parent_map=self.parent_map,
group_map=self.group_map,
Expand Down Expand Up @@ -1001,6 +989,10 @@ def pydantic_semantic_manifest(self) -> PydanticSemanticManifest:

return pydantic_semantic_manifest

@property
def external_node_unique_ids(self):
return [node.unique_id for node in self.nodes.values() if node.is_external_node]

def resolve_refs(
self, source_node: GraphMemberNode, current_project: str
) -> List[MaybeNonSource]:
Expand Down Expand Up @@ -1039,9 +1031,7 @@ def resolve_ref(
target_model_name, pkg, target_model_version, self, source_node
)

if node is not None and (
(hasattr(node, "config") and node.config.enabled) or node.is_public_node
):
if node is not None and hasattr(node, "config") and node.config.enabled:
return node

# it's possible that the node is disabled
Expand Down Expand Up @@ -1296,7 +1286,6 @@ def __reduce_ex__(self, protocol):
self.source_patches,
self.disabled,
self.env_vars,
self.public_nodes,
self._doc_lookup,
self._source_lookup,
self._ref_lookup,
Expand Down Expand Up @@ -1366,9 +1355,6 @@ class WritableManifest(ArtifactMixin):
description="A mapping from group names to their nodes",
)
)
public_nodes: Mapping[UniqueID, PublicModel] = field(
metadata=dict(description=("The public models used in the dbt project"))
)
semantic_nodes: Mapping[UniqueID, SemanticModel] = field(
metadata=dict(description=("The semantic models defined in the dbt project"))
)
Expand Down
2 changes: 0 additions & 2 deletions core/dbt/contracts/graph/manifest_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ def upgrade_manifest_json(manifest: dict, manifest_schema_version: int) -> dict:
manifest["groups"] = {}
if "group_map" not in manifest:
manifest["group_map"] = {}
if "public_nodes" not in manifest:
manifest["public_nodes"] = {}
for metric_content in manifest.get("metrics", {}).values():
# handle attr renames + value translation ("expression" -> "derived")
metric_content = upgrade_ref_content(metric_content)
Expand Down
19 changes: 19 additions & 0 deletions core/dbt/contracts/graph/node_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

from dbt.contracts.graph.unparsed import NodeVersion


@dataclass
class ModelNodeArgs:
name: str
package_name: str
identifier: str
schema: str
database: Optional[str] = None
relation_name: Optional[str] = None
version: Optional[NodeVersion] = None
latest_version: Optional[NodeVersion] = None
deprecation_date: Optional[datetime] = None
generated_at: datetime = field(default_factory=datetime.utcnow)
96 changes: 29 additions & 67 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import hashlib

from mashumaro.types import SerializableType
from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, Iterator, Protocol
from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, Iterator

from dbt.dataclass_schema import dbtClassMixin, ExtensibleDbtClassMixin

Expand Down Expand Up @@ -35,6 +35,7 @@
UnparsedSourceTableDefinition,
UnparsedColumn,
)
from dbt.contracts.graph.node_args import ModelNodeArgs
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
from dbt.events.functions import warn_or_error
from dbt.exceptions import ParsingError, ContractBreakingChangeError
Expand Down Expand Up @@ -268,16 +269,11 @@ def identifier(self):
@dataclass
class DependsOn(MacroDependsOn):
nodes: List[str] = field(default_factory=list)
public_nodes: List[str] = field(default_factory=list)

def add_node(self, value: str):
if value not in self.nodes:
self.nodes.append(value)

def add_public_node(self, value: str):
if value not in self.public_nodes:
self.public_nodes.append(value)


@dataclass
class StateRelation(dbtClassMixin):
Expand Down Expand Up @@ -487,7 +483,7 @@ def same_contents(self, old, adapter_type) -> bool:
)

@property
def is_public_node(self):
def is_external_node(self):
return False


Expand Down Expand Up @@ -552,10 +548,6 @@ def __post_serialize__(self, dct):
def depends_on_nodes(self):
return self.depends_on.nodes

@property
def depends_on_public_nodes(self):
return self.depends_on.public_nodes

@property
def depends_on_macros(self):
return self.depends_on.macros
Expand Down Expand Up @@ -587,6 +579,32 @@ class ModelNode(CompiledNode):
deprecation_date: Optional[datetime] = None
state_relation: Optional[StateRelation] = None

@classmethod
def from_args(cls, args: ModelNodeArgs) -> "ModelNode":
unique_id = f"{NodeType.Model}.{args.package_name}.{args.name}"

return cls(
resource_type=NodeType.Model,
name=args.name,
package_name=args.package_name,
unique_id=unique_id,
fqn=[args.package_name, args.name],
version=args.version,
latest_version=args.latest_version,
relation_name=args.relation_name,
database=args.database,
schema=args.schema,
alias=args.identifier,
deprecation_date=args.deprecation_date,
checksum=FileHash.from_contents(f"{unique_id},{args.generated_at}"),
original_file_path="",
path="",
)

@property
def is_external_node(self) -> bool:
return not self.original_file_path and not self.path

@property
def is_latest_version(self) -> bool:
return self.version is not None and self.version == self.latest_version
Expand Down Expand Up @@ -849,10 +867,6 @@ def same_body(self, other) -> bool:
def depends_on_nodes(self):
return []

@property
def depends_on_public_nodes(self):
return []

@property
def depends_on_macros(self) -> List[str]:
return self.depends_on.macros
Expand Down Expand Up @@ -1180,10 +1194,6 @@ def is_ephemeral_model(self):
def depends_on_nodes(self):
return []

@property
def depends_on_public_nodes(self):
return []

@property
def depends_on(self):
return DependsOn(macros=[], nodes=[])
Expand Down Expand Up @@ -1233,10 +1243,6 @@ class Exposure(GraphNode):
def depends_on_nodes(self):
return self.depends_on.nodes

@property
def depends_on_public_nodes(self):
return self.depends_on.public_nodes

@property
def search_name(self):
return self.name
Expand Down Expand Up @@ -1375,10 +1381,6 @@ class Metric(GraphNode):
def depends_on_nodes(self):
return self.depends_on.nodes

@property
def depends_on_public_nodes(self):
return self.depends_on.public_nodes

@property
def search_name(self):
return self.name
Expand Down Expand Up @@ -1559,46 +1561,6 @@ class ParsedMacroPatch(ParsedPatch):
# ====================================


class ManifestOrPublicNode(Protocol):
name: str
package_name: str
unique_id: str
version: Optional[NodeVersion]
latest_version: Optional[NodeVersion]
relation_name: str
database: Optional[str]
schema: Optional[str]
identifier: Optional[str]

@property
def is_latest_version(self):
pass

@property
def resource_type(self):
pass

@property
def access(self):
pass

@property
def search_name(self):
pass

@property
def is_public_node(self):
pass

@property
def is_versioned(self):
pass

@property
def alias(self):
pass


# ManifestNode without SeedNode, which doesn't have the
# SQL related attributes
ManifestSQLNode = Union[
Expand Down
Loading

0 comments on commit ecf90d6

Please sign in to comment.