Skip to content

Commit

Permalink
Merge pull request #2590 from fishtown-analytics/fix/adapter-macro-na…
Browse files Browse the repository at this point in the history
…mespacing

Fix adapter macro namespacing (#2548)
  • Loading branch information
beckjake authored Jul 1, 2020
2 parents 7fa6a36 + b0c7b3a commit 3af8a22
Show file tree
Hide file tree
Showing 65 changed files with 1,641 additions and 1,245 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
## dbt 0.18.0 (Release TBD)

### Breaking changes
- Previously, dbt put macros from all installed plugins into the namespace. This version of dbt will not include adapter plugin macros unless they are from the currently-in-use adapter or one of its dependencies [#2590](https://github.com/fishtown-analytics/dbt/pull/2590)

### Features
- Added support for Snowflake query tags at the connection and model level ([#1030](https://github.com/fishtown-analytics/dbt/issues/1030), [#2555](https://github.com/fishtown-analytics/dbt/pull/2555/))
- Added option to specify profile when connecting to Redshift via IAM ([#2437](https://github.com/fishtown-analytics/dbt/issues/2437), [#2581](https://github.com/fishtown-analytics/dbt/pull/2581))

### Fixes
- Adapter plugins can once again override plugins defined in core ([#2548](https://github.com/fishtown-analytics/dbt/issues/2548), [#2590](https://github.com/fishtown-analytics/dbt/pull/2590))

Contributors:
- [@brunomurino](https://github.com/brunomurino) ([#2437](https://github.com/fishtown-analytics/dbt/pull/2581))
- [@DrMcTaco](https://github.com/DrMcTaco) ([#1030](https://github.com/fishtown-analytics/dbt/issues/1030)),[#2555](https://github.com/fishtown-analytics/dbt/pull/2555/))
Expand Down
10 changes: 5 additions & 5 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import agate

import dbt.exceptions
import dbt.flags
from dbt.contracts.connection import (
Connection, Identifier, ConnectionState, AdapterRequiredConfig, LazyHandle
)
Expand All @@ -19,6 +18,7 @@
MacroQueryStringSetter,
)
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import flags


class BaseConnectionManager(metaclass=abc.ABCMeta):
Expand All @@ -39,7 +39,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
def __init__(self, profile: AdapterRequiredConfig):
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = dbt.flags.MP_CONTEXT.RLock()
self.lock: RLock = flags.MP_CONTEXT.RLock()
self.query_header: Optional[MacroQueryStringSetter] = None

def set_query_header(self, manifest: Manifest) -> None:
Expand All @@ -60,7 +60,7 @@ def get_thread_connection(self) -> Connection:
)
return self.thread_connections[key]

def set_thread_connection(self, conn):
def set_thread_connection(self, conn: Connection) -> None:
key = self.get_thread_identifier()
if key in self.thread_connections:
raise dbt.exceptions.InternalException(
Expand Down Expand Up @@ -235,7 +235,7 @@ def _close_handle(cls, connection: Connection) -> None:
@classmethod
def _rollback(cls, connection: Connection) -> None:
"""Roll back the given connection."""
if dbt.flags.STRICT_MODE:
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In _rollback, got {connection} - not a Connection!'
Expand All @@ -253,7 +253,7 @@ def _rollback(cls, connection: Connection) -> None:

@classmethod
def close(cls, connection: Connection) -> Connection:
if dbt.flags.STRICT_MODE:
if flags.STRICT_MODE:
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In close, got {connection} - not a Connection!'
Expand Down
29 changes: 13 additions & 16 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import abc
from concurrent.futures import as_completed, Future
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from itertools import chain
from typing import (
Expand All @@ -17,21 +16,24 @@
get_relation_returned_multiple_results,
InternalException, NotImplementedException, RuntimeException,
)
import dbt.flags
from dbt import flags

from dbt import deprecations
from dbt.adapters.protocol import (
AdapterConfig,
ConnectionManagerProtocol,
)
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
from dbt.clients.jinja import MacroGenerator
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedSeedNode
from dbt.contracts.graph.model_config import BaseConfig
from dbt.exceptions import warn_or_error
from dbt.node_types import NodeType
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import filter_null_values, executor

from dbt.adapters.base.connections import BaseConnectionManager, Connection
from dbt.adapters.base.connections import Connection
from dbt.adapters.base.meta import AdapterMeta, available
from dbt.adapters.base.relation import (
ComponentName, BaseRelation, InformationSchema, SchemaSearchMap
Expand Down Expand Up @@ -108,11 +110,6 @@ def _relation_name(rel: Optional[BaseRelation]) -> str:
return str(rel)


@dataclass
class AdapterConfig(BaseConfig):
pass


class BaseAdapter(metaclass=AdapterMeta):
"""The BaseAdapter provides an abstract base class for adapters.
Expand Down Expand Up @@ -151,7 +148,7 @@ class BaseAdapter(metaclass=AdapterMeta):
"""
Relation: Type[BaseRelation] = BaseRelation
Column: Type[BaseColumn] = BaseColumn
ConnectionManager: Type[BaseConnectionManager]
ConnectionManager: Type[ConnectionManagerProtocol]

# A set of clobber config fields accepted by this adapter
# for use in materializations
Expand Down Expand Up @@ -267,7 +264,7 @@ def load_internal_manifest(self) -> Manifest:
def _schema_is_cached(self, database: Optional[str], schema: str) -> bool:
"""Check if the schema is cached, and by default logs if it is not."""

if dbt.flags.USE_CACHE is False:
if flags.USE_CACHE is False:
return False
elif (database, schema) not in self.cache:
logger.debug(
Expand Down Expand Up @@ -323,7 +320,7 @@ def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
"""Populate the relations cache for the given schemas. Returns an
iterable of the schemas populated, as strings.
"""
if not dbt.flags.USE_CACHE:
if not flags.USE_CACHE:
return

cache_schemas = self._get_cache_schemas(manifest)
Expand Down Expand Up @@ -352,7 +349,7 @@ def set_relations_cache(
"""Run a query that gets a populated cache of the relations in the
database and set the cache on this adapter.
"""
if not dbt.flags.USE_CACHE:
if not flags.USE_CACHE:
return

with self.cache.lock:
Expand All @@ -368,7 +365,7 @@ def cache_added(self, relation: Optional[BaseRelation]) -> str:
raise_compiler_error(
'Attempted to cache a null relation for {}'.format(name)
)
if dbt.flags.USE_CACHE:
if flags.USE_CACHE:
self.cache.add(relation)
# so jinja doesn't render things
return ''
Expand All @@ -383,7 +380,7 @@ def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
raise_compiler_error(
'Attempted to drop a null relation for {}'.format(name)
)
if dbt.flags.USE_CACHE:
if flags.USE_CACHE:
self.cache.drop(relation)
return ''

Expand All @@ -405,7 +402,7 @@ def cache_renamed(
.format(src_name, dst_name, name)
)

if dbt.flags.USE_CACHE:
if flags.USE_CACHE:
self.cache.rename(from_relation, to_relation)
return ''

Expand Down
27 changes: 16 additions & 11 deletions core/dbt/adapters/base/plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
from typing import List, Optional, Type

from dbt.adapters.base import BaseAdapter, Credentials
from dbt.adapters.base import Credentials
from dbt.exceptions import CompilationException
from dbt.adapters.protocol import AdapterProtocol


def project_name_from_path(include_path: str) -> str:
# avoid an import cycle
from dbt.config.project import Project
partial = Project.partial_load(include_path)
if partial.project_name is None:
raise CompilationException(
f'Invalid project at {include_path}: name not set!'
)
return partial.project_name


class AdapterPlugin:
Expand All @@ -13,23 +25,16 @@ class AdapterPlugin:
"""
def __init__(
self,
adapter: Type[BaseAdapter],
adapter: Type[AdapterProtocol],
credentials: Type[Credentials],
include_path: str,
dependencies: Optional[List[str]] = None
):
# avoid an import cycle
from dbt.config.project import Project

self.adapter: Type[BaseAdapter] = adapter
self.adapter: Type[AdapterProtocol] = adapter
self.credentials: Type[Credentials] = credentials
self.include_path: str = include_path
partial = Project.partial_load(include_path)
if partial.project_name is None:
raise CompilationException(
f'Invalid project at {include_path}: name not set!'
)
self.project_name: str = partial.project_name
self.project_name: str = project_name_from_path(include_path)
self.dependencies: List[str]
if dependencies is None:
self.dependencies = []
Expand Down
133 changes: 10 additions & 123 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
@@ -1,132 +1,19 @@
from dbt.utils import filter_null_values, deep_merge, classproperty
from dbt.node_types import NodeType

import dbt.exceptions

from collections.abc import Mapping, Hashable
from dataclasses import dataclass, fields
from collections.abc import Hashable
from dataclasses import dataclass
from typing import (
Optional, TypeVar, Generic, Any, Type, Dict, Union, Iterator, Tuple,
Set
Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set
)
from typing_extensions import Protocol

from hologram import JsonSchemaMixin
from hologram.helpers import StrEnum

from dbt.contracts.util import Replaceable
from dbt.contracts.graph.compiled import CompiledNode
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
from dbt.contracts.relation import (
RelationType, ComponentName, HasQuoting, FakeAPIObject, Policy, Path
)
from dbt.exceptions import InternalException
from dbt import deprecations


class RelationType(StrEnum):
Table = 'table'
View = 'view'
CTE = 'cte'
MaterializedView = 'materializedview'
External = 'external'


class ComponentName(StrEnum):
Database = 'database'
Schema = 'schema'
Identifier = 'identifier'


class HasQuoting(Protocol):
quoting: Dict[str, bool]


class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping):
# override the mapping truthiness, len is always >1
def __bool__(self):
return True

def __getitem__(self, key):
try:
return getattr(self, key)
except AttributeError:
raise KeyError(key) from None

def __iter__(self):
deprecations.warn('not-a-dictionary', obj=self)
for _, name in self._get_fields():
yield name

def __len__(self):
deprecations.warn('not-a-dictionary', obj=self)
return len(fields(self.__class__))

def incorporate(self, **kwargs):
value = self.to_dict()
value = deep_merge(value, kwargs)
return self.from_dict(value)


T = TypeVar('T')


@dataclass
class _ComponentObject(FakeAPIObject, Generic[T]):
database: T
schema: T
identifier: T

def get_part(self, key: ComponentName) -> T:
if key == ComponentName.Database:
return self.database
elif key == ComponentName.Schema:
return self.schema
elif key == ComponentName.Identifier:
return self.identifier
else:
raise ValueError(
'Got a key of {}, expected one of {}'
.format(key, list(ComponentName))
)

def replace_dict(self, dct: Dict[ComponentName, T]):
kwargs: Dict[str, T] = {}
for k, v in dct.items():
kwargs[str(k)] = v
return self.replace(**kwargs)


@dataclass
class Policy(_ComponentObject[bool]):
database: bool = True
schema: bool = True
identifier: bool = True


@dataclass
class Path(_ComponentObject[Optional[str]]):
database: Optional[str]
schema: Optional[str]
identifier: Optional[str]

def __post_init__(self):
# handle pesky jinja2.Undefined sneaking in here and messing up render
if not isinstance(self.database, (type(None), str)):
raise dbt.exceptions.CompilationException(
'Got an invalid path database: {}'.format(self.database)
)
if not isinstance(self.schema, (type(None), str)):
raise dbt.exceptions.CompilationException(
'Got an invalid path schema: {}'.format(self.schema)
)
if not isinstance(self.identifier, (type(None), str)):
raise dbt.exceptions.CompilationException(
'Got an invalid path identifier: {}'.format(self.identifier)
)
from dbt.node_types import NodeType
from dbt.utils import filter_null_values, deep_merge, classproperty

def get_lowered_part(self, key: ComponentName) -> Optional[str]:
part = self.get_part(key)
if part is not None:
part = part.lower()
return part
import dbt.exceptions


Self = TypeVar('Self', bound='BaseRelation')
Expand Down Expand Up @@ -161,7 +48,7 @@ def __eq__(self, other):
return self.to_dict() == other.to_dict()

@classmethod
def get_default_quote_policy(cls: Type[Self]) -> Policy:
def get_default_quote_policy(cls) -> Policy:
return cls._get_field_named('quote_policy').default

def get(self, key, default=None):
Expand Down
Loading

0 comments on commit 3af8a22

Please sign in to comment.