Skip to content

Commit

Permalink
Render sql_header for dbt show
Browse files Browse the repository at this point in the history
  • Loading branch information
jtcohen6 committed Aug 16, 2023
1 parent ac539fd commit 7784625
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
17 changes: 13 additions & 4 deletions core/dbt/task/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from dbt.task.compile import CompileTask, CompileRunner
from dbt.task.seed import SeedRunner

from dbt.context.providers import generate_runtime_model_context
from dbt.clients.jinja import get_rendered


class ShowRunner(CompileRunner):
def __init__(self, config, adapter, node, node_index, num_nodes):
Expand All @@ -23,13 +26,19 @@ def execute(self, compiled_node, manifest):
# Allow passing in -1 (or any negative number) to get all rows
limit = None if self.config.args.limit < 0 else self.config.args.limit

compiled_code = compiled_node.compiled_code

if "sql_header" in compiled_node.unrendered_config:
compiled_node.compiled_code = (
compiled_node.unrendered_config["sql_header"] + compiled_node.compiled_code
)
# Currently, we only render sql_header at *parse* time for *running* models:
# See dbt-core issues #2793, #3264, #7151
# So technically this should be "generate_parser_model_context" (I think) instead of "generate_runtime_model_context"
# Generating the context will be slower if we don't actually need to render the sql_header (if it contains no Jinja)
context = generate_runtime_model_context(compiled_node, self.config, manifest)
sql_header = get_rendered(compiled_node.unrendered_config["sql_header"], context)
compiled_code = sql_header + compiled_code

adapter_response, execute_result = self.adapter.execute(
compiled_node.compiled_code, fetch=True, limit=limit
compiled_code, fetch=True, limit=limit
)
end_time = time.time()

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/show/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

models__sql_header = """
{% call set_sql_header(config) %}
set session time zone 'Asia/Kolkata';
set session time zone '{{ var("timezone") }}';
{%- endcall %}
select current_setting('timezone') as timezone
"""
Expand Down
6 changes: 4 additions & 2 deletions tests/functional/show/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,10 @@ def test_seed(self, project):
assert "Previewing node 'sample_seed'" in log_output

def test_sql_header(self, project):
run_dbt(["build"])
(results, log_output) = run_dbt_and_capture(["show", "--select", "sql_header"])
run_dbt(["build", "--vars", "timezone: Asia/Kolkata"])
(results, log_output) = run_dbt_and_capture(
["show", "--select", "sql_header", "--vars", "timezone: Asia/Kolkata"]
)
assert "Asia/Kolkata" in log_output


Expand Down

0 comments on commit 7784625

Please sign in to comment.