diff --git a/dbt/compilation.py b/dbt/compilation.py index c35ce1d4909..1c688d27771 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -102,7 +102,10 @@ def initialize(self): dbt.clients.system.make_directory(self.config.target_path) dbt.clients.system.make_directory(self.config.modules_path) - def compile_node(self, node, manifest): + def compile_node(self, node, manifest, extra_context=None): + if extra_context is None: + extra_context = {} + logger.debug("Compiling {}".format(node.get('unique_id'))) data = node.to_dict() @@ -117,6 +120,7 @@ def compile_node(self, node, manifest): context = dbt.context.runtime.generate( compiled_node, self.config, manifest) + context.update(extra_context) compiled_node.compiled_sql = dbt.clients.jinja.get_rendered( node.get('raw_sql'), diff --git a/dbt/node_runners.py b/dbt/node_runners.py index 0755092b86d..f2ae1e22b05 100644 --- a/dbt/node_runners.py +++ b/dbt/node_runners.py @@ -247,13 +247,13 @@ def execute(self, compiled_node, manifest): def compile(self, manifest): return self._compile_node(self.adapter, self.config, self.node, - manifest) + manifest, {}) @classmethod - def _compile_node(cls, adapter, config, node, manifest): + def _compile_node(cls, adapter, config, node, manifest, extra_context): compiler = dbt.compilation.Compiler(config) - node = compiler.compile_node(node, manifest) - node = cls._inject_runtime_config(adapter, node) + node = compiler.compile_node(node, manifest, extra_context) + node = cls._inject_runtime_config(adapter, node, extra_context) if(node.injected_sql is not None and not (dbt.utils.is_type(node, NodeType.Archive))): @@ -271,9 +271,10 @@ def _compile_node(cls, adapter, config, node, manifest): return node @classmethod - def _inject_runtime_config(cls, adapter, node): + def _inject_runtime_config(cls, adapter, node, extra_context): wrapped_sql = node.wrapped_sql context = cls._node_context(adapter, node) + context.update(extra_context) sql = dbt.clients.jinja.get_rendered(wrapped_sql, context) node.wrapped_sql = sql return node @@ -310,7 +311,7 @@ def raise_on_first_error(self): return False @classmethod - def run_hooks(cls, config, adapter, manifest, hook_type): + def run_hooks(cls, config, adapter, manifest, hook_type, extra_context): nodes = manifest.nodes.values() hooks = get_nodes_by_tags(nodes, {hook_type}, NodeType.Operation) @@ -328,7 +329,8 @@ def run_hooks(cls, config, adapter, manifest, hook_type): # Also, consider configuring psycopg2 (and other adapters?) to # ensure that a transaction is only created if dbt initiates it. adapter.clear_transaction(model_name) - compiled = cls._compile_node(adapter, config, hook, manifest) + compiled = cls._compile_node(adapter, config, hook, manifest, + extra_context) statement = compiled.wrapped_sql hook_index = hook.get('index', len(hooks)) @@ -346,10 +348,10 @@ def run_hooks(cls, config, adapter, manifest, hook_type): adapter.release_connection(model_name) @classmethod - def safe_run_hooks(cls, config, adapter, manifest, hook_type): + def safe_run_hooks(cls, config, adapter, manifest, hook_type, + extra_context): try: - cls.run_hooks(config, adapter, manifest, hook_type) - + cls.run_hooks(config, adapter, manifest, hook_type, extra_context) except dbt.exceptions.RuntimeException: logger.info("Database error while running {}".format(hook_type)) raise @@ -376,7 +378,7 @@ def populate_adapter_cache(cls, config, adapter, manifest): @classmethod def before_run(cls, config, adapter, manifest): cls.populate_adapter_cache(config, adapter, manifest) - cls.safe_run_hooks(config, adapter, manifest, RunHookType.Start) + cls.safe_run_hooks(config, adapter, manifest, RunHookType.Start, {}) cls.create_schemas(config, adapter, manifest) @classmethod @@ -397,7 +399,15 @@ def print_results_line(cls, results, execution_time): @classmethod def after_run(cls, config, adapter, results, manifest): - cls.safe_run_hooks(config, adapter, manifest, RunHookType.End) + # in on-run-end hooks, provide the value 'schemas', which is a list of + # unique schemas that successfully executed models were in + # errored failed skipped + schemas = list(set( + r.node.schema for r in results + if not any((r.errored, r.failed, r.skipped)) + )) + cls.safe_run_hooks(config, adapter, manifest, RunHookType.End, + {'schemas': schemas, 'results': results}) @classmethod def after_hooks(cls, config, adapter, results, manifest, elapsed): diff --git a/test/integration/014_hook_tests/test_run_hooks.py b/test/integration/014_hook_tests/test_run_hooks.py index 304f3b82b61..a92790c8473 100644 --- a/test/integration/014_hook_tests/test_run_hooks.py +++ b/test/integration/014_hook_tests/test_run_hooks.py @@ -45,6 +45,8 @@ def project_config(self): "{{ custom_run_hook('end', target, run_started_at, invocation_id) }}", "create table {{ target.schema }}.end_hook_order_test ( id int )", "drop table {{ target.schema }}.end_hook_order_test", + "create table {{ target.schema }}.schemas ( schema text )", + "insert into {{ target.schema }}.schemas values ({% for schema in schemas %}( '{{ schema }}' ){% if not loop.last %},{% endif %}{% endfor %})", ] } @@ -63,6 +65,12 @@ def get_ctx_vars(self, state): return ctx + def assert_used_schemas(self): + schemas_query = 'select * from {}.schemas'.format(self.unique_schema()) + results = self.run_sql(schemas_query, fetch='all') + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], self.unique_schema()) + def check_hooks(self, state): ctx = self.get_ctx_vars(state) @@ -81,7 +89,7 @@ def check_hooks(self, state): self.assertTrue(ctx['invocation_id'] is not None and len(ctx['invocation_id']) > 0, 'invocation_id was not set') @attr(type='postgres') - def test_pre_and_post_run_hooks(self): + def test__postgres__pre_and_post_run_hooks(self): self.run_dbt(['run']) self.check_hooks('start') @@ -89,9 +97,10 @@ def test_pre_and_post_run_hooks(self): self.assertTableDoesNotExist("start_hook_order_test") self.assertTableDoesNotExist("end_hook_order_test") + self.assert_used_schemas() @attr(type='postgres') - def test_pre_and_post_seed_hooks(self): + def test__postgres__pre_and_post_seed_hooks(self): self.run_dbt(['seed']) self.check_hooks('start') @@ -99,4 +108,4 @@ def test_pre_and_post_seed_hooks(self): self.assertTableDoesNotExist("start_hook_order_test") self.assertTableDoesNotExist("end_hook_order_test") - + self.assert_used_schemas()