Skip to content

Commit

Permalink
Have the adapter be responsible for producing the compiler
Browse files Browse the repository at this point in the history
The adapter's Relation is consulted for adding the ephemeral model prefix

Also hide some things from Jinja

Have the adapter be responsible for producing the compiler, move CTE generation into the Relation object
  • Loading branch information
Jacob Beck committed Aug 18, 2020
1 parent c29892e commit f80a759
Show file tree
Hide file tree
Showing 14 changed files with 211 additions and 94 deletions.
13 changes: 11 additions & 2 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
)
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
from dbt.clients.jinja import MacroGenerator
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
from dbt.contracts.graph.compiled import (
CompileResultNode, CompiledSeedNode
)
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedSeedNode
from dbt.exceptions import warn_or_error
Expand Down Expand Up @@ -289,7 +291,10 @@ def _get_cache_schemas(self, manifest: Manifest) -> Set[BaseRelation]:
return {
self.Relation.create_from(self.config, node).without_identifier()
for node in manifest.nodes.values()
if node.resource_type in NodeType.executable()
if (
node.resource_type in NodeType.executable() and
not node.is_ephemeral_model
)
}

def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
Expand Down Expand Up @@ -1142,6 +1147,10 @@ def get_rows_different_sql(

return sql

def get_compiler(self):
from dbt.compilation import Compiler
return Compiler(self.config)


COLUMNS_EQUAL_SQL = '''
with diff_count as (
Expand Down
17 changes: 17 additions & 0 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,23 @@ def create_from_source(
**kwargs
)

@staticmethod
def add_ephemeral_prefix(name: str):
return f'__dbt__CTE__{name}'

@classmethod
def create_ephemeral_from_node(
cls: Type[Self],
config: HasQuoting,
node: Union[ParsedNode, CompiledNode],
) -> Self:
# Note that ephemeral models are based on the name.
identifier = cls.add_ephemeral_prefix(node.name)
return cls.create(
type=cls.CTE,
identifier=identifier,
).quote(identifier=False)

@classmethod
def create_from_node(
cls: Type[Self],
Expand Down
33 changes: 30 additions & 3 deletions core/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from dataclasses import dataclass
from typing import (
Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, ClassVar,
Tuple, Union
Tuple, Union, Dict, Any
)
from typing_extensions import Protocol

import agate

from dbt.contracts.connection import Connection, AdapterRequiredConfig
from dbt.contracts.graph.compiled import CompiledNode
from dbt.contracts.graph.compiled import (
CompiledNode, NonSourceNode, NonSourceCompiledNode
)
from dbt.contracts.graph.parsed import ParsedNode, ParsedSourceDefinition
from dbt.contracts.graph.model_config import BaseConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.relation import Policy, HasQuoting

from dbt.graph import Graph


@dataclass
class AdapterConfig(BaseConfig):
Expand Down Expand Up @@ -45,6 +49,19 @@ def create_from(
...


class CompilerProtocol(Protocol):
def compile(self, manifest: Manifest, write=True) -> Graph:
...

def compile_node(
self,
node: NonSourceNode,
manifest: Manifest,
extra_context: Optional[Dict[str, Any]] = None,
) -> NonSourceCompiledNode:
...


AdapterConfig_T = TypeVar(
'AdapterConfig_T', bound=AdapterConfig
)
Expand All @@ -57,11 +74,18 @@ def create_from(
Column_T = TypeVar(
'Column_T', bound=ColumnProtocol
)
Compiler_T = TypeVar('Compiler_T', bound=CompilerProtocol)


class AdapterProtocol(
Protocol,
Generic[AdapterConfig_T, ConnectionManager_T, Relation_T, Column_T]
Generic[
AdapterConfig_T,
ConnectionManager_T,
Relation_T,
Column_T,
Compiler_T,
]
):
AdapterSpecificConfigs: ClassVar[Type[AdapterConfig_T]]
Column: ClassVar[Type[Column_T]]
Expand Down Expand Up @@ -132,3 +156,6 @@ def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[str, agate.Table]:
...

def get_compiler(self) -> Compiler_T:
...
73 changes: 47 additions & 26 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
from collections import defaultdict
from typing import List, Dict, Any, Tuple, cast
from typing import List, Dict, Any, Tuple, cast, Optional

import networkx as nx # type: ignore

from dbt import flags
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja
from dbt.clients.system import make_directory
from dbt.context.providers import generate_runtime_model
Expand All @@ -21,7 +22,7 @@
from dbt.graph import Graph
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.node_types import NodeType
from dbt.utils import add_ephemeral_model_prefix, pluralize
from dbt.utils import pluralize

graph_file_name = 'graph.gpickle'

Expand Down Expand Up @@ -156,6 +157,11 @@ def _create_node_context(

return context

def add_ephemeral_prefix(self, name: str):
adapter = get_adapter(self.config)
relation_cls = adapter.Relation
return relation_cls.add_ephemeral_prefix(name)

def _get_compiled_model(
self,
manifest: Manifest,
Expand Down Expand Up @@ -213,7 +219,8 @@ def _recursively_prepend_ctes(
cte_model, manifest, extra_context
)
_extend_prepended_ctes(prepended_ctes, new_prepended_ctes)
new_cte_name = add_ephemeral_model_prefix(cte_model.name)

new_cte_name = self.add_ephemeral_prefix(cte_model.name)
sql = f' {new_cte_name} as (\n{cte_model.compiled_sql}\n)'
_add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))

Expand All @@ -223,8 +230,11 @@ def _recursively_prepend_ctes(

return model, prepended_ctes

def compile_node(
self, node: NonSourceNode, manifest, extra_context=None
def _compile_node(
self,
node: NonSourceNode,
manifest: Manifest,
extra_context: Optional[Dict[str, Any]] = None,
) -> NonSourceCompiledNode:
if extra_context is None:
extra_context = {}
Expand Down Expand Up @@ -295,6 +305,7 @@ def link_graph(self, linker: Linker, manifest: Manifest):
raise RuntimeError("Found a cycle: {}".format(cycle))

def compile(self, manifest: Manifest, write=True) -> Graph:
self.initialize()
linker = Linker()

self.link_graph(linker, manifest)
Expand All @@ -307,11 +318,38 @@ def compile(self, manifest: Manifest, write=True) -> Graph:

return Graph(linker.graph)

def _write_node(self, node: NonSourceNode) -> NonSourceNode:
if not _is_writable(node):
return node
logger.debug(f'Writing injected SQL for node "{node.unique_id}"')

if node.injected_sql is None:
# this should not really happen, but it'd be a shame to crash
# over it
logger.error(
f'Compiled node "{node.unique_id}" had no injected_sql, '
'cannot write sql!'
)
else:
node.build_path = node.write_node(
self.config.target_path,
'compiled',
node.injected_sql
)
return node

def compile_node(
self,
node: NonSourceNode,
manifest: Manifest,
extra_context: Optional[Dict[str, Any]] = None,
write: bool = True,
) -> NonSourceCompiledNode:
node = self._compile_node(node, manifest, extra_context)

def compile_manifest(config, manifest, write=True) -> Graph:
compiler = Compiler(config)
compiler.initialize()
return compiler.compile(manifest, write=write)
if write and _is_writable(node):
self._write_node(node)
return node


def _is_writable(node):
Expand All @@ -322,20 +360,3 @@ def _is_writable(node):
return False

return True


def compile_node(adapter, config, node, manifest, extra_context, write=True):
compiler = Compiler(config)
node = compiler.compile_node(node, manifest, extra_context)

if write and _is_writable(node):
logger.debug('Writing injected SQL for node "{}"'.format(
node.unique_id))

node.build_path = node.write_node(
config.target_path,
'compiled',
node.injected_sql
)

return node
Loading

0 comments on commit f80a759

Please sign in to comment.