Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENTERPRISE-1418] Add support for plain JWT authentication #1078

Merged
merged 4 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240610-171026.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support JWT Authentication
time: 2024-06-10T17:10:26.421463-04:00
custom:
Author: llam15
Issue: 1079 726
37 changes: 31 additions & 6 deletions dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from dbt.adapters.sql import SQLConnectionManager
from dbt.adapters.events.logging import AdapterLogger
from dbt_common.events.functions import warn_or_error
from dbt.adapters.events.types import AdapterEventWarning
from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError
from dbt_common.ui import line_wrap_message, warning_tag


Expand All @@ -70,7 +70,7 @@ class SnowflakeAdapterResponse(AdapterResponse):
@dataclass
class SnowflakeCredentials(Credentials):
account: str
user: str
user: Optional[str] = None
warehouse: Optional[str] = None
role: Optional[str] = None
password: Optional[str] = None
Expand All @@ -96,15 +96,31 @@ class SnowflakeCredentials(Credentials):
reuse_connections: Optional[bool] = None

def __post_init__(self):
if self.authenticator != "oauth" and (
self.oauth_client_secret or self.oauth_client_id or self.token
):
if self.authenticator != "oauth" and (self.oauth_client_secret or self.oauth_client_id):
# the user probably forgot to set 'authenticator' like I keep doing
warn_or_error(
AdapterEventWarning(
base_msg="Authenticator is not set to oauth, but an oauth-only parameter is set! Did you mean to set authenticator: oauth?"
)
)

if self.authenticator not in ["oauth", "jwt"]:
if self.token:
warn_or_error(
AdapterEventWarning(
base_msg=(
"The token parameter was set, but the authenticator was "
"not set to 'oauth' or 'jwt'."
)
)
)

if not self.user:
# The user attribute is only optional if 'authenticator' is 'jwt' or 'oauth'
warn_or_error(
AdapterEventError(base_msg="Invalid profile: 'user' is a required property.")
)

self.account = self.account.replace("_", "-")

@property
Expand Down Expand Up @@ -146,6 +162,8 @@ def auth_args(self):
# Pull all of the optional authentication args for the connector,
# let connector handle the actual arg validation
result = {}
if self.user:
result["user"] = self.user
if self.password:
result["password"] = self.password
if self.host:
Expand Down Expand Up @@ -180,6 +198,14 @@ def auth_args(self):
)

result["token"] = token

elif self.authenticator == "jwt":
# If authenticator is 'jwt', then the 'token' value should be used
# unmodified. We expose this as 'jwt' in the profile, but the value
# passed into the snowflake.connect method should still be 'oauth'
result["token"] = self.token
result["authenticator"] = "oauth"

# enable id token cache for linux
result["client_store_temporary_credential"] = True
# enable mfa token cache for linux
Expand Down Expand Up @@ -346,7 +372,6 @@ def connect():

handle = snowflake.connector.connect(
account=creds.account,
user=creds.user,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
Expand Down
91 changes: 91 additions & 0 deletions tests/functional/oauth/test_jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
Please follow the instructions in test_oauth.py for instructions on how to set up
the security integration required to retrieve a JWT from Snowflake.
"""

import pytest
import os
from dbt.tests.util import run_dbt, check_relations_equal

from dbt.adapters.snowflake import SnowflakeCredentials

_MODELS__MODEL_1_SQL = """
select 1 as id
"""


_MODELS__MODEL_2_SQL = """
select 2 as id
"""


_MODELS__MODEL_3_SQL = """
select * from {{ ref('model_1') }}
union all
select * from {{ ref('model_2') }}
"""


_MODELS__MODEL_4_SQL = """
select 1 as id
union all
select 2 as id
"""


class TestSnowflakeJWT:
"""Tests that setting authenticator: jwt allows setting token to a plain JWT
that will be passed into the Snowflake connection without modification."""

@pytest.fixture(scope="class", autouse=True)
def access_token(self):
"""Because JWTs are short-lived, we need to get a fresh JWT via the refresh
token flow before running the test.

This fixture leverages the existing SnowflakeCredentials._get_access_token
method to retrieve a valid JWT from Snowflake.
"""
client_id = os.getenv("SNOWFLAKE_TEST_OAUTH_CLIENT_ID")
client_secret = os.getenv("SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET")
refresh_token = os.getenv("SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN")

credentials = SnowflakeCredentials(
account=os.getenv("SNOWFLAKE_TEST_ACCOUNT"),
database="",
schema="",
authenticator="oauth",
oauth_client_id=client_id,
oauth_client_secret=client_secret,
token=refresh_token,
)

yield credentials._get_access_token()

@pytest.fixture(scope="class", autouse=True)
def dbt_profile_target(self, access_token):
"""A dbt_profile that has authenticator set to JWT, and token set to
a JWT accepted by Snowflake. Also omits the user, as the user attribute
is optional when the authenticator is set to JWT.
"""
return {
"type": "snowflake",
"threads": 4,
"account": os.getenv("SNOWFLAKE_TEST_ACCOUNT"),
"database": os.getenv("SNOWFLAKE_TEST_DATABASE"),
"warehouse": os.getenv("SNOWFLAKE_TEST_WAREHOUSE"),
"authenticator": "jwt",
"token": access_token,
}

@pytest.fixture(scope="class")
def models(self):
return {
"model_1.sql": _MODELS__MODEL_1_SQL,
"model_2.sql": _MODELS__MODEL_2_SQL,
"model_3.sql": _MODELS__MODEL_3_SQL,
"model_4.sql": _MODELS__MODEL_4_SQL,
}

def test_snowflake_basic(self, project):
run_dbt()
check_relations_equal(project.adapter, ["MODEL_3", "MODEL_4"])
32 changes: 32 additions & 0 deletions tests/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,38 @@ def test_authenticator_private_key_authentication_no_passphrase(self, mock_get_p
]
)

def test_authenticator_jwt_authentication(self):
self.config.credentials = self.config.credentials.replace(
authenticator="jwt", token="my-jwt-token", user=None
)
self.adapter = SnowflakeAdapter(self.config, get_context("spawn"))
conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config")

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls(
[
mock.call(
account="test-account",
autocommit=True,
client_session_keep_alive=False,
database="test_database",
role=None,
schema="public",
warehouse="test_warehouse",
authenticator="oauth",
token="my-jwt-token",
private_key=None,
application="dbt",
client_request_mfa_token=True,
client_store_temporary_credential=True,
insecure_mode=False,
session_parameters={},
reuse_connections=None,
)
]
)

def test_query_tag(self):
self.config.credentials = self.config.credentials.replace(
password="test_password", query_tag="test_query_tag"
Expand Down
Loading