Skip to content

Commit

Permalink
add generic tests to test-paths (#4052)
Browse files Browse the repository at this point in the history
* removed overlooked breakpoint

* first pass

* save progress - singualr tests broken

* fixed to work with both generic and singular tests

* fixed formatting

* added a comment

* change to use /generic subfolder

* fix formatting issues

* fixed bug on code consolidation

* fixed typo

* added test for generic tests

* added changelog entry

* added logic to treat generic tests like macro tests

* add generic test to macro_edges

* fixed generic tests to match unique_ids

* fixed test
  • Loading branch information
emmyoop authored Oct 21, 2021
1 parent 34c23fe commit f79a968
Show file tree
Hide file tree
Showing 16 changed files with 304 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Contributors:
- Turns on the static parser by default and adds the flag `--no-static-parser` to disable it. ([#3377](https://github.com/dbt-labs/dbt-core/issues/3377), [#3939](https://github.com/dbt-labs/dbt-core/pull/3939))
- Generic test FQNs have changed to include the relative path, resource, and column (if applicable) where they are defined. This makes it easier to configure them from the `tests` block in `dbt_project.yml` ([#3259](https://github.com/dbt-labs/dbt-core/pull/3259), [#3880](https://github.com/dbt-labs/dbt-core/pull/3880)
- Turn on partial parsing by default ([#3867](https://github.com/dbt-labs/dbt-core/issues/3867), [#3989](https://github.com/dbt-labs/dbt-core/issues/3989))
- Generic test can now be added under a `generic` subfolder in the `test-paths` directory. ([#4052](https://github.com/dbt-labs/dbt-core/pull/4052))

### Fixes
- Add generic tests defined on sources to the manifest once, not twice ([#3347](https://github.com/dbt-labs/dbt/issues/3347), [#3880](https://github.com/dbt-labs/dbt/pull/3880))
Expand Down
6 changes: 4 additions & 2 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class ParseFileType(StrEnum):
Model = 'model'
Snapshot = 'snapshot'
Analysis = 'analysis'
Test = 'test'
SingularTest = 'singular_test'
GenericTest = 'generic_test'
Seed = 'seed'
Documentation = 'docs'
Schema = 'schema'
Expand All @@ -30,7 +31,8 @@ class ParseFileType(StrEnum):
ParseFileType.Model: 'ModelParser',
ParseFileType.Snapshot: 'SnapshotParser',
ParseFileType.Analysis: 'AnalysisParser',
ParseFileType.Test: 'SingularTestParser',
ParseFileType.SingularTest: 'SingularTestParser',
ParseFileType.GenericTest: 'GenericTestParser',
ParseFileType.Seed: 'SeedParser',
ParseFileType.Documentation: 'DocumentationParser',
ParseFileType.Schema: 'SchemaParser',
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def build_node_edges(nodes: List[ManifestNode]):
return _sort_values(forward_edges), _sort_values(backward_edges)


# Build a map of children of macros
# Build a map of children of macros and generic tests
def build_macro_edges(nodes: List[Any]):
forward_edges: Dict[str, List[str]] = {
n.unique_id: [] for n in nodes if n.unique_id.startswith('macro') or n.depends_on.macros
Expand Down
5 changes: 5 additions & 0 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ class UnparsedMacro(UnparsedBaseNode, HasSQL):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Macro]})


@dataclass
class UnparsedGenericTest(UnparsedBaseNode, HasSQL):
resource_type: NodeType = field(metadata={'restrict': [NodeType.Macro]})


@dataclass
class UnparsedNode(UnparsedBaseNode, HasSQL):
name: str
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/parser/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .analysis import AnalysisParser # noqa
from .base import Parser, ConfiguredParser # noqa
from .singular_test import SingularTestParser # noqa
from .generic_test import GenericTestParser # noqa
from .docs import DocumentationParser # noqa
from .hooks import HookParser # noqa
from .macros import MacroParser # noqa
Expand All @@ -10,6 +11,6 @@
from .snapshots import SnapshotParser # noqa

from . import ( # noqa
analysis, base, singular_test, docs, hooks, macros, models, schemas,
analysis, base, generic_test, singular_test, docs, hooks, macros, models, schemas,
snapshots
)
106 changes: 106 additions & 0 deletions core/dbt/parser/generic_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Iterable, List

import jinja2

from dbt.exceptions import CompilationException
from dbt.clients import jinja
from dbt.contracts.graph.parsed import ParsedGenericTestNode
from dbt.contracts.graph.unparsed import UnparsedMacro
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.contracts.files import SourceFile
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.node_types import NodeType
from dbt.parser.base import BaseParser
from dbt.parser.search import FileBlock
from dbt.utils import MACRO_PREFIX


class GenericTestParser(BaseParser[ParsedGenericTestNode]):

@property
def resource_type(self) -> NodeType:
return NodeType.Macro

@classmethod
def get_compiled_path(cls, block: FileBlock):
return block.path.relative_path

def parse_generic_test(
self, block: jinja.BlockTag, base_node: UnparsedMacro, name: str
) -> ParsedMacro:
unique_id = self.generate_unique_id(name)

return ParsedMacro(
path=base_node.path,
macro_sql=block.full_block,
original_file_path=base_node.original_file_path,
package_name=base_node.package_name,
root_path=base_node.root_path,
resource_type=base_node.resource_type,
name=name,
unique_id=unique_id,
)

def parse_unparsed_generic_test(
self, base_node: UnparsedMacro
) -> Iterable[ParsedMacro]:
try:
blocks: List[jinja.BlockTag] = [
t for t in
jinja.extract_toplevel_blocks(
base_node.raw_sql,
allowed_blocks={'test'},
collect_raw_data=False,
)
if isinstance(t, jinja.BlockTag)
]
except CompilationException as exc:
exc.add_node(base_node)
raise

for block in blocks:
try:
ast = jinja.parse(block.full_block)
except CompilationException as e:
e.add_node(base_node)
raise

# generic tests are structured as macros so we want to count the number of macro blocks
generic_test_nodes = list(ast.find_all(jinja2.nodes.Macro))

if len(generic_test_nodes) != 1:
# things have gone disastrously wrong, we thought we only
# parsed one block!
raise CompilationException(
f'Found multiple generic tests in {block.full_block}, expected 1',
node=base_node
)

generic_test_name = generic_test_nodes[0].name

if not generic_test_name.startswith(MACRO_PREFIX):
continue

name: str = generic_test_name.replace(MACRO_PREFIX, '')
node = self.parse_generic_test(block, base_node, name)
yield node

def parse_file(self, block: FileBlock):
assert isinstance(block.file, SourceFile)
source_file = block.file
assert isinstance(source_file.contents, str)
original_file_path = source_file.path.original_file_path
logger.debug("Parsing {}".format(original_file_path))

# this is really only used for error messages
base_node = UnparsedMacro(
path=original_file_path,
original_file_path=original_file_path,
package_name=self.project.project_name,
raw_sql=source_file.contents,
root_path=self.project.project_root,
resource_type=NodeType.Macro,
)

for node in self.parse_unparsed_generic_test(base_node):
self.manifest.add_macro(block.file, node)
30 changes: 20 additions & 10 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
)
from dbt.parser.base import Parser
from dbt.parser.analysis import AnalysisParser
from dbt.parser.generic_test import GenericTestParser
from dbt.parser.singular_test import SingularTestParser
from dbt.parser.docs import DocumentationParser
from dbt.parser.hooks import HookParser
Expand Down Expand Up @@ -277,9 +278,10 @@ def load(self):
if skip_parsing:
logger.debug("Partial parsing enabled, no changes found, skipping parsing")
else:
# Load Macros
# Load Macros and tests
# We need to parse the macros first, so they're resolvable when
# the other files are loaded
# the other files are loaded. Also need to parse tests, specifically
# generic tests
start_load_macros = time.perf_counter()
self.load_and_parse_macros(project_parser_files)

Expand Down Expand Up @@ -379,14 +381,22 @@ def load_and_parse_macros(self, project_parser_files):
if project.project_name not in project_parser_files:
continue
parser_files = project_parser_files[project.project_name]
if 'MacroParser' not in parser_files:
continue
parser = MacroParser(project, self.manifest)
for file_id in parser_files['MacroParser']:
block = FileBlock(self.manifest.files[file_id])
parser.parse_file(block)
# increment parsed path count for performance tracking
self._perf_info.parsed_path_count = self._perf_info.parsed_path_count + 1
if 'MacroParser' in parser_files:
parser = MacroParser(project, self.manifest)
for file_id in parser_files['MacroParser']:
block = FileBlock(self.manifest.files[file_id])
parser.parse_file(block)
# increment parsed path count for performance tracking
self._perf_info.parsed_path_count = self._perf_info.parsed_path_count + 1
# generic tests hisotrically lived in the macros directoy but can now be nested
# in a /generic directory under /tests so we want to process them here as well
if 'GenericTestParser' in parser_files:
parser = GenericTestParser(project, self.manifest)
for file_id in parser_files['GenericTestParser']:
block = FileBlock(self.manifest.files[file_id])
parser.parse_file(block)
# increment parsed path count for performance tracking
self._perf_info.parsed_path_count = self._perf_info.parsed_path_count + 1

self.build_macro_resolver()
# Look at changed macros and update the macro.depends_on.macros
Expand Down
15 changes: 10 additions & 5 deletions core/dbt/parser/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
ParseFileType.Seed,
ParseFileType.Snapshot,
ParseFileType.Analysis,
ParseFileType.Test,
ParseFileType.SingularTest,
)

mg_files = (
ParseFileType.Macro,
ParseFileType.GenericTest,
)


Expand Down Expand Up @@ -88,7 +93,7 @@ def build_file_diff(self):
if self.saved_files[file_id].parse_file_type == ParseFileType.Schema:
deleted_schema_files.append(file_id)
else:
if self.saved_files[file_id].parse_file_type == ParseFileType.Macro:
if self.saved_files[file_id].parse_file_type in mg_files:
changed_or_deleted_macro_file = True
deleted.append(file_id)

Expand All @@ -106,7 +111,7 @@ def build_file_diff(self):
raise Exception(f"Serialization failure for {file_id}")
changed_schema_files.append(file_id)
else:
if self.saved_files[file_id].parse_file_type == ParseFileType.Macro:
if self.saved_files[file_id].parse_file_type in mg_files:
changed_or_deleted_macro_file = True
changed.append(file_id)
file_diff = {
Expand Down Expand Up @@ -213,7 +218,7 @@ def delete_from_saved(self, file_id):
self.deleted_manifest.files[file_id] = self.saved_manifest.files.pop(file_id)

# macros
if saved_source_file.parse_file_type == ParseFileType.Macro:
if saved_source_file.parse_file_type in mg_files:
self.delete_macro_file(saved_source_file, follow_references=True)

# docs
Expand All @@ -229,7 +234,7 @@ def update_in_saved(self, file_id):

if new_source_file.parse_file_type in mssat_files:
self.update_mssat_in_saved(new_source_file, old_source_file)
elif new_source_file.parse_file_type == ParseFileType.Macro:
elif new_source_file.parse_file_type in mg_files:
self.update_macro_in_saved(new_source_file, old_source_file)
elif new_source_file.parse_file_type == ParseFileType.Documentation:
self.update_doc_in_saved(new_source_file, old_source_file)
Expand Down
13 changes: 12 additions & 1 deletion core/dbt/parser/read_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def get_source_files(project, paths, extension, parse_file_type, saved_files):
for fp in fp_list:
if parse_file_type == ParseFileType.Seed:
fb_list.append(load_seed_source_file(fp, project.project_name))
# singular tests live in /tests but only generic tests live
# in /tests/generic so we want to skip those
elif (parse_file_type == ParseFileType.SingularTest and
'generic/' in fp.relative_path):
continue
else:
file = load_source_file(fp, parse_file_type, project.project_name, saved_files)
# only append the list if it has contents. added to fix #3568
Expand Down Expand Up @@ -137,7 +142,13 @@ def read_files(project, files, parser_files, saved_files):
)

project_files['SingularTestParser'] = read_files_for_parser(
project, files, project.test_paths, '.sql', ParseFileType.Test, saved_files
project, files, project.test_paths, '.sql', ParseFileType.SingularTest, saved_files
)

# all generic tests within /tests must be nested under a /generic subfolder
project_files['GenericTestParser'] = read_files_for_parser(
project, files, ["{}{}".format(test_path, '/generic') for test_path in project.test_paths],
'.sql', ParseFileType.GenericTest, saved_files
)

project_files['SeedParser'] = read_files_for_parser(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,6 @@ def test_postgres_test_context_with_macro_namespace(self):
run_result = self.run_dbt(['test'], expect_pass=False)
results = run_result.results
results = sorted(results, key=lambda r: r.node.name)
# breakpoint()
self.assertEqual(len(results), 4)
# call_pkg_macro_model_c_
self.assertEqual(results[0].status, TestStatus.Fail)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
version: 2

models:
- name: orders
description: "Some order data"
columns:
- name: id
tests:
- unique
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{% test is_odd(model, column_name) %}

with validation as (

select
{{ column_name }} as odd_field

from {{ model }}

),

validation_errors as (

select
odd_field

from validation
-- if this is true, then odd_field is actually even!
where (odd_field % 2) = 0

)

select *
from validation_errors

{% endtest %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{% test is_odd(model, column_name) %}

with validation as (

select
{{ column_name }} as odd_field2

from {{ model }}

),

validation_errors as (

select
odd_field2

from validation
-- if this is true, then odd_field is actually even!
where (odd_field2 % 2) = 0

)

select *
from validation_errors

{% endtest %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
version: 2

models:
- name: orders
description: "Some order data"
columns:
- name: id
tests:
- unique
- is_odd
Loading

0 comments on commit f79a968

Please sign in to comment.