Skip to content

Commit

Permalink
Fixes #7299: dbt retry (#7763)
Browse files Browse the repository at this point in the history
  • Loading branch information
aranke authored Jun 5, 2023
1 parent 60d116b commit dc35f56
Show file tree
Hide file tree
Showing 9 changed files with 411 additions and 4 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230602-083302.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: dbt retry
time: 2023-06-02T08:33:02.410456-07:00
custom:
Author: stu-k aranke
Issue: "7299"
5 changes: 4 additions & 1 deletion core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,9 @@ def add_fn(x):

spinal_cased = k.replace("_", "-")

if v in (None, False):
if k == "macro" and command == CliCommand.RUN_OPERATION:
add_fn(v)
elif v in (None, False):
add_fn(f"--no-{spinal_cased}")
elif v is True:
add_fn(f"--{spinal_cased}")
Expand Down Expand Up @@ -384,6 +386,7 @@ def command_args(command: CliCommand) -> ArgsList:
CliCommand.SNAPSHOT: cli.snapshot,
CliCommand.SOURCE_FRESHNESS: cli.freshness,
CliCommand.TEST: cli.test,
CliCommand.RETRY: cli.retry,
}
click_cmd: Optional[ClickCommand] = CMD_DICT.get(command, None)
if click_cmd is None:
Expand Down
32 changes: 32 additions & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from dbt.task.generate import GenerateTask
from dbt.task.init import InitTask
from dbt.task.list import ListTask
from dbt.task.retry import RetryTask
from dbt.task.run import RunTask
from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
Expand Down Expand Up @@ -576,6 +577,36 @@ def run(ctx, **kwargs):
return results, success


# dbt run
@cli.command("retry")
@click.pass_context
@p.project_dir
@p.profiles_dir
@p.vars
@p.profile
@p.target
@p.state
@p.threads
@p.fail_fast
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
def retry(ctx, **kwargs):
"""Retry the nodes that failed in the previous run."""
task = RetryTask(
ctx.obj["flags"],
ctx.obj["runtime_config"],
ctx.obj["manifest"],
)

results = task.run()
success = task.interpret_results(results)
return results, success


# dbt run operation
@cli.command("run-operation")
@click.pass_context
Expand All @@ -586,6 +617,7 @@ def run(ctx, **kwargs):
@p.project_dir
@p.target
@p.target_path
@p.threads
@p.vars
@requires.postflight
@requires.preflight
Expand Down
1 change: 1 addition & 0 deletions core/dbt/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Command(Enum):
SNAPSHOT = "snapshot"
SOURCE_FRESHNESS = "freshness"
TEST = "test"
RETRY = "retry"

@classmethod
def from_str(cls, s: str) -> "Command":
Expand Down
113 changes: 113 additions & 0 deletions core/dbt/task/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from pathlib import Path

from dbt.cli.flags import Flags
from dbt.cli.types import Command as CliCommand
from dbt.config import RuntimeConfig
from dbt.contracts.results import NodeStatus
from dbt.contracts.state import PreviousState
from dbt.exceptions import DbtRuntimeError
from dbt.graph import GraphQueue
from dbt.task.base import ConfiguredTask
from dbt.task.build import BuildTask
from dbt.task.compile import CompileTask
from dbt.task.generate import GenerateTask
from dbt.task.run import RunTask
from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask

RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}

TASK_DICT = {
"build": BuildTask,
"compile": CompileTask,
"generate": GenerateTask,
"seed": SeedTask,
"snapshot": SnapshotTask,
"test": TestTask,
"run": RunTask,
"run-operation": RunOperationTask,
}

CMD_DICT = {
"build": CliCommand.BUILD,
"compile": CliCommand.COMPILE,
"generate": CliCommand.DOCS_GENERATE,
"seed": CliCommand.SEED,
"snapshot": CliCommand.SNAPSHOT,
"test": CliCommand.TEST,
"run": CliCommand.RUN,
"run-operation": CliCommand.RUN_OPERATION,
}


