Skip to content

Commit

Permalink
Remove unnecessary ConfigSource
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Dec 3, 2024
1 parent 13105cd commit 3f7ee0e
Showing 1 changed file with 52 additions and 76 deletions.
128 changes: 52 additions & 76 deletions core/dbt/context/context_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,73 +27,6 @@ class ModelParts(IsFQNResource):
C = TypeVar("C", bound=BaseConfig)


class ConfigSource:
def __init__(self, project):
self.project = project

def get_config_dict(self, resource_type: NodeType): ...


class UnrenderedConfig(ConfigSource):
def __init__(self, project: Project):
self.project = project

def get_config_dict(self, resource_type: NodeType) -> Dict[str, Any]:
unrendered = self.project.unrendered.project_dict
if resource_type == NodeType.Seed:
model_configs = unrendered.get("seeds")
elif resource_type == NodeType.Snapshot:
model_configs = unrendered.get("snapshots")
elif resource_type == NodeType.Source:
model_configs = unrendered.get("sources")
elif resource_type == NodeType.Test:
model_configs = unrendered.get("data_tests")
elif resource_type == NodeType.Metric:
model_configs = unrendered.get("metrics")
elif resource_type == NodeType.SemanticModel:
model_configs = unrendered.get("semantic_models")
elif resource_type == NodeType.SavedQuery:
model_configs = unrendered.get("saved_queries")
elif resource_type == NodeType.Exposure:
model_configs = unrendered.get("exposures")
elif resource_type == NodeType.Unit:
model_configs = unrendered.get("unit_tests")
else:
model_configs = unrendered.get("models")
if model_configs is None:
return {}
else:
return model_configs


class RenderedConfig(ConfigSource):
def __init__(self, project: Project):
self.project = project

def get_config_dict(self, resource_type: NodeType) -> Dict[str, Any]:
if resource_type == NodeType.Seed:
model_configs = self.project.seeds
elif resource_type == NodeType.Snapshot:
model_configs = self.project.snapshots
elif resource_type == NodeType.Source:
model_configs = self.project.sources
elif resource_type == NodeType.Test:
model_configs = self.project.data_tests
elif resource_type == NodeType.Metric:
model_configs = self.project.metrics
elif resource_type == NodeType.SemanticModel:
model_configs = self.project.semantic_models
elif resource_type == NodeType.SavedQuery:
model_configs = self.project.saved_queries
elif resource_type == NodeType.Exposure:
model_configs = self.project.exposures
elif resource_type == NodeType.Unit:
model_configs = self.project.unit_tests
else:
model_configs = self.project.models
return model_configs


def fix_hooks(config_dict: Dict[str, Any]):
"""Given a config dict that may have `pre-hook`/`post-hook` keys,
convert it from the yucky maybe-a-string, maybe-a-dict to a dict.
Expand All @@ -108,9 +41,6 @@ class BaseContextConfigGenerator(Generic[T]):
def __init__(self, active_project: RuntimeConfig):
self._active_project = active_project

def get_config_source(self, project: Project) -> ConfigSource:
return RenderedConfig(project)

def get_node_project_config(self, project_name: str):
if project_name == self._active_project.project_name:
return self._active_project
Expand All @@ -125,8 +55,7 @@ def get_node_project_config(self, project_name: str):
def _project_configs(
self, project: Project, fqn: List[str], resource_type: NodeType
) -> Iterator[Dict[str, Any]]:
src = self.get_config_source(project)
model_configs = src.get_config_dict(resource_type)
model_configs = self.get_model_configs(project, resource_type)
for level_config in fqn_search(model_configs, fqn):
result = {}
for key, value in level_config.items():
Expand All @@ -142,6 +71,9 @@ def _active_project_configs(
) -> Iterator[Dict[str, Any]]:
return self._project_configs(self._active_project, fqn, resource_type)

@abstractmethod
def get_model_configs(self, project: Project, resource_type: NodeType) -> Dict[str, Any]: ...

@abstractmethod
def merge_config_dicts(
self,
Expand Down Expand Up @@ -175,8 +107,28 @@ class ContextConfigGenerator(BaseContextConfigGenerator[C]):
def __init__(self, active_project: RuntimeConfig):
self._active_project = active_project

def get_config_source(self, project: Project) -> ConfigSource:
return RenderedConfig(project)
def get_model_configs(self, project: Project, resource_type: NodeType) -> Dict[str, Any]:
if resource_type == NodeType.Seed:
model_configs = project.seeds
elif resource_type == NodeType.Snapshot:
model_configs = project.snapshots
elif resource_type == NodeType.Source:
model_configs = project.sources
elif resource_type == NodeType.Test:
model_configs = project.data_tests
elif resource_type == NodeType.Metric:
model_configs = project.metrics
elif resource_type == NodeType.SemanticModel:
model_configs = project.semantic_models
elif resource_type == NodeType.SavedQuery:
model_configs = project.saved_queries
elif resource_type == NodeType.Exposure:
model_configs = project.exposures
elif resource_type == NodeType.Unit:
model_configs = project.unit_tests
else:
model_configs = project.models
return model_configs

def merge_config_dicts(
self,
Expand Down Expand Up @@ -279,8 +231,32 @@ def generate_node_config(


class UnrenderedConfigGenerator(BaseContextConfigGenerator[Dict[str, Any]]):
def get_config_source(self, project: Project) -> ConfigSource:
return UnrenderedConfig(project)
def get_model_configs(self, project: Project, resource_type: NodeType) -> Dict[str, Any]:
unrendered = project.unrendered.project_dict
if resource_type == NodeType.Seed:
model_configs = unrendered.get("seeds")
elif resource_type == NodeType.Snapshot:
model_configs = unrendered.get("snapshots")
elif resource_type == NodeType.Source:
model_configs = unrendered.get("sources")
elif resource_type == NodeType.Test:
model_configs = unrendered.get("data_tests")
elif resource_type == NodeType.Metric:
model_configs = unrendered.get("metrics")
elif resource_type == NodeType.SemanticModel:
model_configs = unrendered.get("semantic_models")
elif resource_type == NodeType.SavedQuery:
model_configs = unrendered.get("saved_queries")
elif resource_type == NodeType.Exposure:
model_configs = unrendered.get("exposures")
elif resource_type == NodeType.Unit:
model_configs = unrendered.get("unit_tests")
else:
model_configs = unrendered.get("models")
if model_configs is None:
return {}
else:
return model_configs

def merge_config_dicts(
self,
Expand Down

0 comments on commit 3f7ee0e

Please sign in to comment.