Skip to content

Commit

Permalink
Merge pull request #2628 from fishtown-analytics/feature/advanced-nod…
Browse files Browse the repository at this point in the history
…e-selection

split up node selection
  • Loading branch information
beckjake authored Jul 20, 2020
2 parents 7f0c1f8 + 91ff7f1 commit fb88d54
Show file tree
Hide file tree
Showing 25 changed files with 1,671 additions and 1,488 deletions.
14 changes: 7 additions & 7 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ jobs:
Set-Service -InputObject $serviceName -StartupType Automatic
Start-Service -InputObject $serviceName
createdb.exe -U postgres dbt
psql.exe -U postgres -c "CREATE ROLE root WITH PASSWORD 'password';"
psql.exe -U postgres -c "ALTER ROLE root WITH LOGIN;"
psql.exe -U postgres -c "GRANT CREATE, CONNECT ON DATABASE dbt TO root WITH GRANT OPTION;"
psql.exe -U postgres -c "CREATE ROLE noaccess WITH PASSWORD 'password' NOSUPERUSER;"
psql.exe -U postgres -c "ALTER ROLE noaccess WITH LOGIN;"
psql.exe -U postgres -c "GRANT CONNECT ON DATABASE dbt TO noaccess;"
& $env:PGBIN\createdb.exe -U postgres dbt
& $env:PGBIN\psql.exe -U postgres -c "CREATE ROLE root WITH PASSWORD 'password';"
& $env:PGBIN\psql.exe -U postgres -c "ALTER ROLE root WITH LOGIN;"
& $env:PGBIN\psql.exe -U postgres -c "GRANT CREATE, CONNECT ON DATABASE dbt TO root WITH GRANT OPTION;"
& $env:PGBIN\psql.exe -U postgres -c "CREATE ROLE noaccess WITH PASSWORD 'password' NOSUPERUSER;"
& $env:PGBIN\psql.exe -U postgres -c "ALTER ROLE noaccess WITH LOGIN;"
& $env:PGBIN\psql.exe -U postgres -c "GRANT CONNECT ON DATABASE dbt TO noaccess;"
displayName: Install postgresql and set up database
- task: UsePythonVersion@0
Expand Down
178 changes: 126 additions & 52 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,35 @@
import os
from collections import defaultdict
from typing import List, Dict, Any
from typing import List, Dict, Any, Tuple, cast

import dbt.utils
import dbt.include
import dbt.tracking
import networkx as nx # type: ignore

from dbt import flags
from dbt.node_types import NodeType
from dbt.linker import Linker

from dbt.clients import jinja
from dbt.clients.system import make_directory
from dbt.context.providers import generate_runtime_model
from dbt.contracts.graph.compiled import NonSourceNode
from dbt.contracts.graph.manifest import Manifest
import dbt.exceptions
import dbt.config
from dbt.contracts.graph.compiled import (
InjectedCTE,
COMPILED_TYPES,
NonSourceNode,
NonSourceCompiledNode,
CompiledSchemaTestNode,
)
from dbt.contracts.graph.parsed import ParsedNode

from dbt.exceptions import dependency_not_found, InternalException
from dbt.graph import Graph
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.node_types import NodeType
from dbt.utils import add_ephemeral_model_prefix, pluralize

graph_file_name = 'graph.gpickle'


