diff --git a/.changes/unreleased/Fixes-20230626-115838.yaml b/.changes/unreleased/Fixes-20230626-115838.yaml new file mode 100644 index 00000000000..03f07c4237d --- /dev/null +++ b/.changes/unreleased/Fixes-20230626-115838.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Move project_root contextvar into events.contextvars +time: 2023-06-26T11:58:38.965299-04:00 +custom: + Author: gshank + Issue: "7937" diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index d3e113bf606..aca875ea2f1 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -45,7 +45,7 @@ SeedExceedsLimitAndPathChanged, SeedExceedsLimitChecksumChanged, ) -from dbt.events.contextvars import set_contextvars +from dbt.events.contextvars import set_log_contextvars from dbt.flags import get_flags from dbt.node_types import ModelLanguage, NodeType, AccessType from dbt_semantic_interfaces.references import ( @@ -327,7 +327,7 @@ def node_info(self): def update_event_status(self, **kwargs): for k, v in kwargs.items(): self._event_status[k] = v - set_contextvars(node_info=self.node_info) + set_log_contextvars(node_info=self.node_info) def clear_event_status(self): self._event_status = dict() diff --git a/core/dbt/events/contextvars.py b/core/dbt/events/contextvars.py index 8688a992ee4..5bdb78fe4e2 100644 --- a/core/dbt/events/contextvars.py +++ b/core/dbt/events/contextvars.py @@ -5,48 +5,65 @@ LOG_PREFIX = "log_" -LOG_PREFIX_LEN = len(LOG_PREFIX) +TASK_PREFIX = "task_" -_log_context_vars: Dict[str, contextvars.ContextVar] = {} +_context_vars: Dict[str, contextvars.ContextVar] = {} -def get_contextvars() -> Dict[str, Any]: +def get_contextvars(prefix: str) -> Dict[str, Any]: rv = {} ctx = contextvars.copy_context() + prefix_len = len(prefix) for k in ctx: - if k.name.startswith(LOG_PREFIX) and ctx[k] is not Ellipsis: - rv[k.name[LOG_PREFIX_LEN:]] = ctx[k] + if k.name.startswith(prefix) and ctx[k] is not Ellipsis: + rv[k.name[prefix_len:]] = ctx[k] return rv def get_node_info(): - cvars = get_contextvars() + cvars = get_contextvars(LOG_PREFIX) if "node_info" in cvars: return cvars["node_info"] else: return {} -def clear_contextvars() -> None: +def get_project_root(): + cvars = get_contextvars(TASK_PREFIX) + if "project_root" in cvars: + return cvars["project_root"] + else: + return None + + +def clear_contextvars(prefix: str) -> None: ctx = contextvars.copy_context() for k in ctx: - if k.name.startswith(LOG_PREFIX): + if k.name.startswith(prefix): k.set(Ellipsis) +def set_log_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]: + return set_contextvars(LOG_PREFIX, **kwargs) + + +def set_task_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]: + return set_contextvars(TASK_PREFIX, **kwargs) + + # put keys and values into context. Returns the contextvar.Token mapping # Save and pass to reset_contextvars -def set_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]: +def set_contextvars(prefix: str, **kwargs: Any) -> Mapping[str, contextvars.Token]: cvar_tokens = {} for k, v in kwargs.items(): - log_key = f"{LOG_PREFIX}{k}" + log_key = f"{prefix}{k}" try: - var = _log_context_vars[log_key] + var = _context_vars[log_key] except KeyError: var = contextvars.ContextVar(log_key, default=Ellipsis) - _log_context_vars[log_key] = var + _context_vars[log_key] = var cvar_tokens[k] = var.set(v) @@ -54,30 +71,44 @@ def set_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]: # reset by Tokens -def reset_contextvars(**kwargs: contextvars.Token) -> None: +def reset_contextvars(prefix: str, **kwargs: contextvars.Token) -> None: for k, v in kwargs.items(): - log_key = f"{LOG_PREFIX}{k}" - var = _log_context_vars[log_key] + log_key = f"{prefix}{k}" + var = _context_vars[log_key] var.reset(v) # remove from contextvars -def unset_contextvars(*keys: str) -> None: +def unset_contextvars(prefix: str, *keys: str) -> None: for k in keys: - if k in _log_context_vars: - log_key = f"{LOG_PREFIX}{k}" - _log_context_vars[log_key].set(Ellipsis) + if k in _context_vars: + log_key = f"{prefix}{k}" + _context_vars[log_key].set(Ellipsis) # Context manager or decorator to set and unset the context vars @contextlib.contextmanager def log_contextvars(**kwargs: Any) -> Generator[None, None, None]: - context = get_contextvars() + context = get_contextvars(LOG_PREFIX) + saved = {k: context[k] for k in context.keys() & kwargs.keys()} + + set_contextvars(LOG_PREFIX, **kwargs) + try: + yield + finally: + unset_contextvars(LOG_PREFIX, *kwargs.keys()) + set_contextvars(LOG_PREFIX, **saved) + + +# Context manager for earlier in task.run +@contextlib.contextmanager +def task_contextvars(**kwargs: Any) -> Generator[None, None, None]: + context = get_contextvars(TASK_PREFIX) saved = {k: context[k] for k in context.keys() & kwargs.keys()} - set_contextvars(**kwargs) + set_contextvars(TASK_PREFIX, **kwargs) try: yield finally: - unset_contextvars(*kwargs.keys()) - set_contextvars(**saved) + unset_contextvars(TASK_PREFIX, *kwargs.keys()) + set_contextvars(TASK_PREFIX, **saved) diff --git a/core/dbt/graph/selector_methods.py b/core/dbt/graph/selector_methods.py index d35e0c21aff..182189401f9 100644 --- a/core/dbt/graph/selector_methods.py +++ b/core/dbt/graph/selector_methods.py @@ -26,7 +26,7 @@ DbtRuntimeError, ) from dbt.node_types import NodeType -from dbt.task.contextvars import cv_project_root +from dbt.events.contextvars import get_project_root SELECTOR_GLOB = "*" @@ -326,7 +326,11 @@ class PathSelectorMethod(SelectorMethod): def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]: """Yields nodes from included that match the given path.""" # get project root from contextvar - root = Path(cv_project_root.get()) + project_root = get_project_root() + if project_root: + root = Path(project_root) + else: + root = Path.cwd() paths = set(p.relative_to(root) for p in root.glob(selector)) for node, real_node in self.all_nodes(included_nodes): ofp = Path(real_node.original_file_path) diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py index 1e28a91ef3f..a7ec1e046db 100644 --- a/core/dbt/task/base.py +++ b/core/dbt/task/base.py @@ -45,7 +45,6 @@ from dbt.graph import Graph from dbt.logger import log_manager from .printer import print_run_result_error -from dbt.task.contextvars import cv_project_root class NoneConfig: @@ -76,8 +75,6 @@ def __init__(self, args, config, project=None): self.args = args self.config = config self.project = config if isinstance(config, Project) else project - if self.config: - cv_project_root.set(self.config.project_root) @classmethod def pre_init_hook(cls, args): diff --git a/core/dbt/task/contextvars.py b/core/dbt/task/contextvars.py deleted file mode 100644 index 6524b0935d1..00000000000 --- a/core/dbt/task/contextvars.py +++ /dev/null @@ -1,6 +0,0 @@ -from contextvars import ContextVar - -# This is a place to hold common contextvars used in tasks so that we can -# avoid circular imports. - -cv_project_root: ContextVar = ContextVar("project_root") diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index b1a74fd1126..ef18e19d313 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -36,7 +36,7 @@ EndRunResult, NothingToDo, ) -from dbt.events.contextvars import log_contextvars +from dbt.events.contextvars import log_contextvars, task_contextvars from dbt.contracts.graph.nodes import ResultNode from dbt.contracts.results import NodeStatus, RunExecutionResult, RunningStatus from dbt.contracts.state import PreviousState @@ -430,25 +430,30 @@ def run(self): """ Run dbt for the query, based on the graph. """ - self._runtime_initialize() + # We set up a context manager here with "task_contextvars" because we + # we need the project_root in runtime_initialize. + with task_contextvars(project_root=self.config.project_root): + self._runtime_initialize() - if self._flattened_nodes is None: - raise DbtInternalError("after _runtime_initialize, _flattened_nodes was still None") + if self._flattened_nodes is None: + raise DbtInternalError( + "after _runtime_initialize, _flattened_nodes was still None" + ) - if len(self._flattened_nodes) == 0: - with TextOnly(): - fire_event(Formatting("")) - warn_or_error(NothingToDo()) - result = self.get_result( - results=[], - generated_at=datetime.utcnow(), - elapsed_time=0.0, - ) - else: - with TextOnly(): - fire_event(Formatting("")) - selected_uids = frozenset(n.unique_id for n in self._flattened_nodes) - result = self.execute_with_hooks(selected_uids) + if len(self._flattened_nodes) == 0: + with TextOnly(): + fire_event(Formatting("")) + warn_or_error(NothingToDo()) + result = self.get_result( + results=[], + generated_at=datetime.utcnow(), + elapsed_time=0.0, + ) + else: + with TextOnly(): + fire_event(Formatting("")) + selected_uids = frozenset(n.unique_id for n in self._flattened_nodes) + result = self.execute_with_hooks(selected_uids) # We have other result types here too, including FreshnessResult if isinstance(result, RunExecutionResult):