Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MacroResolverProtocol, remove lazy loading of manifest in adapter.execute_macro #9243

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20231207-111554.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add MacroResolverProtocol, remove lazy loading of manifest in adapter.execute_macro
time: 2023-12-07T11:15:54.427818+09:00
custom:
Author: michelleark
Issue: "9244"
89 changes: 31 additions & 58 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Mapping,
Expand All @@ -19,6 +18,7 @@
TypedDict,
Union,
FrozenSet,
Iterable,
)
from multiprocessing.context import SpawnContext

Expand All @@ -28,6 +28,7 @@
ConstraintType,
ModelLevelConstraint,
)
from dbt.adapters.contracts.macros import MacroResolverProtocol

import agate
import pytz
Expand Down Expand Up @@ -62,7 +63,6 @@
Integer,
)
from dbt.common.clients.jinja import CallableMacroGenerator
from dbt.contracts.graph.manifest import Manifest, MacroManifest
from dbt.common.events.functions import fire_event, warn_or_error
from dbt.adapters.events.types import (
CacheMiss,
Expand Down Expand Up @@ -257,7 +257,20 @@
self.config = config
self.cache = RelationsCache(log_cache_events=config.log_cache_events)
self.connections = self.ConnectionManager(config, mp_context)
self._macro_manifest_lazy: Optional[MacroManifest] = None
self._macro_resolver: Optional[MacroResolverProtocol] = None

###
# Methods to set / access a macro resolver
###
def set_macro_resolver(self, macro_resolver: MacroResolverProtocol) -> None:
self._macro_resolver = macro_resolver

def get_macro_resolver(self) -> Optional[MacroResolverProtocol]:
return self._macro_resolver

def clear_macro_resolver(self) -> None:
if self._macro_resolver is not None:
self._macro_resolver = None

Check warning on line 273 in core/dbt/adapters/base/impl.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/base/impl.py#L272-L273

Added lines #L272 - L273 were not covered by tests

###
# Methods that pass through to the connection manager
Expand Down Expand Up @@ -370,39 +383,6 @@
"""
return cls.ConnectionManager.TYPE

@property
def _macro_manifest(self) -> MacroManifest:
if self._macro_manifest_lazy is None:
return self.load_macro_manifest()
return self._macro_manifest_lazy

def check_macro_manifest(self) -> Optional[MacroManifest]:
"""Return the internal manifest (used for executing macros) if it's
been initialized, otherwise return None.
"""
return self._macro_manifest_lazy

def load_macro_manifest(self, base_macros_only=False) -> MacroManifest:
# base_macros_only is for the test framework
if self._macro_manifest_lazy is None:
# avoid a circular import
from dbt.parser.manifest import ManifestLoader

manifest = ManifestLoader.load_macros(
self.config,
self.connections.set_query_header,
base_macros_only=base_macros_only,
)
# TODO CT-211
self._macro_manifest_lazy = manifest # type: ignore[assignment]
# TODO CT-211
return self._macro_manifest_lazy # type: ignore[return-value]

def clear_macro_manifest(self):
if self._macro_manifest_lazy is not None:
self._macro_manifest_lazy = None

###
# Caching methods
###
def _schema_is_cached(self, database: Optional[str], schema: str) -> bool:
Expand Down Expand Up @@ -1052,11 +1032,10 @@
def execute_macro(
self,
macro_name: str,
manifest: Optional[Manifest] = None,
macro_resolver: Optional[MacroResolverProtocol] = None,
project: Optional[str] = None,
context_override: Optional[Dict[str, Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
text_only_columns: Optional[Iterable[str]] = None,
) -> AttrDict:
"""Look macro_name up in the manifest and execute its results.

Expand All @@ -1076,13 +1055,11 @@
if context_override is None:
context_override = {}

if manifest is None:
# TODO CT-211
manifest = self._macro_manifest # type: ignore[assignment]
# TODO CT-211
macro = manifest.find_macro_by_name( # type: ignore[union-attr]
macro_name, self.config.project_name, project
)
resolver = macro_resolver or self._macro_resolver
if resolver is None:
raise DbtInternalError("macro resolver was None when calling execute_macro!")

Check warning on line 1060 in core/dbt/adapters/base/impl.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/base/impl.py#L1060

Added line #L1060 was not covered by tests

macro = resolver.find_macro_by_name(macro_name, self.config.project_name, project)
if macro is None:
if project is None:
package_name = "any package"
Expand All @@ -1102,7 +1079,7 @@
# TODO CT-211
macro=macro,
config=self.config,
manifest=manifest, # type: ignore[arg-type]
manifest=resolver, # type: ignore[arg-type]
package_name=project,
)
macro_context.update(context_override)
Expand Down Expand Up @@ -1135,10 +1112,7 @@
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:
kwargs = {"information_schema": information_schema, "schemas": schemas}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
)
table = self.execute_macro(GET_CATALOG_MACRO_NAME, kwargs=kwargs)