class RetryTask(ConfiguredTask):
def __init__(self, args, config, manifest):
super().__init__(args, config, manifest)

state_path = self.args.state or self.config.target_path

if self.args.warn_error:
RETRYABLE_STATUSES.add(NodeStatus.Warn)

self.previous_state = PreviousState(
state_path=Path(state_path),
target_path=Path(self.config.target_path),
project_root=Path(self.config.project_root),
)

if not self.previous_state.results:
raise DbtRuntimeError(
f"Could not find previous run in '{state_path}' target directory"
)

self.previous_args = self.previous_state.results.args
self.previous_command_name = self.previous_args.get("which")
self.task_class = TASK_DICT.get(self.previous_command_name)

def run(self):
unique_ids = set(
[
result.unique_id
for result in self.previous_state.results.results
if result.status in RETRYABLE_STATUSES
]
)

cli_command = CMD_DICT.get(self.previous_command_name)

# Remove these args when their default values are present, otherwise they'll raise an exception
args_to_remove = {
"show": lambda x: True,
"resource_types": lambda x: x == [],
"warn_error_options": lambda x: x == {"exclude": [], "include": []},
}

for k, v in args_to_remove.items():
if k in self.previous_args and v(self.previous_args[k]):
del self.previous_args[k]

retry_flags = Flags.from_dict(cli_command, self.previous_args)
retry_config = RuntimeConfig.from_args(args=retry_flags)

class TaskWrapper(self.task_class):
def get_graph_queue(self):
new_graph = self.graph.get_subset_graph(unique_ids)
return GraphQueue(
new_graph.graph,
self.manifest,
unique_ids,
)

task = TaskWrapper(
retry_flags,
retry_config,
self.manifest,
)

return_value = task.run()
return return_value

def interpret_results(self, *args, **kwargs):
return self.task_class.interpret_results(*args, **kwargs)
9 changes: 6 additions & 3 deletions core/dbt/task/run_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def _run_unsafe(self) -> agate.Table:
def run(self) -> RunResultsArtifact:
start = datetime.utcnow()
self.compile_manifest()

success = True

try:
self._run_unsafe()
except dbt.exceptions.Exception as exc:
Expand All @@ -59,8 +62,7 @@ def run(self) -> RunResultsArtifact:
fire_event(RunningOperationUncaughtError(exc=str(exc)))
fire_event(LogDebugStackTrace(exc_info=traceback.format_exc()))
success = False
else:
success = True

end = datetime.utcnow()

package_name, macro_name = self._get_macro_parts()
Expand Down Expand Up @@ -108,5 +110,6 @@ def run(self) -> RunResultsArtifact:

return results

def interpret_results(self, results):
@classmethod
def interpret_results(cls, results):
return results.results[0].status == RunStatus.Success
47 changes: 47 additions & 0 deletions tests/functional/retry/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
models__sample_model = """select 1 as id, baz as foo"""
models__second_model = """select 1 as id, 2 as bar"""

models__union_model = """
select foo + bar as sum3 from {{ ref('sample_model') }}
left join {{ ref('second_model') }} on sample_model.id = second_model.id
"""

schema_yml = """
models:
- name: sample_model
columns:
- name: foo
tests:
- accepted_values:
values: [3]
quote: false
config:
severity: warn
- name: second_model
columns:
- name: bar
tests:
- accepted_values:
values: [3]
quote: false
config:
severity: warn
- name: union_model
columns:
- name: sum3
tests:
- accepted_values:
values: [3]
quote: false
"""

macros__alter_timezone_sql = """
{% macro alter_timezone(timezone='America/Los_Angeles') %}
{% set sql %}
SET TimeZone='{{ timezone }}';
{% endset %}
{% do run_query(sql) %}
{% do log("Timezone set to: " + timezone, info=True) %}
{% endmacro %}
"""
Loading

0 comments on commit dc35f56

Please sign in to comment.