Skip to content

Commit

Permalink
Add support for Snowflake Key Pair Authentication
Browse files Browse the repository at this point in the history
This PR adds support for Snowflake Key Pair Authentication.

Unit tests verify everything's getting passed through correctly.
  • Loading branch information
Alexander Yermakov authored and Alexander Yermakov committed Jan 17, 2019
1 parent e359a69 commit 438b352
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
30 changes: 30 additions & 0 deletions plugins/snowflake/dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import dbt.compat
import dbt.exceptions
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.logger import GLOBAL_LOGGER as logger
Expand All @@ -29,6 +31,12 @@
'type': 'string',
'description': "Either 'externalbrowser', or a valid Okta url"
},
'private_key_path': {
'type': 'string',
},
'private_key_passphrase': {
'type': 'string',
},
'database': {
'type': 'string',
},
Expand Down Expand Up @@ -104,6 +112,11 @@ def open(cls, connection):
auth_args = {auth_key: credentials[auth_key]
for auth_key in ['user', 'password', 'authenticator']
if auth_key in credentials}

auth_args['private_key'] = cls._get_private_key(
credentials.get('private_key_path'),
credentials.get('private_key_passphrase'))

handle = snowflake.connector.connect(
account=credentials.account,
database=credentials.database,
Expand Down Expand Up @@ -163,6 +176,23 @@ def _split_queries(cls, sql):
split_query = snowflake.connector.util_text.split_statements(sql_buf)
return [part[0] for part in split_query]

@classmethod
def _get_private_key(cls, private_key_path, private_key_passphrase):
"""Get Snowflake private key by path or None."""
if private_key_path is None or private_key_passphrase is None:
return None

with open(private_key_path, 'rb') as key:
p_key = serialization.load_pem_private_key(
key.read(),
password=private_key_passphrase.encode(),
backend=default_backend())

return p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption())

def add_query(self, sql, model_name=None, auto_begin=True,
bindings=None, abridge_sql_log=False):

Expand Down
30 changes: 25 additions & 5 deletions test/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from mock import patch

import mock
import unittest

Expand Down Expand Up @@ -150,7 +152,7 @@ def test_client_session_keep_alive_false_by_default(self):
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_databse',
role=None, schema='public', user='test_user',
warehouse='test_warehouse')
warehouse='test_warehouse', private_key=None)
])

def test_client_session_keep_alive_true(self):
Expand All @@ -164,7 +166,7 @@ def test_client_session_keep_alive_true(self):
account='test_account', autocommit=False,
client_session_keep_alive=True, database='test_databse',
role=None, schema='public', user='test_user',
warehouse='test_warehouse')
warehouse='test_warehouse', private_key=None)
])

def test_user_pass_authentication(self):
Expand All @@ -178,7 +180,7 @@ def test_user_pass_authentication(self):
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_databse',
password='test_password', role=None, schema='public',
user='test_user', warehouse='test_warehouse')
user='test_user', warehouse='test_warehouse', private_key=None)
])

def test_authenticator_user_pass_authentication(self):
Expand All @@ -193,7 +195,7 @@ def test_authenticator_user_pass_authentication(self):
client_session_keep_alive=False, database='test_databse',
password='test_password', role=None, schema='public',
user='test_user', warehouse='test_warehouse',
authenticator='test_sso_url')
authenticator='test_sso_url', private_key=None)
])

def test_authenticator_externalbrowser_authentication(self):
Expand All @@ -207,5 +209,23 @@ def test_authenticator_externalbrowser_authentication(self):
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_databse',
role=None, schema='public', user='test_user',
warehouse='test_warehouse', authenticator='externalbrowser')
warehouse='test_warehouse', authenticator='externalbrowser',
private_key=None)
])

@patch('dbt.adapters.snowflake.SnowflakeConnectionManager._get_private_key', return_value='test_key')
def test_authenticator_private_key_authentication(self, mock_get_private_key):
self.config.credentials = self.config.credentials.incorporate(
private_key_path='/tmp/test_key.p8',
private_key_passphrase='p@ssphr@se')

self.adapter = SnowflakeAdapter(self.config)
self.adapter.connections.get(name='new_connection_with_new_config')

self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_databse',
role=None, schema='public', user='test_user',
warehouse='test_warehouse', private_key='test_key')
])

0 comments on commit 438b352

Please sign in to comment.