Skip to content

Commit

Permalink
Expose write_pandas_kwargs (#57)
Browse files Browse the repository at this point in the history
Co-authored-by: nick-amplify <[email protected]>
  • Loading branch information
jrbourbeau and nick-amplify authored Mar 15, 2024
1 parent a56fe3f commit b5da59e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
10 changes: 7 additions & 3 deletions dask_snowflake/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import Sequence
from typing import Optional, Sequence

import pandas as pd
import pyarrow as pa
Expand All @@ -26,6 +26,7 @@ def write_snowflake(
df: pd.DataFrame,
name: str,
connection_kwargs: dict,
write_pandas_kwargs: Optional[dict] = None,
):
connection_kwargs = {
**{"application": dask.config.get("snowflake.partner", "dask")},
Expand All @@ -39,6 +40,7 @@ def write_snowflake(
# NOTE: since ensure_db_exists uses uppercase for the table name
table_name=name.upper(),
quote_identifiers=False,
**(write_pandas_kwargs or {}),
)


Expand Down Expand Up @@ -73,6 +75,7 @@ def to_snowflake(
df: dd.DataFrame,
name: str,
connection_kwargs: dict,
write_pandas_kwargs: Optional[dict] = None,
compute: bool = True,
):
"""Write a Dask DataFrame to a Snowflake table.
Expand All @@ -90,7 +93,8 @@ def to_snowflake(
Whether or not to compute immediately. If ``True``, write DataFrame
partitions to Snowflake immediately. If ``False``, return a list of
delayed objects that can be computed later. Defaults to ``True``.
write_pandas_kwargs:
Additional keyword arguments that will be passed to ``snowflake.connector.pandas_tools.write_pandas``.
Examples
--------
Expand All @@ -115,7 +119,7 @@ def to_snowflake(
# right partner application ID.
ensure_db_exists(df._meta, name, connection_kwargs).compute()
parts = [
write_snowflake(partition, name, connection_kwargs)
write_snowflake(partition, name, connection_kwargs, write_pandas_kwargs)
for partition in df.to_delayed()
]
if compute:
Expand Down
23 changes: 23 additions & 0 deletions dask_snowflake/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,29 @@ def test_arrow_options(table, connection_kwargs, client):
)


def test_write_pandas_kwargs(table, connection_kwargs, client):
to_snowflake(
ddf.repartition(npartitions=1), name=table, connection_kwargs=connection_kwargs
)
# Overwrite existing table
to_snowflake(
ddf.repartition(npartitions=1),
name=table,
connection_kwargs=connection_kwargs,
write_pandas_kwargs={"overwrite": True},
)

query = f"SELECT * FROM {table}"
df_out = read_snowflake(query, connection_kwargs=connection_kwargs, npartitions=2)
# FIXME: Why does read_snowflake return lower-case columns names?
df_out.columns = df_out.columns.str.upper()
# FIXME: We need to sort the DataFrame because paritions are written
# in a non-sequential order.
dd.utils.assert_eq(
df, df_out.sort_values(by="A").reset_index(drop=True), check_dtype=False
)


def test_application_id_default(table, connection_kwargs, monkeypatch):
# Patch Snowflake's normal connection mechanism with checks that
# the expected application ID is set
Expand Down

0 comments on commit b5da59e

Please sign in to comment.