Skip to content

Commit

Permalink
Merge pull request #3 from dbt-labs/python-model
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenyuLInx authored Jun 28, 2022
2 parents 80c4001 + d61a212 commit 4219827
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 7 deletions.
2 changes: 1 addition & 1 deletion dbt/adapters/snowflake/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "1.2.0a1"
version = "1.3.0a1"
5 changes: 5 additions & 0 deletions dbt/include/snowflake/macros/adapters.sql
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,8 @@
{{ snowflake_dml_explicit_transaction(truncate_dml) }}
{%- endcall %}
{% endmacro %}

{% macro load_df_def() %}
global snowpark_session
load_df_function = snowpark_session.table
{% endmacro %}
69 changes: 65 additions & 4 deletions dbt/include/snowflake/macros/materializations/table.sql
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,33 @@
{{ drop_relation_if_exists(old_relation) }}
{% endif %}

--build model
{% call statement('main') -%}
{{ create_table_as(false, target_relation, sql) }}
{%- endcall %}
{% if config.get('language', 'sql') == 'python' -%}}
{%- set proc_name = api.Relation.create(identifier=identifier ~ "__dbt_sp",
schema=schema,
database=database) -%}
{% set materialization_logic = py_materialize_as_table() %}
{% set setup_stored_proc = py_create_stored_procedure(proc_name, materialization_logic, model, sql) %}

{% do log("Creating stored procedure: " ~ proc_name, info=true) %}
{% do run_query(setup_stored_proc) %}
{% do log("Finished creating stored procedure: " ~ proc_name, info=true) %}

--build model
{% call statement('main') -%}
CALL {{ proc_name }}('{{ target_relation }}');

{%- endcall %}

-- cleanup stuff
{% do run_query("drop procedure if exists " ~ proc_name ~ "(string)") %}

{%- else -%}
--build model
{% call statement('main') -%}
{{ create_table_as(false, target_relation, sql) }}
{%- endcall %}

{%- endif %}

{{ run_hooks(post_hooks) }}

Expand All @@ -32,3 +55,41 @@
{{ return({'relations': [target_relation]}) }}

{% endmaterialization %}

{% macro py_materialize_as_table(config) %}

def materialize(session, df, target_relation):
df.write.mode("overwrite").save_as_table(target_relation)

{% endmacro %}

{% macro py_create_stored_procedure(proc_name, materialization_logic, model, user_supplied_logic) %}

{% set packages = ['snowflake-snowpark-python'] + config.get('packages', []) %}

CREATE OR REPLACE PROCEDURE {{ proc_name }} (target_relation STRING)
RETURNS STRING
LANGUAGE PYTHON
RUNTIME_VERSION = '3.8'
PACKAGES = ('{{ packages | join("', '") }}')
HANDLER = 'run'
AS
$$

snowpark_session = None

{#-- can we wrap in 'def model:' here? or will formatting screw us? --#}
{{ user_supplied_logic }}

{{ materialization_logic }}

def run(session, target_relation):
global snowpark_session
snowpark_session = session
df = model(dbt)
materialize(session, df, target_relation)
return "OK"

$$;

{% endmacro %}
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _get_dbt_core_version():


package_name = "dbt-snowflake"
package_version = "1.2.0a1"
package_version = "1.3.0a1"
dbt_core_version = _get_dbt_core_version()
description = """The Snowflake adapter plugin for dbt"""

Expand Down
9 changes: 8 additions & 1 deletion tests/functional/adapter/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols
from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp
from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod
from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests



class TestSimpleMaterializationsSnowflake(BaseSimpleMaterializations):
pass
Expand Down Expand Up @@ -51,4 +54,8 @@ class TestSnapshotTimestampSnowflake(BaseSnapshotTimestamp):
class TestBaseAdapterMethodSnowflake(BaseAdapterMethod):
@pytest.fixture(scope="class")
def equal_tables(self):
return ["MODEL", "EXPECTED"]
return ["MODEL", "EXPECTED"]


class TestBasePythonModelSnowflake(BasePythonModelTests):
pass

0 comments on commit 4219827

Please sign in to comment.