Skip to content

Commit

Permalink
Allow json_result_force_utf8_encoding specification in `providers.s…
Browse files Browse the repository at this point in the history
…nowflake.hooks.SnowflakeHook` extra dict (#44264)

* Allow json_result_force_utf8_encoding specification in SnowflakeHook extra dict

* Use a set for the not in
  • Loading branch information
ttzhou authored Nov 28, 2024
1 parent 335f64c commit 518d394
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
17 changes: 16 additions & 1 deletion providers/src/airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def _get_conn_params(self) -> dict[str, str | None]:
region = self._get_field(extra_dict, "region") or ""
role = self._get_field(extra_dict, "role") or ""
insecure_mode = _try_to_boolean(self._get_field(extra_dict, "insecure_mode"))
json_result_force_utf8_decoding = _try_to_boolean(
self._get_field(extra_dict, "json_result_force_utf8_decoding")
)
schema = conn.schema or ""
client_request_mfa_token = _try_to_boolean(self._get_field(extra_dict, "client_request_mfa_token"))

Expand All @@ -225,6 +228,9 @@ def _get_conn_params(self) -> dict[str, str | None]:
if insecure_mode:
conn_config["insecure_mode"] = insecure_mode

if json_result_force_utf8_decoding:
conn_config["json_result_force_utf8_decoding"] = json_result_force_utf8_decoding

if client_request_mfa_token:
conn_config["client_request_mfa_token"] = client_request_mfa_token

Expand Down Expand Up @@ -302,7 +308,13 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str:
for k, v in conn_params.items()
if v
and k
not in ["session_parameters", "insecure_mode", "private_key", "client_request_mfa_token"]
not in {
"session_parameters",
"insecure_mode",
"private_key",
"client_request_mfa_token",
"json_result_force_utf8_decoding",
}
}
)

Expand All @@ -324,6 +336,9 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
if "insecure_mode" in conn_params:
engine_kwargs.setdefault("connect_args", {})
engine_kwargs["connect_args"]["insecure_mode"] = True
if "json_result_force_utf8_decoding" in conn_params:
engine_kwargs.setdefault("connect_args", {})
engine_kwargs["connect_args"]["json_result_force_utf8_decoding"] = True
for key in ["session_parameters", "private_key"]:
if conn_params.get(key):
engine_kwargs.setdefault("connect_args", {})
Expand Down
22 changes: 22 additions & 0 deletions providers/tests/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class TestPytestSnowflakeHook:
"extra__snowflake__region": "af_region",
"extra__snowflake__role": "af_role",
"extra__snowflake__insecure_mode": "True",
"extra__snowflake__json_result_force_utf8_decoding": "True",
"extra__snowflake__client_request_mfa_token": "True",
},
},
Expand All @@ -158,6 +159,7 @@ class TestPytestSnowflakeHook:
"user": "user",
"warehouse": "af_wh",
"insecure_mode": True,
"json_result_force_utf8_decoding": True,
"client_request_mfa_token": True,
},
),
Expand All @@ -171,6 +173,7 @@ class TestPytestSnowflakeHook:
"extra__snowflake__region": "af_region",
"extra__snowflake__role": "af_role",
"extra__snowflake__insecure_mode": "False",
"extra__snowflake__json_result_force_utf8_decoding": "False",
"extra__snowflake__client_request_mfa_token": "False",
},
},
Expand Down Expand Up @@ -247,6 +250,7 @@ class TestPytestSnowflakeHook:
"extra": {
**BASE_CONNECTION_KWARGS["extra"],
"extra__snowflake__insecure_mode": False,
"extra__snowflake__json_result_force_utf8_decoding": True,
"extra__snowflake__client_request_mfa_token": False,
},
},
Expand All @@ -266,6 +270,7 @@ class TestPytestSnowflakeHook:
"session_parameters": None,
"user": "user",
"warehouse": "af_wh",
"json_result_force_utf8_decoding": True,
},
),
],
Expand Down Expand Up @@ -473,6 +478,23 @@ def test_get_sqlalchemy_engine_should_support_insecure_mode(self):
)
assert mock_create_engine.return_value == conn

def test_get_sqlalchemy_engine_should_support_json_result_force_utf8_decoding(self):
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["extra__snowflake__json_result_force_utf8_decoding"] = "True"

with (
mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine,
):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
conn = hook.get_sqlalchemy_engine()
mock_create_engine.assert_called_once_with(
"snowflake://user:[email protected]_region/db/public"
"?application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh",
connect_args={"json_result_force_utf8_decoding": True},
)
assert mock_create_engine.return_value == conn

def test_get_sqlalchemy_engine_should_support_session_parameters(self):
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["session_parameters"] = {"TEST_PARAM": "AA", "TEST_PARAM_B": 123}
Expand Down

0 comments on commit 518d394

Please sign in to comment.