From 3ca0b7aa9d5dbbd8922f4caa2b9c8fcdf34bf803 Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Mon, 4 Dec 2023 14:00:51 -0800 Subject: [PATCH 1/3] reorg macro tests so that we can move macros --- dbt/include/databricks/macros/adapters.sql | 39 -- .../macros/relations/table/create.sql | 37 ++ tests/unit/macros/base.py | 157 +++-- .../macros/relations/test_table_macros.py | 225 +++++++ tests/unit/macros/test_adapters_macros.py | 570 +++++------------- tests/unit/macros/test_python_macros.py | 111 ++-- 6 files changed, 603 insertions(+), 536 deletions(-) create mode 100644 dbt/include/databricks/macros/relations/table/create.sql create mode 100644 tests/unit/macros/relations/test_table_macros.py diff --git a/dbt/include/databricks/macros/adapters.sql b/dbt/include/databricks/macros/adapters.sql index 4c2848fe8..bb44a14e7 100644 --- a/dbt/include/databricks/macros/adapters.sql +++ b/dbt/include/databricks/macros/adapters.sql @@ -62,45 +62,6 @@ {%- endif %} {%- endmacro -%} - -{% macro databricks__create_table_as(temporary, relation, compiled_code, language='sql') -%} - {%- if language == 'sql' -%} - {%- if temporary -%} - {{ create_temporary_view(relation, compiled_code) }} - {%- else -%} - {% if config.get('file_format', default='delta') == 'delta' %} - create or replace table {{ relation }} - {% else %} - create table {{ relation }} - {% endif %} - {%- set contract_config = config.get('contract') -%} - {% if contract_config and contract_config.enforced %} - {{ get_assert_columns_equivalent(compiled_code) }} - {%- set compiled_code = get_select_subquery(compiled_code) %} - {% endif %} - {{ file_format_clause() }} - {{ options_clause() }} - {{ partition_cols(label="partitioned by") }} - {{ liquid_clustered_cols(label="cluster by") }} - {{ clustered_cols(label="clustered by") }} - {{ location_clause() }} - {{ comment_clause() }} - {{ tblproperties_clause() }} - as - {{ compiled_code }} - {%- endif -%} - {%- elif language == 'python' -%} - {#-- - N.B. 
Python models _can_ write to temp views HOWEVER they use a different session - and have already expired by the time they need to be used (I.E. in merges for incremental models) - - TODO: Deep dive into spark sessions to see if we can reuse a single session for an entire - dbt invocation. - --#} - {{ databricks__py_write_table(compiled_code=compiled_code, target_relation=relation) }} - {%- endif -%} -{%- endmacro -%} - {% macro get_column_comment_sql(column_name, column_dict) -%} {% if column_name in column_dict and column_dict[column_name]["description"] -%} {% set escaped_description = column_dict[column_name]["description"] | replace("'", "\\'") %} diff --git a/dbt/include/databricks/macros/relations/table/create.sql b/dbt/include/databricks/macros/relations/table/create.sql new file mode 100644 index 000000000..08400ccf8 --- /dev/null +++ b/dbt/include/databricks/macros/relations/table/create.sql @@ -0,0 +1,37 @@ +{% macro databricks__create_table_as(temporary, relation, compiled_code, language='sql') -%} + {%- if language == 'sql' -%} + {%- if temporary -%} + {{ create_temporary_view(relation, compiled_code) }} + {%- else -%} + {% if config.get('file_format', default='delta') == 'delta' %} + create or replace table {{ relation }} + {% else %} + create table {{ relation }} + {% endif %} + {%- set contract_config = config.get('contract') -%} + {% if contract_config and contract_config.enforced %} + {{ get_assert_columns_equivalent(compiled_code) }} + {%- set compiled_code = get_select_subquery(compiled_code) %} + {% endif %} + {{ file_format_clause() }} + {{ options_clause() }} + {{ partition_cols(label="partitioned by") }} + {{ liquid_clustered_cols(label="cluster by") }} + {{ clustered_cols(label="clustered by") }} + {{ location_clause() }} + {{ comment_clause() }} + {{ tblproperties_clause() }} + as + {{ compiled_code }} + {%- endif -%} + {%- elif language == 'python' -%} + {#-- + N.B. 
Python models _can_ write to temp views HOWEVER they use a different session + and have already expired by the time they need to be used (I.E. in merges for incremental models) + + TODO: Deep dive into spark sessions to see if we can reuse a single session for an entire + dbt invocation. + --#} + {{ databricks__py_write_table(compiled_code=compiled_code, target_relation=relation) }} + {%- endif -%} +{%- endmacro -%} \ No newline at end of file diff --git a/tests/unit/macros/base.py b/tests/unit/macros/base.py index f697815e5..2df25c642 100644 --- a/tests/unit/macros/base.py +++ b/tests/unit/macros/base.py @@ -1,56 +1,139 @@ -import unittest -from unittest import mock import re -from jinja2 import Environment, FileSystemLoader, PackageLoader +from mock import Mock +import pytest +from jinja2 import Environment, FileSystemLoader, PackageLoader, Template +from dbt.adapters.databricks.relation import DatabricksRelation -class TestMacros(unittest.TestCase): - def setUp(self): - self.parent_jinja_env = Environment( +class TemplateBundle: + def __init__(self, template, context, relation): + self.template = template + self.context = context + self.relation = relation + + +class MacroTestBase: + @pytest.fixture(autouse=True) + def config(self, context) -> dict: + local_config = {} + context["config"].get = lambda key, default=None, **kwargs: local_config.get(key, default) + return local_config + + @pytest.fixture(autouse=True) + def var(self, context) -> dict: + local_var = {} + context["var"] = lambda key, default=None, **kwargs: local_var.get(key, default) + return local_var + + @pytest.fixture(scope="class") + def default_context(self) -> dict: + context = { + "validation": Mock(), + "model": Mock(), + "exceptions": Mock(), + "config": Mock(), + "statement": lambda r, caller: r, + "adapter": Mock(), + "var": Mock(), + "return": lambda r: r, + } + + return context + + @pytest.fixture(scope="class") + def spark_env(self) -> Environment: + return Environment( 
loader=PackageLoader("dbt.include.spark", "macros"), extensions=["jinja2.ext.do"], ) - self.jinja_env = Environment( - loader=FileSystemLoader("dbt/include/databricks/macros"), + + @pytest.fixture(scope="class") + def spark_template_names(self) -> list: + return ["adapters.sql"] + + @pytest.fixture(scope="class") + def spark_context(self, default_context, spark_env, spark_template_names) -> dict: + return self.build_up_context(default_context, spark_env, spark_template_names) + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros"] + + @pytest.fixture(scope="class") + def databricks_env(self, macro_folders_to_load) -> Environment: + return Environment( + loader=FileSystemLoader( + [f"dbt/include/databricks/{folder}" for folder in macro_folders_to_load] + ), extensions=["jinja2.ext.do"], ) - self.config = {} - self.var = {} - self.default_context = { - "validation": mock.Mock(), - "model": mock.Mock(), - "exceptions": mock.Mock(), - "config": mock.Mock(), - "statement": lambda r, caller: r, - "adapter": mock.Mock(), - "var": mock.Mock(), - "return": lambda r: r, - } - self.default_context["config"].get = lambda key, default=None, **kwargs: self.config.get( - key, default - ) + @pytest.fixture(scope="class") + def databricks_template_names(self) -> list: + return [] + + @pytest.fixture(scope="class") + def databricks_context(self, spark_context, databricks_env, databricks_template_names) -> dict: + if not databricks_template_names: + return spark_context + return self.build_up_context(spark_context, databricks_env, databricks_template_names) + + def build_up_context(self, context, env, template_names): + new_context = context.copy() + for template_name in template_names: + template = env.get_template(template_name, globals=context) + new_context.update(template.module.__dict__) + + return new_context - self.default_context["var"] = lambda key, default=None, **kwargs: self.var.get(key, default) + @pytest.fixture + def 
context(self, databricks_context) -> dict: + return databricks_context.copy() - def _get_template(self, template_filename, parent_context=None): - parent_filename = parent_context or template_filename - parent = self.parent_jinja_env.get_template(parent_filename, globals=self.default_context) - self.default_context.update(parent.module.__dict__) + @pytest.fixture(scope="class") + def template_name(self) -> str: + raise NotImplementedError("Must be implemented by subclasses") - return self.jinja_env.get_template(template_filename, globals=self.default_context) + @pytest.fixture + def template(self, template_name, context, databricks_env) -> Template: + current_template = databricks_env.get_template(template_name, globals=context) - def _run_macro_raw(self, name, *args): def dispatch(macro_name, macro_namespace=None, packages=None): - if hasattr(self.template.module, f"databricks__{macro_name}"): - return getattr(self.template.module, f"databricks__{macro_name}") + if hasattr(current_template.module, f"databricks__{macro_name}"): + return getattr(current_template.module, f"databricks__{macro_name}") + elif f"databricks__{macro_name}" in context: + return context[f"databricks__{macro_name}"] else: - return self.default_context[f"spark__{macro_name}"] + return context[f"spark__{macro_name}"] - self.default_context["adapter"].dispatch = dispatch + context["adapter"].dispatch = dispatch - return getattr(self.template.module, name)(*args) + return current_template - def _run_macro(self, name, *args): - value = self._run_macro_raw(name, *args) + @pytest.fixture(scope="class") + def relation(self): + data = { + "path": { + "database": "some_database", + "schema": "some_schema", + "identifier": "some_table", + }, + "type": None, + } + + return DatabricksRelation.from_dict(data) + + @pytest.fixture + def template_bundle(self, template, context, relation): + context["model"].alias = relation.identifier + return TemplateBundle(template, context, relation) + + def 
run_macro_raw(self, template, name, *args): + return getattr(template.module, name)(*args) + + def run_macro(self, template, name, *args): + value = self.run_macro_raw(template, name, *args) return re.sub(r"\s\s+", " ", value).strip() + + def render_bundle(self, template_bundle, name, *args): + return self.run_macro(template_bundle.template, name, template_bundle.relation, *args) diff --git a/tests/unit/macros/relations/test_table_macros.py b/tests/unit/macros/relations/test_table_macros.py new file mode 100644 index 000000000..b4ddef893 --- /dev/null +++ b/tests/unit/macros/relations/test_table_macros.py @@ -0,0 +1,225 @@ +from mock import Mock +from jinja2 import Environment, FileSystemLoader, PackageLoader +import re +import pytest + +from tests.unit.macros.base import MacroTestBase + + +class TestCreateTableAs(MacroTestBase): + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "create.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros", "macros/relations/table"] + + @pytest.fixture(scope="class") + def databricks_template_names(self) -> list: + return ["adapters.sql"] + + def render_create_table_as(self, template_bundle, temporary=False, sql="select 1"): + return self.run_macro( + template_bundle.template, + "databricks__create_table_as", + temporary, + template_bundle.relation, + sql, + ) + + def test_macros_create_table_as(self, template_bundle): + sql = self.render_create_table_as(template_bundle) + assert sql == f"create or replace table {template_bundle.relation} using delta as select 1" + + @pytest.mark.parametrize("format", ["parquet", "hudi"]) + def test_macros_create_table_as_file_format(self, format, config, template_bundle): + config["file_format"] = format + sql = self.render_create_table_as(template_bundle) + assert sql == f"create table {template_bundle.relation} using {format} as select 1" + + def test_macros_create_table_as_options(self, config, template_bundle): + 
config["options"] = {"compression": "gzip"} + sql = self.render_create_table_as(template_bundle) + expected = ( + f"create or replace table {template_bundle.relation} " + 'using delta options (compression "gzip" ) as select 1' + ) + + assert sql == expected + + def test_macros_create_table_as_hudi_unique_key(self, config, template_bundle): + config["file_format"] = "hudi" + config["unique_key"] = "id" + sql = self.render_create_table_as(template_bundle, sql="select 1 as id") + + expected = ( + f"create table {template_bundle.relation} using hudi options (primaryKey" + ' "id" ) as select 1 as id' + ) + + assert sql == expected + + def test_macros_create_table_as_hudi_unique_key_primary_key_match( + self, config, template_bundle + ): + config["file_format"] = "hudi" + config["unique_key"] = "id" + config["options"] = {"primaryKey": "id"} + sql = self.render_create_table_as(template_bundle, sql="select 1 as id") + + expected = ( + f"create table {template_bundle.relation} using hudi options (primaryKey" + ' "id" ) as select 1 as id' + ) + assert sql == expected + + def test_macros_create_table_as_hudi_unique_key_primary_key_mismatch( + self, config, template_bundle + ): + config["file_format"] = "hudi" + config["unique_key"] = "uuid" + config["options"] = {"primaryKey": "id"} + sql = self.render_create_table_as(template_bundle, sql="select 1 as id, 2 as uuid") + assert "mock.raise_compiler_error()" in sql + + def test_macros_create_table_as_partition(self, config, template_bundle): + config["partition_by"] = "partition_1" + sql = self.render_create_table_as(template_bundle) + + expected = ( + f"create or replace table {template_bundle.relation} using delta" + " partitioned by (partition_1) as select 1" + ) + assert sql == expected + + def test_macros_create_table_as_partitions(self, config, template_bundle): + config["partition_by"] = ["partition_1", "partition_2"] + sql = self.render_create_table_as(template_bundle) + expected = ( + f"create or replace table 
{template_bundle.relation} " + "using delta partitioned by (partition_1,partition_2) as select 1" + ) + + assert sql == expected + + def test_macros_create_table_as_cluster(self, config, template_bundle): + config["clustered_by"] = "cluster_1" + config["buckets"] = "1" + sql = self.render_create_table_as(template_bundle) + + expected = ( + f"create or replace table {template_bundle.relation} " + "using delta clustered by (cluster_1) into 1 buckets as select 1" + ) + + assert sql == expected + + def test_macros_create_table_as_clusters(self, config, template_bundle): + config["clustered_by"] = ["cluster_1", "cluster_2"] + config["buckets"] = "1" + sql = self.render_create_table_as(template_bundle) + + expected = ( + f"create or replace table {template_bundle.relation} " + "using delta clustered by (cluster_1,cluster_2) into 1 buckets as select 1" + ) + + assert sql == expected + + def test_macros_create_table_as_liquid_cluster(self, config, template_bundle): + config["liquid_clustered_by"] = "cluster_1" + sql = self.render_create_table_as(template_bundle) + expected = ( + f"create or replace table {template_bundle.relation} using" + " delta cluster by (cluster_1) as select 1" + ) + + assert sql == expected + + def test_macros_create_table_as_liquid_clusters(self, config, template_bundle): + config["liquid_clustered_by"] = ["cluster_1", "cluster_2"] + config["buckets"] = "1" + sql = self.render_create_table_as(template_bundle) + expected = ( + f"create or replace table {template_bundle.relation} " + "using delta cluster by (cluster_1,cluster_2) as select 1" + ) + + assert sql == expected + + def test_macros_create_table_as_comment(self, config, template_bundle): + config["persist_docs"] = {"relation": True} + template_bundle.context["model"].description = "Description Test" + + sql = self.render_create_table_as(template_bundle) + + expected = ( + f"create or replace table {template_bundle.relation} " + "using delta comment 'Description Test' as select 1" + ) + + 
assert sql == expected + + def test_macros_create_table_as_tblproperties(self, config, template_bundle): + config["tblproperties"] = {"delta.appendOnly": "true"} + sql = self.render_create_table_as(template_bundle) + + expected = ( + f"create or replace table {template_bundle.relation} " + "using delta tblproperties ('delta.appendOnly' = 'true' ) as select 1" + ) + + assert sql == expected + + def test_macros_create_table_as_all_delta(self, config, template_bundle): + config["location_root"] = "/mnt/root" + config["partition_by"] = ["partition_1", "partition_2"] + config["liquid_clustered_by"] = ["cluster_1", "cluster_2"] + config["clustered_by"] = ["cluster_1", "cluster_2"] + config["buckets"] = "1" + config["persist_docs"] = {"relation": True} + config["tblproperties"] = {"delta.appendOnly": "true"} + template_bundle.context["model"].description = "Description Test" + + config["file_format"] = "delta" + sql = self.render_create_table_as(template_bundle) + + expected = ( + f"create or replace table {template_bundle.relation} " + "using delta " + "partitioned by (partition_1,partition_2) " + "cluster by (cluster_1,cluster_2) " + "clustered by (cluster_1,cluster_2) into 1 buckets " + "location '/mnt/root/some_table' " + "comment 'Description Test' " + "tblproperties ('delta.appendOnly' = 'true' ) " + "as select 1" + ) + + assert sql == expected + + def test_macros_create_table_as_all_hudi(self, config, template_bundle): + config["location_root"] = "/mnt/root" + config["partition_by"] = ["partition_1", "partition_2"] + config["clustered_by"] = ["cluster_1", "cluster_2"] + config["buckets"] = "1" + config["persist_docs"] = {"relation": True} + config["tblproperties"] = {"delta.appendOnly": "true"} + template_bundle.context["model"].description = "Description Test" + + config["file_format"] = "hudi" + sql = self.render_create_table_as(template_bundle) + + expected = ( + f"create table {template_bundle.relation} " + "using hudi " + "partitioned by 
(partition_1,partition_2) " + "clustered by (cluster_1,cluster_2) into 1 buckets " + "location '/mnt/root/some_table' " + "comment 'Description Test' " + "tblproperties ('delta.appendOnly' = 'true' ) " + "as select 1" + ) + + assert sql == expected diff --git a/tests/unit/macros/test_adapters_macros.py b/tests/unit/macros/test_adapters_macros.py index c86785e07..bce535f0f 100644 --- a/tests/unit/macros/test_adapters_macros.py +++ b/tests/unit/macros/test_adapters_macros.py @@ -1,347 +1,103 @@ from mock import MagicMock -from dbt.adapters.databricks.relation import DatabricksRelation +import pytest -from tests.unit.macros.base import TestMacros +from tests.unit.macros.base import MacroTestBase -class TestAdaptersMacros(TestMacros): - def setUp(self): - super().setUp() - self.template = self._get_template("adapters.sql") +class TestDatabricksMacros(MacroTestBase): + @pytest.fixture + def template_name(self) -> str: + return "adapters.sql" - def _render_create_table_as(self, relation="my_table", temporary=False, sql="select 1"): - self.default_context["model"].alias = relation - - return self._run_macro("databricks__create_table_as", temporary, relation, sql) - - -class TestSparkMacros(TestAdaptersMacros): - def test_macros_create_table_as(self): - sql = self._render_create_table_as() - - self.assertEqual(sql, "create or replace table my_table using delta as select 1") - - def test_macros_create_table_as_file_format(self): - for format in ["parquet", "hudi"]: - self.config["file_format"] = format - sql = self._render_create_table_as() - self.assertEqual(sql, f"create table my_table using {format} as select 1") - - def test_macros_create_table_as_options(self): - self.config["options"] = {"compression": "gzip"} - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table " - 'using delta options (compression "gzip" ) as select 1', - ) - - def test_macros_create_table_as_hudi_unique_key(self): - self.config["file_format"] = 
"hudi" - self.config["unique_key"] = "id" - sql = self._render_create_table_as(sql="select 1 as id") - - self.assertEqual( - sql, - 'create table my_table using hudi options (primaryKey "id" ) as select 1 as id', - ) - - def test_macros_create_table_as_hudi_unique_key_primary_key_match(self): - self.config["file_format"] = "hudi" - self.config["unique_key"] = "id" - self.config["options"] = {"primaryKey": "id"} - sql = self._render_create_table_as(sql="select 1 as id") - - self.assertEqual( - sql, - 'create table my_table using hudi options (primaryKey "id" ) as select 1 as id', - ) - - def test_macros_create_table_as_hudi_unique_key_primary_key_mismatch(self): - self.config["file_format"] = "hudi" - self.config["unique_key"] = "uuid" - self.config["options"] = {"primaryKey": "id"} - sql = self._render_create_table_as(sql="select 1 as id, 2 as uuid") - self.assertIn("mock.raise_compiler_error()", sql) - - def test_macros_create_table_as_partition(self): - self.config["partition_by"] = "partition_1" - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table using delta partitioned by (partition_1) as select 1", - ) - - def test_macros_create_table_as_partitions(self): - self.config["partition_by"] = ["partition_1", "partition_2"] - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table " - "using delta partitioned by (partition_1,partition_2) as select 1", - ) - - def test_macros_create_table_as_cluster(self): - self.config["clustered_by"] = "cluster_1" - self.config["buckets"] = "1" - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table " - "using delta clustered by (cluster_1) into 1 buckets as select 1", + def test_macros_create_view_as_tblproperties(self, config, template_bundle): + config["tblproperties"] = {"tblproperties_to_view": "true"} + template_bundle.context["model"].alias = "my_table" + 
template_bundle.context["get_columns_in_query"] = MagicMock(return_value=[]) + sql = self.run_macro( + template_bundle.template, "databricks__create_view_as", "my_table", "select 1" ) - - def test_macros_create_table_as_clusters(self): - self.config["clustered_by"] = ["cluster_1", "cluster_2"] - self.config["buckets"] = "1" - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table " - "using delta clustered by (cluster_1,cluster_2) into 1 buckets as select 1", - ) - - def test_macros_create_table_as_liquid_cluster(self): - self.config["liquid_clustered_by"] = "cluster_1" - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table " "using delta cluster by (cluster_1) as select 1", - ) - - def test_macros_create_table_as_liquid_clusters(self): - self.config["liquid_clustered_by"] = ["cluster_1", "cluster_2"] - self.config["buckets"] = "1" - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table " - "using delta cluster by (cluster_1,cluster_2) as select 1", - ) - - def test_macros_create_table_as_location(self): - self.config["location_root"] = "/mnt/root" - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table " - "using delta location '/mnt/root/my_table' as select 1", - ) - - def test_macros_create_table_as_comment(self): - self.config["persist_docs"] = {"relation": True} - self.default_context["model"].description = "Description Test" - - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table " - "using delta comment 'Description Test' as select 1", - ) - - def test_macros_create_table_as_tblproperties(self): - self.config["tblproperties"] = {"delta.appendOnly": "true"} - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table " - "using delta tblproperties ('delta.appendOnly' = 'true' ) 
as select 1", - ) - - def test_macros_create_table_as_all_delta(self): - self.config["location_root"] = "/mnt/root" - self.config["partition_by"] = ["partition_1", "partition_2"] - self.config["liquid_clustered_by"] = ["cluster_1", "cluster_2"] - self.config["clustered_by"] = ["cluster_1", "cluster_2"] - self.config["buckets"] = "1" - self.config["persist_docs"] = {"relation": True} - self.config["tblproperties"] = {"delta.appendOnly": "true"} - self.default_context["model"].description = "Description Test" - - self.config["file_format"] = "delta" - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create or replace table my_table " - "using delta " - "partitioned by (partition_1,partition_2) " - "cluster by (cluster_1,cluster_2) " - "clustered by (cluster_1,cluster_2) into 1 buckets " - "location '/mnt/root/my_table' " - "comment 'Description Test' " - "tblproperties ('delta.appendOnly' = 'true' ) " - "as select 1", - ) - - def test_macros_create_table_as_all_hudi(self): - self.config["location_root"] = "/mnt/root" - self.config["partition_by"] = ["partition_1", "partition_2"] - self.config["clustered_by"] = ["cluster_1", "cluster_2"] - self.config["buckets"] = "1" - self.config["persist_docs"] = {"relation": True} - self.config["tblproperties"] = {"delta.appendOnly": "true"} - self.default_context["model"].description = "Description Test" - - self.config["file_format"] = "hudi" - sql = self._render_create_table_as() - - self.assertEqual( - sql, - "create table my_table " - "using hudi " - "partitioned by (partition_1,partition_2) " - "clustered by (cluster_1,cluster_2) into 1 buckets " - "location '/mnt/root/my_table' " - "comment 'Description Test' " - "tblproperties ('delta.appendOnly' = 'true' ) " - "as select 1", - ) - - def test_macros_create_view_as_tblproperties(self): - self.config["tblproperties"] = {"tblproperties_to_view": "true"} - self.default_context["model"].alias = "my_table" - self.default_context["get_columns_in_query"] = 
MagicMock(return_value=[]) - sql = self._run_macro("databricks__create_view_as", "my_table", "select 1") - - self.assertEqual( - sql, + expected = ( "create or replace view my_table " - "tblproperties ('tblproperties_to_view' = 'true' ) as select 1", - ) - - -class TestDatabricksMacros(TestAdaptersMacros): - def setUp(self): - super().setUp() - data = { - "path": { - "database": "some_database", - "schema": "some_schema", - "identifier": "some_table", - }, - "type": None, - } - - self.relation = DatabricksRelation.from_dict(data) - - def test_macros_create_table_as(self): - sql = self._render_create_table_as(self.relation) - - self.assertEqual( - sql, - ( - "create or replace table " - "`some_database`.`some_schema`.`some_table` " - "using delta as select 1" - ), + "tblproperties ('tblproperties_to_view' = 'true' ) as select 1" ) - def __render_relation_macro(self, name, *args): - self.default_context["model"].alias = self.relation - - return self._run_macro(name, self.relation, *args) - - def test_macros_get_optimize_sql(self): - self.config["zorder"] = "foo" - sql = self.__render_relation_macro("get_optimize_sql") - - self.assertEqual( - sql, - ("optimize " "`some_database`.`some_schema`.`some_table` " "zorder by (foo)"), - ) + assert sql == expected - def test_macro_get_optimize_sql_multiple_args(self): - self.config["zorder"] = ["foo", "bar"] - sql = self.__render_relation_macro("get_optimize_sql") + def test_macros_get_optimize_sql(self, config, template_bundle): + config["zorder"] = "foo" + sql = self.render_bundle(template_bundle, "get_optimize_sql") - self.assertEqual( - sql, - ("optimize " "`some_database`.`some_schema`.`some_table` " "zorder by ( foo, bar )"), - ) + assert sql == "optimize `some_database`.`some_schema`.`some_table` zorder by (foo)" - def test_macros_optimize_with_extraneous_info(self): - self.config["zorder"] = ["foo", "bar"] - self.var["FOO"] = True - r = self.__render_relation_macro("optimize") + def 
test_macro_get_optimize_sql_multiple_args(self, config, template_bundle): + config["zorder"] = ["foo", "bar"] + sql = self.render_bundle(template_bundle, "get_optimize_sql") - self.assertEqual( - r, - "run_optimize_stmt", - ) + assert sql == "optimize `some_database`.`some_schema`.`some_table` zorder by ( foo, bar )" - def test_macros_optimize_with_skip(self): - for key_val in ["DATABRICKS_SKIP_OPTIMIZE", "databricks_skip_optimize"]: - self.var[key_val] = True - r = self.__render_relation_macro("optimize") + def test_macros_optimize_with_extraneous_info(self, config, var, template_bundle): + config["zorder"] = ["foo", "bar"] + var["FOO"] = True + result = self.render_bundle(template_bundle, "optimize") - self.assertEqual( - r, - "", # should skip - ) + assert result == "run_optimize_stmt" - del self.var[key_val] + @pytest.mark.parametrize("key_val", ["DATABRICKS_SKIP_OPTIMIZE", "databricks_skip_optimize"]) + def test_macros_optimize_with_skip(self, key_val, var, template_bundle): + var[key_val] = True + r = self.render_bundle(template_bundle, "optimize") - def __render_constraints(self, *args): - self.default_context["model"].alias = self.relation + assert r == "" - return self._run_macro("databricks_constraints_to_dbt", *args) + def render_constraints(self, template, *args): + return self.run_macro(template, "databricks_constraints_to_dbt", *args) - def test_macros_databricks_constraints_to_dbt(self): + def test_macros_databricks_constraints_to_dbt(self, template): constraint = {"name": "name", "condition": "id > 0"} - r = self.__render_constraints([constraint]) + r = self.render_constraints(template, [constraint]) - self.assertEquals(r, "[{'name': 'name', 'type': 'check', 'expression': 'id > 0'}]") + assert r == "[{'name': 'name', 'type': 'check', 'expression': 'id > 0'}]" - def test_macros_databricks_constraints_missing_name(self): + def test_macros_databricks_constraints_missing_name(self, template): constraint = {"condition": "id > 0"} - r = 
self.__render_constraints([constraint]) + r = self.render_constraints(template, [constraint]) assert "raise_compiler_error" in r - def test_macros_databricks_constraints_missing_condition(self): + def test_macros_databricks_constraints_missing_condition(self, template): constraint = {"name": "name", "condition": ""} - r = self.__render_constraints([constraint]) + r = self.render_constraints(template, [constraint]) assert "raise_compiler_error" in r - def test_macros_databricks_constraints_with_type(self): + def test_macros_databricks_constraints_with_type(self, template): constraint = {"type": "check", "name": "name", "expression": "id > 0"} - r = self.__render_constraints([constraint]) + r = self.render_constraints(template, [constraint]) - self.assertEquals(r, "[{'type': 'check', 'name': 'name', 'expression': 'id > 0'}]") + assert r == "[{'type': 'check', 'name': 'name', 'expression': 'id > 0'}]" - def test_macros_databricks_constraints_with_column_missing_expression(self): + def test_macros_databricks_constraints_with_column_missing_expression(self, template): column = {"name": "col"} constraint = {"name": "name", "condition": "id > 0"} - r = self.__render_constraints([constraint], column) + r = self.render_constraints(template, [constraint], column) assert "raise_compiler_error" in r - def test_macros_databricks_constraints_with_column_and_expression(self): + def test_macros_databricks_constraints_with_column_and_expression(self, template): column = {"name": "col"} constraint = {"type": "check", "name": "name", "expression": "id > 0"} - r = self.__render_constraints([constraint], column) + r = self.render_constraints(template, [constraint], column) - self.assertEquals(r, "[{'type': 'check', 'name': 'name', 'expression': 'id > 0'}]") + assert r == "[{'type': 'check', 'name': 'name', 'expression': 'id > 0'}]" - def test_macros_databricks_constraints_with_column_not_null(self): + def test_macros_databricks_constraints_with_column_not_null(self, template): column = 
{"name": "col"} constraint = "not_null" - r = self.__render_constraints([constraint], column) + r = self.render_constraints(template, [constraint], column) - self.assertEquals(r, "[{'type': 'not_null', 'columns': ['col']}]") + assert r == "[{'type': 'not_null', 'columns': ['col']}]" - def __constraint_model(self): + @pytest.fixture(scope="class") + def constraint_model(self): columns = { "id": {"name": "id", "data_type": "int"}, "name": {"name": "name", "data_type": "string"}, @@ -351,180 +107,182 @@ def __constraint_model(self): "constraints": [{"type": "not_null", "columns": ["id", "name"]}], } - def __render_model_constraints(self, model): - self.default_context["model"].alias = self.relation + def render_model_constraints(self, template, model): + return self.run_macro(template, "get_model_constraints", model) - return self._run_macro("get_model_constraints", model) - - def test_macros_get_model_constraints(self): - model = self.__constraint_model() - r = self.__render_model_constraints(model) + def test_macros_get_model_constraints(self, template, constraint_model): + r = self.render_model_constraints(template, constraint_model) expected = "[{'type': 'not_null', 'columns': ['id', 'name']}]" assert expected in r - def test_macros_get_model_constraints_persist(self): - self.config["persist_constraints"] = True - model = self.__constraint_model() - r = self.__render_model_constraints(model) + def test_macros_get_model_constraints_persist(self, config, template, constraint_model): + config["persist_constraints"] = True + r = self.render_model_constraints(template, constraint_model) expected = "[{'type': 'not_null', 'columns': ['id', 'name']}]" assert expected in r - def test_macros_get_model_constraints_persist_with_meta(self): - self.config["persist_constraints"] = True - model = self.__constraint_model() - model["meta"] = {"constraints": [{"type": "foo"}]} - r = self.__render_model_constraints(model) + def test_macros_get_model_constraints_persist_with_meta( + 
self, config, template, constraint_model + ): + config["persist_constraints"] = True + constraint_model["meta"] = {"constraints": [{"type": "foo"}]} + r = self.render_model_constraints(template, constraint_model) expected = "[{'type': 'foo'}]" assert expected in r - def test_macros_get_model_constraints_no_persist_with_meta(self): - self.config["persist_constraints"] = False - model = self.__constraint_model() - model["meta"] = {"constraints": [{"type": "foo"}]} - r = self.__render_model_constraints(model) + def test_macros_get_model_constraints_no_persist_with_meta( + self, config, template, constraint_model + ): + config["persist_constraints"] = False + constraint_model["meta"] = {"constraints": [{"type": "foo"}]} + r = self.render_model_constraints(template, constraint_model) expected = "[{'type': 'not_null', 'columns': ['id', 'name']}]" assert expected in r - def __render_column_constraints(self, column): - self.default_context["model"].alias = self.relation - - return self._run_macro("get_column_constraints", column) + def render_column_constraints(self, template, column): + return self.run_macro(template, "get_column_constraints", column) - def test_macros_get_column_constraints(self): + def test_macros_get_column_constraints(self, template): column = {"name": "id"} - r = self.__render_column_constraints(column) + r = self.render_column_constraints(template, column) - self.assertEqual(r, "[]") + assert r == "[]" - def test_macros_get_column_constraints_empty(self): + def test_macros_get_column_constraints_empty(self, config, template): column = {"name": "id"} column["constraints"] = [] - self.config["persist_constraints"] = True - r = self.__render_column_constraints(column) + config["persist_constraints"] = True + r = self.render_column_constraints(template, column) - self.assertEqual(r, "[]") + assert r == "[]" - def test_macros_get_column_constraints_non_null(self): + def test_macros_get_column_constraints_non_null(self, config, template): column = 
{"name": "id"} column["constraints"] = [{"type": "non_null"}] - self.config["persist_constraints"] = True - r = self.__render_column_constraints(column) + config["persist_constraints"] = True + r = self.render_column_constraints(template, column) - self.assertEqual(r, "[{'type': 'non_null'}]") + assert r == "[{'type': 'non_null'}]" - def test_macros_get_column_constraints_invalid_meta(self): + def test_macros_get_column_constraints_invalid_meta(self, config, template): column = {"name": "id"} column["constraints"] = [{"type": "non_null"}] - self.config["persist_constraints"] = True + config["persist_constraints"] = True column["meta"] = {"constraint": "foo"} - r = self.__render_column_constraints(column) + r = self.render_column_constraints(template, column) assert "raise_compiler_error" in r - def test_macros_get_column_constraints_valid_meta(self): + def test_macros_get_column_constraints_valid_meta(self, config, template): column = {"name": "id"} column["constraints"] = [{"type": "non_null"}] - self.config["persist_constraints"] = True + config["persist_constraints"] = True column["meta"] = {"constraint": "not_null"} - r = self.__render_column_constraints(column) + r = self.render_column_constraints(template, column) - self.assertEqual(r, "[{'type': 'not_null', 'columns': ['id']}]") + assert r == "[{'type': 'not_null', 'columns': ['id']}]" - def test_macros_get_column_constraints_no_persist(self): + def test_macros_get_column_constraints_no_persist(self, config, template): column = {"name": "id"} column["constraints"] = [{"type": "non_null"}] - self.config["persist_constraints"] = False - r = self.__render_column_constraints(column) - - self.assertEqual(r, "[{'type': 'non_null'}]") - - def __render_constraint_sql(self, constraint, *args): - self.default_context["model"].alias = self.relation - - return self._run_macro("get_constraint_sql", self.relation, constraint, *args) + config["persist_constraints"] = False + r = self.render_column_constraints(template, column) + 
+ assert r == "[{'type': 'non_null'}]" + + def render_constraint_sql(self, template_bundle, constraint, *args): + return self.run_macro( + template_bundle.template, + "get_constraint_sql", + template_bundle.relation, + constraint, + *args + ) - def __model(self): + @pytest.fixture(scope="class") + def model(self): columns = { "id": {"name": "id", "data_type": "int"}, "name": {"name": "name", "data_type": "string"}, } return {"columns": columns} - def test_macros_get_constraint_sql_not_null_with_columns(self): - model = self.__model() - r = self.__render_constraint_sql({"type": "not_null", "columns": ["id", "name"]}, model) + def test_macros_get_constraint_sql_not_null_with_columns(self, template_bundle, model): + r = self.render_constraint_sql( + template_bundle, {"type": "not_null", "columns": ["id", "name"]}, model + ) expected = ( "['alter table `some_database`.`some_schema`.`some_table` change column id " "set not null ;', 'alter table `some_database`.`some_schema`.`some_table` " "change column name set not null ;']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_not_null_with_column(self): - model = self.__model() - r = self.__render_constraint_sql({"type": "not_null"}, model, model["columns"]["id"]) + def test_macros_get_constraint_sql_not_null_with_column(self, template_bundle, model): + r = self.render_constraint_sql( + template_bundle, {"type": "not_null"}, model, model["columns"]["id"] + ) expected = ( "['alter table `some_database`.`some_schema`.`some_table` change column id " "set not null ;']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_not_null_mismatched_columns(self): - model = self.__model() - r = self.__render_constraint_sql( - {"type": "not_null", "columns": ["name"]}, model, model["columns"]["id"] + def test_macros_get_constraint_sql_not_null_mismatched_columns(self, template_bundle, model): + r = self.render_constraint_sql( + template_bundle, + {"type": "not_null", "columns": 
["name"]}, + model, + model["columns"]["id"], ) expected = ( "['alter table `some_database`.`some_schema`.`some_table` change column name " "set not null ;']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_check(self): - model = self.__model() + def test_macros_get_constraint_sql_check(self, template_bundle, model): constraint = { "type": "check", "expression": "id != name", "name": "myconstraint", "columns": ["id", "name"], } - r = self.__render_constraint_sql(constraint, model) + r = self.render_constraint_sql(template_bundle, constraint, model) expected = ( "['alter table `some_database`.`some_schema`.`some_table` add constraint " "myconstraint check (id != name);']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_check_named_constraint(self): - model = self.__model() + def test_macros_get_constraint_sql_check_named_constraint(self, template_bundle, model): constraint = { "type": "check", "expression": "id != name", "name": "myconstraint", } - r = self.__render_constraint_sql(constraint, model) + r = self.render_constraint_sql(template_bundle, constraint, model) expected = ( "['alter table `some_database`.`some_schema`.`some_table` add constraint " "myconstraint check (id != name);']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_check_none_constraint(self): - model = self.__model() + def test_macros_get_constraint_sql_check_none_constraint(self, template_bundle, model): constraint = { "type": "check", "expression": "id != name", } - r = self.__render_constraint_sql(constraint, model) + r = self.render_constraint_sql(template_bundle, constraint, model) expected = ( "['alter table `some_database`.`some_schema`.`some_table` add constraint None " @@ -532,81 +290,77 @@ def test_macros_get_constraint_sql_check_none_constraint(self): ) # noqa: E501 assert expected in r - def test_macros_get_constraint_sql_check_missing_expression(self): - model = self.__model() + def 
test_macros_get_constraint_sql_check_missing_expression(self, template_bundle, model): constraint = { "type": "check", "expression": "", "name": "myconstraint", } - r = self.__render_constraint_sql(constraint, model) + r = self.render_constraint_sql(template_bundle, constraint, model) assert "raise_compiler_error" in r - def test_macros_get_constraint_sql_primary_key(self): - model = self.__model() + def test_macros_get_constraint_sql_primary_key(self, template_bundle, model): constraint = { "type": "primary_key", "name": "myconstraint", "columns": ["name"], } - r = self.__render_constraint_sql(constraint, model) + r = self.render_constraint_sql(template_bundle, constraint, model) expected = ( "['alter table `some_database`.`some_schema`.`some_table` add constraint " "myconstraint primary key(name);']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_primary_key_with_specified_column(self): - model = self.__model() + def test_macros_get_constraint_sql_primary_key_with_specified_column( + self, template_bundle, model + ): constraint = { "type": "primary_key", "name": "myconstraint", "columns": ["name"], } column = {"name": "id"} - r = self.__render_constraint_sql(constraint, model, column) + r = self.render_constraint_sql(template_bundle, constraint, model, column) expected = ( "['alter table `some_database`.`some_schema`.`some_table` add constraint " "myconstraint primary key(name);']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_primary_key_with_name(self): - model = self.__model() + def test_macros_get_constraint_sql_primary_key_with_name(self, template_bundle, model): constraint = { "type": "primary_key", "name": "myconstraint", } column = {"name": "id"} - r = self.__render_constraint_sql(constraint, model, column) + r = self.render_constraint_sql(template_bundle, constraint, model, column) expected = ( "['alter table `some_database`.`some_schema`.`some_table` add constraint " "myconstraint primary 
key(id);']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_foreign_key(self): - model = self.__model() + def test_macros_get_constraint_sql_foreign_key(self, template_bundle, model): constraint = { "type": "foreign_key", "name": "myconstraint", "columns": ["name"], "parent": "parent_table", } - r = self.__render_constraint_sql(constraint, model) + r = self.render_constraint_sql(template_bundle, constraint, model) expected = ( "['alter table `some_database`.`some_schema`.`some_table` add " "constraint myconstraint foreign key(name) references " "some_schema.parent_table;']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_foreign_key_parent_column(self): - model = self.__model() + def test_macros_get_constraint_sql_foreign_key_parent_column(self, template_bundle, model): constraint = { "type": "foreign_key", "name": "myconstraint", @@ -614,17 +368,16 @@ def test_macros_get_constraint_sql_foreign_key_parent_column(self): "parent": "parent_table", "parent_columns": ["parent_name"], } - r = self.__render_constraint_sql(constraint, model) + r = self.render_constraint_sql(template_bundle, constraint, model) expected = ( "['alter table `some_database`.`some_schema`.`some_table` add " "constraint myconstraint foreign key(name) references " "some_schema.parent_table(parent_name);']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_foreign_key_multiple_columns(self): - model = self.__model() + def test_macros_get_constraint_sql_foreign_key_multiple_columns(self, template_bundle, model): constraint = { "type": "foreign_key", "name": "myconstraint", @@ -632,17 +385,18 @@ def test_macros_get_constraint_sql_foreign_key_multiple_columns(self): "parent": "parent_table", "parent_columns": ["parent_name", "parent_id"], } - r = self.__render_constraint_sql(constraint, model) + r = self.render_constraint_sql(template_bundle, constraint, model) expected = ( "['alter table 
`some_database`.`some_schema`.`some_table` add constraint " "myconstraint foreign key(name, id) " "references some_schema.parent_table(parent_name, parent_id);']" - ) # noqa: E501 + ) assert expected in r - def test_macros_get_constraint_sql_foreign_key_columns_supplied_separately(self): - model = self.__model() + def test_macros_get_constraint_sql_foreign_key_columns_supplied_separately( + self, template_bundle, model + ): constraint = { "type": "foreign_key", "name": "myconstraint", @@ -650,11 +404,11 @@ def test_macros_get_constraint_sql_foreign_key_columns_supplied_separately(self) "parent_columns": ["parent_name"], } column = {"name": "id"} - r = self.__render_constraint_sql(constraint, model, column) + r = self.render_constraint_sql(template_bundle, constraint, model, column) expected = ( "['alter table `some_database`.`some_schema`.`some_table` add constraint " "myconstraint foreign key(id) references " "some_schema.parent_table(parent_name);']" - ) # noqa: E501 + ) assert expected in r diff --git a/tests/unit/macros/test_python_macros.py b/tests/unit/macros/test_python_macros.py index 399b5bce1..246b520c0 100644 --- a/tests/unit/macros/test_python_macros.py +++ b/tests/unit/macros/test_python_macros.py @@ -1,79 +1,86 @@ from mock import MagicMock -from tests.unit.macros.base import TestMacros +from tests.unit.macros.base import MacroTestBase +import pytest -class TestPythonMacros(TestMacros): - def setUp(self): - TestMacros.setUp(self) - self.default_context["model"] = MagicMock() - self.template = self._get_template("python.sql", "adapters.sql") - def test_py_get_writer__default_file_format(self): - result = self._run_macro_raw("py_get_writer_options") +class TestPythonMacros(MacroTestBase): + @pytest.fixture(scope="class", autouse=True) + def modify_context(self, default_context) -> dict: + default_context["model"] = MagicMock() + d = {"alias": "schema"} + default_context["model"].__getitem__.side_effect = d.__getitem__ - self.assertEqual(result, 
'.format("delta")') + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "python.sql" - def test_py_get_writer__specified_file_format(self): - self.config["file_format"] = "parquet" - result = self._run_macro_raw("py_get_writer_options") + @pytest.fixture(scope="class") + def databricks_template_names(self) -> list: + return ["adapters.sql"] - self.assertEqual(result, '.format("parquet")') + def test_py_get_writer__default_file_format(self, template): + result = self.run_macro_raw(template, "py_get_writer_options") - def test_py_get_writer__specified_location_root(self): - self.config["location_root"] = "s3://fake_location" - d = {"alias": "schema"} - self.default_context["model"].__getitem__.side_effect = d.__getitem__ - self.default_context["is_incremental"] = MagicMock(return_value=False) - result = self._run_macro_raw("py_get_writer_options") + assert result == '.format("delta")' + + def test_py_get_writer__specified_file_format(self, config, template): + config["file_format"] = "parquet" + result = self.run_macro_raw(template, "py_get_writer_options") + + assert result == '.format("parquet")' + + def test_py_get_writer__specified_location_root(self, config, template, context): + config["location_root"] = "s3://fake_location" + context["is_incremental"] = MagicMock(return_value=False) + result = self.run_macro_raw(template, "py_get_writer_options") expected = '.format("delta")\n.option("path", "s3://fake_location/schema")' - self.assertEqual(result, expected) + assert result == expected - def test_py_get_writer__specified_location_root_on_incremental(self): - self.config["location_root"] = "s3://fake_location" - d = {"alias": "schema"} - self.default_context["model"].__getitem__.side_effect = d.__getitem__ - self.default_context["is_incremental"] = MagicMock(return_value=True) - result = self._run_macro_raw("py_get_writer_options") + def test_py_get_writer__specified_location_root_on_incremental(self, config, template, context): + 
config["location_root"] = "s3://fake_location" + context["is_incremental"] = MagicMock(return_value=True) + result = self.run_macro_raw(template, "py_get_writer_options") expected = '.format("delta")\n.option("path", "s3://fake_location/schema__dbt_tmp")' - self.assertEqual(result, expected) + assert result == expected - def test_py_get_writer__partition_by_single_column(self): - self.config["partition_by"] = "name" - result = self._run_macro_raw("py_get_writer_options") + def test_py_get_writer__partition_by_single_column(self, config, template): + config["partition_by"] = "name" + result = self.run_macro_raw(template, "py_get_writer_options") expected = ".format(\"delta\")\n.partitionBy(['name'])" - self.assertEqual(result, expected) + assert result == expected - def test_py_get_writer__partition_by_array(self): - self.config["partition_by"] = ["name", "date"] - result = self._run_macro_raw("py_get_writer_options") + def test_py_get_writer__partition_by_array(self, config, template): + config["partition_by"] = ["name", "date"] + result = self.run_macro_raw(template, "py_get_writer_options") - self.assertEqual(result, (".format(\"delta\")\n.partitionBy(['name', 'date'])")) + assert result == ".format(\"delta\")\n.partitionBy(['name', 'date'])" - def test_py_get_writer__clustered_by_single_column(self): - self.config["clustered_by"] = "name" - self.config["buckets"] = 2 - result = self._run_macro_raw("py_get_writer_options") + def test_py_get_writer__clustered_by_single_column(self, config, template): + config["clustered_by"] = "name" + config["buckets"] = 2 + result = self.run_macro_raw(template, "py_get_writer_options") - self.assertEqual(result, (".format(\"delta\")\n.bucketBy(2, ['name'])")) + assert result == ".format(\"delta\")\n.bucketBy(2, ['name'])" - def test_py_get_writer__clustered_by_array(self): - self.config["clustered_by"] = ["name", "date"] - self.config["buckets"] = 2 - result = self._run_macro_raw("py_get_writer_options") + def 
test_py_get_writer__clustered_by_array(self, config, template): + config["clustered_by"] = ["name", "date"] + config["buckets"] = 2 + result = self.run_macro_raw(template, "py_get_writer_options") - self.assertEqual(result, (".format(\"delta\")\n.bucketBy(2, ['name', 'date'])")) + assert result == ".format(\"delta\")\n.bucketBy(2, ['name', 'date'])" - def test_py_get_writer__clustered_by_without_buckets(self): - self.config["clustered_by"] = ["name", "date"] - result = self._run_macro_raw("py_get_writer_options") + def test_py_get_writer__clustered_by_without_buckets(self, config, template): + config["clustered_by"] = ["name", "date"] + result = self.run_macro_raw(template, "py_get_writer_options") - self.assertEqual(result, ('.format("delta")')) + assert result == '.format("delta")' - def test_py_try_import__golden_path(self): - result = self._run_macro_raw("py_try_import", "pandas", "pandas_available") + def test_py_try_import__golden_path(self, template): + result = self.run_macro_raw(template, "py_try_import", "pandas", "pandas_available") expected = ( "# make sure pandas exists before using it\n" @@ -83,4 +90,4 @@ def test_py_try_import__golden_path(self): "except ImportError:\n" " pandas_available = False\n" ) - self.assertEqual(result, expected) + assert result == expected From 6725490052846517af4802f56492da597900e836 Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Mon, 4 Dec 2023 16:23:36 -0800 Subject: [PATCH 2/3] fix linting and hopefully tests --- .python-version | 1 + tests/unit/macros/base.py | 88 +++++++++++++++++-- .../macros/relations/test_table_macros.py | 3 - tests/unit/macros/test_adapters_macros.py | 2 +- tests/unit/macros/test_python_macros.py | 10 ++- 5 files changed, 89 insertions(+), 15 deletions(-) create mode 100644 .python-version diff --git a/.python-version b/.python-version new file mode 100644 index 000000000..2c0733315 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11 diff --git a/tests/unit/macros/base.py 
b/tests/unit/macros/base.py index 2df25c642..3d59d22bd 100644 --- a/tests/unit/macros/base.py +++ b/tests/unit/macros/base.py @@ -1,4 +1,5 @@ import re +from typing import Any, Dict from mock import Mock import pytest from jinja2 import Environment, FileSystemLoader, PackageLoader, Template @@ -15,18 +16,27 @@ def __init__(self, template, context, relation): class MacroTestBase: @pytest.fixture(autouse=True) def config(self, context) -> dict: - local_config = {} + """ + Anything you put in this dict will be returned by config in the rendered template + """ + local_config: Dict[str, Any] = {} context["config"].get = lambda key, default=None, **kwargs: local_config.get(key, default) return local_config @pytest.fixture(autouse=True) def var(self, context) -> dict: - local_var = {} + """ + Anything you put in this dict will be returned by config in the rendered template + """ + local_var: Dict[str, Any] = {} context["var"] = lambda key, default=None, **kwargs: local_var.get(key, default) return local_var @pytest.fixture(scope="class") def default_context(self) -> dict: + """ + This is the default context used in all tests. + """ context = { "validation": Mock(), "model": Mock(), @@ -36,12 +46,16 @@ def default_context(self) -> dict: "adapter": Mock(), "var": Mock(), "return": lambda r: r, + "is_incremental": Mock(return_value=False), } return context @pytest.fixture(scope="class") def spark_env(self) -> Environment: + """ + The environment used for rendering dbt-spark macros + """ return Environment( loader=PackageLoader("dbt.include.spark", "macros"), extensions=["jinja2.ext.do"], @@ -49,18 +63,33 @@ def spark_env(self) -> Environment: @pytest.fixture(scope="class") def spark_template_names(self) -> list: + """ + The list of Spark templates to load for the tests. + Use this if your macro relies on macros defined in templates we inherit from dbt-spark. 
+ """ return ["adapters.sql"] @pytest.fixture(scope="class") def spark_context(self, default_context, spark_env, spark_template_names) -> dict: + """ + Adds all the requested Spark macros to the context + """ return self.build_up_context(default_context, spark_env, spark_template_names) @pytest.fixture(scope="class") def macro_folders_to_load(self) -> list: + """ + This is a list of folders from which we look to load Databricks macro templates. + All folders are relative to the dbt/include/databricks folder. + Folders will be searched for in the order they are listed here, in case of name collisions. + """ return ["macros"] @pytest.fixture(scope="class") def databricks_env(self, macro_folders_to_load) -> Environment: + """ + The environment used for rendering Databricks macros + """ return Environment( loader=FileSystemLoader( [f"dbt/include/databricks/{folder}" for folder in macro_folders_to_load] @@ -70,15 +99,28 @@ def databricks_env(self, macro_folders_to_load) -> Environment: @pytest.fixture(scope="class") def databricks_template_names(self) -> list: + """ + The list of databricks templates to load for referencing imported macros in the + tests. Do not include the template you specify in template_name. Use this when you need a + macro defined in a template other than the one you render for the test. + + Ex: If you are testing the python.sql template, you will also need to load ["adapters.sql"] + """ return [] @pytest.fixture(scope="class") def databricks_context(self, spark_context, databricks_env, databricks_template_names) -> dict: + """ + Adds all the requested Databricks macros to the context + """ if not databricks_template_names: return spark_context return self.build_up_context(spark_context, databricks_env, databricks_template_names) def build_up_context(self, context, env, template_names): + """ + Adds macros from the supplied env and template names to the context. 
+ """ new_context = context.copy() for template_name in template_names: template = env.get_template(template_name, globals=context) @@ -86,16 +128,22 @@ def build_up_context(self, context, env, template_names): return new_context - @pytest.fixture - def context(self, databricks_context) -> dict: - return databricks_context.copy() - @pytest.fixture(scope="class") def template_name(self) -> str: + """ + The name of the Databricks template you want to test, not including the path. + + Example: "adapters.sql" + """ raise NotImplementedError("Must be implemented by subclasses") @pytest.fixture - def template(self, template_name, context, databricks_env) -> Template: + def template(self, template_name, databricks_context, databricks_env) -> Template: + """ + This creates the template you will test against. + You generally don't want to override this. + """ + context = databricks_context.copy() current_template = databricks_env.get_template(template_name, globals=context) def dispatch(macro_name, macro_namespace=None, packages=None): @@ -110,8 +158,21 @@ def dispatch(macro_name, macro_namespace=None, packages=None): return current_template + @pytest.fixture + def context(self, template) -> dict: + """ + Access to the context used to render the template. + Modification of the context will work for mocking adapter calls, but may not work for + mocking macros. + If you need to mock a macro, see the use of is_incremental in default_context. + """ + return template.globals + @pytest.fixture(scope="class") def relation(self): + """ + Dummy relation to use in tests. + """ data = { "path": { "database": "some_database", @@ -125,15 +186,28 @@ def relation(self): @pytest.fixture def template_bundle(self, template, context, relation): + """ + Bundles up the compiled template, its context, and a dummy relation. 
+ """ context["model"].alias = relation.identifier return TemplateBundle(template, context, relation) def run_macro_raw(self, template, name, *args): + """ + Run the named macro from a template, and return the rendered value. + """ return getattr(template.module, name)(*args) def run_macro(self, template, name, *args): + """ + Run the named macro from a template, and return the rendered value. + This version strips off extra whitespace and newlines. + """ value = self.run_macro_raw(template, name, *args) return re.sub(r"\s\s+", " ", value).strip() def render_bundle(self, template_bundle, name, *args): + """ + Convenience method for macros that take a relation as a first argument. + """ return self.run_macro(template_bundle.template, name, template_bundle.relation, *args) diff --git a/tests/unit/macros/relations/test_table_macros.py b/tests/unit/macros/relations/test_table_macros.py index b4ddef893..5e91a3b35 100644 --- a/tests/unit/macros/relations/test_table_macros.py +++ b/tests/unit/macros/relations/test_table_macros.py @@ -1,6 +1,3 @@ -from mock import Mock -from jinja2 import Environment, FileSystemLoader, PackageLoader -import re import pytest from tests.unit.macros.base import MacroTestBase diff --git a/tests/unit/macros/test_adapters_macros.py b/tests/unit/macros/test_adapters_macros.py index bce535f0f..e3850cd9e 100644 --- a/tests/unit/macros/test_adapters_macros.py +++ b/tests/unit/macros/test_adapters_macros.py @@ -200,7 +200,7 @@ def render_constraint_sql(self, template_bundle, constraint, *args): "get_constraint_sql", template_bundle.relation, constraint, - *args + *args, ) @pytest.fixture(scope="class") diff --git a/tests/unit/macros/test_python_macros.py b/tests/unit/macros/test_python_macros.py index 246b520c0..b59b2a4e6 100644 --- a/tests/unit/macros/test_python_macros.py +++ b/tests/unit/macros/test_python_macros.py @@ -1,3 +1,4 @@ +from jinja2 import Template from mock import MagicMock from tests.unit.macros.base import MacroTestBase @@ -6,7 +7,7 
@@ class TestPythonMacros(MacroTestBase): @pytest.fixture(scope="class", autouse=True) - def modify_context(self, default_context) -> dict: + def modify_context(self, default_context) -> None: default_context["model"] = MagicMock() d = {"alias": "schema"} default_context["model"].__getitem__.side_effect = d.__getitem__ @@ -32,15 +33,16 @@ def test_py_get_writer__specified_file_format(self, config, template): def test_py_get_writer__specified_location_root(self, config, template, context): config["location_root"] = "s3://fake_location" - context["is_incremental"] = MagicMock(return_value=False) result = self.run_macro_raw(template, "py_get_writer_options") expected = '.format("delta")\n.option("path", "s3://fake_location/schema")' assert result == expected - def test_py_get_writer__specified_location_root_on_incremental(self, config, template, context): + def test_py_get_writer__specified_location_root_on_incremental( + self, config, template: Template, context + ): config["location_root"] = "s3://fake_location" - context["is_incremental"] = MagicMock(return_value=True) + context["is_incremental"].return_value = True result = self.run_macro_raw(template, "py_get_writer_options") expected = '.format("delta")\n.option("path", "s3://fake_location/schema__dbt_tmp")' From 499aa1b402bb187d28bd350a2421654002736ddb Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Tue, 5 Dec 2023 10:48:24 -0800 Subject: [PATCH 3/3] fixing copy-paste typo --- tests/unit/macros/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/macros/base.py b/tests/unit/macros/base.py index 3d59d22bd..cf8025929 100644 --- a/tests/unit/macros/base.py +++ b/tests/unit/macros/base.py @@ -26,7 +26,7 @@ def config(self, context) -> dict: @pytest.fixture(autouse=True) def var(self, context) -> dict: """ - Anything you put in this dict will be returned by config in the rendered template + Anything you put in this dict will be returned by var in the rendered template """ local_var: 
Dict[str, Any] = {} context["var"] = lambda key, default=None, **kwargs: local_var.get(key, default)