Skip to content

Commit

Permalink
consolidate flags (#6788)
Browse files Browse the repository at this point in the history
Co-authored-by: Michelle Ark <[email protected]>
Co-authored-by: Github Build Bot <[email protected]>
  • Loading branch information
3 people authored Feb 7, 2023
1 parent 9c0b62b commit d0b5d75
Show file tree
Hide file tree
Showing 68 changed files with 311 additions and 1,805 deletions.
5 changes: 3 additions & 2 deletions core/dbt/adapters/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from dbt.events.functions import fire_event, fire_event_if
from dbt.events.types import CacheAction, CacheDumpGraph
import dbt.flags as flags
from dbt.flags import get_flags
from dbt.utils import lowercase


Expand Down Expand Up @@ -319,6 +319,7 @@ def add(self, relation):
:param BaseRelation relation: The underlying relation.
"""
flags = get_flags()
cached = _CachedRelation(relation)
fire_event_if(
flags.LOG_CACHE_EVENTS,
Expand Down Expand Up @@ -456,7 +457,7 @@ def rename(self, old, new):
ref_key_2=_make_msg_from_ref_key(new),
)
)

flags = get_flags()
fire_event_if(
flags.LOG_CACHE_EVENTS,
lambda: CacheDumpGraph(before_after="before", action="rename", dump=self.dump_graph()),
Expand Down
89 changes: 83 additions & 6 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,85 @@

from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig
from dbt.helper_types import WarnErrorOptions
from dbt.config.project import PartialProject
from dbt.exceptions import DbtProjectError

if os.name != "nt":
# https://bugs.python.org/issue41567
import multiprocessing.popen_spawn_posix # type: ignore # noqa: F401

# TODO anything that has a default in params should be removed here?
# Or maybe only the ones that's in the root click group
FLAGS_DEFAULTS = {
"INDIRECT_SELECTION": "eager",
"TARGET_PATH": None,
# cli args without user_config or env var option
"FULL_REFRESH": False,
"STRICT_MODE": False,
"STORE_FAILURES": False,
}


# For backwards compatability, some params are defined across multiple levels,
# Top-level value should take precedence.
# e.g. dbt --target-path test2 run --target-path test2
EXPECTED_DUPLICATE_PARAMS = [
"full_refresh",
"target_path",
"version_check",
"fail_fast",
"indirect_selection",
"store_failures",
]


def convert_config(config_name, config_value):
# This function should take care of converting the values from config and original
# set_from_args to the correct type
ret = config_value
if config_name.lower() == "warn_error_options":
ret = WarnErrorOptions(
include=config_value.get("include", []), exclude=config_value.get("exclude", [])
)
return ret


@dataclass(frozen=True)
class Flags:
def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None:

# set the default flags
for key, value in FLAGS_DEFAULTS.items():
object.__setattr__(self, key, value)

if ctx is None:
ctx = get_current_context()

def assign_params(ctx, params_assigned_from_default):
"""Recursively adds all click params to flag object"""
for param_name, param_value in ctx.params.items():
# TODO: this is to avoid duplicate params being defined in two places (version_check in run and cli)
# However this is a bit of a hack and we should find a better way to do this

# N.B. You have to use the base MRO method (object.__setattr__) to set attributes
# when using frozen dataclasses.
# https://docs.python.org/3/library/dataclasses.html#frozen-instances
if hasattr(self, param_name):
raise Exception(f"Duplicate flag names found in click command: {param_name}")
object.__setattr__(self, param_name.upper(), param_value)
if ctx.get_parameter_source(param_name) == ParameterSource.DEFAULT:
params_assigned_from_default.add(param_name)
if hasattr(self, param_name.upper()):
if param_name not in EXPECTED_DUPLICATE_PARAMS:
raise Exception(
f"Duplicate flag names found in click command: {param_name}"
)
else:
# Expected duplicate param from multi-level click command (ex: dbt --full_refresh run --full_refresh)
# Overwrite user-configured param with value from parent context
if ctx.get_parameter_source(param_name) != ParameterSource.DEFAULT:
object.__setattr__(self, param_name.upper(), param_value)
else:
object.__setattr__(self, param_name.upper(), param_value)
if ctx.get_parameter_source(param_name) == ParameterSource.DEFAULT:
params_assigned_from_default.add(param_name)

if ctx.parent:
assign_params(ctx.parent, params_assigned_from_default)

Expand Down Expand Up @@ -64,7 +119,9 @@ def assign_params(ctx, params_assigned_from_default):
user_config_param_value = getattr(user_config, param_assigned_from_default, None)
if user_config_param_value is not None:
object.__setattr__(
self, param_assigned_from_default.upper(), user_config_param_value
self,
param_assigned_from_default.upper(),
convert_config(param_assigned_from_default, user_config_param_value),
)
param_assigned_from_default_copy.remove(param_assigned_from_default)
params_assigned_from_default = param_assigned_from_default_copy
Expand All @@ -73,6 +130,26 @@ def assign_params(ctx, params_assigned_from_default):
object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name)
object.__setattr__(self, "MP_CONTEXT", get_context("spawn"))

# Default LOG_PATH from PROJECT_DIR, if available.
if getattr(self, "LOG_PATH", None) is None:
log_path = "logs"
project_dir = getattr(self, "PROJECT_DIR", None)
# If available, set LOG_PATH from log-path in dbt_project.yml
# Known limitations:
# 1. Using PartialProject here, so no jinja rendering of log-path.
# 2. Programmatic invocations of the cli via dbtRunner may pass a Project object directly,
# which is not being used here to extract log-path.
if project_dir:
try:
partial = PartialProject.from_project_root(
project_dir, verify_version=getattr(self, "VERSION_CHECK", True)
)
log_path = str(partial.project_dict.get("log-path", log_path))
except DbtProjectError:
pass

object.__setattr__(self, "LOG_PATH", log_path)

# Support console DO NOT TRACK initiave
object.__setattr__(
self,
Expand Down
1 change: 0 additions & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,6 @@ def seed(ctx, **kwargs):
ctx.obj["runtime_config"],
ctx.obj["manifest"],
)

results = task.run()
success = task.interpret_results(results)
return results, success
Expand Down
1 change: 1 addition & 0 deletions core/dbt/cli/option_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class WarnErrorOptionsType(YAML):
name = "WarnErrorOptionsType"

def convert(self, value, param, ctx):
# this function is being used by param in click
include_exclude = super().convert(value, param, ctx)

return WarnErrorOptions(
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
"--log-path",
envvar="DBT_LOG_PATH",
help="Configure the 'log-path'. Only applies this setting for the current run. Overrides the 'DBT_LOG_PATH' if it is set.",
default=lambda: Path.cwd() / "logs",
default=None,
type=click.Path(resolve_path=True, path_type=Path),
)

Expand Down Expand Up @@ -415,7 +415,7 @@ def _version_callback(ctx, _param, value):
warn_error_options = click.option(
"--warn-error-options",
envvar="DBT_WARN_ERROR_OPTIONS",
default=None,
default="{}",
help="""If dbt would normally warn, instead raise an exception based on include/exclude configuration. Examples include --select that selects nothing, deprecations, configurations with no associated models, invalid test configurations,
and missing sources/refs in tests. This argument should be a YAML string, with keys 'include' or 'exclude'. eg. '{"include": "all", "exclude": ["NoNodesForSelectionCriteria"]}'""",
type=WarnErrorOptionsType(),
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dbt.adapters.factory import adapter_management, register_adapter
from dbt.flags import set_flags
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile
Expand All @@ -21,6 +22,7 @@ def wrapper(*args, **kwargs):
# Flags
flags = Flags(ctx)
ctx.obj["flags"] = flags
set_flags(flags)

# Tracking
initialize_from_flags(flags.ANONYMOUS_USAGE_STATS, flags.PROFILES_DIR)
Expand Down
7 changes: 4 additions & 3 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
UndefinedCompilationError,
UndefinedMacroError,
)
from dbt import flags
from dbt.flags import get_flags
from dbt.node_types import ModelLanguage


Expand Down Expand Up @@ -99,8 +99,9 @@ def _compile(self, source, filename):
If the value is 'write', also write the files to disk.
WARNING: This can write a ton of data if you aren't careful.
"""
if filename == "<template>" and flags.MACRO_DEBUGGING:
write = flags.MACRO_DEBUGGING == "write"
macro_debugging = get_flags().MACRO_DEBUGGING
if filename == "<template>" and macro_debugging:
write = macro_debugging == "write"
filename = _linecache_inject(source, write)

return super()._compile(source, filename) # type: ignore
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import defaultdict
from typing import List, Dict, Any, Tuple, Optional

from dbt import flags
from dbt.flags import get_flags
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja
from dbt.clients.system import make_directory
Expand Down Expand Up @@ -378,6 +378,7 @@ def _compile_node(
def write_graph_file(self, linker: Linker, manifest: Manifest):
filename = graph_file_name
graph_path = os.path.join(self.config.target_path, filename)
flags = get_flags()
if flags.WRITE_JSON:
linker.write_graph(graph_path, manifest)

Expand Down
43 changes: 25 additions & 18 deletions core/dbt/config/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dbt.dataclass_schema import ValidationError

from dbt import flags
from dbt.flags import get_flags
from dbt.clients.system import load_file_contents
from dbt.clients.yaml_helper import load_yaml_text
from dbt.contracts.connection import Credentials, HasCredentials
Expand Down Expand Up @@ -32,22 +32,6 @@
"""


NO_SUPPLIED_PROFILE_ERROR = """\
dbt cannot run because no profile was specified for this dbt project.
To specify a profile for this project, add a line like the this to
your dbt_project.yml file:
profile: [profile name]
Here, [profile name] should be replaced with a profile name
defined in your profiles.yml file. You can find profiles.yml here:
{profiles_file}/profiles.yml
""".format(
profiles_file=flags.DEFAULT_PROFILES_DIR
)


def read_profile(profiles_dir: str) -> Dict[str, Any]:
path = os.path.join(profiles_dir, "profiles.yml")

Expand Down Expand Up @@ -197,10 +181,33 @@ def pick_profile_name(
args_profile_name: Optional[str],
project_profile_name: Optional[str] = None,
) -> str:
# TODO: Duplicating this method as direct copy of the implementation in dbt.cli.resolvers
# dbt.cli.resolvers implementation can't be used because it causes a circular dependency.
# This should be removed and use a safe default access on the Flags module when
# https://github.com/dbt-labs/dbt-core/issues/6259 is closed.
def default_profiles_dir():
from pathlib import Path

return Path.cwd() if (Path.cwd() / "profiles.yml").exists() else Path.home() / ".dbt"

profile_name = project_profile_name
if args_profile_name is not None:
profile_name = args_profile_name
if profile_name is None:
NO_SUPPLIED_PROFILE_ERROR = """\
dbt cannot run because no profile was specified for this dbt project.
To specify a profile for this project, add a line like the this to
your dbt_project.yml file:
profile: [profile name]
Here, [profile name] should be replaced with a profile name
defined in your profiles.yml file. You can find profiles.yml here:
{profiles_file}/profiles.yml
""".format(
profiles_file=default_profiles_dir()
)
raise DbtProjectError(NO_SUPPLIED_PROFILE_ERROR)
return profile_name

Expand Down Expand Up @@ -423,7 +430,7 @@ def render(
target could not be found.
:returns Profile: The new Profile object.
"""

flags = get_flags()
raw_profiles = read_profile(flags.PROFILES_DIR)
profile_name = cls.pick_profile_name(profile_name_override, project_profile_name)
return cls.from_raw_profiles(
Expand Down
11 changes: 8 additions & 3 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import hashlib
import os

from dbt import flags, deprecations
from dbt.flags import get_flags
from dbt import deprecations
from dbt.clients.system import path_exists, resolve_path_from_base, load_file_contents
from dbt.clients.yaml_helper import load_yaml_text
from dbt.contracts.connection import QueryComment
Expand Down Expand Up @@ -373,9 +374,13 @@ def create_project(self, rendered: RenderComponents) -> "Project":

docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
asset_paths: List[str] = value_or(cfg.asset_paths, [])
target_path: str = flag_or(flags.TARGET_PATH, cfg.target_path, "target")
flags = get_flags()

flag_target_path = str(flags.TARGET_PATH) if flags.TARGET_PATH else None
target_path: str = flag_or(flag_target_path, cfg.target_path, "target")

log_path: str = str(flags.LOG_PATH)
clean_targets: List[str] = value_or(cfg.clean_targets, [target_path])
log_path: str = flag_or(flags.LOG_PATH, cfg.log_path, "logs")
packages_install_path: str = value_or(cfg.packages_install_path, "dbt_packages")
# in the default case we'll populate this once we know the adapter type
# It would be nice to just pass along a Quoting here, but that would
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Type,
)

from dbt import flags
from dbt.flags import get_flags
from dbt.adapters.factory import get_include_paths, get_relation_class_by_name
from dbt.config.project import load_raw_project
from dbt.contracts.connection import AdapterRequiredConfig, Credentials, HasCredentials
Expand Down Expand Up @@ -197,11 +197,10 @@ def new_project(self, project_root: str) -> "RuntimeConfig":

# load the new project and its packages. Don't pass cli variables.
renderer = DbtProjectYamlRenderer(profile)

project = Project.from_project_root(
project_root,
renderer,
verify_version=bool(flags.VERSION_CHECK),
verify_version=bool(getattr(self.args, "VERSION_CHECK", True)),
)

runtime_config = self.from_parts(
Expand Down Expand Up @@ -247,6 +246,7 @@ def collect_parts(cls: Type["RuntimeConfig"], args: Any) -> Tuple[Project, Profi
cli_vars,
args,
)
flags = get_flags()
project = load_project(project_root, bool(flags.VERSION_CHECK), profile, cli_vars)
return project, profile

Expand Down
Loading

0 comments on commit d0b5d75

Please sign in to comment.