Merge pull request #1028 from fishtown-analytics/feature/schemas-on-on-run-end

make schemas available to on-run-end hooks (#908)
beckjake authored Oct 15, 2018
2 parents dd25750 + 4c4bd0c commit 0ca86a5
Showing 3 changed files with 39 additions and 16 deletions.
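
For context: this change threads an extra_context dict through hook compilation so that on-run-end hooks can reference a schemas variable (and the run's results). As a sketch of what that enables, a hypothetical "on-run-end" entry in the style of the integration test further down might grant access on every schema the run touched (the grant statement and group name are illustrative assumptions, not part of this commit):

    {
        "on-run-end": [
            # one grant per schema that contained at least one successful model
            "{% for schema in schemas %}"
            "grant usage on schema {{ schema }} to group reporters;"
            "{% endfor %}",
        ]
    }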
dbt/compilation.py (6 changes: 5 additions & 1 deletion)
@@ -102,7 +102,10 @@ def initialize(self):
         dbt.clients.system.make_directory(self.config.target_path)
         dbt.clients.system.make_directory(self.config.modules_path)

-    def compile_node(self, node, manifest):
+    def compile_node(self, node, manifest, extra_context=None):
+        if extra_context is None:
+            extra_context = {}
+
         logger.debug("Compiling {}".format(node.get('unique_id')))

         data = node.to_dict()
@@ -117,6 +120,7 @@ def compile_node(self, node, manifest):

         context = dbt.context.runtime.generate(
             compiled_node, self.config, manifest)
+        context.update(extra_context)

         compiled_node.compiled_sql = dbt.clients.jinja.get_rendered(
             node.get('raw_sql'),
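
The effect of extra_context is a plain dict merge performed just before rendering: caller-supplied entries are layered on top of the generated runtime context, so they win on key collisions. A minimal sketch of that merge order using the jinja2 package directly (dbt's own dbt.context.runtime.generate and dbt.clients.jinja.get_rendered are not reproduced here, so this approximates the behaviour rather than dbt's implementation):

    import jinja2

    def render_with_extra(raw_sql, base_context, extra_context=None):
        # mirror compile_node: start from the generated context, then let
        # caller-supplied entries override it on key collisions
        context = dict(base_context)
        context.update(extra_context or {})
        return jinja2.Template(raw_sql).render(**context)

    # e.g. render_with_extra("grant usage on schema {{ schemas[0] }} to group x",
    #                        {}, {'schemas': ['analytics']})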
dbt/node_runners.py (34 changes: 22 additions & 12 deletions)
@@ -247,13 +247,13 @@ def execute(self, compiled_node, manifest):

     def compile(self, manifest):
         return self._compile_node(self.adapter, self.config, self.node,
-                                  manifest)
+                                  manifest, {})

     @classmethod
-    def _compile_node(cls, adapter, config, node, manifest):
+    def _compile_node(cls, adapter, config, node, manifest, extra_context):
         compiler = dbt.compilation.Compiler(config)
-        node = compiler.compile_node(node, manifest)
-        node = cls._inject_runtime_config(adapter, node)
+        node = compiler.compile_node(node, manifest, extra_context)
+        node = cls._inject_runtime_config(adapter, node, extra_context)

         if(node.injected_sql is not None and
            not (dbt.utils.is_type(node, NodeType.Archive))):
@@ -271,9 +271,10 @@ def _compile_node(cls, adapter, config, node, manifest):
         return node

     @classmethod
-    def _inject_runtime_config(cls, adapter, node):
+    def _inject_runtime_config(cls, adapter, node, extra_context):
         wrapped_sql = node.wrapped_sql
         context = cls._node_context(adapter, node)
+        context.update(extra_context)
         sql = dbt.clients.jinja.get_rendered(wrapped_sql, context)
         node.wrapped_sql = sql
         return node
@@ -310,7 +311,7 @@ def raise_on_first_error(self):
         return False

     @classmethod
-    def run_hooks(cls, config, adapter, manifest, hook_type):
+    def run_hooks(cls, config, adapter, manifest, hook_type, extra_context):

         nodes = manifest.nodes.values()
         hooks = get_nodes_by_tags(nodes, {hook_type}, NodeType.Operation)
@@ -328,7 +329,8 @@ def run_hooks(cls, config, adapter, manifest, hook_type):
            # Also, consider configuring psycopg2 (and other adapters?) to
            # ensure that a transaction is only created if dbt initiates it.
            adapter.clear_transaction(model_name)
-           compiled = cls._compile_node(adapter, config, hook, manifest)
+           compiled = cls._compile_node(adapter, config, hook, manifest,
+                                        extra_context)
            statement = compiled.wrapped_sql

            hook_index = hook.get('index', len(hooks))
@@ -346,10 +348,10 @@ def run_hooks(cls, config, adapter, manifest, hook_type):
            adapter.release_connection(model_name)

     @classmethod
-    def safe_run_hooks(cls, config, adapter, manifest, hook_type):
+    def safe_run_hooks(cls, config, adapter, manifest, hook_type,
+                       extra_context):
         try:
-            cls.run_hooks(config, adapter, manifest, hook_type)
-
+            cls.run_hooks(config, adapter, manifest, hook_type, extra_context)
         except dbt.exceptions.RuntimeException:
             logger.info("Database error while running {}".format(hook_type))
             raise
@@ -376,7 +378,7 @@ def populate_adapter_cache(cls, config, adapter, manifest):
     @classmethod
     def before_run(cls, config, adapter, manifest):
         cls.populate_adapter_cache(config, adapter, manifest)
-        cls.safe_run_hooks(config, adapter, manifest, RunHookType.Start)
+        cls.safe_run_hooks(config, adapter, manifest, RunHookType.Start, {})
         cls.create_schemas(config, adapter, manifest)

     @classmethod
@@ -397,7 +399,15 @@ def print_results_line(cls, results, execution_time):

     @classmethod
     def after_run(cls, config, adapter, results, manifest):
-        cls.safe_run_hooks(config, adapter, manifest, RunHookType.End)
+        # in on-run-end hooks, provide the value 'schemas', which is a list of
+        # unique schemas that successfully executed models were in
+        # (i.e. models that did not error, fail, or get skipped)
+        schemas = list(set(
+            r.node.schema for r in results
+            if not any((r.errored, r.failed, r.skipped))
+        ))
+        cls.safe_run_hooks(config, adapter, manifest, RunHookType.End,
+                           {'schemas': schemas, 'results': results})

     @classmethod
     def after_hooks(cls, config, adapter, results, manifest, elapsed):
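
The schemas value handed to the end hooks is derived from the run results: a result contributes its node's schema only if it neither errored, failed, nor was skipped, and the set() collapses duplicates. A small self-contained illustration of that filter with mocked-up result objects (MockNode and MockResult are invented for this sketch; they are not dbt types):

    from collections import namedtuple

    MockNode = namedtuple('MockNode', ['schema'])
    MockResult = namedtuple('MockResult', ['node', 'errored', 'failed', 'skipped'])

    results = [
        MockResult(MockNode('analytics'), False, False, False),  # succeeded
        MockResult(MockNode('analytics'), False, False, False),  # duplicate schema
        MockResult(MockNode('staging'), False, False, True),     # skipped: excluded
        MockResult(MockNode('scratch'), True, False, False),     # errored: excluded
    ]

    schemas = list(set(
        r.node.schema for r in results
        if not any((r.errored, r.failed, r.skipped))
    ))
    # schemas == ['analytics']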
test/integration/014_hook_tests/test_run_hooks.py (15 changes: 12 additions & 3 deletions)
@@ -45,6 +45,8 @@ def project_config(self):
                 "{{ custom_run_hook('end', target, run_started_at, invocation_id) }}",
                 "create table {{ target.schema }}.end_hook_order_test ( id int )",
                 "drop table {{ target.schema }}.end_hook_order_test",
+                "create table {{ target.schema }}.schemas ( schema text )",
+                "insert into {{ target.schema }}.schemas values ({% for schema in schemas %}( '{{ schema }}' ){% if not loop.last %},{% endif %}{% endfor %})",
             ]
         }

@@ -63,6 +65,12 @@ def get_ctx_vars(self, state):

         return ctx

+    def assert_used_schemas(self):
+        schemas_query = 'select * from {}.schemas'.format(self.unique_schema())
+        results = self.run_sql(schemas_query, fetch='all')
+        self.assertEqual(len(results), 1)
+        self.assertEqual(results[0][0], self.unique_schema())
+
     def check_hooks(self, state):
         ctx = self.get_ctx_vars(state)

@@ -81,22 +89,23 @@ def check_hooks(self, state):
         self.assertTrue(ctx['invocation_id'] is not None and len(ctx['invocation_id']) > 0, 'invocation_id was not set')

     @attr(type='postgres')
-    def test_pre_and_post_run_hooks(self):
+    def test__postgres__pre_and_post_run_hooks(self):
         self.run_dbt(['run'])

         self.check_hooks('start')
         self.check_hooks('end')

         self.assertTableDoesNotExist("start_hook_order_test")
         self.assertTableDoesNotExist("end_hook_order_test")
+        self.assert_used_schemas()

     @attr(type='postgres')
-    def test_pre_and_post_seed_hooks(self):
+    def test__postgres__pre_and_post_seed_hooks(self):
         self.run_dbt(['seed'])

         self.check_hooks('start')
         self.check_hooks('end')

         self.assertTableDoesNotExist("start_hook_order_test")
         self.assertTableDoesNotExist("end_hook_order_test")

+        self.assert_used_schemas()

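For a run confined to a single schema, as in this test, the new insert hook in project_config renders to roughly the following, with <schema> standing in for the value of self.unique_schema():

    insert into <schema>.schemas values (( '<schema>' ))

assert_used_schemas then reads the table back and checks that exactly one row, naming the test schema, was written.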