diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cbaf6f608e..f39b60afd18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - Add more helpful error message for misconfiguration in profiles.yml ([#2569](https://github.com/fishtown-analytics/dbt/issues/2569), [#2627](https://github.com/fishtown-analytics/dbt/pull/2627)) ### Fixes - Adapter plugins can once again override plugins defined in core ([#2548](https://github.com/fishtown-analytics/dbt/issues/2548), [#2590](https://github.com/fishtown-analytics/dbt/pull/2590)) +- Added `--selector` argument and support for `selectors.yml` file to define selection mechanisms. ([#2172](https://github.com/fishtown-analytics/dbt/issues/2172), [#2640](https://github.com/fishtown-analytics/dbt/pull/2640)) Contributors: - [@brunomurino](https://github.com/brunomurino) ([#2437](https://github.com/fishtown-analytics/dbt/pull/2581)) diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 7bf61bcc1c8..83c787625d8 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -19,6 +19,8 @@ from dbt.exceptions import RecursionException from dbt.exceptions import SemverException from dbt.exceptions import validator_error_message +from dbt.exceptions import RuntimeException +from dbt.graph import SelectionSpec from dbt.helper_types import NoValue from dbt.semver import VersionSpecifier from dbt.semver import versions_compatible @@ -37,6 +39,11 @@ from hologram import ValidationError from .renderer import DbtProjectYamlRenderer +from .selectors import ( + selector_config_from_data, + selector_data_from_root, + SelectorConfig, +) INVALID_VERSION_ERROR = """\ @@ -211,10 +218,12 @@ class PartialProject: def render(self, renderer): packages_dict = package_data_from_root(self.project_root) + selectors_dict = selector_data_from_root(self.project_root) return Project.render_from_dict( self.project_root, self.project_dict, packages_dict, + selectors_dict, renderer, ) @@ -310,6 +319,7 @@ class Project: vars: VarProvider dbt_version: List[VersionSpecifier] packages: Dict[str, Any] + selectors: SelectorConfig query_comment: QueryComment config_version: int @@ -351,6 +361,7 @@ def from_project_config( cls, project_dict: Dict[str, Any], packages_dict: Optional[Dict[str, Any]] = None, + selectors_dict: Optional[Dict[str, Any]] = None, ) -> 'Project': """Create a project from its project and package configuration, as read by yaml.safe_load(). @@ -464,6 +475,11 @@ def from_project_config( except ValidationError as e: raise DbtProjectError(validator_error_message(e)) from e + try: + selectors = selector_config_from_data(selectors_dict) + except ValidationError as e: + raise DbtProjectError(validator_error_message(e)) from e + project = cls( project_name=name, version=version, @@ -488,6 +504,7 @@ def from_project_config( snapshots=snapshots, dbt_version=dbt_version, packages=packages, + selectors=selectors, query_comment=query_comment, sources=sources, vars=vars_value, @@ -568,14 +585,21 @@ def render_from_dict( project_root: str, project_dict: Dict[str, Any], packages_dict: Dict[str, Any], + selectors_dict: Dict[str, Any], renderer: DbtProjectYamlRenderer, ) -> 'Project': rendered_project = renderer.render_data(project_dict) rendered_project['project-root'] = project_root package_renderer = renderer.get_package_renderer() rendered_packages = package_renderer.render_data(packages_dict) + selectors_renderer = renderer.get_selector_renderer() + rendered_selectors = selectors_renderer.render_data(selectors_dict) try: - return cls.from_project_config(rendered_project, rendered_packages) + return cls.from_project_config( + rendered_project, + rendered_packages, + rendered_selectors, + ) except DbtProjectError as exc: if exc.path is None: exc.path = os.path.join(project_root, 'dbt_project.yml') @@ -659,6 +683,14 @@ def as_v1(self, all_projects: Iterable[str]): project.packages = self.packages return project + def get_selector(self, name: str) -> SelectionSpec: + if name not in self.selectors: + raise RuntimeException( + f'Could not find selector named {name}, expected one of ' + f'{list(self.selectors)}' + ) + return self.selectors[name] + def v2_vars_to_v1( dst: Dict[str, Any], src_vars: Dict[str, Any], project_names: Set[str] diff --git a/core/dbt/config/renderer.py b/core/dbt/config/renderer.py index 9e37b0e70a9..a1112134cfd 100644 --- a/core/dbt/config/renderer.py +++ b/core/dbt/config/renderer.py @@ -69,6 +69,9 @@ def name(self): def get_package_renderer(self) -> BaseRenderer: return PackageRenderer(self.context) + def get_selector_renderer(self) -> BaseRenderer: + return SelectorRenderer(self.context) + def should_render_keypath_v1(self, keypath: Keypath) -> bool: if not keypath: return True @@ -206,3 +209,9 @@ class PackageRenderer(BaseRenderer): @property def name(self): return 'Packages config' + + +class SelectorRenderer(BaseRenderer): + @property + def name(self): + return 'Selector config' diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index f90db86c998..7ff2c8b9a23 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -106,6 +106,7 @@ def from_parts( snapshots=project.snapshots, dbt_version=project.dbt_version, packages=project.packages, + selectors=project.selectors, query_comment=project.query_comment, sources=project.sources, vars=project.vars, @@ -361,8 +362,9 @@ def load_projects( project = self.new_project(str(path)) except DbtProjectError as e: raise DbtProjectError( - 'Failed to read package at {}: {}' - .format(path, e) + f'Failed to read package: {e}', + result_type='invalid_project', + path=path, ) from e else: yield project.project_name, project @@ -501,6 +503,7 @@ def from_parts( snapshots=project.snapshots, dbt_version=project.dbt_version, packages=project.packages, + selectors=project.selectors, query_comment=project.query_comment, sources=project.sources, vars=project.vars, diff --git a/core/dbt/config/selectors.py b/core/dbt/config/selectors.py new file mode 100644 index 00000000000..1dfd7438f8f --- /dev/null +++ b/core/dbt/config/selectors.py @@ -0,0 +1,104 @@ +from pathlib import Path +from typing import Dict, Any, Optional + +from hologram import ValidationError + +from .renderer import SelectorRenderer + +from dbt.clients.system import ( + load_file_contents, + path_exists, + resolve_path_from_base, +) +from dbt.clients.yaml_helper import load_yaml_text +from dbt.contracts.selection import SelectorFile +from dbt.exceptions import DbtSelectorsError, RuntimeException +from dbt.graph import parse_from_selectors_definition, SelectionSpec + +MALFORMED_SELECTOR_ERROR = """\ +The selectors.yml file in this project is malformed. Please double check +the contents of this file and fix any errors before retrying. + +You can find more information on the syntax for this file here: +https://docs.getdbt.com/docs/package-management + +Validator Error: +{error} +""" + + +class SelectorConfig(Dict[str, SelectionSpec]): + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'SelectorConfig': + try: + selector_file = SelectorFile.from_dict(data) + selectors = parse_from_selectors_definition(selector_file) + except (ValidationError, RuntimeException) as exc: + raise DbtSelectorsError( + f'Could not read selector file data: {exc}', + result_type='invalid_selector', + ) from exc + + return cls(selectors) + + @classmethod + def render_from_dict( + cls, + data: Dict[str, Any], + renderer: SelectorRenderer, + ) -> 'SelectorConfig': + try: + rendered = renderer.render_data(data) + except (ValidationError, RuntimeException) as exc: + raise DbtSelectorsError( + f'Could not render selector data: {exc}', + result_type='invalid_selector', + ) from exc + return cls.from_dict(rendered) + + @classmethod + def from_path( + cls, path: Path, renderer: SelectorRenderer, + ) -> 'SelectorConfig': + try: + data = load_yaml_text(load_file_contents(str(path))) + except (ValidationError, RuntimeException) as exc: + raise DbtSelectorsError( + f'Could not read selector file: {exc}', + result_type='invalid_selector', + path=path, + ) from exc + + try: + return cls.render_from_dict(data, renderer) + except DbtSelectorsError as exc: + exc.path = path + raise + + +def selector_data_from_root(project_root: str) -> Dict[str, Any]: + selector_filepath = resolve_path_from_base( + 'selectors.yml', project_root + ) + + if path_exists(selector_filepath): + selectors_dict = load_yaml_text(load_file_contents(selector_filepath)) + else: + selectors_dict = None + return selectors_dict + + +def selector_config_from_data( + selectors_data: Optional[Dict[str, Any]] +) -> SelectorConfig: + if selectors_data is None: + selectors_data = {'selectors': []} + + try: + selectors = SelectorConfig.from_dict(selectors_data) + except ValidationError as e: + raise DbtSelectorsError( + MALFORMED_SELECTOR_ERROR.format(error=str(e.message)), + result_type='invalid_selector', + ) from e + return selectors diff --git a/core/dbt/contracts/common.py b/core/dbt/contracts/common.py deleted file mode 100644 index b42f6306be8..00000000000 --- a/core/dbt/contracts/common.py +++ /dev/null @@ -1,11 +0,0 @@ - - -def named_property(name, doc=None): - def get_prop(self): - return self._contents.get(name) - - def set_prop(self, value): - self._contents[name] = value - self.validate() - - return property(get_prop, set_prop, doc=doc) diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index e354a2a3454..8810bf04b86 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -1,5 +1,9 @@ from dbt.node_types import NodeType -from dbt.contracts.util import Replaceable, Mergeable +from dbt.contracts.util import ( + AdditionalPropertiesMixin, + Mergeable, + Replaceable, +) # trigger the PathEncoder import dbt.helper_types # noqa:F401 from dbt.exceptions import CompilationException @@ -177,32 +181,12 @@ def __bool__(self): @dataclass -class AdditionalPropertiesAllowed(ExtensibleJsonSchemaMixin): +class AdditionalPropertiesAllowed( + AdditionalPropertiesMixin, + ExtensibleJsonSchemaMixin +): _extra: Dict[str, Any] = field(default_factory=dict) - @property - def extra(self): - return self._extra - - @classmethod - def from_dict(cls, data, validate=True): - self = super().from_dict(data=data, validate=validate) - keys = self.to_dict(validate=False, omit_none=False) - for key, value in data.items(): - if key not in keys: - self._extra[key] = value - return self - - def to_dict(self, omit_none=True, validate=False): - data = super().to_dict(omit_none=omit_none, validate=validate) - data.update(self._extra) - return data - - def replace(self, **kwargs): - dct = self.to_dict(omit_none=False, validate=False) - dct.update(kwargs) - return self.from_dict(dct) - @dataclass class ExternalPartition(AdditionalPropertiesAllowed, Replaceable): diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py index aab03527009..0df94173a32 100644 --- a/core/dbt/contracts/rpc.py +++ b/core/dbt/contracts/rpc.py @@ -44,6 +44,7 @@ class RPCCompileParameters(RPCParameters): threads: Optional[int] = None models: Union[None, str, List[str]] = None exclude: Union[None, str, List[str]] = None + selector: Optional[str] = None @dataclass @@ -51,6 +52,7 @@ class RPCSnapshotParameters(RPCParameters): threads: Optional[int] = None select: Union[None, str, List[str]] = None exclude: Union[None, str, List[str]] = None + selector: Optional[str] = None @dataclass @@ -64,6 +66,7 @@ class RPCSeedParameters(RPCParameters): threads: Optional[int] = None select: Union[None, str, List[str]] = None exclude: Union[None, str, List[str]] = None + selector: Optional[str] = None show: bool = False diff --git a/core/dbt/contracts/selection.py b/core/dbt/contracts/selection.py new file mode 100644 index 00000000000..d193473ecb8 --- /dev/null +++ b/core/dbt/contracts/selection.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from hologram import JsonSchemaMixin + +from typing import List, Dict, Any, Union + + +@dataclass +class SelectorDefinition(JsonSchemaMixin): + name: str + definition: Union[str, Dict[str, Any]] + + +@dataclass +class SelectorFile(JsonSchemaMixin): + selectors: List[SelectorDefinition] + version: int = 2 + + +# @dataclass +# class SelectorCollection: +# packages: Dict[str, List[SelectorFile]] = field(default_factory=dict) diff --git a/core/dbt/contracts/util.py b/core/dbt/contracts/util.py index e90694fedaf..e2f2c257c2a 100644 --- a/core/dbt/contracts/util.py +++ b/core/dbt/contracts/util.py @@ -44,3 +44,35 @@ def merged(self, *args): class Writable: def write(self, path: str, omit_none: bool = False): write_json(path, self.to_dict(omit_none=omit_none)) # type: ignore + + +class AdditionalPropertiesMixin: + """Make this class an extensible property. + + The underlying class definition must include a type definition for a field + named '_extra' that is of type `Dict[str, Any]`. + """ + ADDITIONAL_PROPERTIES = True + + @classmethod + def from_dict(cls, data, validate=True): + self = super().from_dict(data=data, validate=validate) + keys = self.to_dict(validate=False, omit_none=False) + for key, value in data.items(): + if key not in keys: + self.extra[key] = value + return self + + def to_dict(self, omit_none=True, validate=False): + data = super().to_dict(omit_none=omit_none, validate=validate) + data.update(self.extra) + return data + + def replace(self, **kwargs): + dct = self.to_dict(omit_none=False, validate=False) + dct.update(kwargs) + return self.from_dict(dct) + + @property + def extra(self): + return self._extra diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index 41e6ded9a43..437403d35f7 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -319,6 +319,10 @@ class DbtProjectError(DbtConfigError): pass +class DbtSelectorsError(DbtConfigError): + pass + + class DbtProfileError(DbtConfigError): pass diff --git a/core/dbt/graph/__init__.py b/core/dbt/graph/__init__.py index 4a96f7da650..f815af4828b 100644 --- a/core/dbt/graph/__init__.py +++ b/core/dbt/graph/__init__.py @@ -12,6 +12,7 @@ from .cli import ( # noqa: F401 parse_difference, parse_test_selectors, + parse_from_selectors_definition, ) from .queue import GraphQueue # noqa: F401 from .graph import Graph, UniqueId # noqa: F401 diff --git a/core/dbt/graph/cli.py b/core/dbt/graph/cli.py index 3b0d946e3e8..7eabb501902 100644 --- a/core/dbt/graph/cli.py +++ b/core/dbt/graph/cli.py @@ -2,9 +2,12 @@ import itertools from typing import ( - List, Optional + Dict, List, Optional, Tuple, Any, Union ) +from dbt.contracts.selection import SelectorDefinition, SelectorFile +from dbt.exceptions import InternalException, ValidationException + from .selector_spec import ( SelectionUnion, SelectionSpec, @@ -97,3 +100,169 @@ def parse_test_selectors( return SelectionIntersection( components=[base, intersect_with], expect_exists=True ) + + +RawDefinition = Union[str, Dict[str, Any]] + + +def _get_list_dicts( + dct: Dict[str, Any], key: str +) -> List[RawDefinition]: + result: List[RawDefinition] = [] + if key not in dct: + raise InternalException( + f'Expected to find key {key} in dict, only found {list(dct)}' + ) + values = dct[key] + if not isinstance(values, list): + raise ValidationException( + f'Invalid value type {type(values)} in key "{key}" ' + f'(value "{values}")' + ) + for value in values: + if isinstance(value, dict): + for value_key in value: + if not isinstance(value_key, str): + raise ValidationException( + f'Expected all keys to "{key}" dict to be strings, ' + f'but "{value_key}" is a "{type(value_key)}"' + ) + result.append(value) + elif isinstance(value, str): + result.append(value) + else: + raise ValidationException( + f'Invalid value type {type(value)} in key "{key}", expected ' + f'dict or str (value: {value}).' + ) + + return result + + +def _parse_exclusions(definition) -> Optional[SelectionSpec]: + exclusions = _get_list_dicts(definition, 'exclude') + parsed_exclusions = [ + parse_from_definition(excl) for excl in exclusions + ] + if len(parsed_exclusions) == 1: + return parsed_exclusions[0] + elif len(parsed_exclusions) > 1: + return SelectionUnion( + components=parsed_exclusions, + raw=exclusions + ) + else: + return None + + +def _parse_include_exclude_subdefs( + definitions: List[RawDefinition] +) -> Tuple[List[SelectionSpec], Optional[SelectionSpec]]: + include_parts: List[SelectionSpec] = [] + diff_arg: Optional[SelectionSpec] = None + + for definition in definitions: + if isinstance(definition, dict) and 'exclude' in definition: + # do not allow multiple exclude: defs at the same level + if diff_arg is not None: + raise ValidationException( + f'Got multiple exclusion definitions in definition list ' + f'{definitions}' + ) + diff_arg = _parse_exclusions(definition) + else: + include_parts.append(parse_from_definition(definition)) + + return (include_parts, diff_arg) + + +def parse_union_definition(definition: Dict[str, Any]) -> SelectionSpec: + union_def_parts = _get_list_dicts(definition, 'union') + include, exclude = _parse_include_exclude_subdefs(union_def_parts) + + union = SelectionUnion(components=include) + + if exclude is None: + union.raw = definition + return union + else: + return SelectionDifference( + components=[union, exclude], + raw=definition + ) + + +def parse_intersection_definition( + definition: Dict[str, Any] +) -> SelectionSpec: + intersection_def_parts = _get_list_dicts(definition, 'intersection') + include, exclude = _parse_include_exclude_subdefs(intersection_def_parts) + intersection = SelectionIntersection(components=include) + if exclude is None: + intersection.raw = definition + return intersection + else: + return SelectionDifference( + components=[intersection, exclude], + raw=definition + ) + + +def parse_dict_definition(definition: Dict[str, Any]) -> SelectionSpec: + diff_arg: Optional[SelectionSpec] = None + + if len(definition) == 1: + key = list(definition)[0] + value = definition[key] + if not isinstance(key, str): + raise ValidationException( + f'Expected definition key to be a "str", got one of type ' + f'"{type(key)}" ({key})' + ) + dct = { + 'method': key, + 'value': value, + } + elif 'method' in definition and 'value' in definition: + dct = definition + if 'exclude' in definition: + diff_arg = _parse_exclusions(definition) + dct = {k: v for k, v in dct.items() if k != 'exclude'} + else: + raise ValidationException( + f'Expected exactly 1 key in the selection definition or "method" ' + f'and "value" keys, but got {list(definition)}' + ) + + # if key isn't a valid method name, this will raise + base = SelectionCriteria.from_dict(definition, dct) + if diff_arg is None: + return base + else: + return SelectionDifference(components=[base, diff_arg]) + + +def parse_from_definition(definition: RawDefinition) -> SelectionSpec: + if isinstance(definition, str): + return SelectionCriteria.from_single_spec(definition) + elif 'union' in definition: + return parse_union_definition(definition) + elif 'intersection' in definition: + return parse_intersection_definition(definition) + elif isinstance(definition, dict): + return parse_dict_definition(definition) + else: + raise ValidationException( + f'Expected to find str or dict, instead found ' + f'{type(definition)}: {definition}' + ) + + +def parse_from_selectors_definition( + source: SelectorFile +) -> Dict[str, SelectionSpec]: + result: Dict[str, SelectionSpec] = {} + selector: SelectorDefinition + for selector in source.selectors: + result[selector.name] = parse_from_definition(selector.definition) + return result diff --git a/core/dbt/graph/selector.py b/core/dbt/graph/selector.py index f47e1bc48c2..539340dd873 100644 --- a/core/dbt/graph/selector.py +++ b/core/dbt/graph/selector.py @@ -93,15 +93,15 @@ def collect_specified_neighbors( overlap with the selected set). """ additional: Set[UniqueId] = set() - if spec.select_childrens_parents: + if spec.childrens_parents: additional.update(self.graph.select_childrens_parents(selected)) - if spec.select_parents: - depth = spec.select_parents_max_depth + if spec.parents: + depth = spec.parents_depth additional.update(self.graph.select_parents(selected, depth)) - if spec.select_children: - depth = spec.select_children_max_depth + if spec.children: + depth = spec.children_depth additional.update(self.graph.select_children(selected, depth)) return additional diff --git a/core/dbt/graph/selector_methods.py b/core/dbt/graph/selector_methods.py index 1b029396028..afc968654a2 100644 --- a/core/dbt/graph/selector_methods.py +++ b/core/dbt/graph/selector_methods.py @@ -23,6 +23,7 @@ InternalException, RuntimeException, ) +from dbt.node_types import NodeType SELECTOR_GLOB = '*' @@ -38,6 +39,7 @@ class MethodName(StrEnum): Config = 'config' TestName = 'test_name' TestType = 'test_type' + ResourceType = 'resource_type' def is_selected_node(real_node, node_selector): @@ -259,7 +261,9 @@ def __eq__(self, other): class ConfigSelectorMethod(SelectorMethod): def search( - self, included_nodes: Set[UniqueId], selector: str + self, + included_nodes: Set[UniqueId], + selector: Any, ) -> Iterator[UniqueId]: parts = self.arguments # special case: if the user wanted to compare test severity, @@ -276,14 +280,25 @@ def search( except AttributeError: continue else: - # the selector can only be a str, so call str() on the value. - # of course, if one wished to render the selector in the jinja - # native env, this would no longer be true - - if selector == str(value): + if selector == value: yield node +class ResourceTypeSelectorMethod(SelectorMethod): + def search( + self, included_nodes: Set[UniqueId], selector: str + ) -> Iterator[UniqueId]: + try: + resource_type = NodeType(selector) + except ValueError as exc: + raise RuntimeException( + f'Invalid resource_type selector "{selector}"' + ) from exc + for node, real_node in self.parsed_nodes(included_nodes): + if real_node.resource_type == resource_type: + yield node + + class TestNameSelectorMethod(SelectorMethod): def search( self, included_nodes: Set[UniqueId], selector: str diff --git a/core/dbt/graph/selector_spec.py b/core/dbt/graph/selector_spec.py index ee86e091ca9..f8547c4d2c7 100644 --- a/core/dbt/graph/selector_spec.py +++ b/core/dbt/graph/selector_spec.py @@ -13,7 +13,7 @@ RAW_SELECTOR_PATTERN = re.compile( r'\A' - r'(?P(\@))?' + r'(?P(\@))?' r'(?P((?P(\d*))\+))?' r'((?P([\w.]+)):)?(?P(.*?))' r'(?P(\+(?P(\d*))))?' @@ -57,18 +57,18 @@ def _match_to_int(match: Dict[str, str], key: str) -> Optional[int]: @dataclass class SelectionCriteria: - raw: str + raw: Any method: MethodName method_arguments: List[str] - value: str - select_childrens_parents: bool - select_parents: bool - select_parents_max_depth: Optional[int] - select_children: bool - select_children_max_depth: Optional[int] + value: Any + childrens_parents: bool + parents: bool + parents_depth: Optional[int] + children: bool + children_depth: Optional[int] def __post_init__(self): - if self.select_children and self.select_childrens_parents: + if self.children and self.childrens_parents: raise RuntimeException( f'Invalid node spec {self.raw} - "@" prefix and "+" suffix ' 'are incompatible' @@ -83,7 +83,7 @@ def default_method(cls, value: str) -> MethodName: @classmethod def parse_method( - cls, raw: str, groupdict: Dict[str, Any] + cls, groupdict: Dict[str, Any] ) -> Tuple[MethodName, List[str]]: raw_method = groupdict.get('method') if raw_method is None: @@ -100,35 +100,36 @@ def parse_method( return method_name, method_arguments @classmethod - def from_single_spec(cls, raw: str) -> 'SelectionCriteria': - result = RAW_SELECTOR_PATTERN.match(raw) - if result is None: - # bad spec! - raise RuntimeException(f'Invalid selector spec "{raw}"') - result_dict = result.groupdict() - - if 'value' not in result_dict: + def from_dict(cls, raw: Any, dct: Dict[str, Any]) -> 'SelectionCriteria': + if 'value' not in dct: raise RuntimeException( f'Invalid node spec "{raw}" - no search value!' ) + method_name, method_arguments = cls.parse_method(dct) - method_name, method_arguments = cls.parse_method(raw, result_dict) - - parents_max_depth = _match_to_int(result_dict, 'parents_depth') - children_max_depth = _match_to_int(result_dict, 'children_depth') - + parents_depth = _match_to_int(dct, 'parents_depth') + children_depth = _match_to_int(dct, 'children_depth') return cls( raw=raw, method=method_name, method_arguments=method_arguments, - value=result_dict['value'], - select_childrens_parents=bool(result_dict.get('childs_parents')), - select_parents=bool(result_dict.get('parents')), - select_parents_max_depth=parents_max_depth, - select_children=bool(result_dict.get('children')), - select_children_max_depth=children_max_depth, + value=dct['value'], + childrens_parents=bool(dct.get('childrens_parents')), + parents=bool(dct.get('parents')), + parents_depth=parents_depth, + children=bool(dct.get('children')), + children_depth=children_depth, ) + @classmethod + def from_single_spec(cls, raw: str) -> 'SelectionCriteria': + result = RAW_SELECTOR_PATTERN.match(raw) + if result is None: + # bad spec! + raise RuntimeException(f'Invalid selector spec "{raw}"') + + return cls.from_dict(raw, result.groupdict()) + class BaseSelectionGroup(Iterable[SelectionSpec], metaclass=ABCMeta): def __init__( diff --git a/core/dbt/main.py b/core/dbt/main.py index cde94c77c71..726dca5cea8 100644 --- a/core/dbt/main.py +++ b/core/dbt/main.py @@ -461,6 +461,14 @@ def _add_selection_arguments(*subparsers, **kwargs): Specify the models to exclude. ''', ) + sub.add_argument( + '--selector', + dest='selector_name', + metavar='SELECTOR_NAME', + help=''' + The selector name to use, as defined in selectors.yml + ''' + ) def _add_table_mutability_arguments(*subparsers): @@ -703,6 +711,14 @@ def _build_list_subparser(subparsers, base_subparser): Specify the models to exclude. ''' ) + sub.add_argument( + '--selector', + metavar='SELECTOR_NAME', + dest='selector_name', + help=''' + The selector name to use, as defined in selectors.yml + ''' + ) return sub diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py index 8ae99a3d698..13a99b4fabd 100644 --- a/core/dbt/task/compile.py +++ b/core/dbt/task/compile.py @@ -28,7 +28,10 @@ def raise_on_first_error(self): return True def get_selection_spec(self) -> SelectionSpec: - spec = parse_difference(self.args.models, self.args.exclude) + if self.args.selector_name: + spec = self.config.get_selector(self.args.selector_name) + else: + spec = parse_difference(self.args.models, self.args.exclude) return spec def get_node_selector(self) -> ResourceTypeSelector: diff --git a/core/dbt/task/list.py b/core/dbt/task/list.py index 0d548c0676c..d4a39acd094 100644 --- a/core/dbt/task/list.py +++ b/core/dbt/task/list.py @@ -149,7 +149,10 @@ def selector(self): return self.args.select def get_selection_spec(self) -> SelectionSpec: - spec = parse_difference(self.selector, self.args.exclude) + if self.args.selector_name: + spec = self.config.get_selector(self.args.selector_name) + else: + spec = parse_difference(self.selector, self.args.exclude) return spec def get_node_selector(self): diff --git a/core/dbt/task/rpc/project_commands.py b/core/dbt/task/rpc/project_commands.py index 99a67f8a36b..8bf12215690 100644 --- a/core/dbt/task/rpc/project_commands.py +++ b/core/dbt/task/rpc/project_commands.py @@ -62,6 +62,7 @@ class RemoteCompileProjectTask( def set_args(self, params: RPCCompileParameters) -> None: self.args.models = self._listify(params.models) self.args.exclude = self._listify(params.exclude) + self.args.selector_name = params.selector if params.threads is not None: self.args.threads = params.threads @@ -72,6 +73,7 @@ class RemoteRunProjectTask(RPCCommandTask[RPCCompileParameters], RunTask): def set_args(self, params: RPCCompileParameters) -> None: self.args.models = self._listify(params.models) self.args.exclude = self._listify(params.exclude) + self.args.selector_name = params.selector if params.threads is not None: self.args.threads = params.threads @@ -83,6 +85,7 @@ def set_args(self, params: RPCSeedParameters) -> None: # select has an argparse `dest` value of `models`. self.args.models = self._listify(params.select) self.args.exclude = self._listify(params.exclude) + self.args.selector_name = params.selector if params.threads is not None: self.args.threads = params.threads self.args.show = params.show @@ -94,6 +97,7 @@ class RemoteTestProjectTask(RPCCommandTask[RPCTestParameters], TestTask): def set_args(self, params: RPCTestParameters) -> None: self.args.models = self._listify(params.models) self.args.exclude = self._listify(params.exclude) + self.args.selector_name = params.selector self.args.data = params.data self.args.schema = params.schema if params.threads is not None: @@ -109,6 +113,7 @@ class RemoteDocsGenerateProjectTask( def set_args(self, params: RPCDocsGenerateParameters) -> None: self.args.models = None self.args.exclude = None + self.args.selector_name = None self.args.compile = params.compile def get_catalog_results( @@ -176,6 +181,7 @@ def set_args(self, params: RPCSnapshotParameters) -> None: # select has an argparse `dest` value of `models`. self.args.models = self._listify(params.select) self.args.exclude = self._listify(params.exclude) + self.args.selector_name = params.selector if params.threads is not None: self.args.threads = params.threads @@ -202,6 +208,7 @@ class GetManifest( def set_args(self, params: GetManifestParameters) -> None: self.args.models = None self.args.exclude = None + self.args.selector_name = None def handle_request(self) -> GetManifestResult: task = RemoteCompileProjectTask(self.args, self.config, self.manifest) diff --git a/test/integration/007_graph_selection_tests/test_intersection_syntax.py b/test/integration/007_graph_selection_tests/test_intersection_syntax.py index 175a571fd83..d725d03c39f 100644 --- a/test/integration/007_graph_selection_tests/test_intersection_syntax.py +++ b/test/integration/007_graph_selection_tests/test_intersection_syntax.py @@ -1,4 +1,5 @@ from test.integration.base import DBTIntegrationTest, use_profile +import yaml class TestGraphSelection(DBTIntegrationTest): @@ -11,11 +12,70 @@ def schema(self): def models(self): return "models" - @use_profile('postgres') - def test__postgres__same_model_intersection(self): - self.run_sql_file("seed.sql") - - results = self.run_dbt(['run', '--models', 'users,users']) + @property + def selectors_config(self): + return yaml.safe_load(''' + selectors: + - name: same_intersection + definition: + intersection: + - fqn: users + - fqn:users + - name: tags_intersection + definition: + intersection: + - tag: bi + - tag: users + - name: triple_descending + definition: + intersection: + - fqn: "*" + - tag: bi + - tag: users + - name: triple_ascending + definition: + intersection: + - tag: users + - tag: bi + - fqn: "*" + - name: intersection_with_exclusion + definition: + intersection: + - method: fqn + value: users_rollup_dependency + parents: true + - method: fqn + value: users + children: true + - exclude: + - users_rollup_dependency + - name: intersection_exclude_intersection + definition: + intersection: + - tag:bi + - "@users" + - exclude: + - intersection: + - tag:bi + - method: fqn + value: users_rollup + children: true + - name: intersection_exclude_intersection_lack + definition: + intersection: + - tag:bi + - "@users" + - exclude: + - intersection: + - method: fqn + value: emails + children_parents: true + - method: fqn + value: emails_alt + children_parents: true + ''') + + def _verify_selected_users(self, results): # users self.assertEqual(len(results), 1) @@ -26,56 +86,70 @@ def test__postgres__same_model_intersection(self): self.assertNotIn('subdir', created_models) self.assertNotIn('nested_users', created_models) + @use_profile('postgres') + def test__postgres__same_model_intersection(self): + self.run_sql_file("seed.sql") + + results = self.run_dbt(['run', '--models', 'users,users']) + self._verify_selected_users(results) + + @use_profile('postgres') + def test__postgres__same_model_intersection_selectors(self): + self.run_sql_file("seed.sql") + + results = self.run_dbt(['run', '--selector', 'same_intersection']) + self._verify_selected_users(results) + @use_profile('postgres') def test__postgres__tags_intersection(self): self.run_sql_file("seed.sql") results = self.run_dbt(['run', '--models', 'tag:bi,tag:users']) - # users - self.assertEqual(len(results), 1) + self._verify_selected_users(results) - created_models = self.get_models_in_schema() - self.assertIn('users', created_models) - self.assertNotIn('users_rollup', created_models) - self.assertNotIn('emails_alt', created_models) - self.assertNotIn('subdir', created_models) - self.assertNotIn('nested_users', created_models) + @use_profile('postgres') + def test__postgres__tags_intersection_selectors(self): + self.run_sql_file("seed.sql") + + results = self.run_dbt(['run', '--selector', 'tags_intersection']) + self._verify_selected_users(results) @use_profile('postgres') def test__postgres__intersection_triple_descending(self): self.run_sql_file("seed.sql") results = self.run_dbt(['run', '--models', '*,tag:bi,tag:users']) - # users - self.assertEqual(len(results), 1) + self._verify_selected_users(results) - created_models = self.get_models_in_schema() - self.assertIn('users', created_models) - self.assertNotIn('users_rollup', created_models) - self.assertNotIn('emails_alt', created_models) - self.assertNotIn('subdir', created_models) - self.assertNotIn('nested_users', created_models) + @use_profile('postgres') + def test__postgres__intersection_triple_descending_schema(self): + self.run_sql_file("seed.sql") + + results = self.run_dbt(['run', '--models', '*,tag:bi,tag:users']) + self._verify_selected_users(results) + + @use_profile('postgres') + def test__postgres__intersection_triple_descending_schema_selectors(self): + self.run_sql_file("seed.sql") + + results = self.run_dbt(['run', '--selector', 'triple_descending']) + self._verify_selected_users(results) @use_profile('postgres') def test__postgres__intersection_triple_ascending(self): self.run_sql_file("seed.sql") results = self.run_dbt(['run', '--models', 'tag:users,tag:bi,*']) - # users - self.assertEqual(len(results), 1) - - created_models = self.get_models_in_schema() - self.assertIn('users', created_models) - self.assertNotIn('users_rollup', created_models) - self.assertNotIn('emails_alt', created_models) - self.assertNotIn('subdir', created_models) - self.assertNotIn('nested_users', created_models) + self._verify_selected_users(results) @use_profile('postgres') - def test__postgres__intersection_with_exclusion(self): + def test__postgres__intersection_triple_ascending_schema_selectors(self): self.run_sql_file("seed.sql") - results = self.run_dbt(['run', '--models', '+users_rollup_dependency,users+', '--exclude', 'users_rollup_dependency']) + results = self.run_dbt(['run', '--selector', 'triple_ascending']) + self._verify_selected_users(results) + + def _verify_selected_users_and_rollup(self, results): # users, users_rollup self.assertEqual(len(results), 2) @@ -86,6 +160,20 @@ def test__postgres__intersection_with_exclusion(self): self.assertNotIn('subdir', created_models) self.assertNotIn('nested_users', created_models) + @use_profile('postgres') + def test__postgres__intersection_with_exclusion(self): + self.run_sql_file("seed.sql") + + results = self.run_dbt(['run', '--models', '+users_rollup_dependency,users+', '--exclude', 'users_rollup_dependency']) + self._verify_selected_users_and_rollup(results) + + @use_profile('postgres') + def test__postgres__intersection_with_exclusion_selectors(self): + self.run_sql_file("seed.sql") + + results = self.run_dbt(['run', '--selector', 'intersection_with_exclusion']) + self._verify_selected_users_and_rollup(results) + @use_profile('postgres') def test__postgres__intersection_exclude_intersection(self): self.run_sql_file("seed.sql") @@ -93,15 +181,16 @@ def test__postgres__intersection_exclude_intersection(self): results = self.run_dbt( ['run', '--models', 'tag:bi,@users', '--exclude', 'tag:bi,users_rollup+']) - # users - self.assertEqual(len(results), 1) + self._verify_selected_users(results) - created_models = self.get_models_in_schema() - self.assertIn('users', created_models) - self.assertNotIn('users_rollup', created_models) - self.assertNotIn('emails_alt', created_models) - self.assertNotIn('subdir', created_models) - self.assertNotIn('nested_users', created_models) + @use_profile('postgres') + def test__postgres__intersection_exclude_intersection_selectors(self): + self.run_sql_file("seed.sql") + + results = self.run_dbt( + ['run', '--selector', 'intersection_exclude_intersection'] + ) + self._verify_selected_users(results) @use_profile('postgres') def test__postgres__intersection_exclude_intersection_lack(self): @@ -110,16 +199,15 @@ def test__postgres__intersection_exclude_intersection_lack(self): results = self.run_dbt( ['run', '--models', 'tag:bi,@users', '--exclude', '@emails,@emails_alt']) - # users, users_rollup - self.assertEqual(len(results), 2) + self._verify_selected_users_and_rollup(results) - created_models = self.get_models_in_schema() - self.assertIn('users', created_models) - self.assertIn('users_rollup', created_models) - self.assertNotIn('emails_alt', created_models) - self.assertNotIn('subdir', created_models) - self.assertNotIn('nested_users', created_models) + @use_profile('postgres') + def test__postgres__intersection_exclude_intersection_lack_selector(self): + self.run_sql_file("seed.sql") + results = self.run_dbt( + ['run', '--selector', 'intersection_exclude_intersection_lack']) + self._verify_selected_users_and_rollup(results) @use_profile('postgres') def test__postgres__intersection_exclude_triple_intersection(self): @@ -128,15 +216,7 @@ def test__postgres__intersection_exclude_triple_intersection(self): results = self.run_dbt( ['run', '--models', 'tag:bi,@users', '--exclude', '*,tag:bi,users_rollup']) - # users - self.assertEqual(len(results), 1) - - created_models = self.get_models_in_schema() - self.assertIn('users', created_models) - self.assertNotIn('users_rollup', created_models) - self.assertNotIn('emails_alt', created_models) - self.assertNotIn('subdir', created_models) - self.assertNotIn('nested_users', created_models) + self._verify_selected_users(results) @use_profile('postgres') def test__postgres__intersection_concat(self): diff --git a/test/integration/007_graph_selection_tests/test_tag_selection.py b/test/integration/007_graph_selection_tests/test_tag_selection.py index 635f289482f..3d3b0c3baf7 100644 --- a/test/integration/007_graph_selection_tests/test_tag_selection.py +++ b/test/integration/007_graph_selection_tests/test_tag_selection.py @@ -1,5 +1,7 @@ from test.integration.base import DBTIntegrationTest, use_profile +import yaml + class TestGraphSelection(DBTIntegrationTest): @@ -27,44 +29,119 @@ def project_config(self): } } - @use_profile('postgres') - def test__postgres__select_tag(self): + @property + def selectors_config(self): + return yaml.safe_load(''' + selectors: + - name: tag_specified_as_string_str + definition: tag:specified_as_string + - name: tag_specified_as_string_dict + definition: + method: tag + value: specified_as_string + - name: tag_specified_in_project_children_str + definition: +tag:specified_in_project+ + - name: tag_specified_in_project_children_dict + definition: + method: tag + value: specified_in_project + parents: true + children: true + - name: tagged-bi + definition: + method: tag + value: bi + - name: user_tagged_childrens_parents + definition: + method: tag + value: users + childrens_parents: true + - name: base_ephemerals + definition: + union: + - tag: base + - method: config.materialized + value: ephemeral + - name: warn-severity + definition: + config.severity: warn + - name: roundabout-everything + definition: + union: + - "@tag:users" + - intersection: + - tag: base + - config.materialized: ephemeral + ''') + + def setUp(self): + super().setUp() self.run_sql_file("seed.sql") - results = self.run_dbt(['run', '--models', 'tag:specified_as_string']) + def _verify_select_tag(self, results): self.assertEqual(len(results), 1) models_run = [r.node.name for r in results] self.assertTrue('users' in models_run) @use_profile('postgres') - def test__postgres__select_tag_and_children(self): - self.run_sql_file("seed.sql") + def test__postgres__select_tag(self): + results = self.run_dbt(['run', '--models', 'tag:specified_as_string']) + self._verify_select_tag(results) - results = self.run_dbt(['run', '--models', '+tag:specified_in_project+']) + @use_profile('postgres') + def test__postgres__select_tag_selector_str(self): + results = self.run_dbt(['run', '--selector', 'tag_specified_as_string_str']) + self._verify_select_tag(results) + + @use_profile('postgres') + def test__postgres__select_tag_selector_dict(self): + results = self.run_dbt(['run', '--selector', 'tag_specified_as_string_dict']) + self._verify_select_tag(results) + + def _verify_select_tag_and_children(self, results): self.assertEqual(len(results), 3) models_run = [r.node.name for r in results] self.assertTrue('users' in models_run) self.assertTrue('users_rollup' in models_run) - # check that model configs aren't squashed by project configs @use_profile('postgres') - def test__postgres__select_tag_in_model_with_project_Config(self): - self.run_sql_file("seed.sql") + def test__postgres__select_tag_and_children(self): + results = self.run_dbt(['run', '--models', '+tag:specified_in_project+']) + self._verify_select_tag_and_children(results) - results = self.run_dbt(['run', '--models', 'tag:bi']) + @use_profile('postgres') + def test__postgres__select_tag_and_children_selector_str(self): + results = self.run_dbt(['run', '--selector', 'tag_specified_in_project_children_str']) + self._verify_select_tag_and_children(results) + + @use_profile('postgres') + def test__postgres__select_tag_and_children_selector_dict(self): + results = self.run_dbt(['run', '--selector', 'tag_specified_in_project_children_dict']) + self._verify_select_tag_and_children(results) + + # check that model configs aren't squashed by project configs + def _verify_select_bi(self, results): self.assertEqual(len(results), 2) models_run = [r.node.name for r in results] self.assertTrue('users' in models_run) self.assertTrue('users_rollup' in models_run) - # check that model configs aren't squashed by project configs @use_profile('postgres') - def test__postgres__select_tag_in_model_with_project_Config(self): - self.run_sql_file("seed.sql") + def test__postgres__select_tag_in_model_with_project_config(self): + results = self.run_dbt(['run', '--models', 'tag:bi']) + self._verify_select_bi(results) + @use_profile('postgres') + def test__postgres__select_tag_in_model_with_project_config_selector(self): + results = self.run_dbt(['run', '--selector', 'tagged-bi']) + self._verify_select_bi(results) + + # check that model configs aren't squashed by project configs + @use_profile('postgres') + def test__postgres__select_tag_in_model_with_project_config_parents_children(self): results = self.run_dbt(['run', '--models', '@tag:users']) self.assertEqual(len(results), 4) @@ -90,3 +167,33 @@ def test__postgres__select_tag_in_model_with_project_Config(self): results = self.run_dbt(['test', '--models', '@tag:users tag:base,config.materialized:ephemeral']) self.assertEqual(len(results), 3) assert sorted(r.node.name for r in results) == ['not_null_emails_email', 'unique_users_id', 'unique_users_rollup_gender'] + + + @use_profile('postgres') + def test__postgres__select_tag_in_model_with_project_config_parents_children_selectors(self): + results = self.run_dbt(['run', '--selector', 'user_tagged_childrens_parents']) + self.assertEqual(len(results), 4) + + models_run = set(r.node.name for r in results) + self.assertEqual( + {'users', 'users_rollup', 'emails_alt', 'users_rollup_dependency'}, + models_run + ) + + # just the users/users_rollup tests + results = self.run_dbt(['test', '--selector', 'user_tagged_childrens_parents']) + self.assertEqual(len(results), 2) + assert sorted(r.node.name for r in results) == ['unique_users_id', 'unique_users_rollup_gender'] + # just the email test + results = self.run_dbt(['test', '--selector', 'base_ephemerals']) + self.assertEqual(len(results), 1) + assert results[0].node.name == 'not_null_emails_email' + + # also just the email test + results = self.run_dbt(['test', '--selector', 'warn-severity']) + self.assertEqual(len(results), 1) + assert results[0].node.name == 'not_null_emails_email' + # all 3 tests + results = self.run_dbt(['test', '--selector', 'roundabout-everything']) + self.assertEqual(len(results), 3) + assert sorted(r.node.name for r in results) == ['not_null_emails_email', 'unique_users_id', 'unique_users_rollup_gender'] diff --git a/test/integration/base.py b/test/integration/base.py index 9a2d4edd1c2..5065f1c1d09 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -65,6 +65,7 @@ def __init__(self): self.models = None self.exclude = None self.single_threaded = False + self.selector_name = None class TestArgs: @@ -289,6 +290,10 @@ def presto_profile(self): def packages_config(self): return None + @property + def selectors_config(self): + return None + def unique_schema(self): schema = self.schema @@ -388,6 +393,7 @@ def setUp(self): self.use_profile(self._pick_profile()) self.use_default_project() self.set_packages() + self.set_selectors() self.load_config() def use_default_project(self, overrides=None): @@ -431,6 +437,11 @@ def set_packages(self): with open('packages.yml', 'w') as f: yaml.safe_dump(self.packages_config, f, default_flow_style=True) + def set_selectors(self): + if self.selectors_config is not None: + with open('selectors.yml', 'w') as f: + yaml.safe_dump(self.selectors_config, f, default_flow_style=True) + def load_config(self): # we've written our profile and project. Now we want to instantiate a # fresh adapter for the tests. diff --git a/test/unit/test_graph_selection.py b/test/unit/test_graph_selection.py index a5623a95793..7180680bc24 100644 --- a/test/unit/test_graph_selection.py +++ b/test/unit/test_graph_selection.py @@ -167,19 +167,19 @@ def test_run_specs(include, exclude, expected): @pytest.mark.parametrize( - 'spec,parents,parents_max_depth,children,children_max_depth,filter_type,filter_value,childrens_parents', + 'spec,parents,parents_depth,children,children_depth,filter_type,filter_value,childrens_parents', param_specs, ids=id_macro ) -def test_parse_specs(spec, parents, parents_max_depth, children, children_max_depth, filter_type, filter_value, childrens_parents): +def test_parse_specs(spec, parents, parents_depth, children, children_depth, filter_type, filter_value, childrens_parents): parsed = graph_selector.SelectionCriteria.from_single_spec(spec) - assert parsed.select_parents == parents - assert parsed.select_parents_max_depth == parents_max_depth - assert parsed.select_children == children - assert parsed.select_children_max_depth == children_max_depth + assert parsed.parents == parents + assert parsed.parents_depth == parents_depth + assert parsed.children == children + assert parsed.children_depth == children_depth assert parsed.method == filter_type assert parsed.value == filter_value - assert parsed.select_childrens_parents == childrens_parents + assert parsed.childrens_parents == childrens_parents invalid_specs = [ diff --git a/test/unit/test_graph_selector_parsing.py b/test/unit/test_graph_selector_parsing.py new file mode 100644 index 00000000000..7004694afca --- /dev/null +++ b/test/unit/test_graph_selector_parsing.py @@ -0,0 +1,301 @@ +from dbt.graph import ( + cli, + SelectionUnion, + SelectionIntersection, + SelectionDifference, + SelectionCriteria, +) +from dbt.graph.selector_methods import MethodName +import textwrap +import yaml + +from dbt.contracts.selection import SelectorFile + + +def parse_file(txt: str) -> SelectorFile: + txt = textwrap.dedent(txt) + dct = yaml.safe_load(txt) + sf = SelectorFile.from_dict(dct) + return sf + + +class Union: + def __init__(self, *args): + self.components = args + + def __str__(self): + return f'Union(components={self.components})' + + def __repr__(self): + return f'Union(components={self.components!r})' + + def __eq__(self, other): + if not isinstance(other, SelectionUnion): + return False + + return all(mine == theirs for mine, theirs in zip(self.components, other.components)) + + +class Intersection: + def __init__(self, *args): + self.components = args + + def __str__(self): + return f'Intersection(components={self.components})' + + def __repr__(self): + return f'Intersection(components={self.components!r})' + + def __eq__(self, other): + if not isinstance(other, SelectionIntersection): + return False + + return all(mine == theirs for mine, theirs in zip(self.components, other.components)) + + +class Difference: + def __init__(self, *args): + self.components = args + + def __str__(self): + return f'Difference(components={self.components})' + + def __repr__(self): + return f'Difference(components={self.components!r})' + + def __eq__(self, other): + if not isinstance(other, SelectionDifference): + return False + + return all(mine == theirs for mine, theirs in zip(self.components, other.components)) + + +class Criteria: + def __init__(self, method, value, **kwargs): + self.method = method + self.value = value + self.kwargs = kwargs + + def __str__(self): + return f'Criteria(method={self.method}, value={self.value}, **{self.kwargs})' + + def __repr__(self): + return f'Criteria(method={self.method!r}, value={self.value!r}, **{self.kwargs!r})' + + def __eq__(self, other): + if not isinstance(other, SelectionCriteria): + return False + return ( + self.method == other.method and + self.value == other.value and + all(getattr(other, k) == v for k, v in self.kwargs.items()) + ) + + +def test_parse_simple(): + sf = parse_file('''\ + selectors: + - name: tagged_foo + definition: + tag: foo + ''') + + assert len(sf.selectors) == 1 + parsed = cli.parse_from_selectors_definition(sf) + assert len(parsed) == 1 + assert 'tagged_foo' in parsed + assert Criteria( + method=MethodName.Tag, + method_arguments=[], + value='foo', + children=False, + parents=False, + childrens_parents=False, + children_depth=None, + parents_depth=None, + ) == parsed['tagged_foo'] + + +def test_parse_simple_childrens_parents(): + sf = parse_file('''\ + selectors: + - name: tagged_foo + definition: + method: tag + value: foo + childrens_parents: True + ''') + + assert len(sf.selectors) == 1 + parsed = cli.parse_from_selectors_definition(sf) + assert len(parsed) == 1 + assert 'tagged_foo' in parsed + assert Criteria( + method=MethodName.Tag, + method_arguments=[], + value='foo', + children=False, + parents=False, + childrens_parents=True, + children_depth=None, + parents_depth=None, + ) == parsed['tagged_foo'] + + +def test_parse_simple_arguments_with_modifiers(): + sf = parse_file('''\ + selectors: + - name: configured_view + definition: + method: config.materialized + value: view + parents: True + children: True + children_depth: 2 + ''') + + assert len(sf.selectors) == 1 + parsed = cli.parse_from_selectors_definition(sf) + assert len(parsed) == 1 + assert 'configured_view' in parsed + assert Criteria( + method=MethodName.Config, + method_arguments=['materialized'], + value='view', + children=True, + parents=True, + childrens_parents=False, + children_depth=2, + parents_depth=None, + ) == parsed['configured_view'] + + +def test_parse_union(): + sf = parse_file('''\ + selectors: + - name: views-or-foos + definition: + union: + - method: config.materialized + value: view + - tag: foo + ''') + assert len(sf.selectors) == 1 + parsed = cli.parse_from_selectors_definition(sf) + assert 'views-or-foos' in parsed + assert Union( + Criteria(method=MethodName.Config, value='view', method_arguments=['materialized']), + Criteria(method=MethodName.Tag, value='foo', method_arguments=[]) + ) == parsed['views-or-foos'] + + +def test_parse_intersection(): + sf = parse_file('''\ + selectors: + - name: views-and-foos + definition: + intersection: + - method: config.materialized + value: view + - tag: foo + ''') + assert len(sf.selectors) == 1 + parsed = cli.parse_from_selectors_definition(sf) + + assert 'views-and-foos' in parsed + assert Intersection( + Criteria(method=MethodName.Config, value='view', method_arguments=['materialized']), + Criteria(method=MethodName.Tag, value='foo', method_arguments=[]), + ) == parsed['views-and-foos'] + + +def test_parse_union_excluding(): + sf = parse_file('''\ + selectors: + - name: views-or-foos-not-bars + definition: + union: + - method: config.materialized + value: view + - tag: foo + - exclude: + - tag: bar + ''') + assert len(sf.selectors) == 1 + parsed = cli.parse_from_selectors_definition(sf) + assert 'views-or-foos-not-bars' in parsed + assert Difference( + Union( + Criteria(method=MethodName.Config, value='view', method_arguments=['materialized']), + Criteria(method=MethodName.Tag, value='foo', method_arguments=[]) + ), + Criteria(method=MethodName.Tag, value='bar', method_arguments=[]), + ) == parsed['views-or-foos-not-bars'] + + +def test_parse_yaml_complex(): + sf = parse_file('''\ + selectors: + - name: test_name + definition: + union: + - intersection: + - tag: foo + - tag: bar + - union: + - package: snowplow + - config.materialized: incremental + - union: + - path: "models/snowplow/marketing/custom_events.sql" + - fqn: "snowplow.marketing" + - intersection: + - resource_type: seed + - package: snowplow + - exclude: + - country_codes + - intersection: + - tag: baz + - config.materialized: ephemeral + - name: weeknights + definition: + union: + - tag: nightly + - tag:weeknights_only + ''') + + assert len(sf.selectors) == 2 + parsed = cli.parse_from_selectors_definition(sf) + assert 'test_name' in parsed + assert 'weeknights' in parsed + assert Union( + Criteria(method=MethodName.Tag, value='nightly'), + Criteria(method=MethodName.Tag, value='weeknights_only'), + ) == parsed['weeknights'] + + assert Union( + Intersection( + Criteria(method=MethodName.Tag, value='foo'), + Criteria(method=MethodName.Tag, value='bar'), + Union( + Criteria(method=MethodName.Package, value='snowplow'), + Criteria(method=MethodName.Config, value='incremental', method_arguments=['materialized']), + ), + ), + Union( + Criteria(method=MethodName.Path, value="models/snowplow/marketing/custom_events.sql"), + Criteria(method=MethodName.FQN, value='snowplow.marketing'), + ), + Difference( + Intersection( + Criteria(method=MethodName.ResourceType, value='seed'), + Criteria(method=MethodName.Package, value='snowplow'), + ), + Union( + Criteria(method=MethodName.FQN, value='country_codes'), + Intersection( + Criteria(method=MethodName.Tag, value='baz'), + Criteria(method=MethodName.Config, value='ephemeral', method_arguments=['materialized']), + ), + ), + ), + ) == parsed['test_name'] diff --git a/test/unit/test_graph_selector_spec.py b/test/unit/test_graph_selector_spec.py index fdbd43e917e..68c8611ccac 100644 --- a/test/unit/test_graph_selector_spec.py +++ b/test/unit/test_graph_selector_spec.py @@ -18,11 +18,11 @@ def test_raw_parse_simple(): assert result.method == MethodName.FQN assert result.method_arguments == [] assert result.value == raw - assert not result.select_childrens_parents - assert not result.select_children - assert not result.select_parents - assert result.select_parents_max_depth is None - assert result.select_children_max_depth is None + assert not result.childrens_parents + assert not result.children + assert not result.parents + assert result.parents_depth is None + assert result.children_depth is None def test_raw_parse_simple_infer_path(): @@ -32,11 +32,11 @@ def test_raw_parse_simple_infer_path(): assert result.method == MethodName.Path assert result.method_arguments == [] assert result.value == raw - assert not result.select_childrens_parents - assert not result.select_children - assert not result.select_parents - assert result.select_parents_max_depth is None - assert result.select_children_max_depth is None + assert not result.childrens_parents + assert not result.children + assert not result.parents + assert result.parents_depth is None + assert result.children_depth is None def test_raw_parse_simple_infer_path_modified(): @@ -46,11 +46,11 @@ def test_raw_parse_simple_infer_path_modified(): assert result.method == MethodName.Path assert result.method_arguments == [] assert result.value == raw[1:] - assert result.select_childrens_parents - assert not result.select_children - assert not result.select_parents - assert result.select_parents_max_depth is None - assert result.select_children_max_depth is None + assert result.childrens_parents + assert not result.children + assert not result.parents + assert result.parents_depth is None + assert result.children_depth is None def test_raw_parse_simple_infer_fqn_parents(): @@ -60,11 +60,11 @@ def test_raw_parse_simple_infer_fqn_parents(): assert result.method == MethodName.FQN assert result.method_arguments == [] assert result.value == 'asdf' - assert not result.select_childrens_parents - assert not result.select_children - assert result.select_parents - assert result.select_parents_max_depth is None - assert result.select_children_max_depth is None + assert not result.childrens_parents + assert not result.children + assert result.parents + assert result.parents_depth is None + assert result.children_depth is None def test_raw_parse_simple_infer_fqn_children(): @@ -74,11 +74,11 @@ def test_raw_parse_simple_infer_fqn_children(): assert result.method == MethodName.FQN assert result.method_arguments == [] assert result.value == 'asdf' - assert not result.select_childrens_parents - assert result.select_children - assert not result.select_parents - assert result.select_parents_max_depth is None - assert result.select_children_max_depth is None + assert not result.childrens_parents + assert result.children + assert not result.parents + assert result.parents_depth is None + assert result.children_depth is None def test_raw_parse_complex(): @@ -88,11 +88,11 @@ def test_raw_parse_complex(): assert result.method == MethodName.Config assert result.method_arguments == ['arg', 'secondarg'] assert result.value == 'argument_value' - assert not result.select_childrens_parents - assert result.select_children - assert result.select_parents - assert result.select_parents_max_depth == 2 - assert result.select_children_max_depth == 4 + assert not result.childrens_parents + assert result.children + assert result.parents + assert result.parents_depth == 2 + assert result.children_depth == 4 def test_raw_parse_weird(): @@ -103,11 +103,11 @@ def test_raw_parse_weird(): assert result.method == MethodName.FQN assert result.method_arguments == [] assert result.value == '' - assert not result.select_childrens_parents - assert not result.select_children - assert not result.select_parents - assert result.select_parents_max_depth is None - assert result.select_children_max_depth is None + assert not result.childrens_parents + assert not result.children + assert not result.parents + assert result.parents_depth is None + assert result.children_depth is None def test_raw_parse_invalid(): diff --git a/test/unit/utils.py b/test/unit/utils.py index cd2989e982c..a6f17f8f01d 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -52,7 +52,7 @@ def profile_from_dict(profile, profile_name, cli_vars='{}'): ) -def project_from_dict(project, profile, packages=None, cli_vars='{}'): +def project_from_dict(project, profile, packages=None, selectors=None, cli_vars='{}'): from dbt.context.target import generate_target_context from dbt.config import Project from dbt.config.renderer import DbtProjectYamlRenderer @@ -65,11 +65,11 @@ def project_from_dict(project, profile, packages=None, cli_vars='{}'): project_root = project.pop('project-root', os.getcwd()) return Project.render_from_dict( - project_root, project, packages, renderer + project_root, project, packages, selectors, renderer ) -def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): +def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars='{}'): from dbt.config import Project, Profile, RuntimeConfig from copy import deepcopy @@ -90,6 +90,7 @@ def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): deepcopy(project), profile, packages, + selectors, cli_vars, )