diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e286d32f82..2cbaf6f608e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### Features - Added support for Snowflake query tags at the connection and model level ([#1030](https://github.com/fishtown-analytics/dbt/issues/1030), [#2555](https://github.com/fishtown-analytics/dbt/pull/2555/)) +- Added new node selector methods (`config`, `test_type`, `test_name`, `package`) ([#2425](https://github.com/fishtown-analytics/dbt/issues/2425), [#2629](https://github.com/fishtown-analytics/dbt/pull/2629)) - Added option to specify profile when connecting to Redshift via IAM ([#2437](https://github.com/fishtown-analytics/dbt/issues/2437), [#2581](https://github.com/fishtown-analytics/dbt/pull/2581)) - Add more helpful error message for misconfiguration in profiles.yml ([#2569](https://github.com/fishtown-analytics/dbt/issues/2569), [#2627](https://github.com/fishtown-analytics/dbt/pull/2627)) ### Fixes diff --git a/core/dbt/contracts/graph/model_config.py b/core/dbt/contracts/graph/model_config.py index dfd5555b37c..494cc99ee00 100644 --- a/core/dbt/contracts/graph/model_config.py +++ b/core/dbt/contracts/graph/model_config.py @@ -120,6 +120,7 @@ def insensitive_patterns(*patterns: str): Severity = NewType('Severity', str) + register_pattern(Severity, insensitive_patterns('warn', 'error')) diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index 06366d691c5..41e6ded9a43 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -393,6 +393,12 @@ def __init__(self, thread_id, known, node=None): ) +class InvalidSelectorException(RuntimeException): + def __init__(self, name: str): + self.name = name + super().__init__(name) + + def raise_compiler_error(msg, node=None) -> NoReturn: raise CompilationException(msg, node) diff --git a/core/dbt/graph/__init__.py b/core/dbt/graph/__init__.py index e2355950e8f..4a96f7da650 100644 --- a/core/dbt/graph/__init__.py +++ b/core/dbt/graph/__init__.py @@ -9,6 +9,9 @@ ResourceTypeSelector, NodeSelector, ) -from .cli import parse_difference # noqa: F401 +from .cli import ( # noqa: F401 + parse_difference, + parse_test_selectors, +) from .queue import GraphQueue # noqa: F401 from .graph import Graph, UniqueId # noqa: F401 diff --git a/core/dbt/graph/cli.py b/core/dbt/graph/cli.py index 216250be4c3..3b0d946e3e8 100644 --- a/core/dbt/graph/cli.py +++ b/core/dbt/graph/cli.py @@ -17,6 +17,8 @@ DEFAULT_INCLUDES: List[str] = ['fqn:*', 'source:*'] DEFAULT_EXCLUDES: List[str] = [] +DATA_TEST_SELECTOR: str = 'test_type:data' +SCHEMA_TEST_SELECTOR: str = 'test_type:schema' def parse_union( @@ -64,3 +66,34 @@ def parse_difference( included = parse_union_from_default(include, DEFAULT_INCLUDES) excluded = parse_union_from_default(exclude, DEFAULT_EXCLUDES) return SelectionDifference(components=[included, excluded]) + + +def parse_test_selectors( + data: bool, schema: bool, base: SelectionSpec +) -> SelectionSpec: + union_components = [] + + if data: + union_components.append( + SelectionCriteria.from_single_spec(DATA_TEST_SELECTOR) + ) + if schema: + union_components.append( + SelectionCriteria.from_single_spec(SCHEMA_TEST_SELECTOR) + ) + + intersect_with: SelectionSpec + if not union_components: + return base + elif len(union_components) == 1: + intersect_with = union_components[0] + else: # data and schema tests + intersect_with = SelectionUnion( + components=union_components, + expect_exists=True, + raw=[DATA_TEST_SELECTOR, SCHEMA_TEST_SELECTOR], + ) + + return SelectionIntersection( + components=[base, intersect_with], expect_exists=True + ) diff --git a/core/dbt/graph/selector.py b/core/dbt/graph/selector.py index 44e57def5e0..f47e1bc48c2 100644 --- a/core/dbt/graph/selector.py +++ b/core/dbt/graph/selector.py @@ -1,23 +1,18 @@ -from typing import ( - Set, List, Dict, Union, Type -) +from typing import Set, List, Union from .graph import Graph, UniqueId from .queue import GraphQueue -from .selector_methods import ( - MethodName, - SelectorMethod, - QualifiedNameSelectorMethod, - TagSelectorMethod, - SourceSelectorMethod, - PathSelectorMethod, -) +from .selector_methods import MethodManager from .selector_spec import SelectionCriteria, SelectionSpec from dbt.logger import GLOBAL_LOGGER as logger from dbt.node_types import NodeType -from dbt.exceptions import InternalException, warn_or_error +from dbt.exceptions import ( + InternalException, + InvalidSelectorException, + warn_or_error, +) from dbt.contracts.graph.compiled import NonSourceNode, CompileResultNode from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedSourceDefinition @@ -35,16 +30,9 @@ def alert_non_existence(raw_spec, nodes): ) -class InvalidSelectorError(Exception): - # this internal exception should never escape the module. - pass - - -class NodeSelector: +class NodeSelector(MethodManager): """The node selector is aware of the graph and manifest, """ - SELECTOR_METHODS: Dict[MethodName, Type[SelectorMethod]] = {} - def __init__( self, graph: Graph, @@ -53,16 +41,13 @@ def __init__( self.full_graph = graph self.manifest = manifest - @classmethod - def register_method(cls, name: MethodName, method: Type[SelectorMethod]): - cls.SELECTOR_METHODS[name] = method - - def get_method(self, method: MethodName) -> SelectorMethod: - if method in self.SELECTOR_METHODS: - cls: Type[SelectorMethod] = self.SELECTOR_METHODS[method] - return cls(self.manifest) - else: - raise InvalidSelectorError(method) + # build a subgraph containing only non-empty, enabled nodes and enabled + # sources. + graph_members = { + unique_id for unique_id in self.full_graph.nodes() + if self._is_graph_member(unique_id) + } + self.graph = self.full_graph.subgraph(graph_members) def select_included( self, included_nodes: Set[UniqueId], spec: SelectionCriteria, @@ -70,22 +55,24 @@ def select_included( """Select the explicitly included nodes, using the given spec. Return the selected set of unique IDs. """ - method = self.get_method(spec.method) + method = self.get_method(spec.method, spec.method_arguments) return set(method.search(included_nodes, spec.value)) def get_nodes_from_criteria( - self, graph: Graph, spec: SelectionCriteria + self, + spec: SelectionCriteria, ) -> Set[UniqueId]: - """Given a Graph, get all nodes specified by the spec. + """Get all nodes specified by the single selection criteria. - collect the directly included nodes - find their specified relatives - perform any selector-specific expansion """ - nodes = graph.nodes() + + nodes = self.graph.nodes() try: collected = self.select_included(nodes, spec) - except InvalidSelectorError: + except InvalidSelectorException: valid_selectors = ", ".join(self.SELECTOR_METHODS) logger.info( f"The '{spec.method}' selector specified in {spec.raw} is " @@ -93,12 +80,12 @@ def get_nodes_from_criteria( ) return set() - extras = self.collect_specified_neighbors(spec, graph, collected) - result = self.expand_selection(graph, collected | extras) + extras = self.collect_specified_neighbors(spec, collected) + result = self.expand_selection(collected | extras) return result def collect_specified_neighbors( - self, spec: SelectionCriteria, graph: Graph, selected: Set[UniqueId] + self, spec: SelectionCriteria, selected: Set[UniqueId] ) -> Set[UniqueId]: """Given the set of models selected by the explicit part of the selector (like "tag:foo"), apply the modifiers on the spec ("+"/"@"). @@ -107,18 +94,18 @@ def collect_specified_neighbors( """ additional: Set[UniqueId] = set() if spec.select_childrens_parents: - additional.update(graph.select_childrens_parents(selected)) + additional.update(self.graph.select_childrens_parents(selected)) + if spec.select_parents: - additional.update( - graph.select_parents(selected, spec.select_parents_max_depth) - ) + depth = spec.select_parents_max_depth + additional.update(self.graph.select_parents(selected, depth)) + if spec.select_children: - additional.update( - graph.select_children(selected, spec.select_children_max_depth) - ) + depth = spec.select_children_max_depth + additional.update(self.graph.select_children(selected, depth)) return additional - def select_nodes(self, graph: Graph, spec: SelectionSpec) -> Set[UniqueId]: + def select_nodes(self, spec: SelectionSpec) -> Set[UniqueId]: """Select the nodes in the graph according to the spec. If the spec is a composite spec (a union, difference, or intersection), @@ -126,16 +113,13 @@ def select_nodes(self, graph: Graph, spec: SelectionSpec) -> Set[UniqueId]: selection criteria, resolve that using the given graph. """ if isinstance(spec, SelectionCriteria): - result = self.get_nodes_from_criteria(graph, spec) + result = self.get_nodes_from_criteria(spec) else: node_selections = [ - self.select_nodes(graph, component) + self.select_nodes(component) for component in spec ] - if node_selections: - result = spec.combine_selections(node_selections) - else: - result = set() + result = spec.combined(node_selections) if spec.expect_exists: alert_non_existence(spec.raw, result) return result @@ -168,16 +152,6 @@ def _is_match(self, unique_id: UniqueId) -> bool: ) return self.node_is_match(node) - def build_graph_member_subgraph(self) -> Graph: - """Build a subgraph of all enabled, non-empty nodes based on the full - graph. - """ - graph_members = { - unique_id for unique_id in self.full_graph.nodes() - if self._is_graph_member(unique_id) - } - return self.full_graph.subgraph(graph_members) - def filter_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]: """Return the subset of selected nodes that is a match for this selector. @@ -186,17 +160,13 @@ def filter_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]: unique_id for unique_id in selected if self._is_match(unique_id) } - def expand_selection( - self, filtered_graph: Graph, selected: Set[UniqueId] - ) -> Set[UniqueId]: + def expand_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]: """Perform selector-specific expansion.""" return selected def get_selected(self, spec: SelectionSpec) -> Set[UniqueId]: """get_selected runs trhough the node selection process: - - build a subgraph containing only non-empty, enabled nodes and - enabled sources. - node selection. Based on the include/exclude sets, the set of matched unique IDs is returned - expand the graph at each leaf node, before combination @@ -206,8 +176,7 @@ def get_selected(self, spec: SelectionSpec) -> Set[UniqueId]: - selectors can filter the nodes after all of them have been selected """ - filtered_graph = self.build_graph_member_subgraph() - selected_nodes = self.select_nodes(filtered_graph, spec) + selected_nodes = self.select_nodes(spec) filtered_nodes = self.filter_selection(selected_nodes) return filtered_nodes @@ -236,9 +205,3 @@ def __init__( def node_is_match(self, node): return node.resource_type in self.resource_types - - -NodeSelector.register_method(MethodName.FQN, QualifiedNameSelectorMethod) -NodeSelector.register_method(MethodName.Tag, TagSelectorMethod) -NodeSelector.register_method(MethodName.Source, SourceSelectorMethod) -NodeSelector.register_method(MethodName.Path, PathSelectorMethod) diff --git a/core/dbt/graph/selector_methods.py b/core/dbt/graph/selector_methods.py index 84af87cf23c..1b029396028 100644 --- a/core/dbt/graph/selector_methods.py +++ b/core/dbt/graph/selector_methods.py @@ -1,17 +1,45 @@ import abc from itertools import chain from pathlib import Path -from typing import Set, List +from typing import Set, List, Dict, Iterator, Tuple, Any, Union, Type from hologram.helpers import StrEnum -from dbt.exceptions import RuntimeException from .graph import UniqueId +from dbt.contracts.graph.compiled import ( + CompiledDataTestNode, + CompiledSchemaTestNode, + NonSourceNode, +) +from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.parsed import ( + HasTestMetadata, + ParsedDataTestNode, + ParsedSchemaTestNode, + ParsedSourceDefinition, +) +from dbt.exceptions import ( + InternalException, + RuntimeException, +) + + SELECTOR_GLOB = '*' SELECTOR_DELIMITER = ':' +class MethodName(StrEnum): + FQN = 'fqn' + Tag = 'tag' + Source = 'source' + Path = 'path' + Package = 'package' + Config = 'config' + TestName = 'test_name' + TestType = 'test_type' + + def is_selected_node(real_node, node_selector): for i, selector_part in enumerate(node_selector): @@ -38,24 +66,49 @@ def is_selected_node(real_node, node_selector): return True +SelectorTarget = Union[ParsedSourceDefinition, NonSourceNode] + + class SelectorMethod(metaclass=abc.ABCMeta): - def __init__(self, manifest): - self.manifest = manifest + def __init__(self, manifest: Manifest, arguments: List[str]): + self.manifest: Manifest = manifest + self.arguments: List[str] = arguments - def parsed_nodes(self, included_nodes): - for unique_id, node in self.manifest.nodes.items(): + def parsed_nodes( + self, + included_nodes: Set[UniqueId] + ) -> Iterator[Tuple[UniqueId, NonSourceNode]]: + + for key, node in self.manifest.nodes.items(): + unique_id = UniqueId(key) if unique_id not in included_nodes: continue yield unique_id, node - def source_nodes(self, included_nodes): - for unique_id, source in self.manifest.sources.items(): + def source_nodes( + self, + included_nodes: Set[UniqueId] + ) -> Iterator[Tuple[UniqueId, ParsedSourceDefinition]]: + + for key, source in self.manifest.sources.items(): + unique_id = UniqueId(key) if unique_id not in included_nodes: continue yield unique_id, source + def all_nodes( + self, + included_nodes: Set[UniqueId] + ) -> Iterator[Tuple[UniqueId, SelectorTarget]]: + yield from chain(self.parsed_nodes(included_nodes), + self.source_nodes(included_nodes)) + @abc.abstractmethod - def search(self, included_nodes: Set[UniqueId], selector: str): + def search( + self, + included_nodes: Set[UniqueId], + selector: str, + ) -> Iterator[UniqueId]: raise NotImplementedError('subclasses should implement this') @@ -88,7 +141,9 @@ def node_is_match( return False - def search(self, included_nodes, selector): + def search( + self, included_nodes: Set[UniqueId], selector: str + ) -> Iterator[UniqueId]: """Yield all nodes in the graph that match the selector. :param str selector: The selector or node name @@ -106,18 +161,20 @@ def search(self, included_nodes, selector): class TagSelectorMethod(SelectorMethod): - def search(self, included_nodes, selector): - """ yields nodes from graph that have the specified tag """ - search = chain(self.parsed_nodes(included_nodes), - self.source_nodes(included_nodes)) - for node, real_node in search: + def search( + self, included_nodes: Set[UniqueId], selector: str + ) -> Iterator[UniqueId]: + """ yields nodes from included that have the specified tag """ + for node, real_node in self.all_nodes(included_nodes): if selector in real_node.tags: yield node class SourceSelectorMethod(SelectorMethod): - def search(self, included_nodes, selector): - """yields nodes from graph are the specified source.""" + def search( + self, included_nodes: Set[UniqueId], selector: str + ) -> Iterator[UniqueId]: + """yields nodes from included are the specified source.""" parts = selector.split('.') target_package = SELECTOR_GLOB if len(parts) == 1: @@ -145,17 +202,16 @@ def search(self, included_nodes, selector): class PathSelectorMethod(SelectorMethod): - def search(self, included_nodes, selector): - """Yield all nodes in the graph that match the given path. + def search( + self, included_nodes: Set[UniqueId], selector: str + ) -> Iterator[UniqueId]: + """Yields nodes from inclucded that match the given path. - :param str selector: The path selector """ # use '.' and not 'root' for easy comparison root = Path.cwd() paths = set(p.relative_to(root) for p in root.glob(selector)) - search = chain(self.parsed_nodes(included_nodes), - self.source_nodes(included_nodes)) - for node, real_node in search: + for node, real_node in self.all_nodes(included_nodes): if Path(real_node.root_path) != root: continue ofp = Path(real_node.original_file_path) @@ -165,8 +221,122 @@ def search(self, included_nodes, selector): yield node -class MethodName(StrEnum): - FQN = 'fqn' - Tag = 'tag' - Source = 'source' - Path = 'path' +class PackageSelectorMethod(SelectorMethod): + def search( + self, included_nodes: Set[UniqueId], selector: str + ) -> Iterator[UniqueId]: + """Yields nodes from included that have the specified package""" + for node, real_node in self.all_nodes(included_nodes): + if real_node.package_name == selector: + yield node + + +def _getattr_descend(obj: Any, attrs: List[str]) -> Any: + value = obj + for attr in attrs: + try: + value = getattr(value, attr) + except AttributeError: + # if it implements getitem (dict, list, ...), use that. On failure, + # raise an attribute error instead of the KeyError, TypeError, etc. + # that arbitrary getitem calls might raise + try: + value = value[attr] + except Exception as exc: + raise AttributeError( + f"'{type(value)}' object has no attribute '{attr}'" + ) from exc + return value + + +class CaseInsensitive(str): + def __eq__(self, other): + if isinstance(other, str): + return self.upper() == other.upper() + else: + return self.upper() == other + + +class ConfigSelectorMethod(SelectorMethod): + def search( + self, included_nodes: Set[UniqueId], selector: str + ) -> Iterator[UniqueId]: + parts = self.arguments + # special case: if the user wanted to compare test severity, + # make the comparison case-insensitive + if parts == ['severity']: + selector = CaseInsensitive(selector) + + # search sources is kind of useless now source configs only have + # 'enabled', which you can't really filter on anyway, but maybe we'll + # add more someday, so search them anyway. + for node, real_node in self.all_nodes(included_nodes): + try: + value = _getattr_descend(real_node.config, parts) + except AttributeError: + continue + else: + # the selector can only be a str, so call str() on the value. + # of course, if one wished to render the selector in the jinja + # native env, this would no longer be true + + if selector == str(value): + yield node + + +class TestNameSelectorMethod(SelectorMethod): + def search( + self, included_nodes: Set[UniqueId], selector: str + ) -> Iterator[UniqueId]: + for node, real_node in self.parsed_nodes(included_nodes): + if isinstance(real_node, HasTestMetadata): + if real_node.test_metadata.name == selector: + yield node + + +class TestTypeSelectorMethod(SelectorMethod): + def search( + self, included_nodes: Set[UniqueId], selector: str + ) -> Iterator[UniqueId]: + search_types: Tuple[Type, ...] + if selector == 'schema': + search_types = (ParsedSchemaTestNode, CompiledSchemaTestNode) + elif selector == 'data': + search_types = (ParsedDataTestNode, CompiledDataTestNode) + else: + raise RuntimeException( + f'Invalid test type selector {selector}: expected "data" or ' + '"schema"' + ) + + for node, real_node in self.parsed_nodes(included_nodes): + if isinstance(real_node, search_types): + yield node + + +class MethodManager: + SELECTOR_METHODS: Dict[MethodName, Type[SelectorMethod]] = { + MethodName.FQN: QualifiedNameSelectorMethod, + MethodName.Tag: TagSelectorMethod, + MethodName.Source: SourceSelectorMethod, + MethodName.Path: PathSelectorMethod, + MethodName.Package: PackageSelectorMethod, + MethodName.Config: ConfigSelectorMethod, + MethodName.TestName: TestNameSelectorMethod, + MethodName.TestType: TestTypeSelectorMethod, + } + + def __init__(self, manifest: Manifest): + self.manifest = manifest + + def get_method( + self, method: MethodName, method_arguments: List[str] + ) -> SelectorMethod: + + if method not in self.SELECTOR_METHODS: + raise InternalException( + f'Method name "{method}" is a valid node selection ' + f'method name, but it is not handled' + ) + cls: Type[SelectorMethod] = self.SELECTOR_METHODS[method] + return cls(self.manifest, method_arguments) diff --git a/core/dbt/graph/selector_spec.py b/core/dbt/graph/selector_spec.py index d0ab4cf8103..ee86e091ca9 100644 --- a/core/dbt/graph/selector_spec.py +++ b/core/dbt/graph/selector_spec.py @@ -4,21 +4,22 @@ from dataclasses import dataclass from typing import ( - Set, Iterator, List, Optional, Dict, Union, Any, Sequence + Set, Iterator, List, Optional, Dict, Union, Any, Iterable, Tuple ) from .graph import UniqueId from .selector_methods import MethodName -from dbt.exceptions import RuntimeException +from dbt.exceptions import RuntimeException, InvalidSelectorException RAW_SELECTOR_PATTERN = re.compile( r'\A' r'(?P(\@))?' r'(?P((?P(\d*))\+))?' - r'((?P(\w+)):)?(?P(.*?))' + r'((?P([\w.]+)):)?(?P(.*?))' r'(?P(\+(?P(\d*))))?' r'\Z' ) +SELECTOR_METHOD_SEPARATOR = '.' def _probably_path(value: str): @@ -58,6 +59,7 @@ def _match_to_int(match: Dict[str, str], key: str) -> Optional[int]: class SelectionCriteria: raw: str method: MethodName + method_arguments: List[str] value: str select_childrens_parents: bool select_parents: bool @@ -80,17 +82,22 @@ def default_method(cls, value: str) -> MethodName: return MethodName.FQN @classmethod - def parse_method(cls, raw: str, groupdict: Dict[str, Any]) -> MethodName: + def parse_method( + cls, raw: str, groupdict: Dict[str, Any] + ) -> Tuple[MethodName, List[str]]: raw_method = groupdict.get('method') if raw_method is None: - return cls.default_method(groupdict['value']) + return cls.default_method(groupdict['value']), [] + method_parts: List[str] = raw_method.split(SELECTOR_METHOD_SEPARATOR) try: - return MethodName(raw_method) - except ValueError: - raise RuntimeException( - f'unknown selector filter "{raw_method}" in "{raw}"' - ) from None + method_name = MethodName(method_parts[0]) + except ValueError as exc: + raise InvalidSelectorException(method_parts[0]) from exc + + method_arguments: List[str] = method_parts[1:] + + return method_name, method_arguments @classmethod def from_single_spec(cls, raw: str) -> 'SelectionCriteria': @@ -105,14 +112,15 @@ def from_single_spec(cls, raw: str) -> 'SelectionCriteria': f'Invalid node spec "{raw}" - no search value!' ) - method = cls.parse_method(raw, result_dict) + method_name, method_arguments = cls.parse_method(raw, result_dict) parents_max_depth = _match_to_int(result_dict, 'parents_depth') children_max_depth = _match_to_int(result_dict, 'children_depth') return cls( raw=raw, - method=method, + method=method_name, + method_arguments=method_arguments, value=result_dict['value'], select_childrens_parents=bool(result_dict.get('childs_parents')), select_parents=bool(result_dict.get('parents')), @@ -122,10 +130,10 @@ def from_single_spec(cls, raw: str) -> 'SelectionCriteria': ) -class BaseSelectionGroup(metaclass=ABCMeta): +class BaseSelectionGroup(Iterable[SelectionSpec], metaclass=ABCMeta): def __init__( self, - components: Sequence[SelectionSpec], + components: Iterable[SelectionSpec], expect_exists: bool = False, raw: Any = None, ): @@ -146,6 +154,12 @@ def combine_selections( '_combine_selections not implemented!' ) + def combined(self, selections: List[Set[UniqueId]]) -> Set[UniqueId]: + if not selections: + return set() + + return self.combine_selections(selections) + class SelectionIntersection(BaseSelectionGroup): def combine_selections( diff --git a/core/dbt/task/list.py b/core/dbt/task/list.py index 4fef6c16220..0d548c0676c 100644 --- a/core/dbt/task/list.py +++ b/core/dbt/task/list.py @@ -1,7 +1,11 @@ import json from typing import Type -from dbt.graph import ResourceTypeSelector, parse_difference, SelectionSpec +from dbt.graph import ( + parse_difference, + ResourceTypeSelector, + SelectionSpec, +) from dbt.task.runnable import GraphRunnableTask, ManifestTask from dbt.task.test import TestSelector from dbt.node_types import NodeType @@ -158,8 +162,6 @@ def get_node_selector(self): return TestSelector( graph=self.graph, manifest=self.manifest, - schema=True, - data=True ) else: return ResourceTypeSelector( diff --git a/core/dbt/task/test.py b/core/dbt/task/test.py index bc99aef681f..170250b260a 100644 --- a/core/dbt/task/test.py +++ b/core/dbt/task/test.py @@ -16,7 +16,12 @@ ) from dbt.contracts.results import RunModelResult from dbt.exceptions import raise_compiler_error, InternalException -from dbt.graph import ResourceTypeSelector, Graph, UniqueId +from dbt.graph import ( + ResourceTypeSelector, + SelectionSpec, + UniqueId, + parse_test_selectors, +) from dbt.node_types import NodeType, RunHookType from dbt import flags @@ -102,38 +107,20 @@ def after_execute(self, result): class TestSelector(ResourceTypeSelector): - def __init__( - self, graph, manifest, data: bool, schema: bool - ): + def __init__(self, graph, manifest): super().__init__( graph=graph, manifest=manifest, resource_types=[NodeType.Test], ) - self.data = data - self.schema = schema - def expand_selection( - self, filtered_graph: Graph, selected: Set[UniqueId] - ) -> Set[UniqueId]: + def expand_selection(self, selected: Set[UniqueId]) -> Set[UniqueId]: selected_tests = { - n for n in filtered_graph.select_successors(selected) + n for n in self.graph.select_successors(selected) if self.manifest.nodes[n].resource_type == NodeType.Test } return selected | selected_tests - def node_is_match(self, node): - if super().node_is_match(node): - test_types = [self.data, self.schema] - - if all(test_types) or not any(test_types): - return True - elif self.data: - return isinstance(node, DATA_TEST_TYPES) - elif self.schema: - return isinstance(node, SCHEMA_TEST_TYPES) - return False - class TestTask(RunTask): """ @@ -150,6 +137,14 @@ def safe_run_hooks( # Don't execute on-run-* hooks for tests pass + def get_selection_spec(self) -> SelectionSpec: + base_spec = super().get_selection_spec() + return parse_test_selectors( + data=self.args.data, + schema=self.args.schema, + base=base_spec + ) + def get_node_selector(self) -> TestSelector: if self.manifest is None or self.graph is None: raise InternalException( @@ -158,8 +153,6 @@ def get_node_selector(self) -> TestSelector: return TestSelector( graph=self.graph, manifest=self.manifest, - data=self.args.data, - schema=self.args.schema, ) def get_runner_type(self): diff --git a/test/integration/007_graph_selection_tests/models/schema.yml b/test/integration/007_graph_selection_tests/models/schema.yml index 6e06720672e..9f4f699a926 100644 --- a/test/integration/007_graph_selection_tests/models/schema.yml +++ b/test/integration/007_graph_selection_tests/models/schema.yml @@ -4,7 +4,8 @@ models: columns: - name: email tests: - - unique + - not_null: + severity: warn - name: users columns: - name: id diff --git a/test/integration/007_graph_selection_tests/test_graph_selection.py b/test/integration/007_graph_selection_tests/test_graph_selection.py index 8e99b06b39f..a3624e7367f 100644 --- a/test/integration/007_graph_selection_tests/test_graph_selection.py +++ b/test/integration/007_graph_selection_tests/test_graph_selection.py @@ -256,6 +256,12 @@ def test__postgres__childrens_parents(self): self.assertNotIn('subdir', created_models) self.assertNotIn('nested_users', created_models) + results = self.run_dbt( + ['test', '--models', 'test_name:not_null'], + ) + self.assertEqual(len(results), 1) + assert results[0].node.name == 'not_null_emails_email' + @use_profile('postgres') def test__postgres__more_childrens_parents(self): self.run_sql_file("seed.sql") @@ -270,6 +276,13 @@ def test__postgres__more_childrens_parents(self): self.assertNotIn('subdir', created_models) self.assertNotIn('nested_users', created_models) + results = self.run_dbt( + ['test', '--models', 'test_name:unique'], + ) + self.assertEqual(len(results), 2) + assert sorted([r.node.name for r in results]) == ['unique_users_id', 'unique_users_rollup_gender'] + + @use_profile('snowflake') def test__snowflake__skip_intermediate(self): self.run_sql_file("seed.sql") @@ -335,3 +348,10 @@ def test__postgres__concat_exclude_concat(self): self.assertNotIn('users_rollup', created_models) self.assertNotIn('subdir', created_models) self.assertNotIn('nested_users', created_models) + + results = self.run_dbt( + ['test', '--models', '@emails_alt', 'users_rollup', '--exclude', 'emails_alt', 'users_rollup'] + ) + self.assertEqual(len(results), 1) + assert results[0].node.name == 'unique_users_id' + diff --git a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py index 60a517aa098..7da7ee01ae6 100644 --- a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py +++ b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py @@ -47,7 +47,7 @@ def test__postgres__schema_tests_no_specifiers(self): self.run_schema_and_assert( None, None, - ['unique_emails_email', + ['not_null_emails_email', 'unique_table_model_id', 'unique_users_id', 'unique_users_rollup_gender'] @@ -83,7 +83,7 @@ def test__postgres__schema_tests_specify_tag_and_children(self): self.run_schema_and_assert( ['tag:base+'], None, - ['unique_emails_email', + ['not_null_emails_email', 'unique_users_id', 'unique_users_rollup_gender'] ) @@ -109,7 +109,7 @@ def test__postgres__schema_tests_specify_exclude_only(self): self.run_schema_and_assert( None, ['users_rollup'], - ['unique_emails_email', 'unique_table_model_id', 'unique_users_id'] + ['not_null_emails_email', 'unique_table_model_id', 'unique_users_id'] ) @use_profile('postgres') @@ -127,7 +127,7 @@ def test__postgres__schema_tests_with_glob(self): self.run_schema_and_assert( ['*'], ['users'], - ['unique_emails_email', 'unique_table_model_id', 'unique_users_rollup_gender'] + ['not_null_emails_email', 'unique_table_model_id', 'unique_users_rollup_gender'] ) @use_profile('postgres') @@ -151,5 +151,5 @@ def test__postgres__schema_tests_exclude_pkg(self): self.run_schema_and_assert( None, ['dbt_integration_project'], - ['unique_emails_email', 'unique_users_id', 'unique_users_rollup_gender'] + ['not_null_emails_email', 'unique_users_id', 'unique_users_rollup_gender'] ) diff --git a/test/integration/007_graph_selection_tests/test_tag_selection.py b/test/integration/007_graph_selection_tests/test_tag_selection.py index 1925e281aaa..635f289482f 100644 --- a/test/integration/007_graph_selection_tests/test_tag_selection.py +++ b/test/integration/007_graph_selection_tests/test_tag_selection.py @@ -73,3 +73,20 @@ def test__postgres__select_tag_in_model_with_project_Config(self): {'users', 'users_rollup', 'emails_alt', 'users_rollup_dependency'}, models_run ) + + # just the users/users_rollup tests + results = self.run_dbt(['test', '--models', '@tag:users']) + self.assertEqual(len(results), 2) + assert sorted(r.node.name for r in results) == ['unique_users_id', 'unique_users_rollup_gender'] + # just the email test + results = self.run_dbt(['test', '--models', 'tag:base,config.materialized:ephemeral']) + self.assertEqual(len(results), 1) + assert results[0].node.name == 'not_null_emails_email' + # also just the email test + results = self.run_dbt(['test', '--models', 'config.severity:warn']) + self.assertEqual(len(results), 1) + assert results[0].node.name == 'not_null_emails_email' + # all 3 tests + results = self.run_dbt(['test', '--models', '@tag:users tag:base,config.materialized:ephemeral']) + self.assertEqual(len(results), 3) + assert sorted(r.node.name for r in results) == ['not_null_emails_email', 'unique_users_id', 'unique_users_rollup_gender'] diff --git a/test/integration/047_dbt_ls_test/models/incremental.sql b/test/integration/047_dbt_ls_test/models/incremental.sql new file mode 100644 index 00000000000..2bb15542da9 --- /dev/null +++ b/test/integration/047_dbt_ls_test/models/incremental.sql @@ -0,0 +1,12 @@ +{{ + config( + materialized = "incremental", + incremental_strategy = "delete+insert", + ) +}} + +select * from {{ ref('seed') }} + +{% if is_incremental() %} + where a > (select max(a) from {{this}}) +{% endif %} diff --git a/test/integration/047_dbt_ls_test/test_ls.py b/test/integration/047_dbt_ls_test/test_ls.py index 00683001f59..62ca874103c 100644 --- a/test/integration/047_dbt_ls_test/test_ls.py +++ b/test/integration/047_dbt_ls_test/test_ls.py @@ -33,6 +33,10 @@ def project_config(self): }, } + def setUp(self): + super().setUp() + self.maxDiff = None + def run_dbt_ls(self, args=None, expect_pass=True): log_manager.stdout_console() full_args = ['ls'] @@ -123,8 +127,8 @@ def expect_analyses_output(self): def expect_model_output(self): expectations = { - 'name': ('ephemeral', 'inner', 'outer'), - 'selector': ('test.ephemeral', 'test.sub.inner', 'test.outer'), + 'name': ('ephemeral', 'incremental', 'inner', 'outer'), + 'selector': ('test.ephemeral', 'test.incremental', 'test.sub.inner', 'test.outer'), 'json': ( { 'name': 'ephemeral', @@ -146,6 +150,27 @@ def expect_model_output(self): 'alias': 'ephemeral', 'resource_type': 'model', }, + { + 'name': 'incremental', + 'package_name': 'test', + 'depends_on': {'nodes': ['seed.test.seed'], 'macros': ['macro.dbt.is_incremental']}, + 'tags': [], + 'config': { + 'enabled': True, + 'materialized': 'incremental', + 'post-hook': [], + 'tags': [], + 'pre-hook': [], + 'quoting': {}, + 'vars': {}, + 'column_types': {}, + 'persist_docs': {}, + 'full_refresh': None, + 'incremental_strategy': 'delete+insert', + }, + 'alias': 'incremental', + 'resource_type': 'model', + }, { 'name': 'inner', 'package_name': 'test', @@ -187,7 +212,7 @@ def expect_model_output(self): 'resource_type': 'model', }, ), - 'path': (self.dir('models/ephemeral.sql'), self.dir('models/sub/inner.sql'), self.dir('models/outer.sql')), + 'path': (self.dir('models/ephemeral.sql'), self.dir('models/incremental.sql'), self.dir('models/sub/inner.sql'), self.dir('models/outer.sql')), } self.expect_given_output(['--resource-type', 'model'], expectations) @@ -351,6 +376,7 @@ def expect_all_output(self): # sources are like models - (package.source_name.table_name) expected_default = { 'test.ephemeral', + 'test.incremental', 'test.snapshot.my_snapshot', 'test.sub.inner', 'test.outer', @@ -387,7 +413,13 @@ def expect_select(self): self.assertEqual(set(results), {'test.outer', 'test.sub.inner'}) results = self.run_dbt_ls(['--resource-type', 'model', '--exclude', 'inner']) - self.assertEqual(set(results), {'test.ephemeral', 'test.outer'}) + self.assertEqual(set(results), {'test.ephemeral', 'test.outer', 'test.incremental'}) + + results = self.run_dbt_ls(['--select', 'config.incremental_strategy:delete+insert']) + self.assertEqual(set(results), {'test.incremental'}) + + self.run_dbt_ls(['--select', 'config.incremental_strategy:insert_overwrite'], expect_pass=False) + @use_profile('postgres') def test_postgres_ls(self): diff --git a/test/unit/test_graph_selection.py b/test/unit/test_graph_selection.py index d4c19da7cd4..a5623a95793 100644 --- a/test/unit/test_graph_selection.py +++ b/test/unit/test_graph_selection.py @@ -7,6 +7,7 @@ import dbt.exceptions import dbt.graph.selector as graph_selector import dbt.graph.cli as graph_cli +from dbt.node_types import NodeType import networkx as nx @@ -28,7 +29,13 @@ def _get_manifest(graph): for unique_id in graph: fqn = unique_id.split('.') node = mock.MagicMock( - unique_id=unique_id, fqn=fqn, package_name=fqn[0], tags=[] + unique_id=unique_id, + fqn=fqn, + package_name=fqn[0], + tags=[], + resource_type=NodeType.Model, + empty=False, + config=mock.MagicMock(enabled=True), ) nodes[unique_id] = node @@ -111,11 +118,9 @@ def id_macro(arg): def test_run_specs(include, exclude, expected): graph = _get_graph() manifest = _get_manifest(graph) - # we can use the same graph twice. selector = graph_selector.NodeSelector(graph, manifest) spec = graph_cli.parse_difference(include, exclude) - subgraph = graph.subgraph(set(graph.nodes())) - selected = selector.select_nodes(subgraph, spec) + selected = selector.select_nodes(spec) assert selected == expected diff --git a/test/unit/test_graph_selector_methods.py b/test/unit/test_graph_selector_methods.py new file mode 100644 index 00000000000..fd3d5c6c0c7 --- /dev/null +++ b/test/unit/test_graph_selector_methods.py @@ -0,0 +1,543 @@ +import pytest + +from datetime import datetime + +from dbt.contracts.graph.parsed import ( + DependsOn, + NodeConfig, + ParsedModelNode, + ParsedSeedNode, + ParsedSnapshotNode, + ParsedDataTestNode, + ParsedSchemaTestNode, + ParsedSourceDefinition, + TestConfig, + TestMetadata, +) +from dbt.contracts.graph.manifest import Manifest +from dbt.node_types import NodeType +from dbt.graph.selector_methods import ( + MethodManager, + QualifiedNameSelectorMethod, + TagSelectorMethod, + SourceSelectorMethod, + PathSelectorMethod, + PackageSelectorMethod, + ConfigSelectorMethod, + TestNameSelectorMethod, + TestTypeSelectorMethod, +) + + +def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, alias=None, config_kwargs=None, fqn_extras=None): + if refs is None: + refs = [] + if sources is None: + sources = [] + if tags is None: + tags = [] + if path is None: + path = f'{name}.sql' + if alias is None: + alias = name + if config_kwargs is None: + config_kwargs = {} + + if fqn_extras is None: + fqn_extras = [] + + fqn = [pkg] + fqn_extras + [name] + + depends_on_nodes = [] + source_values = [] + ref_values = [] + for ref in refs: + ref_values.append([ref.name]) + depends_on_nodes.append(ref.unique_id) + for src in sources: + source_values.append([src.source_name, src.name]) + depends_on_nodes.append(src.unique_id) + + return ParsedModelNode( + raw_sql=sql, + database='dbt', + schema='dbt_schema', + alias=alias, + name=name, + fqn=fqn, + unique_id=f'model.{pkg}.{name}', + package_name=pkg, + root_path='/usr/dbt/some-project', + path=path, + original_file_path=f'models/{path}', + config=NodeConfig(**config_kwargs), + tags=tags, + refs=ref_values, + sources=source_values, + depends_on=DependsOn(nodes=depends_on_nodes), + resource_type=NodeType.Model, + ) + + +def make_seed(pkg, name, path=None, loader=None, alias=None, tags=None, fqn_extras=None): + if alias is None: + alias = name + if tags is None: + tags = [] + if path is None: + path = f'{name}.csv' + + if fqn_extras is None: + fqn_extras = [] + + fqn = [pkg] + fqn_extras + [name] + return ParsedSeedNode( + raw_sql='', + database='dbt', + schema='dbt_schema', + alias=alias, + name=name, + fqn=fqn, + unique_id=f'seed.{pkg}.{name}', + package_name=pkg, + root_path='/usr/dbt/some-project', + path=path, + original_file_path=f'data/{path}', + tags=tags, + resource_type=NodeType.Seed, + ) + + +def make_source(pkg, source_name, table_name, path=None, loader=None, identifier=None, fqn_extras=None): + if path is None: + path = 'models/schema.yml' + if loader is None: + loader = 'my_loader' + if identifier is None: + identifier = table_name + + if fqn_extras is None: + fqn_extras = [] + + fqn = [pkg] + fqn_extras + [source_name, table_name] + + return ParsedSourceDefinition( + fqn=fqn, + database='dbt', + schema='dbt_schema', + unique_id=f'source.{pkg}.{source_name}.{table_name}', + package_name=pkg, + root_path='/usr/dbt/some-project', + path=path, + original_file_path=path, + name=table_name, + source_name=source_name, + loader='my_loader', + identifier=identifier, + resource_type=NodeType.Source, + loaded_at_field='loaded_at', + tags=[], + source_description='', + ) + + +def make_unique_test(pkg, test_model, column_name, path=None, refs=None, sources=None, tags=None): + return make_schema_test(pkg, 'unique', test_model, {}, column_name=column_name) + + +def make_not_null_test(pkg, test_model, column_name, path=None, refs=None, sources=None, tags=None): + return make_schema_test(pkg, 'not_null', test_model, {}, column_name=column_name) + + +def make_schema_test(pkg, test_name, test_model, test_kwargs, path=None, refs=None, sources=None, tags=None, column_name=None): + kwargs = test_kwargs.copy() + ref_values = [] + source_values = [] + # this doesn't really have to be correct + if isinstance(test_model, ParsedSourceDefinition): + kwargs['model'] = "{{ source('" + test_model.source_name + "', '" + test_model.name + "') }}" + source_values.append([test_model.source_name, test_model.name]) + else: + kwargs['model'] = "{{ ref('" + test_model.name + "')}}" + ref_values.append([test_model.name]) + if column_name is not None: + kwargs['column_name'] = column_name + + # whatever + args_name = test_model.search_name.replace(".", "_") + if column_name is not None: + args_name += '_' + column_name + node_name = f'{test_name}_{args_name}' + raw_sql = '{{ config(severity="ERROR") }}{{ test_' + test_name + '(**dbt_schema_test_kwargs) }}' + name_parts = test_name.split('.') + + if len(name_parts) == 2: + namespace, test_name = name_parts + macro_depends = f'model.{namespace}.{test_name}' + elif len(name_parts) == 1: + namespace = None + macro_depends = f'model.dbt.{test_name}' + else: + assert False, f'invalid test name: {test_name}' + + if path is None: + path = 'schema.yml' + if tags is None: + tags = ['schema'] + + if refs is None: + refs = [] + if sources is None: + sources = [] + + depends_on_nodes = [] + for ref in refs: + ref_values.append([ref.name]) + depends_on_nodes.append(ref.unique_id) + + for source in sources: + source_values.append([source.source_name, source.name]) + depends_on_nodes.append(source.unique_id) + + return ParsedSchemaTestNode( + raw_sql=raw_sql, + test_metadata=TestMetadata( + namespace=namespace, + name=test_name, + kwargs=kwargs, + ), + database='dbt', + schema='dbt_postgres', + name=node_name, + alias=node_name, + fqn=['minimal', 'schema_test', node_name], + unique_id=f'test.{pkg}.{node_name}', + package_name=pkg, + root_path='/usr/dbt/some-project', + path=f'schema_test/{node_name}.sql', + original_file_path=f'models/{path}', + resource_type=NodeType.Test, + tags=tags, + refs=ref_values, + sources=[], + depends_on=DependsOn( + macros=[macro_depends], + nodes=['model.minimal.view_model'] + ), + column_name=column_name, + ) + + +def make_data_test(pkg, name, sql, refs=None, sources=None, tags=None, path=None, config_kwargs=None): + + if refs is None: + refs = [] + if sources is None: + sources = [] + if tags is None: + tags = ['data'] + if path is None: + path = f'{name}.sql' + + if config_kwargs is None: + config_kwargs = {} + + fqn = ['minimal', 'data_test', name] + + depends_on_nodes = [] + source_values = [] + ref_values = [] + for ref in refs: + ref_values.append([ref.name]) + depends_on_nodes.append(ref.unique_id) + for src in sources: + source_values.append([src.source_name, src.name]) + depends_on_nodes.append(src.unique_id) + + return ParsedDataTestNode( + raw_sql=sql, + database='dbt', + schema='dbt_schema', + name=name, + alias=name, + fqn=fqn, + unique_id=f'test.{pkg}.{name}', + package_name=pkg, + root_path='/usr/dbt/some-project', + path=path, + original_file_path=f'tests/{path}', + config=TestConfig(**config_kwargs), + tags=tags, + refs=ref_values, + sources=source_values, + depends_on=DependsOn(nodes=depends_on_nodes), + resource_type=NodeType.Test, + ) + + +@pytest.fixture +def seed(): + return make_seed( + 'pkg', + 'seed' + ) + + +@pytest.fixture +def source(): + return make_source( + 'pkg', + 'raw', + 'seed', + identifier='seed' + ) + + +@pytest.fixture +def ephemeral_model(source): + return make_model( + 'pkg', + 'ephemeral_model', + 'select * from {{ source("raw", "seed") }}', + config_kwargs={'materialized': 'ephemeral'}, + sources=[source], + ) + + +@pytest.fixture +def view_model(ephemeral_model): + return make_model( + 'pkg', + 'view_model', + 'select * from {{ ref("ephemeral_model") }}', + config_kwargs={'materialized': 'view'}, + refs=[ephemeral_model], + tags=['uses_ephemeral'], + ) + + +@pytest.fixture +def table_model(ephemeral_model): + return make_model( + 'pkg', + 'table_model', + 'select * from {{ ref("ephemeral_model") }}', + config_kwargs={'materialized': 'table'}, + refs=[ephemeral_model], + tags=['uses_ephemeral'], + path='subdirectory/table_model.sql' + ) + + +@pytest.fixture +def ext_source(): + return make_source( + 'ext', + 'ext_raw', + 'ext_source', + ) + + +@pytest.fixture +def ext_source_2(): + return make_source( + 'ext', + 'ext_raw', + 'ext_source_2', + ) + + +@pytest.fixture +def ext_source_other(): + return make_source( + 'ext', + 'raw', + 'ext_source', + ) + + +@pytest.fixture +def ext_source_other_2(): + return make_source( + 'ext', + 'raw', + 'ext_source_2', + ) + + +@pytest.fixture +def ext_model(ext_source): + return make_model( + 'ext', + 'ext_model', + 'select * from {{ source("ext_raw", "ext_source") }}', + sources=[ext_source], + ) + + +@pytest.fixture +def union_model(seed, ext_source): + return make_model( + 'pkg', + 'union_model', + 'select * from {{ ref("seed") }} union all select * from {{ source("ext_raw", "ext_source") }}', + config_kwargs={'materialized': 'table'}, + refs=[seed], + sources=[ext_source], + fqn_extras=['unions'], + path='subdirectory/union_model.sql', + tags=['unions'], + ) + + +@pytest.fixture +def table_id_unique(table_model): + return make_unique_test('pkg', table_model, 'id') + + +@pytest.fixture +def table_id_not_null(table_model): + return make_not_null_test('pkg', table_model, 'id') + + +@pytest.fixture +def view_id_unique(view_model): + return make_unique_test('pkg', view_model, 'id') + + +@pytest.fixture +def ext_source_id_unique(ext_source): + return make_unique_test('ext', ext_source, 'id') + + +@pytest.fixture +def view_test_nothing(view_model): + return make_data_test('pkg', 'view_test_nothing', 'select * from {{ ref("view_model") }} limit 0', refs=[view_model]) + + +@pytest.fixture +def manifest(seed, source, ephemeral_model, view_model, table_model, ext_source, ext_model, union_model, ext_source_2, ext_source_other, ext_source_other_2, table_id_unique, table_id_not_null, view_id_unique, ext_source_id_unique, view_test_nothing): + nodes = [seed, ephemeral_model, view_model, table_model, union_model, ext_model, table_id_unique, table_id_not_null, view_id_unique, ext_source_id_unique, view_test_nothing] + sources = [source, ext_source, ext_source_2, ext_source_other, ext_source_other_2] + manifest = Manifest( + nodes={n.unique_id: n for n in nodes}, + sources={s.unique_id: s for s in sources}, + macros={}, + docs={}, + files={}, + generated_at=datetime.utcnow(), + disabled=[], + ) + return manifest + + +def search_manifest_using_method(manifest, method, selection): + selected = method.search(set(manifest.nodes) | set(manifest.sources), selection) + results = {manifest.expect(uid).search_name for uid in selected} + return results + + +def test_select_fqn(manifest): + methods = MethodManager(manifest) + method = methods.get_method('fqn', []) + assert isinstance(method, QualifiedNameSelectorMethod) + assert method.arguments == [] + + assert search_manifest_using_method(manifest, method, 'pkg.unions') == {'union_model'} + assert not search_manifest_using_method(manifest, method, 'ext.unions') + # sources don't show up, because selection pretends they have no FQN. Should it? + assert search_manifest_using_method(manifest, method, 'pkg') == {'union_model', 'table_model', 'view_model', 'ephemeral_model', 'seed'} + assert search_manifest_using_method(manifest, method, 'ext') == {'ext_model'} + + +def test_select_tag(manifest): + methods = MethodManager(manifest) + method = methods.get_method('tag', []) + assert isinstance(method, TagSelectorMethod) + assert method.arguments == [] + + assert search_manifest_using_method(manifest, method, 'uses_ephemeral') == {'view_model', 'table_model'} + assert not search_manifest_using_method(manifest, method, 'missing') + + +def test_select_source(manifest): + methods = MethodManager(manifest) + method = methods.get_method('source', []) + assert isinstance(method, SourceSelectorMethod) + assert method.arguments == [] + + # the lookup is based on how many components you provide: source, source.table, package.source.table + assert search_manifest_using_method(manifest, method, 'raw') == {'raw.seed', 'raw.ext_source', 'raw.ext_source_2'} + assert search_manifest_using_method(manifest, method, 'raw.seed') == {'raw.seed'} + assert search_manifest_using_method(manifest, method, 'pkg.raw.seed') == {'raw.seed'} + assert search_manifest_using_method(manifest, method, 'pkg.*.*') == {'raw.seed'} + assert search_manifest_using_method(manifest, method, 'raw.*') == {'raw.seed', 'raw.ext_source', 'raw.ext_source_2'} + assert search_manifest_using_method(manifest, method, 'ext.raw.*') == {'raw.ext_source', 'raw.ext_source_2'} + assert not search_manifest_using_method(manifest, method, 'missing') + assert not search_manifest_using_method(manifest, method, 'raw.missing') + assert not search_manifest_using_method(manifest, method, 'missing.raw.seed') + + assert search_manifest_using_method(manifest, method, 'ext.*.*') == {'ext_raw.ext_source', 'ext_raw.ext_source_2', 'raw.ext_source', 'raw.ext_source_2'} + assert search_manifest_using_method(manifest, method, 'ext_raw') == {'ext_raw.ext_source', 'ext_raw.ext_source_2'} + assert search_manifest_using_method(manifest, method, 'ext.ext_raw.*') == {'ext_raw.ext_source', 'ext_raw.ext_source_2'} + assert not search_manifest_using_method(manifest, method, 'pkg.ext_raw.*') + + +# TODO: this requires writing out files +@pytest.mark.skip('TODO: write manifest files to disk') +def test_select_path(manifest): + methods = MethodManager(manifest) + method = methods.get_method('path', []) + assert isinstance(method, PathSelectorMethod) + assert method.arguments == [] + + assert search_manifest_using_method(manifest, method, 'subdirectory/*.sql') == {'union_model', 'table_model'} + assert search_manifest_using_method(manifest, method, 'subdirectory/union_model.sql') == {'union_model'} + assert search_manifest_using_method(manifest, method, 'models/*.sql') == {'view_model', 'ephemeral_model'} + assert not search_manifest_using_method(manifest, method, 'missing') + assert not search_manifest_using_method(manifest, method, 'models/missing.sql') + assert not search_manifest_using_method(manifest, method, 'models/missing*') + + +def test_select_package(manifest): + methods = MethodManager(manifest) + method = methods.get_method('package', []) + assert isinstance(method, PackageSelectorMethod) + assert method.arguments == [] + + assert search_manifest_using_method(manifest, method, 'pkg') == {'union_model', 'table_model', 'view_model', 'ephemeral_model', 'seed', 'raw.seed', 'unique_table_model_id', 'not_null_table_model_id', 'unique_view_model_id', 'view_test_nothing'} + assert search_manifest_using_method(manifest, method, 'ext') == {'ext_model', 'ext_raw.ext_source', 'ext_raw.ext_source_2', 'raw.ext_source', 'raw.ext_source_2', 'unique_ext_raw_ext_source_id'} + + assert not search_manifest_using_method(manifest, method, 'missing') + + +def test_select_config_materialized(manifest): + methods = MethodManager(manifest) + method = methods.get_method('config', ['materialized']) + assert isinstance(method, ConfigSelectorMethod) + assert method.arguments == ['materialized'] + + # yes, technically tests are "views" + assert search_manifest_using_method(manifest, method, 'view') == {'view_model', 'ext_model', 'unique_table_model_id', 'not_null_table_model_id', 'unique_view_model_id', 'unique_ext_raw_ext_source_id', 'view_test_nothing'} + assert search_manifest_using_method(manifest, method, 'table') == {'table_model', 'union_model'} + + +def test_select_test_name(manifest): + methods = MethodManager(manifest) + method = methods.get_method('test_name', []) + assert isinstance(method, TestNameSelectorMethod) + assert method.arguments == [] + + assert search_manifest_using_method(manifest, method, 'unique') == {'unique_table_model_id', 'unique_view_model_id', 'unique_ext_raw_ext_source_id'} + assert search_manifest_using_method(manifest, method, 'not_null') == {'not_null_table_model_id'} + assert not search_manifest_using_method(manifest, method, 'notatest') + + +def test_select_test_type(manifest): + methods = MethodManager(manifest) + method = methods.get_method('test_type', []) + assert isinstance(method, TestTypeSelectorMethod) + assert method.arguments == [] + assert search_manifest_using_method(manifest, method, 'schema') == {'unique_table_model_id', 'not_null_table_model_id', 'unique_view_model_id', 'unique_ext_raw_ext_source_id'} + assert search_manifest_using_method(manifest, method, 'data') == {'view_test_nothing'} + diff --git a/test/unit/test_graph_selector_spec.py b/test/unit/test_graph_selector_spec.py new file mode 100644 index 00000000000..fdbd43e917e --- /dev/null +++ b/test/unit/test_graph_selector_spec.py @@ -0,0 +1,151 @@ +import pytest + +from dbt.exceptions import RuntimeException +from dbt.graph.selector_spec import ( + SelectionCriteria, + SelectionIntersection, + SelectionDifference, + SelectionUnion, +) +from dbt.graph.selector_methods import MethodName +import os + + +def test_raw_parse_simple(): + raw = 'asdf' + result = SelectionCriteria.from_single_spec(raw) + assert result.raw == raw + assert result.method == MethodName.FQN + assert result.method_arguments == [] + assert result.value == raw + assert not result.select_childrens_parents + assert not result.select_children + assert not result.select_parents + assert result.select_parents_max_depth is None + assert result.select_children_max_depth is None + + +def test_raw_parse_simple_infer_path(): + raw = os.path.join('asdf', '*') + result = SelectionCriteria.from_single_spec(raw) + assert result.raw == raw + assert result.method == MethodName.Path + assert result.method_arguments == [] + assert result.value == raw + assert not result.select_childrens_parents + assert not result.select_children + assert not result.select_parents + assert result.select_parents_max_depth is None + assert result.select_children_max_depth is None + + +def test_raw_parse_simple_infer_path_modified(): + raw = '@' + os.path.join('asdf', '*') + result = SelectionCriteria.from_single_spec(raw) + assert result.raw == raw + assert result.method == MethodName.Path + assert result.method_arguments == [] + assert result.value == raw[1:] + assert result.select_childrens_parents + assert not result.select_children + assert not result.select_parents + assert result.select_parents_max_depth is None + assert result.select_children_max_depth is None + + +def test_raw_parse_simple_infer_fqn_parents(): + raw = '+asdf' + result = SelectionCriteria.from_single_spec(raw) + assert result.raw == raw + assert result.method == MethodName.FQN + assert result.method_arguments == [] + assert result.value == 'asdf' + assert not result.select_childrens_parents + assert not result.select_children + assert result.select_parents + assert result.select_parents_max_depth is None + assert result.select_children_max_depth is None + + +def test_raw_parse_simple_infer_fqn_children(): + raw = 'asdf+' + result = SelectionCriteria.from_single_spec(raw) + assert result.raw == raw + assert result.method == MethodName.FQN + assert result.method_arguments == [] + assert result.value == 'asdf' + assert not result.select_childrens_parents + assert result.select_children + assert not result.select_parents + assert result.select_parents_max_depth is None + assert result.select_children_max_depth is None + + +def test_raw_parse_complex(): + raw = '2+config.arg.secondarg:argument_value+4' + result = SelectionCriteria.from_single_spec(raw) + assert result.raw == raw + assert result.method == MethodName.Config + assert result.method_arguments == ['arg', 'secondarg'] + assert result.value == 'argument_value' + assert not result.select_childrens_parents + assert result.select_children + assert result.select_parents + assert result.select_parents_max_depth == 2 + assert result.select_children_max_depth == 4 + + +def test_raw_parse_weird(): + # you can have an empty method name (defaults to FQN/path) and you can have + # an empty value, so you can also have this... + result = SelectionCriteria.from_single_spec('') + assert result.raw == '' + assert result.method == MethodName.FQN + assert result.method_arguments == [] + assert result.value == '' + assert not result.select_childrens_parents + assert not result.select_children + assert not result.select_parents + assert result.select_parents_max_depth is None + assert result.select_children_max_depth is None + + +def test_raw_parse_invalid(): + with pytest.raises(RuntimeException): + SelectionCriteria.from_single_spec('invalid_method:something') + + with pytest.raises(RuntimeException): + SelectionCriteria.from_single_spec('@foo+') + + +def test_intersection(): + fqn_a = SelectionCriteria.from_single_spec('fqn:model_a') + fqn_b = SelectionCriteria.from_single_spec('fqn:model_b') + intersection = SelectionIntersection(components=[fqn_a, fqn_b]) + assert list(intersection) == [fqn_a, fqn_b] + combined = intersection.combine_selections([{'model_a', 'model_b', 'model_c'}, {'model_c', 'model_d'}]) + assert combined == {'model_c'} + + +def test_difference(): + fqn_a = SelectionCriteria.from_single_spec('fqn:model_a') + fqn_b = SelectionCriteria.from_single_spec('fqn:model_b') + difference = SelectionDifference(components=[fqn_a, fqn_b]) + assert list(difference) == [fqn_a, fqn_b] + combined = difference.combine_selections([{'model_a', 'model_b', 'model_c'}, {'model_c', 'model_d'}]) + assert combined == {'model_a', 'model_b'} + + fqn_c = SelectionCriteria.from_single_spec('fqn:model_c') + difference = SelectionDifference(components=[fqn_a, fqn_b, fqn_c]) + assert list(difference) == [fqn_a, fqn_b, fqn_c] + combined = difference.combine_selections([{'model_a', 'model_b', 'model_c'}, {'model_c', 'model_d'}, {'model_a'}]) + assert combined == {'model_b'} + + +def test_union(): + fqn_a = SelectionCriteria.from_single_spec('fqn:model_a') + fqn_b = SelectionCriteria.from_single_spec('fqn:model_b') + fqn_c = SelectionCriteria.from_single_spec('fqn:model_c') + difference = SelectionUnion(components=[fqn_a, fqn_b, fqn_c]) + combined = difference.combine_selections([{'model_a', 'model_b'}, {'model_b', 'model_c'}, {'model_d'}]) + assert combined == {'model_a', 'model_b', 'model_c', 'model_d'} diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py index c846c218077..e7d25489ed4 100644 --- a/test/unit/test_manifest.py +++ b/test/unit/test_manifest.py @@ -17,13 +17,12 @@ NodeConfig, ParsedSeedNode, ParsedSourceDefinition, - ParsedDocumentation, ) from dbt.contracts.graph.compiled import CompiledModelNode from dbt.node_types import NodeType import freezegun -from .utils import MockMacro +from .utils import MockMacro, MockDocumentation, MockSource, MockNode, MockMaterialization, MockGenerateMacro REQUIRED_PARSED_NODE_KEYS = frozenset({ @@ -663,62 +662,6 @@ def test__build_flat_graph(self): # Tests of the manifest search code (find_X_by_Y) -def MockMaterialization(package, name='my_materialization', adapter_type=None, kwargs={}): - if adapter_type is None: - adapter_type = 'default' - kwargs['adapter_type'] = adapter_type - return MockMacro(package, f'materialization_{name}_{adapter_type}', kwargs) - - -def MockGenerateMacro(package, component='some_component', kwargs={}): - name = f'generate_{component}_name' - return MockMacro(package, name=name, kwargs=kwargs) - - -def MockSource(package, source_name, name, kwargs={}): - src = mock.MagicMock( - __class__=ParsedSourceDefinition, - resource_type=NodeType.Source, - source_name=source_name, - package_name=package, - unique_id=f'source.{package}.{source_name}.{name}', - search_name=f'{source_name}.{name}', - **kwargs - ) - src.name = name - return src - - -def MockNode(package, name, resource_type=NodeType.Model, kwargs={}): - if resource_type == NodeType.Model: - cls = ParsedModelNode - elif resource_type == NodeType.Seed: - cls = ParsedSeedNode - else: - raise ValueError(f'I do not know how to handle {resource_type}') - node = mock.MagicMock( - __class__=cls, - resource_type=resource_type, - package_name=package, - unique_id=f'{str(resource_type)}.{package}.{name}', - search_name=name, - **kwargs - ) - node.name = name - return node - - -def MockDocumentation(package, name, kwargs={}): - doc = mock.MagicMock( - __class__=ParsedDocumentation, - resource_type=NodeType.Documentation, - package_name=package, - search_name=name, - unique_id=f'{package}.{name}', - ) - doc.name = name - return doc - class TestManifestSearch(unittest.TestCase): _macros = [] diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index 6824bd3e3a6..29dbdcbfdd3 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -26,60 +26,13 @@ NodeConfig, TestConfig, TimestampSnapshotConfig, SnapshotStrategy, ) from dbt.contracts.graph.parsed import ( - ParsedModelNode, ParsedMacro, ParsedNodePatch, ParsedSourceDefinition, - DependsOn, ColumnInfo, ParsedDataTestNode, ParsedSnapshotNode, - ParsedAnalysisNode, ParsedDocumentation, UnpatchedSourceDefinition, - ParsedSeedNode + ParsedModelNode, ParsedMacro, ParsedNodePatch, DependsOn, ColumnInfo, + ParsedDataTestNode, ParsedSnapshotNode, ParsedAnalysisNode, + UnpatchedSourceDefinition ) from dbt.contracts.graph.unparsed import Docs -from .utils import config_from_parts_or_dicts, normalize, generate_name_macros - - -def MockSource(package, source_name, name, **kwargs): - src = mock.MagicMock( - __class__=ParsedSourceDefinition, - resource_type=NodeType.Source, - source_name=source_name, - package_name=package, - unique_id=f'source.{package}.{source_name}.{name}', - search_name=f'{source_name}.{name}', - **kwargs - ) - src.name = name - return src - - -def MockNode(package, name, resource_type=NodeType.Model, **kwargs): - if resource_type == NodeType.Model: - cls = ParsedModelNode - elif resource_type == NodeType.Seed: - cls = ParsedSeedNode - else: - raise ValueError(f'I do not know how to handle {resource_type}') - node = mock.MagicMock( - __class__=cls, - resource_type=resource_type, - package_name=package, - unique_id=f'{str(resource_type)}.{package}.{name}', - search_name=name, - **kwargs - ) - node.name = name - return node - - -def MockDocumentation(package, name, **kwargs): - doc = mock.MagicMock( - __class__=ParsedDocumentation, - resource_type=NodeType.Documentation, - package_name=package, - search_name=name, - unique_id=f'{package}.{name}', - **kwargs - ) - doc.name = name - return doc +from .utils import config_from_parts_or_dicts, normalize, generate_name_macros, MockNode, MockSource, MockDocumentation def get_abs_os_path(unix_path): diff --git a/test/unit/utils.py b/test/unit/utils.py index 784b3803e5c..cd2989e982c 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -195,7 +195,7 @@ def _make_table_of(self, rows, column_types): return table -def MockMacro(package, name='my_macro', kwargs={}): +def MockMacro(package, name='my_macro', **kwargs): from dbt.contracts.graph.parsed import ParsedMacro from dbt.node_types import NodeType @@ -214,3 +214,69 @@ def MockMacro(package, name='my_macro', kwargs={}): ) macro.name = name return macro + + +def MockMaterialization(package, name='my_materialization', adapter_type=None, **kwargs): + if adapter_type is None: + adapter_type = 'default' + kwargs['adapter_type'] = adapter_type + return MockMacro(package, f'materialization_{name}_{adapter_type}', **kwargs) + + +def MockGenerateMacro(package, component='some_component', **kwargs): + name = f'generate_{component}_name' + return MockMacro(package, name=name, **kwargs) + + +def MockSource(package, source_name, name, **kwargs): + from dbt.node_types import NodeType + from dbt.contracts.graph.parsed import ParsedSourceDefinition + src = mock.MagicMock( + __class__=ParsedSourceDefinition, + resource_type=NodeType.Source, + source_name=source_name, + package_name=package, + unique_id=f'source.{package}.{source_name}.{name}', + search_name=f'{source_name}.{name}', + **kwargs + ) + src.name = name + return src + + +def MockNode(package, name, resource_type=None, **kwargs): + from dbt.node_types import NodeType + from dbt.contracts.graph.parsed import ParsedModelNode, ParsedSeedNode + if resource_type is None: + resource_type = NodeType.Model + if resource_type == NodeType.Model: + cls = ParsedModelNode + elif resource_type == NodeType.Seed: + cls = ParsedSeedNode + else: + raise ValueError(f'I do not know how to handle {resource_type}') + node = mock.MagicMock( + __class__=cls, + resource_type=resource_type, + package_name=package, + unique_id=f'{str(resource_type)}.{package}.{name}', + search_name=name, + **kwargs + ) + node.name = name + return node + + +def MockDocumentation(package, name, **kwargs): + from dbt.node_types import NodeType + from dbt.contracts.graph.parsed import ParsedDocumentation + doc = mock.MagicMock( + __class__=ParsedDocumentation, + resource_type=NodeType.Documentation, + package_name=package, + search_name=name, + unique_id=f'{package}.{name}', + **kwargs + ) + doc.name = name + return doc