results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type]
return results
Expand All @@ -1154,10 +1128,7 @@
"information_schema": information_schema,
"relations": relations,
}
table = self.execute_macro(
GET_CATALOG_RELATIONS_MACRO_NAME,
kwargs=kwargs,
)
table = self.execute_macro(GET_CATALOG_RELATIONS_MACRO_NAME, kwargs=kwargs)

results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type]
return results
Expand Down Expand Up @@ -1258,7 +1229,7 @@
source: BaseRelation,
loaded_at_field: str,
filter: Optional[str],
manifest: Optional[Manifest] = None,
macro_resolver: Optional[MacroResolverProtocol] = None,
) -> Tuple[Optional[AdapterResponse], FreshnessResponse]:
"""Calculate the freshness of sources in dbt, and return it"""
kwargs: Dict[str, Any] = {
Expand All @@ -1274,7 +1245,9 @@
AttrDict, # current: contains AdapterResponse + agate.Table
agate.Table, # previous: just table
]
result = self.execute_macro(FRESHNESS_MACRO_NAME, kwargs=kwargs, manifest=manifest)
result = self.execute_macro(

Check warning on line 1248 in core/dbt/adapters/base/impl.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/base/impl.py#L1248

Added line #L1248 was not covered by tests
FRESHNESS_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver
)
if isinstance(result, agate.Table):
warn_or_error(CollectFreshnessReturnSignature())
adapter_response = None
Expand Down Expand Up @@ -1304,14 +1277,14 @@
def calculate_freshness_from_metadata(
self,
source: BaseRelation,
manifest: Optional[Manifest] = None,
macro_resolver: Optional[MacroResolverProtocol] = None,
) -> Tuple[Optional[AdapterResponse], FreshnessResponse]:
kwargs: Dict[str, Any] = {
"information_schema": source.information_schema_only(),
"relations": [source],
}
result = self.execute_macro(
GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, manifest=manifest
GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver
)
adapter_response, table = result.response, result.table # type: ignore[attr-defined]

Expand Down
11 changes: 11 additions & 0 deletions core/dbt/adapters/contracts/macros.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Optional
from typing_extensions import Protocol

from dbt.common.clients.jinja import MacroProtocol


class MacroResolverProtocol(Protocol):
def find_macro_by_name(
self, name: str, root_project_name: str, package: Optional[str]
) -> Optional[MacroProtocol]:
raise NotImplementedError("find_macro_by_name not implemented")

Check warning on line 11 in core/dbt/adapters/contracts/macros.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/contracts/macros.py#L11

Added line #L11 was not covered by tests
10 changes: 10 additions & 0 deletions core/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import agate

from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse
from dbt.adapters.contracts.macros import MacroResolverProtocol
from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig
from dbt.common.contracts.config.base import BaseConfig
from dbt.contracts.graph.manifest import Manifest
Expand Down Expand Up @@ -66,6 +67,15 @@
def __init__(self, config: AdapterRequiredConfig) -> None:
...

def set_macro_resolver(self, macro_resolver: MacroResolverProtocol) -> None:
...

Check warning on line 71 in core/dbt/adapters/protocol.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/protocol.py#L71

Added line #L71 was not covered by tests

def get_macro_resolver(self) -> Optional[MacroResolverProtocol]:
...

Check warning on line 74 in core/dbt/adapters/protocol.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/protocol.py#L74

Added line #L74 was not covered by tests

def clear_macro_resolver(self) -> None:
...

Check warning on line 77 in core/dbt/adapters/protocol.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/protocol.py#L77

Added line #L77 was not covered by tests

@classmethod
def type(cls) -> str:
pass
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing_extensions import Protocol

