Skip to content

Commit

Permalink
feat(db-engine-specs): add support for Postgres root cert (apache#11720)
Browse files Browse the repository at this point in the history
* feat(db-engine-specs): add support for Postgres root cert

* remove logging of json decode exception message

* fix error message

* fix error message
  • Loading branch information
villebro authored and amitmiran137 committed Jan 14, 2021
1 parent 9954d05 commit e273d17
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
? `${t('ERROR: ')}${
typeof error.message === 'string'
? error.message
: (error.message as Record<string, string[]>).sqlalchemy_uri
: Object.entries(error.message as Record<string, string[]>)
.map(([key, value]) => `(${key}) ${value.join(', ')}`)
.join('\n')
}`
: t('ERROR: Connection failed. '),
);
Expand Down
2 changes: 1 addition & 1 deletion superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
[
_(
"Invalid connection string, a valid string usually follows: "
"dirver://user:password@database-host/database-name"
"driver://user:password@database-host/database-name"
)
]
)
Expand Down
7 changes: 4 additions & 3 deletions superset/db_engine_specs/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Any, Dict, Optional, TYPE_CHECKING

from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetException
from superset.utils import core as utils

if TYPE_CHECKING:
Expand Down Expand Up @@ -65,12 +66,12 @@ def get_extra_params(database: "Database") -> Dict[str, Any]:
:param database: database instance from which to extract extras
:raises CertificateException: If certificate is not valid/unparseable
:raises SupersetException: If database extra json payload is unparseable
"""
try:
extra = json.loads(database.extra or "{}")
except json.JSONDecodeError as ex:
logger.error(ex)
raise ex
except json.JSONDecodeError:
raise SupersetException("Unable to parse database extras")

if database.server_cert:
engine_params = extra.get("engine_params", {})
Expand Down
29 changes: 29 additions & 0 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
Expand All @@ -22,11 +24,14 @@
from sqlalchemy.dialects.postgresql.base import PGInspector

from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetException
from superset.utils import core as utils

if TYPE_CHECKING:
from superset.models.core import Database # pragma: no cover

logger = logging.getLogger()


# Replace psycopg2.tz.FixedOffsetTimezone with pytz, which is serializable by PyArrow
# https://github.com/stub42/pytz/blob/b70911542755aeeea7b5a9e066df5e1c87e8f2c8/src/pytz/reference.py#L25
Expand Down Expand Up @@ -115,3 +120,27 @@ def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds")
return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.US')"""
return None

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
"""
For Postgres, the path to a SSL certificate is placed in `connect_args`.
:param database: database instance from which to extract extras
:raises CertificateException: If certificate is not valid/unparseable
:raises SupersetException: If database extra json payload is unparseable
"""
try:
extra = json.loads(database.extra or "{}")
except json.JSONDecodeError:
raise SupersetException("Unable to parse database extras")

if database.server_cert:
engine_params = extra.get("engine_params", {})
connect_args = engine_params.get("connect_args", {})
connect_args["sslmode"] = connect_args.get("sslmode", "verify-full")
path = utils.create_ssl_cert_file(database.server_cert)
connect_args["sslrootcert"] = path
engine_params["connect_args"] = connect_args
extra["engine_params"] = engine_params
return extra
4 changes: 2 additions & 2 deletions tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_create_database(self):
database_data = {
"database_name": "test-create-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"server_cert": ssl_certificate,
"server_cert": None,
"extra": json.dumps(extra),
}

Expand Down Expand Up @@ -761,7 +761,7 @@ def test_test_connection(self):
"extra": json.dumps(extra),
"impersonate_user": False,
"sqlalchemy_uri": example_db.safe_sqlalchemy_uri(),
"server_cert": ssl_certificate,
"server_cert": None,
}
url = "api/v1/database/test_connection"
rv = self.post_assert_metric(url, data, "test_connection")
Expand Down
20 changes: 20 additions & 0 deletions tests/db_engine_specs/druid_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock

from sqlalchemy import column

from superset.db_engine_specs.druid import DruidEngineSpec
from tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.fixtures.certificates import ssl_certificate
from tests.fixtures.database import default_db_extra


class TestDruidDbEngineSpec(TestDbEngineSpec):
Expand Down Expand Up @@ -54,3 +58,19 @@ def test_timegrain_expressions(self):
col=sqla_col, pdf=None, time_grain=grain
)
self.assertEqual(str(actual), expected)

def test_extras_without_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = DruidEngineSpec.get_extra_params(db)
assert "connect_args" not in extras["engine_params"]

def test_extras_with_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = ssl_certificate
extras = DruidEngineSpec.get_extra_params(db)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["scheme"] == "https"
assert "ssl_verify_cert" in connect_args
30 changes: 30 additions & 0 deletions tests/db_engine_specs/postgres_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from superset.db_engine_specs import engines
from superset.db_engine_specs.postgres import PostgresEngineSpec
from tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.fixtures.certificates import ssl_certificate
from tests.fixtures.database import default_db_extra


class TestPostgresDbEngineSpec(TestDbEngineSpec):
Expand Down Expand Up @@ -124,3 +126,31 @@ def test_engine_alias_name(self):
DB Eng Specs (postgres): Test "postgres" in engine spec
"""
self.assertIn("postgres", engines)

def test_extras_without_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = PostgresEngineSpec.get_extra_params(db)
assert "connect_args" not in extras["engine_params"]

def test_extras_with_ssl_default(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = ssl_certificate
extras = PostgresEngineSpec.get_extra_params(db)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["sslmode"] == "verify-full"
assert "sslrootcert" in connect_args

def test_extras_with_ssl_custom(self):
db = mock.Mock()
db.extra = default_db_extra.replace(
'"engine_params": {}',
'"engine_params": {"connect_args": {"sslmode": "verify-ca"}}',
)
db.server_cert = ssl_certificate
extras = PostgresEngineSpec.get_extra_params(db)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["sslmode"] == "verify-ca"
assert "sslrootcert" in connect_args
22 changes: 22 additions & 0 deletions tests/fixtures/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
default_db_extra = """{
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_csv_upload": []
}"""

0 comments on commit e273d17

Please sign in to comment.