Skip to content

Commit

Permalink
Enable unit testing in non-root packages (#9184)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank authored Nov 30, 2023
1 parent bf6bffa commit ca82f54
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 38 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231130-130948.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support unit tests in non-root packages
time: 2023-11-30T13:09:48.206007-05:00
custom:
  Author: gshank
  Issue: "8285"
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def create_ephemeral_from_node(
cls: Type[Self],
config: HasQuoting,
node: ManifestNode,
limit: Optional[int],
limit: Optional[int] = None,
) -> Self:
# Note that ephemeral models are based on the name.
identifier = cls.add_ephemeral_prefix(node.name)
Expand Down
72 changes: 35 additions & 37 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,9 @@ def load(self) -> Manifest:
return self.unit_test_manifest

def parse_unit_test_case(self, test_case: UnitTestDefinition):
package_name = self.root_project.project_name

# Create unit test node based on the node being tested
tested_node = self.manifest.ref_lookup.perform_lookup(
f"model.{package_name}.{test_case.model}", self.manifest
f"model.{test_case.package_name}.{test_case.model}", self.manifest
)
assert isinstance(tested_node, ModelNode)

Expand All @@ -68,7 +66,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
unit_test_node = UnitTestNode(
name=name,
resource_type=NodeType.Unit,
package_name=package_name,
package_name=test_case.package_name,
path=get_pseudo_test_path(name, test_case.original_file_path),
original_file_path=test_case.original_file_path,
unique_id=test_case.unique_id,
Expand All @@ -92,7 +90,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
unit_test_node, # type: ignore
self.root_project,
self.manifest,
package_name,
test_case.package_name,
)
get_rendered(unit_test_node.raw_code, ctx, unit_test_node, capture_macros=True)
# unit_test_node now has a populated refs/sources
Expand Down Expand Up @@ -121,7 +119,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
project_root = self.root_project.project_root
common_fields = {
"resource_type": NodeType.Model,
"package_name": package_name,
"package_name": test_case.package_name,
"original_file_path": original_input_node.original_file_path,
"config": ModelConfig(materialized="ephemeral"),
"database": original_input_node.database,
Expand All @@ -142,7 +140,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
input_name = f"{unit_test_node.name}__{original_input_node.name}"
input_node = ModelNode(
**common_fields,
unique_id=f"model.{package_name}.{input_name}",
unique_id=f"model.{test_case.package_name}.{input_name}",
name=input_name,
path=original_input_node.path,
)
Expand All @@ -153,7 +151,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
input_name = f"{unit_test_node.name}__{original_input_node.search_name}__{original_input_node.name}"
input_node = UnitTestSourceDefinition(
**common_fields,
unique_id=f"model.{package_name}.{input_name}",
unique_id=f"model.{test_case.package_name}.{input_name}",
name=original_input_node.name, # must be the same name for source lookup to work
path=input_name + ".sql", # for writing out compiled_code
source_name=original_input_node.source_name, # needed for source lookup
Expand Down Expand Up @@ -227,35 +225,6 @@ def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None:
self.schema_parser = schema_parser
self.yaml = yaml

def _load_rows_from_seed(self, ref_str: str) -> List[Dict[str, Any]]:
    """Read the rows of the seed named by ``ref_str`` from its CSV file on disk.

    Args:
        ref_str: A ref expression for the seed (presumably of the form
            ``ref('seed')``); it is wrapped in ``{{ }}`` and parsed with
            ``py_extract_from_source`` to obtain the seed name and optional package.

    Returns:
        One dict per data row of the seed file, as produced by ``csv.DictReader``.

    Raises:
        ParsingError: If the lookup finds no node for the reference, or the node
            it finds is not a seed. The message differs depending on whether the
            reference targeted this project or another package.
    """
    ref = py_extract_from_source("{{ " + ref_str + " }}")["refs"][0]

    seed_name = ref["name"]
    # Fall back to the current project when the ref does not name a package.
    package_name = ref.get("package", self.project.project_name)

    seed_node = self.manifest.ref_lookup.find(seed_name, package_name, None, self.manifest)

    if not seed_node or seed_node.resource_type != NodeType.Seed:
        if package_name != self.project.project_name:
            # Seed not found in custom package specified
            raise ParsingError(
                f"Unable to find seed '{package_name}.{seed_name}' for unit tests in '{package_name}' package"
            )
        raise ParsingError(
            f"Unable to find seed '{package_name}.{seed_name}' for unit tests in directories: {self.project.seed_paths}"
        )

    seed_path = Path(seed_node.root_path) / seed_node.original_file_path
    # newline="" hands line-ending handling to the csv module, so quoted fields
    # containing embedded newlines are read back intact.
    with open(seed_path, "r", newline="") as f:
        return list(DictReader(f))

def parse(self) -> ParseResult:
for data in self.get_key_dicts():
unit_test = self._get_unit_test(data)
Expand Down Expand Up @@ -351,3 +320,32 @@ def _build_fqn(self, package_name, original_file_path, model_name, test_name):
fqn.append(model_name)
fqn.append(test_name)
return fqn

def _load_rows_from_seed(self, ref_str: str) -> List[Dict[str, Any]]:
"""Read the rows of the seed named by ``ref_str`` from its CSV file on disk.

Returns one dict per data row, as produced by ``DictReader``.

Raises ``ParsingError`` when the lookup finds no node for the reference or the
node found is not a seed; a missing seed does not yield an empty list.
"""
ref = py_extract_from_source("{{ " + ref_str + " }}")["refs"][0]

rows: List[Dict[str, Any]] = []

seed_name = ref["name"]
package_name = ref.get("package", self.project.project_name)

seed_node = self.manifest.ref_lookup.find(seed_name, package_name, None, self.manifest)

if not seed_node or seed_node.resource_type != NodeType.Seed:
# Seed not found in custom package specified
if package_name != self.project.project_name:
raise ParsingError(

Check warning on line 338 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L338

Added line #L338 was not covered by tests
f"Unable to find seed '{package_name}.{seed_name}' for unit tests in '{package_name}' package"
)
else:
raise ParsingError(
f"Unable to find seed '{package_name}.{seed_name}' for unit tests in directories: {self.project.seed_paths}"
)

seed_path = Path(seed_node.root_path) / seed_node.original_file_path
# NOTE(review): the csv module documentation recommends opening the file with
# newline="" so quoted fields containing newlines are parsed correctly.
with open(seed_path, "r") as f:
for row in DictReader(f):
rows.append(row)

return rows
114 changes: 114 additions & 0 deletions tests/functional/unit_testing/test_ut_dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import pytest
from dbt.tests.util import run_dbt, get_unique_ids_in_results
from dbt.tests.fixtures.project import write_project_files

# Project file for the local dependency package ("local_dep").
local_dependency__dbt_project_yml = """
name: 'local_dep'
version: '1.0'
seeds:
  quote_columns: False
"""

# Schema file for the dependency: one source with a `unique` test on `id`,
# plus a unit test for `dep_model` that supplies rows for its `seed` input.
local_dependency__schema_yml = """
sources:
  - name: seed_source
    schema: "{{ var('schema_override', target.schema) }}"
    tables:
      - name: "seed"
        columns:
          - name: id
            tests:
              - unique
unit_tests:
  - name: test_dep_model_id
    model: dep_model
    given:
      - input: ref('seed')
        rows:
          - {id: 1, name: Joe}
    expect:
      rows:
        - {name_id: Joe_1}
"""

# Model in the dependency package; concatenates name and id from the seed.
local_dependency__dep_model_sql = """
select name || '_' || id as name_id from {{ ref('seed') }}
"""

# Seed data shipped with the dependency package.
local_dependency__seed_csv = """id,name
1,Mary
2,Sam
3,John
"""

# Root-project model that selects from the dependency's model.
my_model_sql = """
select * from {{ ref('dep_model') }}
"""

# Root-project schema file: a unit test for `my_model` with `dep_model` as input.
my_model_schema_yml = """
unit_tests:
  - name: test_my_model_name_id
    model: my_model
    given:
      - input: ref('dep_model')
        rows:
          - {name_id: Joe_1}
    expect:
      rows:
        - {name_id: Joe_1}
"""


class TestUnitTestingInDependency:
    """Functional test: unit tests defined in a local dependency package run alongside the root project's."""

    @pytest.fixture(scope="class", autouse=True)
    def setUp(self, project_root):
        """Write the ``local_dependency`` package files under the project root."""
        dependency_tree = {
            "dbt_project.yml": local_dependency__dbt_project_yml,
            "models": {
                "schema.yml": local_dependency__schema_yml,
                "dep_model.sql": local_dependency__dep_model_sql,
            },
            "seeds": {"seed.csv": local_dependency__seed_csv},
        }
        write_project_files(project_root, "local_dependency", dependency_tree)

    @pytest.fixture(scope="class")
    def packages(self):
        """Declare the local dependency written by ``setUp`` as a package."""
        return {"packages": [{"local": "local_dependency"}]}

    @pytest.fixture(scope="class")
    def models(self):
        """Provide the root project's model and its schema file."""
        return {
            "my_model.sql": my_model_sql,
            "schema.yml": my_model_schema_yml,
        }

    def test_unit_test_in_dependency(self, project):
        """Install, seed and run the project, then check test counts per selector."""
        run_dbt(["deps"])
        run_dbt(["seed"])
        run_results = run_dbt(["run"])
        assert len(run_results) == 2

        all_test_results = run_dbt(["test"])
        assert len(all_test_results) == 3
        seen_ids = get_unique_ids_in_results(all_test_results)
        assert "unit_test.local_dep.dep_model.test_dep_model_id" in seen_ids

        # (selector, expected number of results), checked in this order
        selector_expectations = (
            # two unit tests, 1 in root package, one in local_dep package
            ("test_type:unit", 2),
            # 2 tests in local_dep package
            ("local_dep", 2),
            # 1 test in root package
            ("test", 1),
        )
        for selector, expected_count in selector_expectations:
            selected = run_dbt(["test", "--select", selector])
            assert len(selected) == expected_count

0 comments on commit ca82f54

Please sign in to comment.