From 7f3be0fe58ec863679169e945441e01fcdc0e64c Mon Sep 17 00:00:00 2001 From: Ben Rifkind Date: Wed, 3 Feb 2021 21:45:06 -0700 Subject: [PATCH] feat: Add config key for connect_arg for SqlAlchemyExtractor (#434) * General connect_args for SqlAlchemyExtractor Signed-off-by: benrifkind * lint Signed-off-by: benrifkind * more lint Signed-off-by: benrifkind --- .../extractor/sql_alchemy_extractor.py | 10 ++++++- .../extractor/test_sql_alchemy_extractor.py | 30 ++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/databuilder/extractor/sql_alchemy_extractor.py b/databuilder/extractor/sql_alchemy_extractor.py index 2c90a8fd3..a375781a5 100644 --- a/databuilder/extractor/sql_alchemy_extractor.py +++ b/databuilder/extractor/sql_alchemy_extractor.py @@ -14,6 +14,7 @@ class SQLAlchemyExtractor(Extractor): # Config keys CONN_STRING = 'conn_string' EXTRACT_SQL = 'extract_sql' + CONNECT_ARGS = 'connect_args' """ An Extractor that extracts records via SQLAlchemy. Database that supports SQLAlchemy can use this extractor """ @@ -25,6 +26,7 @@ def init(self, conf: ConfigTree) -> None: """ self.conf = conf self.conn_string = conf.get_string(SQLAlchemyExtractor.CONN_STRING) + self.connection = self._get_connection() self.extract_sql = conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL) @@ -40,7 +42,13 @@ def _get_connection(self) -> Any: """ Create a SQLAlchemy connection to Database """ - engine = create_engine(self.conn_string) + connect_args = { + k: v + for k, v in self.conf.get_config( + self.CONNECT_ARGS, default=ConfigTree() + ).items() + } + engine = create_engine(self.conn_string, connect_args=connect_args) conn = engine.connect() return conn diff --git a/tests/unit/extractor/test_sql_alchemy_extractor.py b/tests/unit/extractor/test_sql_alchemy_extractor.py index b9bd967a6..32b4ed6e1 100644 --- a/tests/unit/extractor/test_sql_alchemy_extractor.py +++ b/tests/unit/extractor/test_sql_alchemy_extractor.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import unittest -from typing import Any +from typing import Any, Dict from mock import patch from pyhocon import ConfigFactory @@ -94,6 +94,34 @@ def test_extraction_with_model_class(self: Any, mock_method: Any) -> None: self.assertIsInstance(result, TableMetadataResult) self.assertEqual(result.name, 'test_table') + @patch('databuilder.extractor.sql_alchemy_extractor.create_engine') + def test_get_connection(self: Any, mock_method: Any) -> None: + """ + Test that configs are passed through correctly to the _get_connection method + """ + extractor = SQLAlchemyExtractor() + config_dict: Dict[str, Any] = { + 'extractor.sqlalchemy.conn_string': 'TEST_CONNECTION', + 'extractor.sqlalchemy.extract_sql': 'SELECT 1 FROM TEST_TABLE;' + } + conf = ConfigFactory.from_dict(config_dict) + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + extractor._get_connection() + mock_method.assert_called_with('TEST_CONNECTION', connect_args={}) + + extractor = SQLAlchemyExtractor() + config_dict = { + 'extractor.sqlalchemy.conn_string': 'TEST_CONNECTION', + 'extractor.sqlalchemy.extract_sql': 'SELECT 1 FROM TEST_TABLE;', + 'extractor.sqlalchemy.connect_args': {"protocol": "https"}, + } + conf = ConfigFactory.from_dict(config_dict) + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + extractor._get_connection() + mock_method.assert_called_with('TEST_CONNECTION', connect_args={"protocol": "https"}) + class TableMetadataResult: """