diff --git a/dbt_meshify/cli.py b/dbt_meshify/cli.py
index f4b9f46..7a62408 100644
--- a/dbt_meshify/cli.py
+++ b/dbt_meshify/cli.py
@@ -12,6 +12,32 @@
     help="The path to the dbt project to operate on. Defaults to the current directory.",
 )
 
+project_paths = click.option(
+    "--project-paths",
+    cls=MultiOption,
+    multiple=True,
+    type=tuple,
+    default=None,
+    help="The paths to the set of dbt projects to connect. Must supply 2+ paths.",
+)
+
+projects_dir = click.option(
+    "--projects-dir",
+    type=click.Path(exists=True),
+    default=None,
+    help="The path to a directory containing multiple dbt projects. Directory must contain 2+ projects.",
+)
+
+exclude_projects = click.option(
+    "--exclude-projects",
+    "-e",
+    cls=MultiOption,
+    multiple=True,
+    type=tuple,
+    default=None,
+    help="The set of dbt projects to exclude from the operation when using the --projects-dir option.",
+)
+
 create_path = click.option(
     "--create-path",
     type=click.Path(exists=False),
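Note: the three options above are plain `click.option` objects held in module-level variables so several commands can share them. A minimal, self-contained sketch of that reuse pattern follows; it drops the custom `cls=MultiOption` / `type=tuple` handling that `dbt_meshify.cli` supplies, and `my_command` is a hypothetical command, not part of the codebase.

```python
import click

# A reusable option object, applied to any command via decorator stacking.
project_paths = click.option(
    "--project-paths",
    multiple=True,  # plain click collects repeated flags into a tuple
    default=None,
    help="The paths to the set of dbt projects to connect.",
)


@click.command()
@project_paths
def my_command(project_paths: tuple):
    click.echo(f"Operating on {len(project_paths or ())} projects")


if __name__ == "__main__":
    # e.g. python sketch.py --project-paths proj_a --project-paths proj_b
    my_command()
```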
diff --git a/dbt_meshify/linker.py b/dbt_meshify/linker.py
index 1f49d39..551bdd9 100644
--- a/dbt_meshify/linker.py
+++ b/dbt_meshify/linker.py
@@ -1,8 +1,14 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Set
+from typing import Set, Union
 
-from dbt_meshify.dbt_projects import BaseDbtProject
+from dbt.node_types import AccessType
+from loguru import logger
+
+from dbt_meshify.dbt_projects import BaseDbtProject, DbtProject
+from dbt_meshify.exceptions import FatalMeshifyException, FileEditorException
+from dbt_meshify.storage.dbt_project_editors import DbtProjectEditor, YMLOperationType
+from dbt_meshify.storage.file_content_editors import DbtMeshConstructor
 
 
 class ProjectDependencyType(str, Enum):
@@ -16,12 +22,14 @@ class ProjectDependencyType(str, Enum):
 class ProjectDependency:
     """ProjectDependencies define shared resources between two different projects"""
 
-    upstream: str
-    downstream: str
+    upstream_resource: str
+    upstream_project_name: str
+    downstream_resource: str
+    downstream_project_name: str
     type: ProjectDependencyType
 
     def __key(self):
-        return self.upstream, self.downstream, self.type
+        return self.upstream_resource, self.downstream_resource, self.type
 
     def __hash__(self):
         return hash(self.__key())
@@ -45,12 +53,14 @@ def _find_relation_dependencies(
         source_relations: Set[str], target_relations: Set[str]
     ) -> Set[str]:
         """
-        Identify dependencies between projects using shared relations.
+        Identify dependencies between projects using shared relation names.
         """
         return source_relations.intersection(target_relations)
 
     def _source_dependencies(
-        self, project: BaseDbtProject, other_project: BaseDbtProject
+        self,
+        project: Union[BaseDbtProject, DbtProject],
+        other_project: Union[BaseDbtProject, DbtProject],
     ) -> Set[ProjectDependency]:
         """
         Identify source-hack dependencies between projects.
@@ -74,8 +84,10 @@ def _source_dependencies(
 
         forward_dependencies = {
             ProjectDependency(
-                upstream=project.model_relation_names[relation],
-                downstream=other_project.source_relation_names[relation],
+                upstream_resource=project.model_relation_names[relation],
+                upstream_project_name=project.name,
+                downstream_resource=other_project.source_relation_names[relation],
+                downstream_project_name=other_project.name,
                 type=ProjectDependencyType.Source,
             )
             for relation in relations
@@ -96,8 +108,10 @@ def _source_dependencies(
 
         backward_dependencies = {
             ProjectDependency(
-                upstream=other_project.model_relation_names[relation],
-                downstream=project.source_relation_names[relation],
+                upstream_resource=other_project.model_relation_names[relation],
+                upstream_project_name=other_project.name,
+                downstream_resource=project.source_relation_names[relation],
+                downstream_project_name=project.name,
                 type=ProjectDependencyType.Source,
            )
             for relation in backwards_relations
@@ -106,7 +120,9 @@
         return forward_dependencies | backward_dependencies
 
     def _package_dependencies(
-        self, project: BaseDbtProject, other_project: BaseDbtProject
+        self,
+        project: Union[BaseDbtProject, DbtProject],
+        other_project: Union[BaseDbtProject, DbtProject],
     ) -> Set[ProjectDependency]:
         """
         Identify package-imported dependencies between projects.
@@ -115,9 +131,13 @@
         project (Project B) imports Project A and references the model.
         """
 
-        if project.project_id not in other_project.installed_packages():
+        if (
+            project.project_id not in other_project.installed_packages()
+            and other_project.project_id not in project.installed_packages()
+        ):
             return set()
 
+        # find which models are in both manifests
         relations = self._find_relation_dependencies(
             source_relations={
                 model.relation_name
@@ -131,17 +151,54 @@
             },
         )
 
-        return {
+        # find the children of the shared models in the downstream project
+        package_children = [
+            {
+                'upstream_resource': project.model_relation_names[relation],
+                'downstream_resource': child,
+            }
+            for relation in relations
+            for child in other_project.manifest.child_map[project.model_relation_names[relation]]
+        ]
+
+        forward_dependencies = {
             ProjectDependency(
-                upstream=project.model_relation_names[relation],
-                downstream=other_project.model_relation_names[relation],
+                upstream_resource=child['upstream_resource'],
+                upstream_project_name=project.name,
+                downstream_resource=child['downstream_resource'],
+                downstream_project_name=other_project.name,
                 type=ProjectDependencyType.Package,
             )
+            for child in package_children
+        }
+
+        # repeat in the reverse direction, finding children of the shared models
+        # where `project` is the downstream project
+        backward_package_children = [
+            {
+                'upstream_resource': other_project.model_relation_names[relation],
+                'downstream_resource': child,
+            }
             for relation in relations
+            for child in project.manifest.child_map[other_project.model_relation_names[relation]]
+        ]
+
+        backward_dependencies = {
+            ProjectDependency(
+                upstream_resource=child['upstream_resource'],
+                upstream_project_name=other_project.name,
+                downstream_resource=child['downstream_resource'],
+                downstream_project_name=project.name,
+                type=ProjectDependencyType.Package,
+            )
+            for child in backward_package_children
         }
 
+        return forward_dependencies | backward_dependencies
+
     def dependencies(
-        self, project: BaseDbtProject, other_project: BaseDbtProject
+        self,
+        project: Union[BaseDbtProject, DbtProject],
+        other_project: Union[BaseDbtProject, DbtProject],
     ) -> Set[ProjectDependency]:
         """Detect dependencies between two projects and return a list of resources shared."""
@@ -156,3 +213,102 @@ def dependencies(
         dependencies.update(package_dependencies)
 
         return dependencies
+
+    def resolve_dependency(
+        self,
+        dependency: ProjectDependency,
+        upstream_project: DbtProject,
+        downstream_project: DbtProject,
+    ):
+        upstream_manifest_entry = upstream_project.get_manifest_node(dependency.upstream_resource)
+        if not upstream_manifest_entry:
+            raise ValueError(
+                f"Could not find upstream resource {dependency.upstream_resource} in project {upstream_project.name}"
+            )
+        downstream_manifest_entry = downstream_project.get_manifest_node(
+            dependency.downstream_resource
+        )
+        upstream_catalog_entry = upstream_project.get_catalog_entry(dependency.upstream_resource)
+        upstream_mesh_constructor = DbtMeshConstructor(
+            project_path=upstream_project.path,
+            node=upstream_manifest_entry,  # type: ignore
+            catalog=upstream_catalog_entry,
+        )
+
+        downstream_mesh_constructor = DbtMeshConstructor(
+            project_path=downstream_project.path,
+            node=downstream_manifest_entry,  # type: ignore
+            catalog=None,
+        )
+        downstream_editor = DbtProjectEditor(downstream_project)
+        # for either dependency type, add contracts and make model public
+        try:
+            upstream_mesh_constructor.add_model_access(AccessType.Public)
+            logger.success(
+                f"Successfully updated model access: {dependency.upstream_resource} is now {AccessType.Public}"
+            )
+        except FatalMeshifyException as e:
+            logger.error(f"Failed to update model access: {dependency.upstream_resource}")
+            raise e
+        try:
+            upstream_mesh_constructor.add_model_contract()
+            logger.success(f"Successfully added contract to model: {dependency.upstream_resource}")
+        except FatalMeshifyException as e:
+            logger.error(f"Failed to add contract to model: {dependency.upstream_resource}")
+            raise e
+
+        if dependency.type == ProjectDependencyType.Source:
+            for child in downstream_project.manifest.child_map[dependency.downstream_resource]:
+                constructor = DbtMeshConstructor(
+                    project_path=downstream_project.path,
+                    node=downstream_project.get_manifest_node(child),  # type: ignore
+                    catalog=None,
+                )
+                try:
+                    constructor.replace_source_with_refs(
+                        source_unique_id=dependency.downstream_resource,
+                        model_unique_id=dependency.upstream_resource,
+                    )
+                    logger.success(
+                        f"Successfully replaced source function with ref to upstream resource: {dependency.downstream_resource} now calls {dependency.upstream_resource} directly"
+                    )
+                except FileEditorException as e:
+                    logger.error(
+                        f"Failed to replace source function with ref to upstream resource"
+                    )
+                    raise e
+            try:
+                downstream_editor.update_resource_yml_entry(
+                    downstream_mesh_constructor, operation_type=YMLOperationType.Delete
+                )
+                logger.success(
+                    f"Successfully deleted unnecessary source: {dependency.downstream_resource}"
+                )
+            except FatalMeshifyException as e:
+                logger.error(f"Failed to delete unnecessary source")
+                raise e
+
+        if dependency.type == ProjectDependencyType.Package:
+            try:
+                downstream_mesh_constructor.update_model_refs(
+                    model_name=upstream_manifest_entry.name,
+                    project_name=dependency.upstream_project_name,
+                )
+                logger.success(
+                    f"Successfully updated model refs: {dependency.downstream_resource} now references {dependency.upstream_resource}"
+                )
+            except FileEditorException as e:
+                logger.error(f"Failed to update model refs")
+                raise e
+
+        # for both types, add upstream project to downstream project's dependencies.yml
+        try:
+            downstream_editor.update_dependencies_yml(name=upstream_project.name)
+            logger.success(
+                f"Successfully added {dependency.upstream_project_name} to {dependency.downstream_project_name}'s dependencies.yml"
+            )
+        except FileEditorException as e:
+            logger.error(
+                f"Failed to add {dependency.upstream_project_name} to {dependency.downstream_project_name}'s dependencies.yml"
+            )
+            raise e
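Note: `ProjectDependency.__key` deliberately omits the two project-name fields, so the same resource pair detected while scanning in either direction collapses to a single entry once the dependencies land in a set. A small sketch of that de-duplication, mirroring the dataclass above (the resource IDs are illustrative):

```python
from dataclasses import dataclass
from enum import Enum


class ProjectDependencyType(str, Enum):
    Source = "source"
    Package = "package"


@dataclass
class ProjectDependency:
    upstream_resource: str
    upstream_project_name: str
    downstream_resource: str
    downstream_project_name: str
    type: ProjectDependencyType

    def __key(self):
        # project names intentionally excluded from identity
        return self.upstream_resource, self.downstream_resource, self.type

    def __hash__(self):
        return hash(self.__key())


a = ProjectDependency(
    "model.src_proj_a.shared_model", "src_proj_a",
    "source.src_proj_b.src_proj_a.shared_model", "src_proj_b",
    ProjectDependencyType.Source,
)
b = ProjectDependency(
    "model.src_proj_a.shared_model", "src_proj_a",
    "source.src_proj_b.src_proj_a.shared_model", "src_proj_b",
    ProjectDependencyType.Source,
)
# duplicates found from both scan directions collapse to one set member
assert len({a, b}) == 1
```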
f"Successfully added {dependency.upstream_project_name} to {dependency.downstream_project_name}'s dependencies.yml" + ) + except FileEditorException as e: + logger.error( + f"Failed to add {dependency.upstream_project_name} to {dependency.downstream_project_name}'s dependencies.yml" + ) + raise e diff --git a/dbt_meshify/main.py b/dbt_meshify/main.py index da7ccaf..4999148 100644 --- a/dbt_meshify/main.py +++ b/dbt_meshify/main.py @@ -1,31 +1,36 @@ import os import sys +from itertools import combinations from pathlib import Path -from typing import Optional +from typing import List, Optional import click import yaml from dbt.contracts.graph.unparsed import Owner from loguru import logger -from dbt_meshify.storage.dbt_project_creator import DbtSubprojectCreator +from dbt_meshify.storage.dbt_project_editors import DbtSubprojectCreator from .cli import ( TupleCompatibleCommand, create_path, exclude, + exclude_projects, group_yml_path, owner, owner_email, owner_name, owner_properties, project_path, + project_paths, + projects_dir, read_catalog, select, selector, ) from .dbt_projects import DbtProject, DbtProjectHolder from .exceptions import FatalMeshifyException +from .linker import Linker from .storage.file_content_editors import DbtMeshConstructor log_format = "{time:HH:mm:ss} | {level} | {message}" @@ -48,26 +53,80 @@ def operation(): @cli.command(name="connect") -@click.argument("projects-dir", type=click.Path(exists=True), default=".") -def connect(projects_dir): +@project_paths +@projects_dir +@exclude_projects +@read_catalog +def connect( + project_paths: tuple, projects_dir: Path, exclude_projects: List[str], read_catalog: bool +): """ - !!! info - This command is not yet implemented - Connects multiple dbt projects together by adding all necessary dbt Mesh constructs """ - holder = DbtProjectHolder() - - while True: - path_string = input("Enter the relative path to a dbt project (enter 'done' to finish): ") - if path_string == "done": - break - - path = Path(path_string).expanduser().resolve() - project = DbtProject.from_directory(path, read_catalog) - holder.register_project(project) - - print(holder.project_map()) + if project_paths and projects_dir: + raise click.BadOptionUsage( + option_name="project_paths", + message="Cannot specify both project_paths and projects_dir", + ) + # 1. initialize all the projects supplied to the command + # 2. compute the dependency graph between each combination of 2 projects in that set + # 3. for each dependency, add the necessary dbt Mesh constructs to each project. 
+ # This includes: + # - adding the dependency to the dependencies.yml file of the downstream project + # - adding contracts and public access to the upstream models + # - deleting the source definition of the upstream models in the downstream project + # - updating the `{{ source }}` macro in the downstream project to a {{ ref }} to the upstream project + + linker = Linker() + if project_paths: + dbt_projects = [ + DbtProject.from_directory(project_path, read_catalog) for project_path in project_paths + ] + + if projects_dir: + dbt_project_paths = [path.parent for path in Path(projects_dir).glob("**/dbt_project.yml")] + all_dbt_projects = [ + DbtProject.from_directory(project_path, read_catalog) + for project_path in dbt_project_paths + ] + dbt_projects = [ + project for project in all_dbt_projects if project.name not in exclude_projects + ] + + project_map = {project.name: project for project in dbt_projects} + dbt_project_combinations = [combo for combo in combinations(dbt_projects, 2)] + all_dependencies = set() + for dbt_project_combo in dbt_project_combinations: + dependencies = linker.dependencies(dbt_project_combo[0], dbt_project_combo[1]) + if len(dependencies) == 0: + logger.info( + f"No dependencies found between {dbt_project_combo[0].name} and {dbt_project_combo[1].name}" + ) + continue + + noun = "dependency" if len(dependencies) == 1 else "dependencies" + logger.info( + f"Found {len(dependencies)} {noun} between {dbt_project_combo[0].name} and {dbt_project_combo[1].name}" + ) + all_dependencies.update(dependencies) + if len(all_dependencies) == 0: + logger.info("No dependencies found between any of the projects") + return + + noun = "dependency" if len(all_dependencies) == 1 else "dependencies" + logger.info(f"Found {len(all_dependencies)} unique {noun} between all projects.") + for dependency in all_dependencies: + logger.info( + f"Resolving dependency between {dependency.upstream_resource} and {dependency.downstream_resource}" + ) + try: + linker.resolve_dependency( + dependency, + project_map[dependency.upstream_project_name], + project_map[dependency.downstream_project_name], + ) + except Exception as e: + raise FatalMeshifyException(f"Error resolving dependency : {dependency} {e}") @cli.command( @@ -98,13 +157,13 @@ def split(ctx, project_name, select, exclude, project_path, selector, create_pat create_path = Path(create_path).expanduser().resolve() create_path.parent.mkdir(parents=True, exist_ok=True) - subproject_creator = DbtSubprojectCreator(subproject=subproject, target_directory=create_path) + subproject_creator = DbtSubprojectCreator(project=subproject, target_directory=create_path) logger.info(f"Creating subproject {subproject.name}...") try: subproject_creator.initialize() logger.success(f"Successfully created subproject {subproject.name}") except Exception as e: - raise FatalMeshifyException(f"Error creating subproject {subproject.name}") + raise FatalMeshifyException(f"Error creating subproject {subproject.name}: error {e}") @operation.command(name="add-contract") @@ -230,7 +289,7 @@ def create_group( ) group_owner: Owner = Owner( - name=owner_name, email=owner_email, _extra=yaml.safe_load(owner_properties or '{}') + name=owner_name, email=owner_email, _extra=yaml.safe_load(owner_properties or "{}") ) grouper = ResourceGrouper(project) @@ -279,4 +338,4 @@ def group( Detects the edges of the group, makes their access public, and adds contracts to them """ ctx.forward(create_group) - ctx.invoke(add_contract, select=f'group:{name}', project_path=project_path, 
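Note: `connect` visits each unordered pair of projects exactly once; `Linker.dependencies` already scans both directions within a pair, so ordered permutations would double-count. A short sketch of the pairing step (project names are illustrative):

```python
from itertools import combinations

# Stand-ins for DbtProject instances; only the names matter for the pairing.
dbt_projects = ["src_proj_a", "src_proj_b", "dest_proj_a"]

for left, right in combinations(dbt_projects, 2):
    print(f"checking {left} <-> {right}")

# checking src_proj_a <-> src_proj_b
# checking src_proj_a <-> dest_proj_a
# checking src_proj_b <-> dest_proj_a
```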
diff --git a/dbt_meshify/storage/dbt_project_creator.py b/dbt_meshify/storage/dbt_project_editors.py
similarity index 77%
rename from dbt_meshify/storage/dbt_project_creator.py
rename to dbt_meshify/storage/dbt_project_editors.py
index e049763..2f09fe3 100644
--- a/dbt_meshify/storage/dbt_project_creator.py
+++ b/dbt_meshify/storage/dbt_project_editors.py
@@ -1,11 +1,13 @@
+from enum import Enum
 from pathlib import Path
-from typing import Optional, Set
+from typing import Optional, Set, Union
 
 from dbt.contracts.graph.nodes import ManifestNode
+from dbt.contracts.util import Identifier
 from dbt.node_types import AccessType
 from loguru import logger
 
-from dbt_meshify.dbt_projects import DbtSubProject
+from dbt_meshify.dbt_projects import DbtProject, DbtSubProject
 from dbt_meshify.storage.file_content_editors import (
     DbtMeshConstructor,
     filter_empty_dict_items,
@@ -14,28 +16,107 @@
 from dbt_meshify.utilities.grouper import ResourceGrouper
 
 
-class DbtSubprojectCreator:
+class YMLOperationType(str, Enum):
+    """YMLOperationType defines the type of yml operation to perform."""
+
+    Move = "move"
+    Delete = "delete"
+
+
+class DbtProjectEditor:
+    def __init__(self, project: Union[DbtSubProject, DbtProject]):
+        self.project = project
+        self.file_manager = DbtFileManager(
+            read_project_path=project.path,
+        )
+
+    def move_resource(self, meshify_constructor: DbtMeshConstructor) -> None:
+        """
+        move a resource file from one project to another
+
+        """
+        current_path = meshify_constructor.get_resource_path()
+        self.file_manager.move_file(current_path)
+
+    def copy_resource(self, meshify_constructor: DbtMeshConstructor) -> None:
+        """
+        copy a resource file from one project to another
+
+        """
+        resource_path = meshify_constructor.get_resource_path()
+        contents = self.file_manager.read_file(resource_path)
+        self.file_manager.write_file(resource_path, contents)
+
+    def update_resource_yml_entry(
+        self,
+        meshify_constructor: DbtMeshConstructor,
+        operation_type: YMLOperationType = YMLOperationType.Move,
+    ) -> None:
+        """
+        move (or delete) a resource yml entry from one project to another
+        """
+        current_yml_path = meshify_constructor.get_patch_path()
+        new_yml_path = self.file_manager.write_project_path / current_yml_path
+        full_yml_entry = self.file_manager.read_file(current_yml_path)
+        source_name = (
+            meshify_constructor.node.source_name
+            if hasattr(meshify_constructor.node, "source_name")
+            else None
+        )
+        resource_entry, remainder = meshify_constructor.get_yml_entry(
+            resource_name=meshify_constructor.node.name,
+            full_yml=full_yml_entry,  # type: ignore
+            resource_type=meshify_constructor.node.resource_type,
+            source_name=source_name,
+        )
+        try:
+            existing_yml = self.file_manager.read_file(new_yml_path)
+        except FileNotFoundError:
+            existing_yml = None
+        if operation_type == YMLOperationType.Move:
+            new_yml_contents = meshify_constructor.add_entry_to_yml(
+                resource_entry, existing_yml, meshify_constructor.node.resource_type  # type: ignore
+            )
+            self.file_manager.write_file(current_yml_path, new_yml_contents)
+        if remainder:
+            self.file_manager.write_file(current_yml_path, remainder, writeback=True)
+        else:
+            self.file_manager.delete_file(current_yml_path)
+
+    def update_dependencies_yml(self, name: Union[str, None] = None) -> None:
+        try:
+            contents = self.file_manager.read_file(Path("dependencies.yml"))
+        except FileNotFoundError:
+            contents = {"projects": []}
+
+        contents["projects"].append({"name": str(Identifier(name)) if name else self.project.name})  # type: ignore
+        self.file_manager.write_file(Path("dependencies.yml"), contents, writeback=True)
+
+
+class DbtSubprojectCreator(DbtProjectEditor):
     """
     Takes a `DbtSubProject` and creates the directory structure and files for it.
     """
 
-    def __init__(self, subproject: DbtSubProject, target_directory: Optional[Path] = None):
-        self.subproject = subproject
-        self.target_directory = target_directory if target_directory else subproject.path
+    def __init__(self, project: DbtSubProject, target_directory: Optional[Path] = None):
+        if not isinstance(project, DbtSubProject):
+            raise TypeError(f"DbtSubprojectCreator requires a DbtSubProject, got {type(project)}")
+        super().__init__(project)
+        self.target_directory = target_directory if target_directory else project.path
         self.file_manager = DbtFileManager(
-            read_project_path=subproject.parent_project.path,
+            read_project_path=project.parent_project.path,
             write_project_path=self.target_directory,
         )
-        self.subproject_boundary_models = self._get_subproject_boundary_models()
+        self.project_boundary_models = self._get_subproject_boundary_models()
 
     def _get_subproject_boundary_models(self) -> Set[str]:
         """
         get a set of boundary model unique_ids for all the selected resources
         """
-        nodes = set(filter(lambda x: not x.startswith("source"), self.subproject.resources))
-        parent_project_name = self.subproject.parent_project.name
+        nodes = set(filter(lambda x: not x.startswith("source"), self.project.resources))  # type: ignore
+        parent_project_name = self.project.parent_project.name  # type: ignore
         interface = ResourceGrouper.identify_interface(
-            graph=self.subproject.graph.graph, selected_bunch=nodes
+            graph=self.project.graph.graph, selected_bunch=nodes
         )
         boundary_models = set(
             filter(
@@ -49,7 +130,7 @@ def write_project_file(self) -> None:
         """
         Writes the dbt_project.yml file for the subproject in the specified subdirectory
         """
-        contents = self.subproject.project.to_dict()
+        contents = self.project.project.to_dict()
         # was getting a weird serialization error from ruamel on this value
         # it's been deprecated, so no reason to keep it
         contents.pop("version")
@@ -71,48 +152,39 @@ def copy_packages_dir(self) -> None:
         """
         raise NotImplementedError("copy_packages_dir not implemented yet")
 
-    def update_dependencies_yml(self) -> None:
-        try:
-            contents = self.file_manager.read_file(Path("dependencies.yml"))
-        except FileNotFoundError:
-            contents = {"projects": []}
-
-        contents["projects"].append({"name": self.subproject.name})  # type: ignore
-        self.file_manager.write_file(Path("dependencies.yml"), contents, writeback=True)
-
     def update_child_refs(self, resource: ManifestNode) -> None:
         downstream_models = [
             node.unique_id
-            for node in self.subproject.manifest.nodes.values()
+            for node in self.project.manifest.nodes.values()
             if node.resource_type == "model" and resource.unique_id in node.depends_on.nodes  # type: ignore
         ]
         for model in downstream_models:
-            model_node = self.subproject.get_manifest_node(model)
+            model_node = self.project.get_manifest_node(model)
             if not model_node:
                 raise KeyError(f"Resource {model} not found in manifest")
             meshify_constructor = DbtMeshConstructor(
-                project_path=self.subproject.parent_project.path, node=model_node, catalog=None
+                project_path=self.project.parent_project.path, node=model_node, catalog=None  # type: ignore
             )
             meshify_constructor.update_model_refs(
-                model_name=resource.name, project_name=self.subproject.name
+                model_name=resource.name, project_name=self.project.name
             )
 
     def initialize(self) -> None:
         """Initialize this subproject as a full dbt project at the provided `target_directory`."""
-        subproject = self.subproject
-        for unique_id in subproject.resources | subproject.custom_macros | subproject.groups:
+        subproject = self.project
+        for unique_id in subproject.resources | subproject.custom_macros | subproject.groups:  # type: ignore
             resource = subproject.get_manifest_node(unique_id)
             catalog = subproject.get_catalog_entry(unique_id)
             if not resource:
                 raise KeyError(f"Resource {unique_id} not found in manifest")
             meshify_constructor = DbtMeshConstructor(
-                project_path=subproject.parent_project.path, node=resource, catalog=catalog
+                project_path=subproject.parent_project.path, node=resource, catalog=catalog  # type: ignore
             )
             if resource.resource_type in ["model", "test", "snapshot", "seed"]:
                 # ignore generic tests, as moving the yml entry will move the test too
                 if resource.resource_type == "test" and len(resource.unique_id.split(".")) == 4:
                     continue
-                if resource.unique_id in self.subproject_boundary_models:
+                if resource.unique_id in self.project_boundary_models:
                     logger.info(
                         f"Adding contract to and publicizing boundary node {resource.unique_id}"
                     )
@@ -130,7 +202,7 @@ def initialize(self) -> None:
                     # apply access method too
                     logger.info(f"Updating ref functions for children of {resource.unique_id}...")
                     try:
-                        self.update_child_refs(resource)
+                        self.update_child_refs(resource)  # type: ignore
                         logger.success(
                             f"Successfully updated ref functions for children of {resource.unique_id}"
                         )
@@ -145,7 +217,7 @@ def initialize(self) -> None:
                 )
                 try:
                     self.move_resource(meshify_constructor)
-                    self.move_resource_yml_entry(meshify_constructor)
+                    self.update_resource_yml_entry(meshify_constructor)
                     logger.success(
                         f"Successfully moved {resource.unique_id} and associated YML to subproject {subproject.name}"
                     )
@@ -172,7 +244,7 @@ def initialize(self) -> None:
                     f"Moving resource {resource.unique_id} to subproject {subproject.name}..."
                 )
                 try:
-                    self.move_resource_yml_entry(meshify_constructor)
+                    self.update_resource_yml_entry(meshify_constructor)
                    logger.success(
                         f"Successfully moved resource {resource.unique_id} to subproject {subproject.name}"
                     )
@@ -186,51 +258,3 @@ def initialize(self) -> None:
         self.copy_packages_yml_file()
         self.update_dependencies_yml()
         # self.copy_packages_dir()
-
-    def move_resource(self, meshify_constructor: DbtMeshConstructor) -> None:
-        """
-        move a resource file from one project to another
-
-        """
-        current_path = meshify_constructor.get_resource_path()
-        self.file_manager.move_file(current_path)
-
-    def copy_resource(self, meshify_constructor: DbtMeshConstructor) -> None:
-        """
-        copy a resource file from one project to another
-
-        """
-        resource_path = meshify_constructor.get_resource_path()
-        contents = self.file_manager.read_file(resource_path)
-        self.file_manager.write_file(resource_path, contents)
-
-    def move_resource_yml_entry(self, meshify_constructor: DbtMeshConstructor) -> None:
-        """
-        move a resource yml entry from one project to another
-        """
-        current_yml_path = meshify_constructor.get_patch_path()
-        new_yml_path = self.file_manager.write_project_path / current_yml_path
-        full_yml_entry = self.file_manager.read_file(current_yml_path)
-        source_name = (
-            meshify_constructor.node.source_name
-            if hasattr(meshify_constructor.node, "source_name")
-            else None
-        )
-        resource_entry, remainder = meshify_constructor.get_yml_entry(
-            resource_name=meshify_constructor.node.name,
-            full_yml=full_yml_entry,  # type: ignore
-            resource_type=meshify_constructor.node.resource_type,
-            source_name=source_name,
-        )
-        try:
-            existing_yml = self.file_manager.read_file(new_yml_path)
-        except FileNotFoundError:
-            existing_yml = None
-        new_yml_contents = meshify_constructor.add_entry_to_yml(
-            resource_entry, existing_yml, meshify_constructor.node.resource_type  # type: ignore
-        )
-        self.file_manager.write_file(current_yml_path, new_yml_contents)
-        if remainder:
-            self.file_manager.write_file(current_yml_path, remainder, writeback=True)
-        else:
-            self.file_manager.delete_file(current_yml_path)
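Note: `update_dependencies_yml` appends the upstream project name to the downstream project's `dependencies.yml`, creating the file when it doesn't exist yet. A rough standalone sketch of the intended file shape; PyYAML is used here for brevity, whereas the real implementation goes through `DbtFileManager` and ruamel.yaml, and wraps the name in dbt's `Identifier` for validation:

```python
from pathlib import Path

import yaml  # PyYAML, for illustration only


def update_dependencies_yml(project_dir: Path, upstream_name: str) -> None:
    """Append an upstream project to dependencies.yml, creating it if missing."""
    dep_path = project_dir / "dependencies.yml"
    try:
        contents = yaml.safe_load(dep_path.read_text()) or {"projects": []}
    except FileNotFoundError:
        contents = {"projects": []}
    contents["projects"].append({"name": upstream_name})
    dep_path.write_text(yaml.safe_dump(contents))


# update_dependencies_yml(Path("src_proj_b"), "src_proj_a") is expected to yield:
# projects:
# - name: src_proj_a
```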
diff --git a/dbt_meshify/storage/file_content_editors.py b/dbt_meshify/storage/file_content_editors.py
index 4900972..d2831cc 100644
--- a/dbt_meshify/storage/file_content_editors.py
+++ b/dbt_meshify/storage/file_content_editors.py
@@ -281,7 +281,7 @@ def add_model_version_to_yml(
         models_yml["models"] = list(models.values())
         return models_yml
 
-    def update_sql_refs(self, model_code: str, model_name: str, project_name: str):
+    def update_refs__sql(self, model_code: str, model_name: str, project_name: str):
         import re
 
         # pattern to search for ref() with optional spaces and either single or double quotes
@@ -295,7 +295,32 @@ def update_sql_refs(self, model_code: str, model_name: str, project_name: str):
 
         return new_code
 
-    def update_python_refs(self, model_code: str, model_name: str, project_name: str):
+    def replace_source_with_ref__sql(
+        self, model_code: str, source_unique_id: str, model_unique_id: str
+    ):
+        import re
+
+        source_parsed = source_unique_id.split(".")
+        model_parsed = model_unique_id.split(".")
+
+        # pattern to search for source() with optional spaces and either single or double quotes
+        pattern = re.compile(
+            r"{{\s*source\s*\(\s*['\"]"
+            + re.escape(source_parsed[2])
+            + r"['\"]\s*,\s*['\"]"
+            + re.escape(source_parsed[3])
+            + r"['\"]\s*\)\s*}}"
+        )
+
+        # replacement string with the new format
+        replacement = f"{{{{ ref('{model_parsed[1]}', '{model_parsed[2]}') }}}}"
+
+        # perform replacement
+        new_code = re.sub(pattern, replacement, model_code)
+
+        return new_code
+
+    def update_refs__python(self, model_code: str, model_name: str, project_name: str):
         import re
 
         # pattern to search for ref() with optional spaces and either single or double quotes
@@ -309,6 +334,31 @@ def update_python_refs(self, model_code: str, model_name: str, project_name: str
 
         return new_code
 
+    def replace_source_with_ref__python(
+        self, model_code: str, source_unique_id: str, model_unique_id: str
+    ):
+        import re
+
+        source_parsed = source_unique_id.split(".")
+        model_parsed = model_unique_id.split(".")
+
+        # pattern to search for source() with optional spaces and either single or double quotes
+        pattern = re.compile(
+            r"dbt\.source\s*\(\s*['\"]"
+            + re.escape(source_parsed[2])
+            + r"['\"]\s*,\s*['\"]"
+            + re.escape(source_parsed[3])
+            + r"['\"]\s*\)"
+        )
+
+        # replacement string with the new format
+        replacement = f'dbt.ref("{model_parsed[1]}", "{model_parsed[2]}")'
+
+        # perform replacement
+        new_code = re.sub(pattern, replacement, model_code)
+
+        return new_code
+
 
 class DbtMeshConstructor(DbtMeshFileEditor):
     def __init__(
@@ -347,7 +397,6 @@ def get_patch_path(self) -> Path:
             filename = f"_{self.node.resource_type.pluralize()}.yml"
             yml_path = resource_path.parent / filename
             self.file_manager.write_file(yml_path, {})
-            logger.info(f"Schema entry for {self.node.unique_id} written to {yml_path}")
         return yml_path
 
     def get_resource_path(self) -> Path:
@@ -362,6 +411,7 @@ def get_resource_path(self) -> Path:
     def add_model_contract(self) -> None:
         """Adds a model contract to the model's yaml"""
         yml_path = self.get_patch_path()
+        logger.info(f"Adding contract to {self.node.name} at {yml_path}")
         # read the yml file
         # pass empty dict if no file contents returned
         models_yml = self.file_manager.read_file(yml_path)
@@ -382,7 +432,7 @@ def add_model_contract(self) -> None:
     def add_model_access(self, access_type: AccessType) -> None:
         """Adds a model contract to the model's yaml"""
         yml_path = self.get_patch_path()
-        logger.info(f"Adding model contract for {self.node.name} at {yml_path}")
+        logger.info(f"Adding {access_type} access to {self.node.name} at {yml_path}")
         # read the yml file
         # pass empty dict if no file contents returned
         models_yml = self.file_manager.read_file(yml_path)
@@ -478,7 +528,7 @@ def update_model_refs(self, model_name: str, project_name: str) -> None:
         # read the model file
         model_code = str(self.file_manager.read_file(model_path))
         # This can be defined in the init for this class.
-        ref_update_methods = {'sql': self.update_sql_refs, 'python': self.update_python_refs}
+        ref_update_methods = {'sql': self.update_refs__sql, 'python': self.update_refs__python}
         # Here, we're trusting the dbt-core code to check the languages for us. 🐉
         updated_code = ref_update_methods[self.node.language](
             model_name=model_name,
@@ -487,3 +537,28 @@ def update_model_refs(self, model_name: str, project_name: str) -> None:
         )
         # write the updated model code to the file
         self.file_manager.write_file(model_path, updated_code)
+
+    def replace_source_with_refs(self, source_unique_id: str, model_unique_id: str) -> None:
+        """Replaces source() calls in the model's code with ref() calls to the upstream model"""
+        model_path = self.get_resource_path()
+
+        if model_path is None:
+            raise ModelFileNotFoundError(
+                f"Unable to find path to model {self.node.name}. Aborting."
+            )
+
+        # read the model file
+        model_code = str(self.file_manager.read_file(model_path))
+        # This can be defined in the init for this class.
+        ref_update_methods = {
+            'sql': self.replace_source_with_ref__sql,
+            'python': self.replace_source_with_ref__python,
+        }
+        # Here, we're trusting the dbt-core code to check the languages for us. 🐉
+        updated_code = ref_update_methods[self.node.language](
+            model_code=model_code,
+            source_unique_id=source_unique_id,
+            model_unique_id=model_unique_id,
+        )
+        # write the updated model code to the file
+        self.file_manager.write_file(model_path, updated_code)
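Note: a worked example of the `source()` → cross-project `ref()` rewrite, using the same pattern construction as `replace_source_with_ref__sql` above. Unique IDs follow dbt's `resource_type.package.(source_name.)name` convention; the specific values are illustrative:

```python
import re

source_unique_id = "source.src_proj_b.src_proj_a.shared_model"
model_unique_id = "model.src_proj_a.shared_model"

source_parsed = source_unique_id.split(".")  # [..., source_name, table_name]
model_parsed = model_unique_id.split(".")    # [resource_type, project, model]

# matches {{ source('<source_name>', '<table_name>') }} with flexible spacing/quotes
pattern = re.compile(
    r"{{\s*source\s*\(\s*['\"]"
    + re.escape(source_parsed[2])
    + r"['\"]\s*,\s*['\"]"
    + re.escape(source_parsed[3])
    + r"['\"]\s*\)\s*}}"
)
replacement = f"{{{{ ref('{model_parsed[1]}', '{model_parsed[2]}') }}}}"

sql = "select * from {{ source('src_proj_a', 'shared_model') }}"
print(re.sub(pattern, replacement, sql))
# select * from {{ ref('src_proj_a', 'shared_model') }}
```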
diff --git a/dbt_meshify/storage/file_manager.py b/dbt_meshify/storage/file_manager.py
index 850b4a4..137085b 100644
--- a/dbt_meshify/storage/file_manager.py
+++ b/dbt_meshify/storage/file_manager.py
@@ -5,8 +5,10 @@
 from pathlib import Path
 from typing import Any, Dict, Optional, Union
 
+from dbt.contracts.util import Identifier
 from ruamel.yaml import YAML
 from ruamel.yaml.compat import StringIO
+from ruamel.yaml.representer import Representer
 
 
 class DbtYAML(YAML):
@@ -29,7 +31,7 @@ def dump(self, data, stream=None, **kw):
 
 
 yaml = DbtYAML()
-
+yaml.register_class(Identifier)
 
 FileContent = Union[Dict[str, str], str]
 
@@ -55,7 +57,7 @@ def __init__(
         self.write_project_path = write_project_path if write_project_path else read_project_path
 
     def read_file(self, path: Path) -> Union[Dict[str, Any], str]:
-        """Returns the yaml for a model in the dbt project's manifest"""
+        """Returns the file contents at a given path"""
         full_path = self.read_project_path / path
         if full_path.suffix == ".yml":
             return yaml.load(full_path.read_text())
@@ -96,4 +98,5 @@ def move_file(self, path: Path) -> None:
 
     def delete_file(self, path: Path) -> None:
         """deletes the specified file"""
-        path.unlink()
+        delete_path = self.read_project_path / path
+        delete_path.unlink()
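Note: ruamel.yaml will only dump types it knows how to represent, and `update_dependencies_yml` now writes dbt's `Identifier` (a `str` subclass), hence the `yaml.register_class(Identifier)` call above. A sketch of the underlying mechanism with an explicit representer; `Name` is a hypothetical stand-in for `Identifier`, and the representer approach is shown here instead of `register_class` to make the serialization explicit:

```python
from ruamel.yaml import YAML
from ruamel.yaml.compat import StringIO


class Name(str):
    """Stand-in for dbt.contracts.util.Identifier (a str subclass)."""


yaml = YAML()
# teach the representer to serialize Name values as plain YAML strings
yaml.representer.add_representer(
    Name, lambda dumper, data: dumper.represent_str(str(data))
)

stream = StringIO()
yaml.dump({"projects": [{"name": Name("src_proj_a")}]}, stream)
print(stream.getvalue())
# projects:
# - name: src_proj_a
```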
diff --git a/tests/integration/test_connect_command.py b/tests/integration/test_connect_command.py
new file mode 100644
index 0000000..48cfe90
--- /dev/null
+++ b/tests/integration/test_connect_command.py
@@ -0,0 +1,119 @@
+from pathlib import Path
+
+import pytest
+from click.testing import CliRunner
+
+from dbt_meshify.main import connect
+from tests.dbt_project_utils import setup_test_project, teardown_test_project
+
+producer_project_path = "test-projects/source-hack/src_proj_a"
+source_consumer_project_path = "test-projects/source-hack/src_proj_b"
+package_consumer_project_path = "test-projects/source-hack/dest_proj_a"
+copy_producer_project_path = producer_project_path + "_copy"
+copy_source_consumer_project_path = source_consumer_project_path + "_copy"
+copy_package_consumer_project_path = package_consumer_project_path + "_copy"
+
+
+@pytest.fixture
+def producer_project():
+    setup_test_project(producer_project_path, copy_producer_project_path)
+    setup_test_project(source_consumer_project_path, copy_source_consumer_project_path)
+    setup_test_project(package_consumer_project_path, copy_package_consumer_project_path)
+    # yield to the test. We'll come back here after the test returns.
+    yield
+
+    teardown_test_project(copy_producer_project_path)
+    teardown_test_project(copy_source_consumer_project_path)
+    teardown_test_project(copy_package_consumer_project_path)
+
+
+class TestConnectCommand:
+    def test_connect_source_hack(self, producer_project):
+        runner = CliRunner()
+        result = runner.invoke(
+            connect,
+            ["--project-paths", copy_producer_project_path, copy_source_consumer_project_path],
+        )
+
+        assert result.exit_code == 0
+
+        # assert that the source is replaced with a ref
+        x_proj_ref = "{{ ref('src_proj_a', 'shared_model') }}"
+        source_func = "{{ source('src_proj_a', 'shared_model') }}"
+        child_sql = (
+            Path(copy_source_consumer_project_path) / "models" / "downstream_model.sql"
+        ).read_text()
+        assert x_proj_ref in child_sql
+        assert source_func not in child_sql
+
+        # assert that the source was deleted
+        # may want to add some nuance in the future for cases where we delete a single source out of a set
+        assert not (
+            Path(copy_source_consumer_project_path) / "models" / "staging" / "_sources.yml"
+        ).exists()
+
+        # assert that the dependencies yml was created with a pointer to the upstream project
+        assert (
+            "src_proj_a"
+            in (Path(copy_source_consumer_project_path) / "dependencies.yml").read_text()
+        )
+
+    def test_connect_package(self, producer_project):
+        runner = CliRunner()
+        result = runner.invoke(
+            connect,
+            ["--project-paths", copy_producer_project_path, copy_package_consumer_project_path],
+        )
+
+        assert result.exit_code == 0
+
+        # assert that the source is replaced with a ref
+        x_proj_ref = "{{ ref('src_proj_a', 'shared_model') }}"
+        child_sql = (
+            Path(copy_package_consumer_project_path) / "models" / "downstream_model.sql"
+        ).read_text()
+        assert x_proj_ref in child_sql
+
+        # assert that the dependencies yml was created with a pointer to the upstream project
+        assert (
+            "src_proj_a"
+            in (Path(copy_package_consumer_project_path) / "dependencies.yml").read_text()
+        )
+
+    def test_packages_dir_exclude_packages(self):
+        # copy all three packages into a subdirectory to ensure no repeat project names in one directory
+        subdir_producer_project = "test-projects/source-hack/subdir/src_proj_a"
+        subdir_source_consumer_project = "test-projects/source-hack/subdir/src_proj_b"
+        subdir_package_consumer_project = "test-projects/source-hack/subdir/dest_proj_a"
+        setup_test_project(producer_project_path, subdir_producer_project)
+        setup_test_project(source_consumer_project_path, subdir_source_consumer_project)
+        setup_test_project(package_consumer_project_path, subdir_package_consumer_project)
+        runner = CliRunner()
+        result = runner.invoke(
+            connect,
+            [
+                "--projects-dir",
+                "test-projects/source-hack/subdir",
+                "--exclude-projects",
+                "src_proj_b",
+            ],
+        )
+
+        assert result.exit_code == 0
+
+        # assert that the source is replaced with a ref
+        x_proj_ref = "{{ ref('src_proj_a', 'shared_model') }}"
+        child_sql = (
+            Path(subdir_package_consumer_project) / "models" / "downstream_model.sql"
+        ).read_text()
+        assert x_proj_ref in child_sql
+
+        # assert that the dependencies yml was created with a pointer to the upstream project
+        assert (
+            "src_proj_a"
+            in (Path(subdir_package_consumer_project) / "dependencies.yml").read_text()
+        )
+
+        teardown_test_project(subdir_producer_project)
+        teardown_test_project(subdir_package_consumer_project)
+        teardown_test_project(subdir_source_consumer_project)
diff --git a/tests/integration/test_dependency_detection.py b/tests/integration/test_dependency_detection.py
index 1aa749e..d9b5ec2 100644
--- a/tests/integration/test_dependency_detection.py
+++ b/tests/integration/test_dependency_detection.py
@@ -50,8 +50,10 @@ def test_linker_detects_source_dependencies(self, src_proj_a, src_proj_b):
 
         assert dependencies == {
             ProjectDependency(
-                upstream="model.src_proj_a.shared_model",
-                downstream="source.src_proj_b.src_proj_a.shared_model",
+                upstream_resource="model.src_proj_a.shared_model",
+                upstream_project_name="src_proj_a",
+                downstream_resource="source.src_proj_b.src_proj_a.shared_model",
+                downstream_project_name="src_proj_b",
                 type=ProjectDependencyType.Source,
             )
         }
@@ -64,8 +66,10 @@ def test_linker_detects_source_dependencies_bidirectionally(self, src_proj_a, sr
 
         assert dependencies == {
             ProjectDependency(
-                upstream="model.src_proj_a.shared_model",
-                downstream="source.src_proj_b.src_proj_a.shared_model",
+                upstream_resource="model.src_proj_a.shared_model",
+                upstream_project_name="src_proj_a",
+                downstream_resource="source.src_proj_b.src_proj_a.shared_model",
+                downstream_project_name="src_proj_b",
                 type=ProjectDependencyType.Source,
             )
         }
@@ -78,8 +82,10 @@ def test_linker_detects_package_import_dependencies(self, src_proj_a, dest_proj_
 
         assert dependencies == {
             ProjectDependency(
-                upstream="model.src_proj_a.shared_model",
-                downstream="model.src_proj_a.shared_model",
+                upstream_resource="model.src_proj_a.shared_model",
+                upstream_project_name="src_proj_a",
+                downstream_resource="model.dest_proj_a.downstream_model",
+                downstream_project_name="dest_proj_a",
                 type=ProjectDependencyType.Package,
             )
         }
diff --git a/tests/integration/test_subproject_creator.py b/tests/integration/test_subproject_creator.py
index 22fc3a1..ddc998e 100644
--- a/tests/integration/test_subproject_creator.py
+++ b/tests/integration/test_subproject_creator.py
@@ -6,7 +6,7 @@
 from dbt_meshify.dbt import Dbt
 from dbt_meshify.dbt_projects import DbtProject
-from dbt_meshify.storage.dbt_project_creator import DbtSubprojectCreator
+from dbt_meshify.storage.dbt_project_editors import DbtSubprojectCreator
 from dbt_meshify.storage.file_content_editors import DbtMeshConstructor
 
 test_project_profile = yaml.safe_load(
@@ -106,7 +106,7 @@ def test_move_yml_entry(self) -> None:
         subproject = split_project()
         meshify_constructor = get_meshify_constructor(subproject, model_unique_id)
         creator = DbtSubprojectCreator(subproject)
-        creator.move_resource_yml_entry(meshify_constructor)
+        creator.update_resource_yml_entry(meshify_constructor)
         # the original path should still exist, since we take only the single model entry
         assert Path("test/models/example/schema.yml").exists()
         assert Path("test/subdir/models/example/schema.yml").exists()
diff --git a/tests/unit/test_update_ref_functions.py b/tests/unit/test_update_ref_functions.py
index 4a8d92f..941d150 100644
--- a/tests/unit/test_update_ref_functions.py
+++ b/tests/unit/test_update_ref_functions.py
@@ -38,7 +38,7 @@ def read_yml(yml_str):
 
 class TestRemoveResourceYml:
     def test_update_sql_ref_function__basic(self):
-        updated_sql = meshify.update_sql_refs(
+        updated_sql = meshify.update_refs__sql(
             model_code=simple_model_sql,
             model_name=upstream_model_name,
             project_name=upstream_project_name,
@@ -46,7 +46,7 @@ def test_update_sql_ref_function__basic(self):
         assert updated_sql == expected_simple_model_sql
 
     def test_update_python_ref_function__basic(self):
-        updated_python = meshify.update_python_refs(
+        updated_python = meshify.update_refs__python(
             model_code=simple_model_python,
             model_name=upstream_model_name,
             project_name=upstream_project_name,