def _compiled_type_for(model: ParsedNode):
if type(model) not in COMPILED_TYPES:
raise dbt.exceptions.InternalException(
'Asked to compile {} node, but it has no compiled form'
.format(type(model))
raise InternalException(
f'Asked to compile {type(model)} node, but it has no compiled form'
)
return COMPILED_TYPES[type(model)]

Expand All @@ -53,7 +50,7 @@ def print_compile_stats(stats):
results.update(stats)

stat_line = ", ".join([
dbt.utils.pluralize(ct, names.get(t)) for t, ct in results.items()
pluralize(ct, names.get(t)) for t, ct in results.items()
if t in names
])

Expand Down Expand Up @@ -94,48 +91,54 @@ def _extend_prepended_ctes(prepended_ctes, new_prepended_ctes):
_add_prepended_cte(prepended_ctes, new_cte)


def prepend_ctes(model, manifest):
model, _, manifest = recursively_prepend_ctes(model, manifest)
class Linker:
def __init__(self, data=None):
    """Create the linker around a fresh directed graph.

    ``data`` is an optional mapping that is expanded into keyword
    arguments for ``nx.DiGraph``; when omitted, no keyword arguments are
    passed.
    """
    graph_kwargs = {} if data is None else data
    self.graph = nx.DiGraph(**graph_kwargs)

return (model, manifest)


def recursively_prepend_ctes(model, manifest):
if model.extra_ctes_injected:
return (model, model.extra_ctes, manifest)

if flags.STRICT_MODE:
if not isinstance(model, tuple(COMPILED_TYPES.values())):
raise dbt.exceptions.InternalException(
'Bad model type: {}'.format(type(model))
)
def edges(self):
    """Return the edges of the underlying graph."""
    graph = self.graph
    return graph.edges()

prepended_ctes: List[InjectedCTE] = []
def nodes(self):
    """Return the nodes of the underlying graph."""
    graph = self.graph
    return graph.nodes()

for cte in model.extra_ctes:
cte_id = cte.id
cte_to_add = manifest.nodes.get(cte_id)
cte_to_add, new_prepended_ctes, manifest = recursively_prepend_ctes(
cte_to_add, manifest)
_extend_prepended_ctes(prepended_ctes, new_prepended_ctes)
new_cte_name = '__dbt__CTE__{}'.format(cte_to_add.name)
sql = ' {} as (\n{}\n)'.format(new_cte_name, cte_to_add.compiled_sql)
_add_prepended_cte(prepended_ctes, InjectedCTE(id=cte_id, sql=sql))
def find_cycles(self):
    """Describe one cycle in the graph, or return None if there is none.

    The description is the first element of every entry returned by
    ``nx.find_cycle``, joined with ``" --> "``.
    """
    try:
        found = nx.find_cycle(self.graph)
    except nx.NetworkXNoCycle:
        return None

    # `found` is a List[Tuple[str, ...]]; only the first item of each
    # tuple is reported.
    starts = [entry[0] for entry in found]
    return " --> ".join(starts)

model.prepend_ctes(prepended_ctes)
def dependency(self, node1, node2):
    """Indicate that node1 depends on node2.

    Both nodes are added to the graph, followed by an edge that points
    from node2 (the dependency) to node1 (the dependent).
    """
    for graph_node in (node1, node2):
        self.graph.add_node(graph_node)
    self.graph.add_edge(node2, node1)

manifest.update_node(model)
def add_node(self, node):
    """Add a single node to the underlying graph."""
    graph = self.graph
    graph.add_node(node)

return (model, prepended_ctes, manifest)
def write_graph(self, outfile: str, manifest: Manifest):
    """Write a copy of the graph to ``outfile`` as a gpickle file.

    Every node id in the graph is looked up with ``manifest.expect`` and
    its ``to_dict()`` output is attached to the copied graph's entry for
    that node. ``self.graph`` itself is not modified.
    """
    annotated = self.graph.copy()
    for unique_id in self.graph.nodes():
        node_attrs = manifest.expect(unique_id).to_dict()
        # re-adding an existing node attaches the attributes to it
        annotated.add_node(unique_id, **node_attrs)
    nx.write_gpickle(annotated, outfile)


class Compiler:
def __init__(self, config):
self.config = config

def initialize(self):
dbt.clients.system.make_directory(self.config.target_path)
dbt.clients.system.make_directory(self.config.modules_path)
make_directory(self.config.target_path)
make_directory(self.config.modules_path)

def _create_node_context(
self,
Expand All @@ -149,11 +152,80 @@ def _create_node_context(
context.update(extra_context)
if isinstance(node, CompiledSchemaTestNode):
# for test nodes, add a special keyword args value to the context
dbt.clients.jinja.add_rendered_test_kwargs(context, node)
jinja.add_rendered_test_kwargs(context, node)

return context

def compile_node(self, node, manifest, extra_context=None):
def _get_compiled_model(
    self,
    manifest: Manifest,
    cte_id: str,
    extra_context: Dict[str, Any],
) -> NonSourceCompiledNode:
    """Return the compiled node that ``cte_id`` refers to.

    If the manifest's node is already compiled it is returned as-is. If it
    is an uncompiled ephemeral model it is compiled now, pushed into the
    manifest via ``manifest.sync_update_node`` and returned.

    Raises InternalException when ``cte_id`` is not in ``manifest.nodes``
    or when the node is uncompiled and not an ephemeral model.
    """
    if cte_id not in manifest.nodes:
        raise InternalException(
            f'During compilation, found a cte reference that could not be '
            f'resolved: {cte_id}'
        )

    referenced = manifest.nodes[cte_id]

    if getattr(referenced, 'compiled', False):
        assert isinstance(referenced, tuple(COMPILED_TYPES.values()))
        return cast(NonSourceCompiledNode, referenced)

    if not referenced.is_ephemeral_model:
        raise InternalException(
            f'During compilation, found an uncompiled cte that '
            f'was not an ephemeral model: {cte_id}'
        )

    # this must be some kind of parsed node that we can compile.
    # we know it's not a parsed source definition
    assert isinstance(referenced, tuple(COMPILED_TYPES))
    # compile the node, then update the manifest so it holds the compiled
    # version.
    # NOTE(review): the return value of sync_update_node is discarded, so
    # the node returned here is always the one compiled by this call -
    # confirm that is intended.
    compiled = self.compile_node(referenced, manifest, extra_context)
    manifest.sync_update_node(compiled)
    return compiled

def _recursively_prepend_ctes(
    self,
    model: NonSourceCompiledNode,
    manifest: Manifest,
    extra_context: Dict[str, Any],
) -> Tuple[NonSourceCompiledNode, List[InjectedCTE]]:
    """Inject the CTEs for everything in ``model.extra_ctes`` into model.

    Each referenced node is resolved (compiling it if needed), its own
    CTEs are collected recursively, and then its compiled SQL is wrapped
    in a named CTE. The collected CTEs are prepended to ``model`` and the
    manifest is updated with the result.

    Returns the model and the list of CTEs that were prepended. A model
    whose ``extra_ctes_injected`` is already set is returned unchanged
    together with its existing ``extra_ctes``.
    """
    if model.extra_ctes_injected:
        return (model, model.extra_ctes)

    if flags.STRICT_MODE and not isinstance(
        model, tuple(COMPILED_TYPES.values())
    ):
        raise InternalException(f'Bad model type: {type(model)}')

    collected: List[InjectedCTE] = []

    for cte_ref in model.extra_ctes:
        resolved = self._get_compiled_model(
            manifest, cte_ref.id, extra_context
        )
        resolved, nested = self._recursively_prepend_ctes(
            resolved, manifest, extra_context
        )
        # the referenced node's own CTEs go in before the node itself
        _extend_prepended_ctes(collected, nested)

        alias = add_ephemeral_model_prefix(resolved.name)
        cte_sql = f' {alias} as (\n{resolved.compiled_sql}\n)'
        _add_prepended_cte(
            collected, InjectedCTE(id=cte_ref.id, sql=cte_sql)
        )

    model.prepend_ctes(collected)
    manifest.update_node(model)

    return model, collected

def compile_node(
self, node: NonSourceNode, manifest, extra_context=None
) -> NonSourceCompiledNode:
if extra_context is None:
extra_context = {}

Expand All @@ -173,14 +245,16 @@ def compile_node(self, node, manifest, extra_context=None):
compiled_node, manifest, extra_context
)

compiled_node.compiled_sql = dbt.clients.jinja.get_rendered(
compiled_node.compiled_sql = jinja.get_rendered(
node.raw_sql,
context,
node)

compiled_node.compiled = True

injected_node, _ = prepend_ctes(compiled_node, manifest)
injected_node, _ = self._recursively_prepend_ctes(
compiled_node, manifest, extra_context
)

return injected_node

Expand All @@ -207,7 +281,7 @@ def link_node(
(manifest.sources[dependency].unique_id)
)
else:
dbt.exceptions.dependency_not_found(node, dependency)
dependency_not_found(node, dependency)

def link_graph(self, linker: Linker, manifest: Manifest):
for source in manifest.sources.values():
Expand All @@ -220,7 +294,7 @@ def link_graph(self, linker: Linker, manifest: Manifest):
if cycle:
raise RuntimeError("Found a cycle: {}".format(cycle))

def compile(self, manifest: Manifest, write=True):
def compile(self, manifest: Manifest, write=True) -> Graph:
linker = Linker()

self.link_graph(linker, manifest)
Expand All @@ -231,10 +305,10 @@ def compile(self, manifest: Manifest, write=True):
self.write_graph_file(linker, manifest)
print_compile_stats(stats)

return linker
return Graph(linker.graph)


def compile_manifest(config, manifest, write=True) -> Linker:
def compile_manifest(config, manifest, write=True) -> Graph:
    """Build a Compiler for ``config``, initialize it and compile manifest.

    ``write`` is forwarded to ``Compiler.compile``, whose result is
    returned.
    """
    manifest_compiler = Compiler(config)
    manifest_compiler.initialize()
    graph = manifest_compiler.compile(manifest, write=write)
    return graph
Expand Down
33 changes: 31 additions & 2 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
from dataclasses import dataclass, field
from datetime import datetime
from itertools import chain
from multiprocessing.synchronize import Lock
from typing import (
Dict, List, Optional, Union, Mapping, MutableMapping, Any, Set, Tuple,
TypeVar, Callable, Iterable, Generic
TypeVar, Callable, Iterable, Generic, cast
)
from typing_extensions import Protocol
from uuid import UUID

from hologram import JsonSchemaMixin

from dbt.contracts.graph.compiled import CompileResultNode, NonSourceNode
from dbt.contracts.graph.compiled import (
CompileResultNode, NonSourceNode, NonSourceCompiledNode
)
from dbt.contracts.graph.parsed import (
ParsedMacro, ParsedDocumentation, ParsedNodePatch, ParsedMacroPatch,
ParsedSourceDefinition
Expand All @@ -28,6 +31,7 @@
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.node_types import NodeType
from dbt import deprecations
from dbt import flags
from dbt import tracking
import dbt.utils

Expand Down Expand Up @@ -546,6 +550,10 @@ class Disabled(Generic[D]):


def _update_into(dest: MutableMapping[str, T], new_item: T):
"""Update dest to overwrite whatever is at dest[new_item.unique_id] with
new_item. There must be an existing value to overwrite, and the two nodes
must have the same original file path.
"""
unique_id = new_item.unique_id
if unique_id not in dest:
raise dbt.exceptions.RuntimeException(
Expand Down Expand Up @@ -577,6 +585,7 @@ class Manifest:
_docs_cache: Optional[DocCache] = None
_sources_cache: Optional[SourceCache] = None
_refs_cache: Optional[RefableCache] = None
_lock: Lock = field(default_factory=flags.MP_CONTEXT.Lock)

@classmethod
def from_macros(
Expand All @@ -598,6 +607,26 @@ def from_macros(
files=files,
)

def sync_update_node(
    self, new_node: NonSourceCompiledNode
) -> NonSourceCompiledNode:
    """Update the node while holding the manifest lock.

    Locking should only be needed when an ephemeral ancestor of a node is
    compiled at runtime: several threads could be just-in-time compiling
    the same ephemeral dependency, and they should all end up with a
    consistent view of the manifest.

    If the node currently in the manifest is not compiled, it is replaced
    by ``new_node`` and ``new_node`` is returned. If it is already
    compiled, the manifest is left alone and the existing node is
    returned.
    """
    with self._lock:
        current = self.nodes[new_node.unique_id]
        if not getattr(current, 'compiled', False):
            _update_into(self.nodes, new_node)
            return new_node
        # already compiled -> must be a NonSourceCompiledNode
        return cast(NonSourceCompiledNode, current)

def update_node(self, new_node: NonSourceNode):
    """Overwrite the manifest's entry for ``new_node`` without locking."""
    node_map = self.nodes
    _update_into(node_map, new_node)

Expand Down
14 changes: 14 additions & 0 deletions core/dbt/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .selector_spec import ( # noqa: F401
SelectionUnion,
SelectionSpec,
SelectionIntersection,
SelectionDifference,
SelectionCriteria,
)
from .selector import ( # noqa: F401
ResourceTypeSelector,
NodeSelector,
)
from .cli import parse_difference # noqa: F401
from .queue import GraphQueue # noqa: F401
from .graph import Graph, UniqueId # noqa: F401
Loading

0 comments on commit fb88d54

Please sign in to comment.