diff --git a/dbt/adapters/snowflake/__version__.py b/dbt/adapters/snowflake/__version__.py index a6b977228..a9fe3c3ee 100644 --- a/dbt/adapters/snowflake/__version__.py +++ b/dbt/adapters/snowflake/__version__.py @@ -1 +1 @@ -version = "1.2.0a1" +version = "1.3.0a1" diff --git a/dbt/include/snowflake/macros/adapters.sql b/dbt/include/snowflake/macros/adapters.sql index c5f07ff5f..b672f475f 100644 --- a/dbt/include/snowflake/macros/adapters.sql +++ b/dbt/include/snowflake/macros/adapters.sql @@ -273,3 +273,8 @@ {{ snowflake_dml_explicit_transaction(truncate_dml) }} {%- endcall %} {% endmacro %} + +{% macro load_df_def() %} + global snowpark_session + load_df_function = snowpark_session.table +{% endmacro %} diff --git a/dbt/include/snowflake/macros/materializations/table.sql b/dbt/include/snowflake/macros/materializations/table.sql index 49f97069b..0d7f194ac 100644 --- a/dbt/include/snowflake/macros/materializations/table.sql +++ b/dbt/include/snowflake/macros/materializations/table.sql @@ -18,10 +18,33 @@ {{ drop_relation_if_exists(old_relation) }} {% endif %} - --build model - {% call statement('main') -%} - {{ create_table_as(false, target_relation, sql) }} - {%- endcall %} + {% if config.get('language', 'sql') == 'python' -%}} + {%- set proc_name = api.Relation.create(identifier=identifier ~ "__dbt_sp", + schema=schema, + database=database) -%} + {% set materialization_logic = py_materialize_as_table() %} + {% set setup_stored_proc = py_create_stored_procedure(proc_name, materialization_logic, model, sql) %} + + {% do log("Creating stored procedure: " ~ proc_name, info=true) %} + {% do run_query(setup_stored_proc) %} + {% do log("Finished creating stored procedure: " ~ proc_name, info=true) %} + + --build model + {% call statement('main') -%} + CALL {{ proc_name }}('{{ target_relation }}'); + + {%- endcall %} + + -- cleanup stuff + {% do run_query("drop procedure if exists " ~ proc_name ~ "(string)") %} + + {%- else -%} + --build model + {% call statement('main') -%} + {{ create_table_as(false, target_relation, sql) }} + {%- endcall %} + + {%- endif %} {{ run_hooks(post_hooks) }} @@ -32,3 +55,41 @@ {{ return({'relations': [target_relation]}) }} {% endmaterialization %} + +{% macro py_materialize_as_table(config) %} + +def materialize(session, df, target_relation): + df.write.mode("overwrite").save_as_table(target_relation) + +{% endmacro %} + +{% macro py_create_stored_procedure(proc_name, materialization_logic, model, user_supplied_logic) %} + +{% set packages = ['snowflake-snowpark-python'] + config.get('packages', []) %} + +CREATE OR REPLACE PROCEDURE {{ proc_name }} (target_relation STRING) +RETURNS STRING +LANGUAGE PYTHON +RUNTIME_VERSION = '3.8' +PACKAGES = ('{{ packages | join("', '") }}') +HANDLER = 'run' +AS +$$ + +snowpark_session = None + +{#-- can we wrap in 'def model:' here? or will formatting screw us? --#} +{{ user_supplied_logic }} + +{{ materialization_logic }} + +def run(session, target_relation): + global snowpark_session + snowpark_session = session + df = model(dbt) + materialize(session, df, target_relation) + return "OK" + +$$; + +{% endmacro %} diff --git a/setup.py b/setup.py index 332b5d68e..0319a9908 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ def _get_dbt_core_version(): package_name = "dbt-snowflake" -package_version = "1.2.0a1" +package_version = "1.3.0a1" dbt_core_version = _get_dbt_core_version() description = """The Snowflake adapter plugin for dbt""" diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py index fec10aa33..4a117b11f 100644 --- a/tests/functional/adapter/test_basic.py +++ b/tests/functional/adapter/test_basic.py @@ -12,6 +12,9 @@ from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod +from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests + + class TestSimpleMaterializationsSnowflake(BaseSimpleMaterializations): pass @@ -51,4 +54,8 @@ class TestSnapshotTimestampSnowflake(BaseSnapshotTimestamp): class TestBaseAdapterMethodSnowflake(BaseAdapterMethod): @pytest.fixture(scope="class") def equal_tables(self): - return ["MODEL", "EXPECTED"] \ No newline at end of file + return ["MODEL", "EXPECTED"] + + +class TestBasePythonModelSnowflake(BasePythonModelTests): + pass