Skip to content

Commit

Permalink
Set Schema for Postgres Connector [#1362] (#1375)
Browse files Browse the repository at this point in the history
* Allow a schema to be specified for Postgres beyond the default schema.

- Add db_schema as an optional postgres secret.
- Add an empty set_schema method on the base SQL Connector that subclasses can optionally override to set a schema for an entire session.
- Define PostgreSQLConnector.set_schema to set the search path
- Add a required secrets_schema property to be defined on every SQLConnector
- Move "create_client" to the base SQLConnector and remove most locations where it was overridden because the primary element that was changing was the secrets schema.
- Remove Redshift overrides for retrieve_data and mask_data since their only purposes are to set the schema, which the base sql connector can now do.

* Update CHANGELOG.

* Update the secrets format in testing now that db_schema can be optionally set.

* Update separate missed test concerning new db_schema secrets attribute

* Update CHANGELOG.md

* Remove unused `random` import from this file.
  • Loading branch information
pattisdr authored Sep 23, 2022
1 parent ebd9e7d commit cbf487e
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 126 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ The types of changes are:
* `api_key` auth strategy for SaaS connectors [#1331](https://github.com/ethyca/fidesops/pull/1331)
* Access support for Rollbar [#1361](https://github.com/ethyca/fidesops/pull/1361)
* Adds a new Timescale connector [#1327](https://github.com/ethyca/fidesops/pull/1327)
* Allow querying the non-default schema with the Postgres Connector [#1375](https://github.com/ethyca/fidesops/pull/1375)

### Removed

Expand Down
24 changes: 23 additions & 1 deletion docker/sample_data/postgres_example.sql
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,26 @@ INSERT INTO public.report VALUES

INSERT INTO public.type_link_test VALUES
('1', 'name1'),
('2', 'name2');
('2', 'name2');


CREATE SCHEMA backup_schema;
CREATE TABLE backup_schema.product (LIKE public.product INCLUDING ALL);
CREATE TABLE backup_schema.address (LIKE public.address INCLUDING ALL);
CREATE TABLE backup_schema.customer (LIKE public.customer INCLUDING ALL);
CREATE TABLE backup_schema.employee (LIKE public.employee INCLUDING ALL);
CREATE TABLE backup_schema.payment_card (LIKE public.payment_card INCLUDING ALL);
CREATE TABLE backup_schema.orders (LIKE public.orders INCLUDING ALL);
CREATE TABLE backup_schema.order_item (LIKE public.order_item INCLUDING ALL);
CREATE TABLE backup_schema.visit (LIKE public.visit INCLUDING ALL);
CREATE TABLE backup_schema.login (LIKE public.login INCLUDING ALL);
CREATE TABLE backup_schema.service_request (LIKE public.service_request INCLUDING ALL);
CREATE TABLE backup_schema.report (LIKE public.report INCLUDING ALL);
CREATE TABLE backup_schema.composite_pk_test (LIKE public.composite_pk_test INCLUDING ALL);
CREATE TABLE backup_schema.type_link_test (LIKE public.type_link_test INCLUDING ALL);

INSERT INTO backup_schema.customer VALUES
(1, '[email protected]', 'Johanna Customer', '2022-05-01 12:22:11', 7);

INSERT INTO backup_schema.address VALUES
(7, '311', 'Test Street', 'Test Town', 'TX', '79843');
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class PostgreSQLSchema(ConnectionConfigSecretsSchema):
username: Optional[str] = None
password: Optional[str] = None
dbname: Optional[str] = None
db_schema: Optional[str] = None
host: Optional[
str
] = None # Either the entire "url" *OR* the "host" should be supplied.
Expand Down
188 changes: 63 additions & 125 deletions src/fidesops/ops/service/connectors/sql_connector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type

from snowflake.sqlalchemy import URL as Snowflake_URL
from sqlalchemy import Column, text
Expand All @@ -18,10 +18,11 @@

from fidesops.ops.common_exceptions import ConnectionException
from fidesops.ops.graph.traversal import Row, TraversalNode
from fidesops.ops.models.connectionconfig import ConnectionTestStatus
from fidesops.ops.models.connectionconfig import ConnectionConfig, ConnectionTestStatus
from fidesops.ops.models.policy import Policy
from fidesops.ops.models.privacy_request import PrivacyRequest
from fidesops.ops.schemas.connection_configuration import (
ConnectionConfigSecretsSchema,
MicrosoftSQLServerSchema,
PostgreSQLSchema,
RedshiftSchema,
Expand Down Expand Up @@ -52,6 +53,16 @@ class SQLConnector(BaseConnector[Engine]):
"""A SQL connector represents an abstract connector to any datastore that can be
interacted with via standard SQL via SQLAlchemy"""

secrets_schema: Type[ConnectionConfigSecretsSchema]

def __init__(self, configuration: ConnectionConfig):
    """Instantiate a SQL-based connector.

    Raises:
        NotImplementedError: if the concrete subclass has not assigned a
            ``secrets_schema`` class attribute.
    """
    super().__init__(configuration)
    # The bare ``secrets_schema: Type[...]`` annotation on the base class does
    # not create an attribute, so a direct ``self.secrets_schema`` access on a
    # non-conforming subclass would raise AttributeError before this guard
    # could run; getattr makes the intended NotImplementedError reachable.
    if not getattr(self, "secrets_schema", None):
        raise NotImplementedError(
            "SQL Connectors must define their secrets schema class"
        )

@staticmethod
def cursor_result_to_rows(results: CursorResult) -> List[Row]:
"""Convert SQLAlchemy results to a list of dictionaries"""
Expand Down Expand Up @@ -119,6 +130,7 @@ def retrieve_data(
return []
logger.info("Starting data retrieval for %s", node.address)
with client.connect() as connection:
self.set_schema(connection)
results = connection.execute(stmt)
return self.cursor_result_to_rows(results)

Expand All @@ -140,6 +152,7 @@ def mask_data(
)
if update_stmt is not None:
with client.connect() as connection:
self.set_schema(connection)
results: LegacyCursorResult = connection.execute(update_stmt)
update_ct = update_ct + results.rowcount
return update_ct
Expand All @@ -150,13 +163,29 @@ def close(self) -> None:
logger.debug(" disposing of %s", self.__class__)
self.db_client.dispose()

def create_client(self) -> Engine:
    """Return a SQLAlchemy Engine for this connector's configured database.

    Secrets are validated through the subclass's ``secrets_schema``; an
    explicit ``url`` secret takes precedence over a URI built from parts.
    """
    secrets = self.configuration.secrets or {}
    config = self.secrets_schema(**secrets)
    connection_uri = config.url or self.build_uri()
    # Parameter hiding and SQL echo are mutually exclusive by design.
    return create_engine(
        connection_uri,
        hide_parameters=self.hide_parameters,
        echo=not self.hide_parameters,
    )

def set_schema(self, connection: Connection) -> None:
    """No-op hook: subclasses may optionally override this to set a schema
    (e.g. the Postgres search_path) on the given connection so that it
    persists for the entire session."""


class PostgreSQLConnector(SQLConnector):
"""Connector specific to postgresql"""

secrets_schema = PostgreSQLSchema

def build_uri(self) -> str:
"""Build URI of format postgresql://[user[:password]@][netloc][:port][/dbname]"""
config = PostgreSQLSchema(**self.configuration.secrets or {})
config = self.secrets_schema(**self.configuration.secrets or {})

user_password = ""
if config.username:
Expand All @@ -169,23 +198,24 @@ def build_uri(self) -> str:
dbname = f"/{config.dbname}" if config.dbname else ""
return f"postgresql://{user_password}{netloc}{port}{dbname}"

def create_client(self) -> Engine:
"""Returns a SQLAlchemy Engine that can be used to interact with a PostgreSQL database"""
config = PostgreSQLSchema(**self.configuration.secrets or {})
uri = config.url or self.build_uri()
return create_engine(
uri,
hide_parameters=self.hide_parameters,
echo=not self.hide_parameters,
)
def set_schema(self, connection: Connection) -> None:
    """Set the Postgres ``search_path`` for the current session when a
    ``db_schema`` secret is configured; otherwise do nothing."""
    config = self.secrets_schema(**self.configuration.secrets or {})
    if config.db_schema:
        # This hook runs before both data retrieval and masking, so the log
        # line must not claim a specific operation ("before retrieving data"
        # was misleading during mask_data).
        logger.info("Setting PostgreSQL search_path for the current session")
        # Bind the schema name as a parameter rather than interpolating it
        # into the SQL string.
        stmt = text("SET search_path to :search_path")
        stmt = stmt.bindparams(search_path=config.db_schema)
        connection.execute(stmt)


class MySQLConnector(SQLConnector):
"""Connector specific to MySQL"""

secrets_schema = MySQLSchema

def build_uri(self) -> str:
"""Build URI of format mysql+pymysql://[user[:password]@][netloc][:port][/dbname]"""
config = MySQLSchema(**self.configuration.secrets or {})
config = self.secrets_schema(**self.configuration.secrets or {})

user_password = ""
if config.username:
Expand All @@ -199,16 +229,6 @@ def build_uri(self) -> str:
url = f"mysql+pymysql://{user_password}{netloc}{port}{dbname}"
return url

def create_client(self) -> Engine:
"""Returns a SQLAlchemy Engine that can be used to interact with a MySQL database"""
config = MySQLSchema(**self.configuration.secrets or {})
uri = config.url or self.build_uri()
return create_engine(
uri,
hide_parameters=self.hide_parameters,
echo=not self.hide_parameters,
)

@staticmethod
def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
"""
Expand All @@ -220,9 +240,11 @@ def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
class MariaDBConnector(SQLConnector):
"""Connector specific to MariaDB"""

secrets_schema = MariaDBSchema

def build_uri(self) -> str:
"""Build URI of format mariadb+pymysql://[user[:password]@][netloc][:port][/dbname]"""
config = MariaDBSchema(**self.configuration.secrets or {})
config = self.secrets_schema(**self.configuration.secrets or {})

user_password = ""
if config.username:
Expand All @@ -236,16 +258,6 @@ def build_uri(self) -> str:
url = f"mariadb+pymysql://{user_password}{netloc}{port}{dbname}"
return url

def create_client(self) -> Engine:
"""Returns a SQLAlchemy Engine that can be used to interact with a MariaDB database"""
config = MariaDBSchema(**self.configuration.secrets or {})
uri = config.url or self.build_uri()
return create_engine(
uri,
hide_parameters=self.hide_parameters,
echo=not self.hide_parameters,
)

@staticmethod
def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
"""
Expand All @@ -257,89 +269,27 @@ def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]:
class RedshiftConnector(SQLConnector):
"""Connector specific to Amazon Redshift"""

secrets_schema = RedshiftSchema

# Overrides BaseConnector.build_uri
def build_uri(self) -> str:
"""Build URI of format redshift+psycopg2://user:password@[host][:port][/database]"""
config = RedshiftSchema(**self.configuration.secrets or {})
config = self.secrets_schema(**self.configuration.secrets or {})

port = f":{config.port}" if config.port else ""
database = f"/{config.database}" if config.database else ""
url = f"redshift+psycopg2://{config.user}:{config.password}@{config.host}{port}{database}"
return url

# Overrides SQLConnector.create_client
def create_client(self) -> Engine:
"""Returns a SQLAlchemy Engine that can be used to interact with an Amazon Redshift cluster"""
config = RedshiftSchema(**self.configuration.secrets or {})
uri = config.url or self.build_uri()
return create_engine(
uri,
hide_parameters=self.hide_parameters,
echo=not self.hide_parameters,
)

def set_schema(self, connection: Connection) -> None:
"""Sets the search_path for the duration of the session"""
config = RedshiftSchema(**self.configuration.secrets or {})
config = self.secrets_schema(**self.configuration.secrets or {})
if config.db_schema:
logger.info("Setting Redshift search_path before retrieving data")
stmt = text("SET search_path to :search_path")
stmt = stmt.bindparams(search_path=config.db_schema)
connection.execute(stmt)

# Overrides SQLConnector.retrieve_data
def retrieve_data(
self,
node: TraversalNode,
policy: Policy,
privacy_request: PrivacyRequest,
input_data: Dict[str, List[Any]],
) -> List[Row]:
"""Retrieve data from Amazon Redshift
For redshift, we also set the search_path to be the schema defined on the ConnectionConfig if
applicable - persists for the current session.
"""
query_config = self.query_config(node)
client = self.client()
stmt = query_config.generate_query(input_data, policy)
if stmt is None:
return []

logger.info("Starting data retrieval for %s", node.address)
with client.connect() as connection:
self.set_schema(connection)
results = connection.execute(stmt)
return SQLConnector.cursor_result_to_rows(results)

# Overrides SQLConnector.mask_data
def mask_data(
self,
node: TraversalNode,
policy: Policy,
privacy_request: PrivacyRequest,
rows: List[Row],
input_data: Dict[str, List[Any]],
) -> int:
"""Execute a masking request. Returns the number of records masked
For redshift, we also set the search_path to be the schema defined on the ConnectionConfig if
applicable - persists for the current session.
"""
query_config = self.query_config(node)
update_ct = 0
client = self.client()
for row in rows:
update_stmt = query_config.generate_update_stmt(
row, policy, privacy_request
)
if update_stmt is not None:
with client.connect() as connection:
self.set_schema(connection)
results: LegacyCursorResult = connection.execute(update_stmt)
update_ct = update_ct + results.rowcount
return update_ct

# Overrides SQLConnector.query_config
def query_config(self, node: TraversalNode) -> RedshiftQueryConfig:
"""Query wrapper corresponding to the input traversal_node."""
Expand All @@ -349,19 +299,23 @@ def query_config(self, node: TraversalNode) -> RedshiftQueryConfig:
class BigQueryConnector(SQLConnector):
"""Connector specific to Google BigQuery"""

secrets_schema = BigQuerySchema

# Overrides BaseConnector.build_uri
def build_uri(self) -> str:
"""Build URI of format"""
config = BigQuerySchema(**self.configuration.secrets or {})
config = self.secrets_schema(**self.configuration.secrets or {})
dataset = f"/{config.dataset}" if config.dataset else ""
return f"bigquery://{config.keyfile_creds.project_id}{dataset}"

# Overrides SQLConnector.create_client
def create_client(self) -> Engine:
"""
Returns a SQLAlchemy Engine that can be used to interact with Google BigQuery
Returns a SQLAlchemy Engine that can be used to interact with Google BigQuery.
Overrides to pass in credentials_info
"""
config = BigQuerySchema(**self.configuration.secrets or {})
config = self.secrets_schema(**self.configuration.secrets or {})
uri = config.url or self.build_uri()

return create_engine(
Expand Down Expand Up @@ -402,11 +356,13 @@ def mask_data(
class SnowflakeConnector(SQLConnector):
"""Connector specific to Snowflake"""

secrets_schema = SnowflakeSchema

def build_uri(self) -> str:
"""Build URI of format 'snowflake://<user_login_name>:<password>@<account_identifier>/<database_name>/
<schema_name>?warehouse=<warehouse_name>&role=<role_name>'
"""
config = SnowflakeSchema(**self.configuration.secrets or {})
config = self.secrets_schema(**self.configuration.secrets or {})

kwargs = {}

Expand All @@ -428,16 +384,6 @@ def build_uri(self) -> str:
url: str = Snowflake_URL(**kwargs)
return url

def create_client(self) -> Engine:
"""Returns a SQLAlchemy Engine that can be used to interact with Snowflake"""
config = SnowflakeSchema(**self.configuration.secrets or {})
uri: str = config.url or self.build_uri()
return create_engine(
uri,
hide_parameters=self.hide_parameters,
echo=not self.hide_parameters,
)

def query_config(self, node: TraversalNode) -> SQLQueryConfig:
    """Return the Snowflake-specific query config wrapping the input traversal_node."""
    snowflake_config: SQLQueryConfig = SnowflakeQueryConfig(node)
    return snowflake_config
Expand All @@ -448,14 +394,16 @@ class MicrosoftSQLServerConnector(SQLConnector):
Connector specific to Microsoft SQL Server
"""

secrets_schema = MicrosoftSQLServerSchema

def build_uri(self) -> URL:
"""
Build URI of format
mssql+pyodbc://[username]:[password]@[host]:[port]/[dbname]?driver=ODBC+Driver+17+for+SQL+Server
Returns URL obj, since SQLAlchemy's create_engine method accepts either a URL obj or a string
"""

config = MicrosoftSQLServerSchema(**self.configuration.secrets or {})
config = self.secrets_schema(**self.configuration.secrets or {})

url = URL.create(
"mssql+pyodbc",
Expand All @@ -469,16 +417,6 @@ def build_uri(self) -> URL:

return url

def create_client(self) -> Engine:
"""Returns a SQLAlchemy Engine that can be used to interact with a MicrosoftSQLServer database"""
config = MicrosoftSQLServerSchema(**self.configuration.secrets or {})
uri = config.url or self.build_uri()
return create_engine(
uri,
hide_parameters=self.hide_parameters,
echo=not self.hide_parameters,
)

def query_config(self, node: TraversalNode) -> SQLQueryConfig:
    """Return the SQL Server-specific query config wrapping the input traversal_node."""
    mssql_config: SQLQueryConfig = MicrosoftSQLServerQueryConfig(node)
    return mssql_config
Expand Down
Loading

0 comments on commit cbf487e

Please sign in to comment.