from dbt.adapters.base.column import Column
from dbt.common.clients.jinja import MacroProtocol
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt.common.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
Expand Down Expand Up @@ -1363,7 +1364,7 @@ class MacroContext(ProviderContext):

def __init__(
self,
model: Macro,
model: MacroProtocol,
config: RuntimeConfig,
manifest: Manifest,
provider: Provider,
Expand Down Expand Up @@ -1520,7 +1521,7 @@ def generate_runtime_model_context(


def generate_runtime_macro_context(
macro: Macro,
macro: MacroProtocol,
config: RuntimeConfig,
manifest: Manifest,
package_name: Optional[str],
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@
# the config and adapter may be persistent.
if reset:
config.clear_dependencies()
adapter.clear_macro_manifest()
adapter.clear_macro_resolver()

Check warning on line 289 in core/dbt/parser/manifest.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/manifest.py#L289

Added line #L289 was not covered by tests
macro_hook = adapter.connections.set_query_header

flags = get_flags()
Expand Down Expand Up @@ -1000,7 +1000,7 @@

def save_macros_to_adapter(self, adapter):
macro_manifest = MacroManifest(self.manifest.macros)
adapter._macro_manifest_lazy = macro_manifest
adapter.set_macro_resolver(macro_manifest)
# This executes the callable macro_hook and sets the
# query headers
self.macro_hook(macro_manifest)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/freshness.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def execute(self, compiled_node, manifest):

adapter_response, freshness = self.adapter.calculate_freshness_from_metadata(
relation,
manifest=manifest,
macro_resolver=manifest,
)

status = compiled_node.freshness.status(freshness["age"])
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/run_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _run_unsafe(self, package_name, macro_name) -> agate.Table:
with adapter.connection_named("macro_{}".format(macro_name)):
adapter.clear_transaction()
res = adapter.execute_macro(
macro_name, project=package_name, kwargs=macro_kwargs, manifest=self.manifest
macro_name, project=package_name, kwargs=macro_kwargs, macro_resolver=self.manifest
Copy link
Contributor Author

@MichelleArk MichelleArk Dec 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it may be possible to remove the macro_resolver usage entirely but I'd like to do that in a follow-on to keep this PR more contained to a strict refactor.

)

return res
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def execute(self, compiled_node, manifest):
model_context = generate_runtime_model_context(compiled_node, self.config, manifest)
compiled_node.compiled_code = self.adapter.execute_macro(
macro_name="get_show_sql",
manifest=manifest,
macro_resolver=manifest,
context_override=model_context,
kwargs={
"compiled_code": model_context["compiled_code"],
Expand Down
17 changes: 16 additions & 1 deletion core/dbt/tests/fixtures/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
import yaml

from dbt.parser.manifest import ManifestLoader
from dbt.common.exceptions import CompilationError, DbtDatabaseError
import dbt.flags as flags
from dbt.config.runtime import RuntimeConfig
Expand Down Expand Up @@ -289,7 +290,13 @@ def adapter(
adapter = get_adapter(runtime_config)
# We only need the base macros, not macros from dependencies, and don't want
# to run 'dbt deps' here.
adapter.load_macro_manifest(base_macros_only=True)
manifest = ManifestLoader.load_macros(
runtime_config,
adapter.connections.set_query_header,
base_macros_only=True,
)

adapter.set_macro_resolver(manifest)
yield adapter
adapter.cleanup_connections()
reset_adapters()
Expand Down Expand Up @@ -450,6 +457,14 @@ def create_test_schema(self, schema_name=None):

# Drop the unique test schema, usually called in test cleanup
def drop_test_schema(self):
if self.adapter.get_macro_resolver() is None:
manifest = ManifestLoader.load_macros(
self.adapter.config,
self.adapter.connections.set_query_header,
base_macros_only=True,
)
self.adapter.set_macro_resolver(manifest)

with get_connection(self.adapter):
for schema_name in self.created_schemas:
relation = self.adapter.Relation.create(database=self.database, schema=schema_name)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ def _mock_state_check(self):

self.psycopg2.connect.return_value = self.handle
self.adapter = PostgresAdapter(self.config, self.mp_context)
self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config)
self.adapter.set_macro_resolver(load_internal_manifest_macros(self.config))
self.adapter.connections.query_header = MacroQueryStringSetter(
self.config, self.adapter._macro_manifest_lazy
self.config, self.adapter.get_macro_resolver()
)

self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add")
Expand Down
Loading