diff --git a/dbt_meshify/cli.py b/dbt_meshify/cli.py
index f4b9f46..7a62408 100644
--- a/dbt_meshify/cli.py
+++ b/dbt_meshify/cli.py
@@ -12,6 +12,32 @@
help="The path to the dbt project to operate on. Defaults to the current directory.",
)
+project_paths = click.option(
+ "--project-paths",
+ cls=MultiOption,
+ multiple=True,
+ type=tuple,
+ default=None,
+ help="The paths to the set of dbt projects to connect. Must supply 2+ paths.",
+)
+
+projects_dir = click.option(
+ "--projects-dir",
+ type=click.Path(exists=True),
+ default=None,
+ help="The path to a directory containing multiple dbt projects. Directory must contain 2+ projects.",
+)
+
+exclude_projects = click.option(
+ "--exclude-projects",
+ "-e",
+ cls=MultiOption,
+ multiple=True,
+ type=tuple,
+ default=None,
+ help="The set of dbt projects to exclude from the operation when using the --projects-dir option.",
+)
+
create_path = click.option(
"--create-path",
type=click.Path(exists=False),
diff --git a/dbt_meshify/linker.py b/dbt_meshify/linker.py
index 1f49d39..551bdd9 100644
--- a/dbt_meshify/linker.py
+++ b/dbt_meshify/linker.py
@@ -1,8 +1,14 @@
from dataclasses import dataclass
from enum import Enum
-from typing import Set
+from typing import Set, Union
-from dbt_meshify.dbt_projects import BaseDbtProject
+from dbt.node_types import AccessType
+from loguru import logger
+
+from dbt_meshify.dbt_projects import BaseDbtProject, DbtProject
+from dbt_meshify.exceptions import FatalMeshifyException, FileEditorException
+from dbt_meshify.storage.dbt_project_editors import DbtProjectEditor, YMLOperationType
+from dbt_meshify.storage.file_content_editors import DbtMeshConstructor
class ProjectDependencyType(str, Enum):
@@ -16,12 +22,14 @@ class ProjectDependencyType(str, Enum):
class ProjectDependency:
"""ProjectDependencies define shared resources between two different projects"""
- upstream: str
- downstream: str
+ upstream_resource: str
+ upstream_project_name: str
+ downstream_resource: str
+ downstream_project_name: str
type: ProjectDependencyType
def __key(self):
- return self.upstream, self.downstream, self.type
+ return self.upstream_resource, self.downstream_resource, self.type
def __hash__(self):
return hash(self.__key())
@@ -45,12 +53,14 @@ def _find_relation_dependencies(
source_relations: Set[str], target_relations: Set[str]
) -> Set[str]:
"""
- Identify dependencies between projects using shared relations.
+ Identify dependencies between projects using shared relation names.
"""
return source_relations.intersection(target_relations)
def _source_dependencies(
- self, project: BaseDbtProject, other_project: BaseDbtProject
+ self,
+ project: Union[BaseDbtProject, DbtProject],
+ other_project: Union[BaseDbtProject, DbtProject],
) -> Set[ProjectDependency]:
"""
Identify source-hack dependencies between projects.
@@ -74,8 +84,10 @@ def _source_dependencies(
forward_dependencies = {
ProjectDependency(
- upstream=project.model_relation_names[relation],
- downstream=other_project.source_relation_names[relation],
+ upstream_resource=project.model_relation_names[relation],
+ upstream_project_name=project.name,
+ downstream_resource=other_project.source_relation_names[relation],
+ downstream_project_name=other_project.name,
type=ProjectDependencyType.Source,
)
for relation in relations
@@ -96,8 +108,10 @@ def _source_dependencies(
backward_dependencies = {
ProjectDependency(
- upstream=other_project.model_relation_names[relation],
- downstream=project.source_relation_names[relation],
+ upstream_resource=other_project.model_relation_names[relation],
+ upstream_project_name=other_project.name,
+ downstream_resource=project.source_relation_names[relation],
+ downstream_project_name=project.name,
type=ProjectDependencyType.Source,
)
for relation in backwards_relations
@@ -106,7 +120,9 @@ def _source_dependencies(
return forward_dependencies | backward_dependencies
def _package_dependencies(
- self, project: BaseDbtProject, other_project: BaseDbtProject
+ self,
+ project: Union[BaseDbtProject, DbtProject],
+ other_project: Union[BaseDbtProject, DbtProject],
) -> Set[ProjectDependency]:
"""
Identify package-imported dependencies between projects.
@@ -115,9 +131,13 @@ def _package_dependencies(
project (Project B) imports Project A and references the model.
"""
- if project.project_id not in other_project.installed_packages():
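+ # check both directions, since either project may install the other as a package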
+ if (
+ project.project_id not in other_project.installed_packages()
+ and other_project.project_id not in project.installed_packages()
+ ):
return set()
+ # find which models are in both manifests
relations = self._find_relation_dependencies(
source_relations={
model.relation_name
@@ -131,17 +151,54 @@ def _package_dependencies(
},
)
- return {
+ # find the children of the shared models in the downstream project
+ package_children = [
+ {
+ 'upstream_resource': project.model_relation_names[relation],
+ 'downstream_resource': child,
+ }
+ for relation in relations
+ for child in other_project.manifest.child_map[project.model_relation_names[relation]]
+ ]
+
+ forward_dependencies = {
ProjectDependency(
- upstream=project.model_relation_names[relation],
- downstream=other_project.model_relation_names[relation],
+ upstream_resource=child['upstream_resource'],
+ upstream_project_name=project.name,
+ downstream_resource=child['downstream_resource'],
+ downstream_project_name=other_project.name,
type=ProjectDependencyType.Package,
)
+ for child in package_children
+ }
+
+ # likewise in the reverse direction, in case the package dependency points the other way
+ backward_package_children = [
+ {
+ 'upstream_resource': other_project.model_relation_names[relation],
+ 'downstream_resource': child,
+ }
for relation in relations
+ for child in project.manifest.child_map[other_project.model_relation_names[relation]]
+ ]
+
+ backward_dependencies = {
+ ProjectDependency(
+ upstream_resource=child['upstream_resource'],
+ upstream_project_name=other_project.name,
+ downstream_resource=child['downstream_resource'],
+ downstream_project_name=project.name,
+ type=ProjectDependencyType.Package,
+ )
+ for child in backward_package_children
}
+ return forward_dependencies | backward_dependencies
+
def dependencies(
- self, project: BaseDbtProject, other_project: BaseDbtProject
+ self,
+ project: Union[BaseDbtProject, DbtProject],
+ other_project: Union[BaseDbtProject, DbtProject],
) -> Set[ProjectDependency]:
"""Detect dependencies between two projects and return a list of resources shared."""
@@ -156,3 +213,102 @@ def dependencies(
dependencies.update(package_dependencies)
return dependencies
+
+ def resolve_dependency(
+ self,
+ dependency: ProjectDependency,
+ upstream_project: DbtProject,
+ downstream_project: DbtProject,
+ ):
+ upstream_manifest_entry = upstream_project.get_manifest_node(dependency.upstream_resource)
+ if not upstream_manifest_entry:
+ raise ValueError(
+ f"Could not find upstream resource {dependency.upstream_resource} in project {upstream_project.name}"
+ )
+ downstream_manifest_entry = downstream_project.get_manifest_node(
+ dependency.downstream_resource
+ )
+ upstream_catalog_entry = upstream_project.get_catalog_entry(dependency.upstream_resource)
+ upstream_mesh_constructor = DbtMeshConstructor(
+ project_path=upstream_project.path,
+ node=upstream_manifest_entry, # type: ignore
+ catalog=upstream_catalog_entry,
+ )
+
+ downstream_mesh_constructor = DbtMeshConstructor(
+ project_path=downstream_project.path,
+ node=downstream_manifest_entry, # type: ignore
+ catalog=None,
+ )
+ downstream_editor = DbtProjectEditor(downstream_project)
+ # for either dependency type, add contracts and make model public
+ try:
+ upstream_mesh_constructor.add_model_access(AccessType.Public)
+ logger.success(
+ f"Successfully update model access : {dependency.upstream_resource} is now {AccessType.Public}"
+ )
+ except FatalMeshifyException as e:
+ logger.error(f"Failed to update model access: {dependency.upstream_resource}")
+ raise e
+ try:
+ upstream_mesh_constructor.add_model_contract()
+ logger.success(f"Successfully added contract to model: {dependency.upstream_resource}")
+ except FatalMeshifyException as e:
+ logger.error(f"Failed to add contract to model: {dependency.upstream_resource}")
+ raise e
+
+ if dependency.type == ProjectDependencyType.Source:
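+ # update every model that selects from the hacked source, then drop the source definition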
+ for child in downstream_project.manifest.child_map[dependency.downstream_resource]:
+ constructor = DbtMeshConstructor(
+ project_path=downstream_project.path,
+ node=downstream_project.get_manifest_node(child), # type: ignore
+ catalog=None,
+ )
+ try:
+ constructor.replace_source_with_refs(
+ source_unique_id=dependency.downstream_resource,
+ model_unique_id=dependency.upstream_resource,
+ )
+ logger.success(
+ f"Successfully replaced source function with ref to upstream resource: {dependency.downstream_resource} now calls {dependency.upstream_resource} directly"
+ )
+ except FileEditorException as e:
+ logger.error(
+ f"Failed to replace source function with ref to upstream resource"
+ )
+ raise e
+ try:
+ downstream_editor.update_resource_yml_entry(
+ downstream_mesh_constructor, operation_type=YMLOperationType.Delete
+ )
+ logger.success(
+ f"Successfully deleted unnecessary source: {dependency.downstream_resource}"
+ )
+ except FatalMeshifyException as e:
+ logger.error(f"Failed to delete unnecessary source")
+ raise e
+
+ if dependency.type == ProjectDependencyType.Package:
+ try:
+ downstream_mesh_constructor.update_model_refs(
+ model_name=upstream_manifest_entry.name,
+ project_name=dependency.upstream_project_name,
+ )
+ logger.success(
+ f"Successfully updated model refs: {dependency.downstream_resource} now references {dependency.upstream_resource}"
+ )
+ except FileEditorException as e:
+ logger.error(f"Failed to update model refs")
+ raise e
+
+ # for both types, add upstream project to downstream project's dependencies.yml
+ try:
+ downstream_editor.update_dependencies_yml(name=upstream_project.name)
+ logger.success(
+ f"Successfully added {dependency.upstream_project_name} to {dependency.downstream_project_name}'s dependencies.yml"
+ )
+ except FileEditorException as e:
+ logger.error(
+ f"Failed to add {dependency.upstream_project_name} to {dependency.downstream_project_name}'s dependencies.yml"
+ )
+ raise e
diff --git a/dbt_meshify/main.py b/dbt_meshify/main.py
index da7ccaf..4999148 100644
--- a/dbt_meshify/main.py
+++ b/dbt_meshify/main.py
@@ -1,31 +1,36 @@
import os
import sys
+from itertools import combinations
from pathlib import Path
-from typing import Optional
+from typing import List, Optional
import click
import yaml
from dbt.contracts.graph.unparsed import Owner
from loguru import logger
-from dbt_meshify.storage.dbt_project_creator import DbtSubprojectCreator
+from dbt_meshify.storage.dbt_project_editors import DbtSubprojectCreator
from .cli import (
TupleCompatibleCommand,
create_path,
exclude,
+ exclude_projects,
group_yml_path,
owner,
owner_email,
owner_name,
owner_properties,
project_path,
+ project_paths,
+ projects_dir,
read_catalog,
select,
selector,
)
from .dbt_projects import DbtProject, DbtProjectHolder
from .exceptions import FatalMeshifyException
+from .linker import Linker
from .storage.file_content_editors import DbtMeshConstructor
log_format = "{time:HH:mm:ss} | {level} | {message}"
@@ -48,26 +53,80 @@ def operation():
@cli.command(name="connect")
-@click.argument("projects-dir", type=click.Path(exists=True), default=".")
-def connect(projects_dir):
+@project_paths
+@projects_dir
+@exclude_projects
+@read_catalog
+def connect(
+ project_paths: tuple, projects_dir: Path, exclude_projects: List[str], read_catalog: bool
+):
"""
- !!! info
- This command is not yet implemented
-
Connects multiple dbt projects together by adding all necessary dbt Mesh constructs
"""
- holder = DbtProjectHolder()
-
- while True:
- path_string = input("Enter the relative path to a dbt project (enter 'done' to finish): ")
- if path_string == "done":
- break
-
- path = Path(path_string).expanduser().resolve()
- project = DbtProject.from_directory(path, read_catalog)
- holder.register_project(project)
-
- print(holder.project_map())
+ if project_paths and projects_dir:
+ raise click.BadOptionUsage(
+ option_name="project_paths",
+ message="Cannot specify both project_paths and projects_dir",
+ )
+ # 1. initialize all the projects supplied to the command
+ # 2. compute the dependency graph between each combination of 2 projects in that set
+ # 3. for each dependency, add the necessary dbt Mesh constructs to each project.
+ # This includes:
+ # - adding the dependency to the dependencies.yml file of the downstream project
+ # - adding contracts and public access to the upstream models
+ # - deleting the source definition of the upstream models in the downstream project
+ # - replacing the `{{ source }}` call in the downstream project with a `{{ ref }}` to the upstream project
+
+ linker = Linker()
+ if project_paths:
+ dbt_projects = [
+ DbtProject.from_directory(project_path, read_catalog) for project_path in project_paths
+ ]
+
+ if projects_dir:
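+ # a directory is treated as a dbt project if it contains a dbt_project.yml at any depth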
+ dbt_project_paths = [path.parent for path in Path(projects_dir).glob("**/dbt_project.yml")]
+ all_dbt_projects = [
+ DbtProject.from_directory(project_path, read_catalog)
+ for project_path in dbt_project_paths
+ ]
+ dbt_projects = [
+ project for project in all_dbt_projects if project.name not in exclude_projects
+ ]
+
+ project_map = {project.name: project for project in dbt_projects}
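+ # evaluate each unordered pair of projects exactly once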
+ dbt_project_combinations = [combo for combo in combinations(dbt_projects, 2)]
+ all_dependencies = set()
+ for dbt_project_combo in dbt_project_combinations:
+ dependencies = linker.dependencies(dbt_project_combo[0], dbt_project_combo[1])
+ if len(dependencies) == 0:
+ logger.info(
+ f"No dependencies found between {dbt_project_combo[0].name} and {dbt_project_combo[1].name}"
+ )
+ continue
+
+ noun = "dependency" if len(dependencies) == 1 else "dependencies"
+ logger.info(
+ f"Found {len(dependencies)} {noun} between {dbt_project_combo[0].name} and {dbt_project_combo[1].name}"
+ )
+ all_dependencies.update(dependencies)
+ if len(all_dependencies) == 0:
+ logger.info("No dependencies found between any of the projects")
+ return
+
+ noun = "dependency" if len(all_dependencies) == 1 else "dependencies"
+ logger.info(f"Found {len(all_dependencies)} unique {noun} between all projects.")
+ for dependency in all_dependencies:
+ logger.info(
+ f"Resolving dependency between {dependency.upstream_resource} and {dependency.downstream_resource}"
+ )
+ try:
+ linker.resolve_dependency(
+ dependency,
+ project_map[dependency.upstream_project_name],
+ project_map[dependency.downstream_project_name],
+ )
+ except Exception as e:
+ raise FatalMeshifyException(f"Error resolving dependency : {dependency} {e}")
@cli.command(
@@ -98,13 +157,13 @@ def split(ctx, project_name, select, exclude, project_path, selector, create_pat
create_path = Path(create_path).expanduser().resolve()
create_path.parent.mkdir(parents=True, exist_ok=True)
- subproject_creator = DbtSubprojectCreator(subproject=subproject, target_directory=create_path)
+ subproject_creator = DbtSubprojectCreator(project=subproject, target_directory=create_path)
logger.info(f"Creating subproject {subproject.name}...")
try:
subproject_creator.initialize()
logger.success(f"Successfully created subproject {subproject.name}")
except Exception as e:
- raise FatalMeshifyException(f"Error creating subproject {subproject.name}")
+ raise FatalMeshifyException(f"Error creating subproject {subproject.name}: error {e}")
@operation.command(name="add-contract")
@@ -230,7 +289,7 @@ def create_group(
)
group_owner: Owner = Owner(
- name=owner_name, email=owner_email, _extra=yaml.safe_load(owner_properties or '{}')
+ name=owner_name, email=owner_email, _extra=yaml.safe_load(owner_properties or "{}")
)
grouper = ResourceGrouper(project)
@@ -279,4 +338,4 @@ def group(
Detects the edges of the group, makes their access public, and adds contracts to them
"""
ctx.forward(create_group)
- ctx.invoke(add_contract, select=f'group:{name}', project_path=project_path, public_only=True)
+ ctx.invoke(add_contract, select=f"group:{name}", project_path=project_path, public_only=True)
diff --git a/dbt_meshify/storage/dbt_project_creator.py b/dbt_meshify/storage/dbt_project_editors.py
similarity index 77%
rename from dbt_meshify/storage/dbt_project_creator.py
rename to dbt_meshify/storage/dbt_project_editors.py
index e049763..2f09fe3 100644
--- a/dbt_meshify/storage/dbt_project_creator.py
+++ b/dbt_meshify/storage/dbt_project_editors.py
@@ -1,11 +1,13 @@
+from enum import Enum
from pathlib import Path
-from typing import Optional, Set
+from typing import Optional, Set, Union
from dbt.contracts.graph.nodes import ManifestNode
+from dbt.contracts.util import Identifier
from dbt.node_types import AccessType
from loguru import logger
-from dbt_meshify.dbt_projects import DbtSubProject
+from dbt_meshify.dbt_projects import DbtProject, DbtSubProject
from dbt_meshify.storage.file_content_editors import (
DbtMeshConstructor,
filter_empty_dict_items,
@@ -14,28 +16,107 @@
from dbt_meshify.utilities.grouper import ResourceGrouper
-class DbtSubprojectCreator:
+class YMLOperationType(str, Enum):
+ """ProjectDependencyTypes define how the dependency relationship was defined."""
+
+ Move = "move"
+ Delete = "delete"
+
+
+class DbtProjectEditor:
+ def __init__(self, project: Union[DbtSubProject, DbtProject]):
+ self.project = project
+ self.file_manager = DbtFileManager(
+ read_project_path=project.path,
+ )
+
+ def move_resource(self, meshify_constructor: DbtMeshConstructor) -> None:
+ """
+ move a resource file from one project to another
+
+ """
+ current_path = meshify_constructor.get_resource_path()
+ self.file_manager.move_file(current_path)
+
+ def copy_resource(self, meshify_constructor: DbtMeshConstructor) -> None:
+ """
+ copy a resource file from one project to another
+
+ """
+ resource_path = meshify_constructor.get_resource_path()
+ contents = self.file_manager.read_file(resource_path)
+ self.file_manager.write_file(resource_path, contents)
+
+ def update_resource_yml_entry(
+ self,
+ meshify_constructor: DbtMeshConstructor,
+ operation_type: YMLOperationType = YMLOperationType.Move,
+ ) -> None:
+ """
+ move or delete a resource yml entry, depending on the operation type
+ """
+ current_yml_path = meshify_constructor.get_patch_path()
+ new_yml_path = self.file_manager.write_project_path / current_yml_path
+ full_yml_entry = self.file_manager.read_file(current_yml_path)
+ source_name = (
+ meshify_constructor.node.source_name
+ if hasattr(meshify_constructor.node, "source_name")
+ else None
+ )
+ resource_entry, remainder = meshify_constructor.get_yml_entry(
+ resource_name=meshify_constructor.node.name,
+ full_yml=full_yml_entry, # type: ignore
+ resource_type=meshify_constructor.node.resource_type,
+ source_name=source_name,
+ )
+ try:
+ existing_yml = self.file_manager.read_file(new_yml_path)
+ except FileNotFoundError:
+ existing_yml = None
+ if operation_type == YMLOperationType.Move:
+ new_yml_contents = meshify_constructor.add_entry_to_yml(
+ resource_entry, existing_yml, meshify_constructor.node.resource_type # type: ignore
+ )
+ self.file_manager.write_file(current_yml_path, new_yml_contents)
+ if remainder:
+ self.file_manager.write_file(current_yml_path, remainder, writeback=True)
+ else:
+ self.file_manager.delete_file(current_yml_path)
+
+ def update_dependencies_yml(self, name: Union[str, None] = None) -> None:
+ try:
+ contents = self.file_manager.read_file(Path("dependencies.yml"))
+ except FileNotFoundError:
+ contents = {"projects": []}
+
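+ # wrap the name in dbt's Identifier to validate it, then coerce back to a plain string for YAML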
+ contents["projects"].append({"name": str(Identifier(name)) if name else self.project.name}) # type: ignore
+ self.file_manager.write_file(Path("dependencies.yml"), contents, writeback=True)
+
+
+class DbtSubprojectCreator(DbtProjectEditor):
"""
Takes a `DbtSubProject` and creates the directory structure and files for it.
"""
- def __init__(self, subproject: DbtSubProject, target_directory: Optional[Path] = None):
- self.subproject = subproject
- self.target_directory = target_directory if target_directory else subproject.path
+ def __init__(self, project: DbtSubProject, target_directory: Optional[Path] = None):
+ if not isinstance(project, DbtSubProject):
+ raise TypeError(f"DbtSubprojectCreator requires a DbtSubProject, got {type(project)}")
+ super().__init__(project)
+ self.target_directory = target_directory if target_directory else project.path
self.file_manager = DbtFileManager(
- read_project_path=subproject.parent_project.path,
+ read_project_path=project.parent_project.path,
write_project_path=self.target_directory,
)
- self.subproject_boundary_models = self._get_subproject_boundary_models()
+ self.project_boundary_models = self._get_subproject_boundary_models()
def _get_subproject_boundary_models(self) -> Set[str]:
"""
get a set of boundary model unique_ids for all the selected resources
"""
- nodes = set(filter(lambda x: not x.startswith("source"), self.subproject.resources))
- parent_project_name = self.subproject.parent_project.name
+ nodes = set(filter(lambda x: not x.startswith("source"), self.project.resources)) # type: ignore
+ parent_project_name = self.project.parent_project.name # type: ignore
interface = ResourceGrouper.identify_interface(
- graph=self.subproject.graph.graph, selected_bunch=nodes
+ graph=self.project.graph.graph, selected_bunch=nodes
)
boundary_models = set(
filter(
@@ -49,7 +130,7 @@ def write_project_file(self) -> None:
"""
Writes the dbt_project.yml file for the subproject in the specified subdirectory
"""
- contents = self.subproject.project.to_dict()
+ contents = self.project.project.to_dict()
# was gettinga weird serialization error from ruamel on this value
# it's been deprecated, so no reason to keep it
contents.pop("version")
@@ -71,48 +152,39 @@ def copy_packages_dir(self) -> None:
"""
raise NotImplementedError("copy_packages_dir not implemented yet")
- def update_dependencies_yml(self) -> None:
- try:
- contents = self.file_manager.read_file(Path("dependencies.yml"))
- except FileNotFoundError:
- contents = {"projects": []}
-
- contents["projects"].append({"name": self.subproject.name}) # type: ignore
- self.file_manager.write_file(Path("dependencies.yml"), contents, writeback=True)
-
def update_child_refs(self, resource: ManifestNode) -> None:
downstream_models = [
node.unique_id
- for node in self.subproject.manifest.nodes.values()
+ for node in self.project.manifest.nodes.values()
if node.resource_type == "model" and resource.unique_id in node.depends_on.nodes # type: ignore
]
for model in downstream_models:
- model_node = self.subproject.get_manifest_node(model)
+ model_node = self.project.get_manifest_node(model)
if not model_node:
raise KeyError(f"Resource {model} not found in manifest")
meshify_constructor = DbtMeshConstructor(
- project_path=self.subproject.parent_project.path, node=model_node, catalog=None
+ project_path=self.project.parent_project.path, node=model_node, catalog=None # type: ignore
)
meshify_constructor.update_model_refs(
- model_name=resource.name, project_name=self.subproject.name
+ model_name=resource.name, project_name=self.project.name
)
def initialize(self) -> None:
"""Initialize this subproject as a full dbt project at the provided `target_directory`."""
- subproject = self.subproject
- for unique_id in subproject.resources | subproject.custom_macros | subproject.groups:
+ subproject = self.project
+ for unique_id in subproject.resources | subproject.custom_macros | subproject.groups: # type: ignore
resource = subproject.get_manifest_node(unique_id)
catalog = subproject.get_catalog_entry(unique_id)
if not resource:
raise KeyError(f"Resource {unique_id} not found in manifest")
meshify_constructor = DbtMeshConstructor(
- project_path=subproject.parent_project.path, node=resource, catalog=catalog
+ project_path=subproject.parent_project.path, node=resource, catalog=catalog # type: ignore
)
if resource.resource_type in ["model", "test", "snapshot", "seed"]:
# ignore generic tests, as moving the yml entry will move the test too
if resource.resource_type == "test" and len(resource.unique_id.split(".")) == 4:
continue
- if resource.unique_id in self.subproject_boundary_models:
+ if resource.unique_id in self.project_boundary_models:
logger.info(
f"Adding contract to and publicizing boundary node {resource.unique_id}"
)
@@ -130,7 +202,7 @@ def initialize(self) -> None:
# apply access method too
logger.info(f"Updating ref functions for children of {resource.unique_id}...")
try:
- self.update_child_refs(resource)
+ self.update_child_refs(resource) # type: ignore
logger.success(
f"Successfully updated ref functions for children of {resource.unique_id}"
)
@@ -145,7 +217,7 @@ def initialize(self) -> None:
)
try:
self.move_resource(meshify_constructor)
- self.move_resource_yml_entry(meshify_constructor)
+ self.update_resource_yml_entry(meshify_constructor)
logger.success(
f"Successfully moved {resource.unique_id} and associated YML to subproject {subproject.name}"
)
@@ -172,7 +244,7 @@ def initialize(self) -> None:
f"Moving resource {resource.unique_id} to subproject {subproject.name}..."
)
try:
- self.move_resource_yml_entry(meshify_constructor)
+ self.update_resource_yml_entry(meshify_constructor)
logger.success(
f"Successfully moved resource {resource.unique_id} to subproject {subproject.name}"
)
@@ -186,51 +258,3 @@ def initialize(self) -> None:
self.copy_packages_yml_file()
self.update_dependencies_yml()
# self.copy_packages_dir()
-
- def move_resource(self, meshify_constructor: DbtMeshConstructor) -> None:
- """
- move a resource file from one project to another
-
- """
- current_path = meshify_constructor.get_resource_path()
- self.file_manager.move_file(current_path)
-
- def copy_resource(self, meshify_constructor: DbtMeshConstructor) -> None:
- """
- copy a resource file from one project to another
-
- """
- resource_path = meshify_constructor.get_resource_path()
- contents = self.file_manager.read_file(resource_path)
- self.file_manager.write_file(resource_path, contents)
-
- def move_resource_yml_entry(self, meshify_constructor: DbtMeshConstructor) -> None:
- """
- move a resource yml entry from one project to another
- """
- current_yml_path = meshify_constructor.get_patch_path()
- new_yml_path = self.file_manager.write_project_path / current_yml_path
- full_yml_entry = self.file_manager.read_file(current_yml_path)
- source_name = (
- meshify_constructor.node.source_name
- if hasattr(meshify_constructor.node, "source_name")
- else None
- )
- resource_entry, remainder = meshify_constructor.get_yml_entry(
- resource_name=meshify_constructor.node.name,
- full_yml=full_yml_entry, # type: ignore
- resource_type=meshify_constructor.node.resource_type,
- source_name=source_name,
- )
- try:
- existing_yml = self.file_manager.read_file(new_yml_path)
- except FileNotFoundError:
- existing_yml = None
- new_yml_contents = meshify_constructor.add_entry_to_yml(
- resource_entry, existing_yml, meshify_constructor.node.resource_type # type: ignore
- )
- self.file_manager.write_file(current_yml_path, new_yml_contents)
- if remainder:
- self.file_manager.write_file(current_yml_path, remainder, writeback=True)
- else:
- self.file_manager.delete_file(current_yml_path)
diff --git a/dbt_meshify/storage/file_content_editors.py b/dbt_meshify/storage/file_content_editors.py
index 4900972..d2831cc 100644
--- a/dbt_meshify/storage/file_content_editors.py
+++ b/dbt_meshify/storage/file_content_editors.py
@@ -281,7 +281,7 @@ def add_model_version_to_yml(
models_yml["models"] = list(models.values())
return models_yml
- def update_sql_refs(self, model_code: str, model_name: str, project_name: str):
+ def update_refs__sql(self, model_code: str, model_name: str, project_name: str):
import re
# pattern to search for ref() with optional spaces and either single or double quotes
@@ -295,7 +295,32 @@ def update_sql_refs(self, model_code: str, model_name: str, project_name: str):
return new_code
- def update_python_refs(self, model_code: str, model_name: str, project_name: str):
+ def replace_source_with_ref__sql(
+ self, model_code: str, source_unique_id: str, model_unique_id: str
+ ):
+ import re
+
+ source_parsed = source_unique_id.split(".")
+ model_parsed = model_unique_id.split(".")
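+ # dbt unique_id formats: source.<project>.<source_name>.<table_name>
+ # and model.<project>.<model_name>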
+
+ # pattern to search for source() with optional spaces and either single or double quotes
+ pattern = re.compile(
+ r"{{\s*source\s*\(\s*['\"]"
+ + re.escape(source_parsed[2])
+ + r"['\"]\s*,\s*['\"]"
+ + re.escape(source_parsed[3])
+ + r"['\"]\s*\)\s*}}"
+ )
+
+ # replacement string with the new format
+ replacement = f"{{{{ ref('{model_parsed[1]}', '{model_parsed[2]}') }}}}"
+
+ # perform replacement
+ new_code = re.sub(pattern, replacement, model_code)
+
+ return new_code
+
+ def update_refs__python(self, model_code: str, model_name: str, project_name: str):
import re
# pattern to search for ref() with optional spaces and either single or double quotes
@@ -309,6 +334,31 @@ def update_python_refs(self, model_code: str, model_name: str, project_name: str
return new_code
+ def replace_source_with_ref__python(
+ self, model_code: str, source_unique_id: str, model_unique_id: str
+ ):
+ import re
+
+ source_parsed = source_unique_id.split(".")
+ model_parsed = model_unique_id.split(".")
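+ # dbt unique_id formats: source.<project>.<source_name>.<table_name> and model.<project>.<model_name>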
+
+ # pattern to search for source() with optional spaces and either single or double quotes
+ pattern = re.compile(
+ r"dbt\.source\s*\(\s*['\"]"
+ + re.escape(source_parsed[2])
+ + r"['\"]\s*,\s*['\"]"
+ + re.escape(source_parsed[3])
+ + r"['\"]\s*\)"
+ )
+
+ # replacement string with the new format
+ replacement = f'dbt.ref("{model_parsed[1]}", "{model_parsed[2]}")'
+
+ # perform replacement
+ new_code = re.sub(pattern, replacement, model_code)
+
+ return new_code
+
class DbtMeshConstructor(DbtMeshFileEditor):
def __init__(
@@ -347,7 +397,6 @@ def get_patch_path(self) -> Path:
filename = f"_{self.node.resource_type.pluralize()}.yml"
yml_path = resource_path.parent / filename
self.file_manager.write_file(yml_path, {})
- logger.info(f"Schema entry for {self.node.unique_id} written to {yml_path}")
return yml_path
def get_resource_path(self) -> Path:
@@ -362,6 +411,7 @@ def get_resource_path(self) -> Path:
def add_model_contract(self) -> None:
"""Adds a model contract to the model's yaml"""
yml_path = self.get_patch_path()
+ logger.info(f"Adding contract to {self.node.name} at {yml_path}")
# read the yml file
# pass empty dict if no file contents returned
models_yml = self.file_manager.read_file(yml_path)
@@ -382,7 +432,7 @@ def add_model_contract(self) -> None:
def add_model_access(self, access_type: AccessType) -> None:
"""Adds a model contract to the model's yaml"""
yml_path = self.get_patch_path()
- logger.info(f"Adding model contract for {self.node.name} at {yml_path}")
+ logger.info(f"Adding {access_type} access to {self.node.name} at {yml_path}")
# read the yml file
# pass empty dict if no file contents returned
models_yml = self.file_manager.read_file(yml_path)
@@ -478,7 +528,7 @@ def update_model_refs(self, model_name: str, project_name: str) -> None:
# read the model file
model_code = str(self.file_manager.read_file(model_path))
# This can be defined in the init for this clas.
- ref_update_methods = {'sql': self.update_sql_refs, 'python': self.update_python_refs}
+ ref_update_methods = {'sql': self.update_refs__sql, 'python': self.update_refs__python}
# Here, we're trusting the dbt-core code to check the languages for us. 🐉
updated_code = ref_update_methods[self.node.language](
model_name=model_name,
@@ -487,3 +537,28 @@ def update_model_refs(self, model_name: str, project_name: str) -> None:
)
# write the updated model code to the file
self.file_manager.write_file(model_path, updated_code)
+
+ def replace_source_with_refs(self, source_unique_id: str, model_unique_id: str) -> None:
+ """Updates the model refs in the model's sql file"""
+ model_path = self.get_resource_path()
+
+ if model_path is None:
+ raise ModelFileNotFoundError(
+ f"Unable to find path to model {self.node.name}. Aborting."
+ )
+
+ # read the model file
+ model_code = str(self.file_manager.read_file(model_path))
+ # This can be defined in the init for this class.
+ ref_update_methods = {
+ 'sql': self.replace_source_with_ref__sql,
+ 'python': self.replace_source_with_ref__python,
+ }
+ # Here, we're trusting the dbt-core code to check the languages for us. 🐉
+ updated_code = ref_update_methods[self.node.language](
+ model_code=model_code,
+ source_unique_id=source_unique_id,
+ model_unique_id=model_unique_id,
+ )
+ # write the updated model code to the file
+ self.file_manager.write_file(model_path, updated_code)
diff --git a/dbt_meshify/storage/file_manager.py b/dbt_meshify/storage/file_manager.py
index 850b4a4..137085b 100644
--- a/dbt_meshify/storage/file_manager.py
+++ b/dbt_meshify/storage/file_manager.py
@@ -5,8 +5,10 @@
from pathlib import Path
from typing import Any, Dict, Optional, Union
+from dbt.contracts.util import Identifier
from ruamel.yaml import YAML
from ruamel.yaml.compat import StringIO
+from ruamel.yaml.representer import Representer
class DbtYAML(YAML):
@@ -29,7 +31,7 @@ def dump(self, data, stream=None, **kw):
yaml = DbtYAML()
-
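+# register Identifier so ruamel can serialize dbt Identifier values (e.g. in dependencies.yml)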
+yaml.register_class(Identifier)
FileContent = Union[Dict[str, str], str]
@@ -55,7 +57,7 @@ def __init__(
self.write_project_path = write_project_path if write_project_path else read_project_path
def read_file(self, path: Path) -> Union[Dict[str, Any], str]:
- """Returns the yaml for a model in the dbt project's manifest"""
+ """Returns the file contents at a given path"""
full_path = self.read_project_path / path
if full_path.suffix == ".yml":
return yaml.load(full_path.read_text())
@@ -96,4 +98,5 @@ def move_file(self, path: Path) -> None:
def delete_file(self, path: Path) -> None:
"""deletes the specified file"""
- path.unlink()
+ delete_path = self.read_project_path / path
+ delete_path.unlink()
diff --git a/tests/integration/test_connect_command.py b/tests/integration/test_connect_command.py
new file mode 100644
index 0000000..48cfe90
--- /dev/null
+++ b/tests/integration/test_connect_command.py
@@ -0,0 +1,119 @@
+from pathlib import Path
+
+import pytest
+from click.testing import CliRunner
+
+from dbt_meshify.main import connect
+from tests.dbt_project_utils import setup_test_project, teardown_test_project
+
+producer_project_path = "test-projects/source-hack/src_proj_a"
+source_consumer_project_path = "test-projects/source-hack/src_proj_b"
+package_consumer_project_path = "test-projects/source-hack/dest_proj_a"
+copy_producer_project_path = producer_project_path + "_copy"
+copy_source_consumer_project_path = source_consumer_project_path + "_copy"
+copy_package_consumer_project_path = package_consumer_project_path + "_copy"
+
+
+@pytest.fixture
+def producer_project():
+ setup_test_project(producer_project_path, copy_producer_project_path)
+ setup_test_project(source_consumer_project_path, copy_source_consumer_project_path)
+ setup_test_project(package_consumer_project_path, copy_package_consumer_project_path)
+ # yield to the test. We'll come back here after the test returns.
+ yield
+
+ teardown_test_project(copy_producer_project_path)
+ teardown_test_project(copy_source_consumer_project_path)
+ teardown_test_project(copy_package_consumer_project_path)
+
+
+class TestConnectCommand:
+ def test_connect_source_hack(self, producer_project):
+ runner = CliRunner()
+ result = runner.invoke(
+ connect,
+ ["--project-paths", copy_producer_project_path, copy_source_consumer_project_path],
+ )
+
+ assert result.exit_code == 0
+
+ # assert that the source is replaced with a ref
+ x_proj_ref = "{{ ref('src_proj_a', 'shared_model') }}"
+ source_func = "{{ source('src_proj_a', 'shared_model') }}"
+ child_sql = (
+ Path(copy_source_consumer_project_path) / "models" / "downstream_model.sql"
+ ).read_text()
+ assert x_proj_ref in child_sql
+ assert source_func not in child_sql
+
+ # assert that the source was deleted
+ # may want to add some nuance in the future for cases where we delete a single source out of a set
+ assert not (
+ Path(copy_source_consumer_project_path) / "models" / "staging" / "_sources.yml"
+ ).exists()
+
+ # assert that the dependencies yml was created with a pointer to the upstream project
+ assert (
+ "src_proj_a"
+ in (Path(copy_source_consumer_project_path) / "dependencies.yml").read_text()
+ )
+
+ def test_connect_package(self, producer_project):
+ runner = CliRunner()
+ result = runner.invoke(
+ connect,
+ ["--project-paths", copy_producer_project_path, copy_package_consumer_project_path],
+ )
+
+ assert result.exit_code == 0
+
+ # assert that the source is replaced with a ref
+ x_proj_ref = "{{ ref('src_proj_a', 'shared_model') }}"
+ child_sql = (
+ Path(copy_package_consumer_project_path) / "models" / "downstream_model.sql"
+ ).read_text()
+ assert x_proj_ref in child_sql
+
+ # assert that the dependencies yml was created with a pointer to the upstream project
+ assert (
+ "src_proj_a"
+ in (Path(copy_package_consumer_project_path) / "dependencies.yml").read_text()
+ )
+
+ def test_packages_dir_exclude_packages(self):
+ # copy all three packages into a subdirectory to ensure no repeat project names in one directory
+ subdir_producer_project = "test-projects/source-hack/subdir/src_proj_a"
+ subdir_source_consumer_project = "test-projects/source-hack/subdir/src_proj_b"
+ subdir_package_consumer_project = "test-projects/source-hack/subdir/dest_proj_a"
+ setup_test_project(producer_project_path, subdir_producer_project)
+ setup_test_project(source_consumer_project_path, subdir_source_consumer_project)
+ setup_test_project(package_consumer_project_path, subdir_package_consumer_project)
+ runner = CliRunner()
+ result = runner.invoke(
+ connect,
+ [
+ "--projects-dir",
+ "test-projects/source-hack/subdir",
+ "--exclude-projects",
+ "src_proj_b",
+ ],
+ )
+
+ assert result.exit_code == 0
+
+ # assert that the source is replaced with a ref
+ x_proj_ref = "{{ ref('src_proj_a', 'shared_model') }}"
+ child_sql = (
+ Path(subdir_package_consumer_project) / "models" / "downstream_model.sql"
+ ).read_text()
+ assert x_proj_ref in child_sql
+
+ # assert that the dependencies yml was created with a pointer to the upstream project
+ assert (
+ "src_proj_a"
+ in (Path(subdir_package_consumer_project) / "dependencies.yml").read_text()
+ )
+
+ teardown_test_project(subdir_producer_project)
+ teardown_test_project(subdir_package_consumer_project)
+ teardown_test_project(subdir_source_consumer_project)
diff --git a/tests/integration/test_dependency_detection.py b/tests/integration/test_dependency_detection.py
index 1aa749e..d9b5ec2 100644
--- a/tests/integration/test_dependency_detection.py
+++ b/tests/integration/test_dependency_detection.py
@@ -50,8 +50,10 @@ def test_linker_detects_source_dependencies(self, src_proj_a, src_proj_b):
assert dependencies == {
ProjectDependency(
- upstream="model.src_proj_a.shared_model",
- downstream="source.src_proj_b.src_proj_a.shared_model",
+ upstream_resource="model.src_proj_a.shared_model",
+ upstream_project_name="src_proj_a",
+ downstream_resource="source.src_proj_b.src_proj_a.shared_model",
+ downstream_project_name="src_proj_b",
type=ProjectDependencyType.Source,
)
}
@@ -64,8 +66,10 @@ def test_linker_detects_source_dependencies_bidirectionally(self, src_proj_a, sr
assert dependencies == {
ProjectDependency(
- upstream="model.src_proj_a.shared_model",
- downstream="source.src_proj_b.src_proj_a.shared_model",
+ upstream_resource="model.src_proj_a.shared_model",
+ upstream_project_name="src_proj_a",
+ downstream_resource="source.src_proj_b.src_proj_a.shared_model",
+ downstream_project_name="src_proj_b",
type=ProjectDependencyType.Source,
)
}
@@ -78,8 +82,10 @@ def test_linker_detects_package_import_dependencies(self, src_proj_a, dest_proj_
assert dependencies == {
ProjectDependency(
- upstream="model.src_proj_a.shared_model",
- downstream="model.src_proj_a.shared_model",
+ upstream_resource="model.src_proj_a.shared_model",
+ upstream_project_name="src_proj_a",
+ downstream_resource="model.dest_proj_a.downstream_model",
+ downstream_project_name="dest_proj_a",
type=ProjectDependencyType.Package,
)
}
diff --git a/tests/integration/test_subproject_creator.py b/tests/integration/test_subproject_creator.py
index 22fc3a1..ddc998e 100644
--- a/tests/integration/test_subproject_creator.py
+++ b/tests/integration/test_subproject_creator.py
@@ -6,7 +6,7 @@
from dbt_meshify.dbt import Dbt
from dbt_meshify.dbt_projects import DbtProject
-from dbt_meshify.storage.dbt_project_creator import DbtSubprojectCreator
+from dbt_meshify.storage.dbt_project_editors import DbtSubprojectCreator
from dbt_meshify.storage.file_content_editors import DbtMeshConstructor
test_project_profile = yaml.safe_load(
@@ -106,7 +106,7 @@ def test_move_yml_entry(self) -> None:
subproject = split_project()
meshify_constructor = get_meshify_constructor(subproject, model_unique_id)
creator = DbtSubprojectCreator(subproject)
- creator.move_resource_yml_entry(meshify_constructor)
+ creator.update_resource_yml_entry(meshify_constructor)
# the original path should still exist, since we take only the single model entry
assert Path("test/models/example/schema.yml").exists()
assert Path("test/subdir/models/example/schema.yml").exists()
diff --git a/tests/unit/test_update_ref_functions.py b/tests/unit/test_update_ref_functions.py
index 4a8d92f..941d150 100644
--- a/tests/unit/test_update_ref_functions.py
+++ b/tests/unit/test_update_ref_functions.py
@@ -38,7 +38,7 @@ def read_yml(yml_str):
class TestRemoveResourceYml:
def test_update_sql_ref_function__basic(self):
- updated_sql = meshify.update_sql_refs(
+ updated_sql = meshify.update_refs__sql(
model_code=simple_model_sql,
model_name=upstream_model_name,
project_name=upstream_project_name,
@@ -46,7 +46,7 @@ def test_update_sql_ref_function__basic(self):
assert updated_sql == expected_simple_model_sql
def test_update_python_ref_function__basic(self):
- updated_python = meshify.update_python_refs(
+ updated_python = meshify.update_refs__python(
model_code=simple_model_python,
model_name=upstream_model_name,
project_name=upstream_project_name,