From cbf487ec5b1b836e10954b768907848a121a2ec5 Mon Sep 17 00:00:00 2001 From: Dawn Pattison Date: Fri, 23 Sep 2022 14:00:41 -0500 Subject: [PATCH] Set Schema for Postgres Connector [#1362] (#1375) * Allow a schema to be specified for Postgres beyond the default schema. - Add db_schema as an optional postgres secret. - Add an empty set_schema method on the Base SQL Connector that can optionally be added on a subclass to set a schema for an entire session. - Define PostgreSQLConnector.set_schema to set the search path - Add a required secrets_schema property to be defined on every SQLConnector - Move "create_client" to the base SQLConnector and remove most locations where it was overridden because the primary element that was changing was the secrets schema. - Remove Redshift overrides for retrieve_data and mask_data since their only purposes are to set the schema, which the base sql connector can now do. * Update CHANGELOG. * Update the secrets format in testing now that db_schema can be optionally set. * Update separate missed test concerning new db_schema secrets attribute * Update CHANGELOG.md * Random import removed from this file. 
--- CHANGELOG.md | 1 + docker/sample_data/postgres_example.sql | 24 ++- .../connection_secrets_postgres.py | 1 + .../ops/service/connectors/sql_connector.py | 188 ++++++------------ .../test_connection_config_endpoints.py | 2 + tests/ops/fixtures/postgres_fixtures.py | 26 +++ ...st_connection_configuration_integration.py | 4 + tests/ops/integration_tests/test_sql_task.py | 85 ++++++++ 8 files changed, 205 insertions(+), 126 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 888e98bad1..47b1a67748 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ The types of changes are: * `api_key` auth strategy for SaaS connectors [#1331](https://github.com/ethyca/fidesops/pull/1331) * Access support for Rollbar [#1361](https://github.com/ethyca/fidesops/pull/1361) * Adds a new Timescale connector [#1327](https://github.com/ethyca/fidesops/pull/1327) +* Allow querying the non-default schema with the Postgres Connector [#1375](https://github.com/ethyca/fidesops/pull/1375) ### Removed diff --git a/docker/sample_data/postgres_example.sql b/docker/sample_data/postgres_example.sql index 23b6b6017d..43754707a7 100644 --- a/docker/sample_data/postgres_example.sql +++ b/docker/sample_data/postgres_example.sql @@ -174,4 +174,26 @@ INSERT INTO public.report VALUES INSERT INTO public.type_link_test VALUES ('1', 'name1'), -('2', 'name2'); \ No newline at end of file +('2', 'name2'); + + +CREATE SCHEMA backup_schema; +CREATE TABLE backup_schema.product (LIKE public.product INCLUDING ALL); +CREATE TABLE backup_schema.address (LIKE public.address INCLUDING ALL); +CREATE TABLE backup_schema.customer (LIKE public.customer INCLUDING ALL); +CREATE TABLE backup_schema.employee (LIKE public.employee INCLUDING ALL); +CREATE TABLE backup_schema.payment_card (LIKE public.payment_card INCLUDING ALL); +CREATE TABLE backup_schema.orders (LIKE public.orders INCLUDING ALL); +CREATE TABLE backup_schema.order_item (LIKE public.order_item INCLUDING ALL); 
+CREATE TABLE backup_schema.visit (LIKE public.visit INCLUDING ALL); +CREATE TABLE backup_schema.login (LIKE public.login INCLUDING ALL); +CREATE TABLE backup_schema.service_request (LIKE public.service_request INCLUDING ALL); +CREATE TABLE backup_schema.report (LIKE public.report INCLUDING ALL); +CREATE TABLE backup_schema.composite_pk_test (LIKE public.composite_pk_test INCLUDING ALL); +CREATE TABLE backup_schema.type_link_test (LIKE public.type_link_test INCLUDING ALL); + +INSERT INTO backup_schema.customer VALUES +(1, 'customer-500@example.com', 'Johanna Customer', '2022-05-01 12:22:11', 7); + +INSERT INTO backup_schema.address VALUES +(7, '311', 'Test Street', 'Test Town', 'TX', '79843'); diff --git a/src/fidesops/ops/schemas/connection_configuration/connection_secrets_postgres.py b/src/fidesops/ops/schemas/connection_configuration/connection_secrets_postgres.py index 2b54889fce..63261a6919 100644 --- a/src/fidesops/ops/schemas/connection_configuration/connection_secrets_postgres.py +++ b/src/fidesops/ops/schemas/connection_configuration/connection_secrets_postgres.py @@ -12,6 +12,7 @@ class PostgreSQLSchema(ConnectionConfigSecretsSchema): username: Optional[str] = None password: Optional[str] = None dbname: Optional[str] = None + db_schema: Optional[str] = None host: Optional[ str ] = None # Either the entire "url" *OR* the "host" should be supplied. 
diff --git a/src/fidesops/ops/service/connectors/sql_connector.py b/src/fidesops/ops/service/connectors/sql_connector.py index d0de626c05..b9f7c5f0de 100644 --- a/src/fidesops/ops/service/connectors/sql_connector.py +++ b/src/fidesops/ops/service/connectors/sql_connector.py @@ -1,6 +1,6 @@ import logging from abc import abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type from snowflake.sqlalchemy import URL as Snowflake_URL from sqlalchemy import Column, text @@ -18,10 +18,11 @@ from fidesops.ops.common_exceptions import ConnectionException from fidesops.ops.graph.traversal import Row, TraversalNode -from fidesops.ops.models.connectionconfig import ConnectionTestStatus +from fidesops.ops.models.connectionconfig import ConnectionConfig, ConnectionTestStatus from fidesops.ops.models.policy import Policy from fidesops.ops.models.privacy_request import PrivacyRequest from fidesops.ops.schemas.connection_configuration import ( + ConnectionConfigSecretsSchema, MicrosoftSQLServerSchema, PostgreSQLSchema, RedshiftSchema, @@ -52,6 +53,16 @@ class SQLConnector(BaseConnector[Engine]): """A SQL connector represents an abstract connector to any datastore that can be interacted with via standard SQL via SQLAlchemy""" + secrets_schema: Type[ConnectionConfigSecretsSchema] + + def __init__(self, configuration: ConnectionConfig): + """Instantiate a SQL-based connector""" + super().__init__(configuration) + if not self.secrets_schema: + raise NotImplementedError( + "SQL Connectors must define their secrets schema class" + ) + @staticmethod def cursor_result_to_rows(results: CursorResult) -> List[Row]: """Convert SQLAlchemy results to a list of dictionaries""" @@ -119,6 +130,7 @@ def retrieve_data( return [] logger.info("Starting data retrieval for %s", node.address) with client.connect() as connection: + self.set_schema(connection) results = connection.execute(stmt) return self.cursor_result_to_rows(results) @@ -140,6 
+152,7 @@ def mask_data( ) if update_stmt is not None: with client.connect() as connection: + self.set_schema(connection) results: LegacyCursorResult = connection.execute(update_stmt) update_ct = update_ct + results.rowcount return update_ct @@ -150,13 +163,29 @@ def close(self) -> None: logger.debug(" disposing of %s", self.__class__) self.db_client.dispose() + def create_client(self) -> Engine: + """Returns a SQLAlchemy Engine that can be used to interact with a database""" + config = self.secrets_schema(**self.configuration.secrets or {}) + uri = config.url or self.build_uri() + return create_engine( + uri, + hide_parameters=self.hide_parameters, + echo=not self.hide_parameters, + ) + + def set_schema(self, connection: Connection) -> None: + """Optionally override to set the schema for a given database that + persists through the entire session""" + class PostgreSQLConnector(SQLConnector): """Connector specific to postgresql""" + secrets_schema = PostgreSQLSchema + def build_uri(self) -> str: """Build URI of format postgresql://[user[:password]@][netloc][:port][/dbname]""" - config = PostgreSQLSchema(**self.configuration.secrets or {}) + config = self.secrets_schema(**self.configuration.secrets or {}) user_password = "" if config.username: @@ -169,23 +198,24 @@ def build_uri(self) -> str: dbname = f"/{config.dbname}" if config.dbname else "" return f"postgresql://{user_password}{netloc}{port}{dbname}" - def create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with a PostgreSQL database""" - config = PostgreSQLSchema(**self.configuration.secrets or {}) - uri = config.url or self.build_uri() - return create_engine( - uri, - hide_parameters=self.hide_parameters, - echo=not self.hide_parameters, - ) + def set_schema(self, connection: Connection) -> None: + """Sets the schema for a postgres database if applicable""" + config = self.secrets_schema(**self.configuration.secrets or {}) + if config.db_schema: + logger.info("Setting 
PostgreSQL search_path before retrieving data") + stmt = text("SET search_path to :search_path") + stmt = stmt.bindparams(search_path=config.db_schema) + connection.execute(stmt) class MySQLConnector(SQLConnector): """Connector specific to MySQL""" + secrets_schema = MySQLSchema + def build_uri(self) -> str: """Build URI of format mysql+pymysql://[user[:password]@][netloc][:port][/dbname]""" - config = MySQLSchema(**self.configuration.secrets or {}) + config = self.secrets_schema(**self.configuration.secrets or {}) user_password = "" if config.username: @@ -199,16 +229,6 @@ def build_uri(self) -> str: url = f"mysql+pymysql://{user_password}{netloc}{port}{dbname}" return url - def create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with a MySQL database""" - config = MySQLSchema(**self.configuration.secrets or {}) - uri = config.url or self.build_uri() - return create_engine( - uri, - hide_parameters=self.hide_parameters, - echo=not self.hide_parameters, - ) - @staticmethod def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: """ @@ -220,9 +240,11 @@ def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: class MariaDBConnector(SQLConnector): """Connector specific to MariaDB""" + secrets_schema = MariaDBSchema + def build_uri(self) -> str: """Build URI of format mariadb+pymysql://[user[:password]@][netloc][:port][/dbname]""" - config = MariaDBSchema(**self.configuration.secrets or {}) + config = self.secrets_schema(**self.configuration.secrets or {}) user_password = "" if config.username: @@ -236,16 +258,6 @@ def build_uri(self) -> str: url = f"mariadb+pymysql://{user_password}{netloc}{port}{dbname}" return url - def create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with a MariaDB database""" - config = MariaDBSchema(**self.configuration.secrets or {}) - uri = config.url or self.build_uri() - return create_engine( - uri, - 
hide_parameters=self.hide_parameters, - echo=not self.hide_parameters, - ) - @staticmethod def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: """ @@ -257,89 +269,27 @@ def cursor_result_to_rows(results: LegacyCursorResult) -> List[Row]: class RedshiftConnector(SQLConnector): """Connector specific to Amazon Redshift""" + secrets_schema = RedshiftSchema + # Overrides BaseConnector.build_uri def build_uri(self) -> str: """Build URI of format redshift+psycopg2://user:password@[host][:port][/database]""" - config = RedshiftSchema(**self.configuration.secrets or {}) + config = self.secrets_schema(**self.configuration.secrets or {}) port = f":{config.port}" if config.port else "" database = f"/{config.database}" if config.database else "" url = f"redshift+psycopg2://{config.user}:{config.password}@{config.host}{port}{database}" return url - # Overrides SQLConnector.create_client - def create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with an Amazon Redshift cluster""" - config = RedshiftSchema(**self.configuration.secrets or {}) - uri = config.url or self.build_uri() - return create_engine( - uri, - hide_parameters=self.hide_parameters, - echo=not self.hide_parameters, - ) - def set_schema(self, connection: Connection) -> None: """Sets the search_path for the duration of the session""" - config = RedshiftSchema(**self.configuration.secrets or {}) + config = self.secrets_schema(**self.configuration.secrets or {}) if config.db_schema: logger.info("Setting Redshift search_path before retrieving data") stmt = text("SET search_path to :search_path") stmt = stmt.bindparams(search_path=config.db_schema) connection.execute(stmt) - # Overrides SQLConnector.retrieve_data - def retrieve_data( - self, - node: TraversalNode, - policy: Policy, - privacy_request: PrivacyRequest, - input_data: Dict[str, List[Any]], - ) -> List[Row]: - """Retrieve data from Amazon Redshift - - For redshift, we also set the search_path to be the 
schema defined on the ConnectionConfig if - applicable - persists for the current session. - """ - query_config = self.query_config(node) - client = self.client() - stmt = query_config.generate_query(input_data, policy) - if stmt is None: - return [] - - logger.info("Starting data retrieval for %s", node.address) - with client.connect() as connection: - self.set_schema(connection) - results = connection.execute(stmt) - return SQLConnector.cursor_result_to_rows(results) - - # Overrides SQLConnector.mask_data - def mask_data( - self, - node: TraversalNode, - policy: Policy, - privacy_request: PrivacyRequest, - rows: List[Row], - input_data: Dict[str, List[Any]], - ) -> int: - """Execute a masking request. Returns the number of records masked - - For redshift, we also set the search_path to be the schema defined on the ConnectionConfig if - applicable - persists for the current session. - """ - query_config = self.query_config(node) - update_ct = 0 - client = self.client() - for row in rows: - update_stmt = query_config.generate_update_stmt( - row, policy, privacy_request - ) - if update_stmt is not None: - with client.connect() as connection: - self.set_schema(connection) - results: LegacyCursorResult = connection.execute(update_stmt) - update_ct = update_ct + results.rowcount - return update_ct - # Overrides SQLConnector.query_config def query_config(self, node: TraversalNode) -> RedshiftQueryConfig: """Query wrapper corresponding to the input traversal_node.""" @@ -349,19 +299,23 @@ def query_config(self, node: TraversalNode) -> RedshiftQueryConfig: class BigQueryConnector(SQLConnector): """Connector specific to Google BigQuery""" + secrets_schema = BigQuerySchema + # Overrides BaseConnector.build_uri def build_uri(self) -> str: """Build URI of format""" - config = BigQuerySchema(**self.configuration.secrets or {}) + config = self.secrets_schema(**self.configuration.secrets or {}) dataset = f"/{config.dataset}" if config.dataset else "" return 
f"bigquery://{config.keyfile_creds.project_id}{dataset}" # Overrides SQLConnector.create_client def create_client(self) -> Engine: """ - Returns a SQLAlchemy Engine that can be used to interact with Google BigQuery + Returns a SQLAlchemy Engine that can be used to interact with Google BigQuery. + + Overrides to pass in credentials_info """ - config = BigQuerySchema(**self.configuration.secrets or {}) + config = self.secrets_schema(**self.configuration.secrets or {}) uri = config.url or self.build_uri() return create_engine( @@ -402,11 +356,13 @@ def mask_data( class SnowflakeConnector(SQLConnector): """Connector specific to Snowflake""" + secrets_schema = SnowflakeSchema + def build_uri(self) -> str: """Build URI of format 'snowflake://:@// ?warehouse=&role=' """ - config = SnowflakeSchema(**self.configuration.secrets or {}) + config = self.secrets_schema(**self.configuration.secrets or {}) kwargs = {} @@ -428,16 +384,6 @@ def build_uri(self) -> str: url: str = Snowflake_URL(**kwargs) return url - def create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with Snowflake""" - config = SnowflakeSchema(**self.configuration.secrets or {}) - uri: str = config.url or self.build_uri() - return create_engine( - uri, - hide_parameters=self.hide_parameters, - echo=not self.hide_parameters, - ) - def query_config(self, node: TraversalNode) -> SQLQueryConfig: """Query wrapper corresponding to the input traversal_node.""" return SnowflakeQueryConfig(node) @@ -448,6 +394,8 @@ class MicrosoftSQLServerConnector(SQLConnector): Connector specific to Microsoft SQL Server """ + secrets_schema = MicrosoftSQLServerSchema + def build_uri(self) -> URL: """ Build URI of format @@ -455,7 +403,7 @@ def build_uri(self) -> URL: Returns URL obj, since SQLAlchemy's create_engine method accepts either a URL obj or a string """ - config = MicrosoftSQLServerSchema(**self.configuration.secrets or {}) + config = self.secrets_schema(**self.configuration.secrets or 
{}) url = URL.create( "mssql+pyodbc", @@ -469,16 +417,6 @@ def build_uri(self) -> URL: return url - def create_client(self) -> Engine: - """Returns a SQLAlchemy Engine that can be used to interact with a MicrosoftSQLServer database""" - config = MicrosoftSQLServerSchema(**self.configuration.secrets or {}) - uri = config.url or self.build_uri() - return create_engine( - uri, - hide_parameters=self.hide_parameters, - echo=not self.hide_parameters, - ) - def query_config(self, node: TraversalNode) -> SQLQueryConfig: """Query wrapper corresponding to the input traversal_node.""" return MicrosoftSQLServerQueryConfig(node) diff --git a/tests/ops/api/v1/endpoints/test_connection_config_endpoints.py b/tests/ops/api/v1/endpoints/test_connection_config_endpoints.py index 1672d8d7f7..e9340c455f 100644 --- a/tests/ops/api/v1/endpoints/test_connection_config_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_connection_config_endpoints.py @@ -958,6 +958,7 @@ def test_put_connection_config_secrets( "username": None, "password": None, "url": None, + "db_schema": None, } payload = {"url": "postgresql://test_user:test_pass@localhost:1234/my_test_db"} @@ -979,6 +980,7 @@ def test_put_connection_config_secrets( "username": None, "password": None, "url": payload["url"], + "db_schema": None, } assert connection_config.last_test_timestamp is None assert connection_config.last_test_succeeded is None diff --git a/tests/ops/fixtures/postgres_fixtures.py b/tests/ops/fixtures/postgres_fixtures.py index a406510800..2b49cc8aa2 100644 --- a/tests/ops/fixtures/postgres_fixtures.py +++ b/tests/ops/fixtures/postgres_fixtures.py @@ -188,6 +188,32 @@ def read_connection_config( connection_config.delete(db) +@pytest.fixture(scope="function") +def postgres_connection_config_with_schema( + db: Session, +) -> Generator: + """Create a connection config with a db_schema set which allows the PostgresConnector to connect + to a non-default schema""" + connection_config = ConnectionConfig.create( + db=db, + 
data={ + "name": str(uuid4()), + "key": "my_postgres_db_backup_schema", + "connection_type": ConnectionType.postgres, + "access": AccessLevel.write, + "secrets": integration_secrets["postgres_example"], + "disabled": False, + "description": "Backup postgres data", + }, + ) + connection_config.secrets[ + "db_schema" + ] = "backup_schema" # Matches the second schema created in postgres_example.sql + connection_config.save(db) + yield connection_config + connection_config.delete(db) + + + @pytest.fixture(scope="function") def postgres_integration_session_cls(connection_config): example_postgres_uri = PostgreSQLConnector(connection_config).build_uri() diff --git a/tests/ops/integration_tests/test_connection_configuration_integration.py b/tests/ops/integration_tests/test_connection_configuration_integration.py index 2b1977a9ad..a2da794279 100644 --- a/tests/ops/integration_tests/test_connection_configuration_integration.py +++ b/tests/ops/integration_tests/test_connection_configuration_integration.py @@ -70,6 +70,7 @@ def test_postgres_db_connection_incorrect_secrets( "username": None, "password": None, "url": None, + "db_schema": None, } assert connection_config.last_test_timestamp is not None assert connection_config.last_test_succeeded is False @@ -88,6 +89,7 @@ def test_postgres_db_connection_connect_with_components( "dbname": "postgres_example", "username": "postgres", "password": "postgres", + "db_schema": None, } auth_header = generate_auth_header(scopes=[CONNECTION_CREATE_OR_UPDATE]) @@ -113,6 +115,7 @@ def test_postgres_db_connection_connect_with_components( "username": "postgres", "password": "postgres", "url": None, + "db_schema": None, } assert connection_config.last_test_timestamp is not None assert connection_config.last_test_succeeded is True @@ -153,6 +156,7 @@ def test_postgres_db_connection_connect_with_url( "username": None, "password": None, "url": payload["url"], + "db_schema": None, } assert connection_config.last_test_timestamp is not None assert
connection_config.last_test_succeeded is True diff --git a/tests/ops/integration_tests/test_sql_task.py b/tests/ops/integration_tests/test_sql_task.py index 52d33e24f3..006cd7d5c3 100644 --- a/tests/ops/integration_tests/test_sql_task.py +++ b/tests/ops/integration_tests/test_sql_task.py @@ -360,6 +360,91 @@ async def test_postgres_access_request_task( ) + +@pytest.mark.integration_postgres +@pytest.mark.integration +@pytest.mark.asyncio +async def test_postgres_privacy_requests_against_non_default_schema( + db, + policy, + postgres_connection_config_with_schema, + postgres_integration_db, + erasure_policy, +) -> None: + """Assert that the postgres connector can make access and erasure requests against the non-default (non-public) schema""" + + privacy_request = PrivacyRequest(id=str(uuid4())) + database_name = "postgres_backup" + customer_email = "customer-500@example.com" + + dataset = integration_db_dataset( + database_name, postgres_connection_config_with_schema.key + ) + graph = DatasetGraph(dataset) + + access_results = await graph_task.run_access_request( + privacy_request, + policy, + graph, + [postgres_connection_config_with_schema], + {"email": customer_email}, + db, + ) + + # Confirm data retrieved from backup_schema, not public schema. This data only exists in the backup_schema.
+ assert access_results == { + f"{database_name}:address": [ + { + "id": 7, + "street": "Test Street", + "city": "Test Town", + "state": "TX", + "zip": "79843", + } + ], + f"{database_name}:payment_card": [], + f"{database_name}:orders": [], + f"{database_name}:customer": [ + { + "id": 1, + "name": "Johanna Customer", + "email": "customer-500@example.com", + "address_id": 7, + } + ], + } + + rule = erasure_policy.rules[0] + target = rule.targets[0] + target.data_category = "user" + target.save(db) + # Update data category on customer name + field([dataset], database_name, "customer", "name").data_categories = ["user.name"] + + erasure_results = await graph_task.run_erasure( + privacy_request, + erasure_policy, + graph, + [postgres_connection_config_with_schema], + {"email": customer_email}, + get_cached_data_for_erasures(privacy_request.id), + db, + ) + + # Confirm record masked in non-default schema + assert erasure_results == { + f"{database_name}:customer": 1, + f"{database_name}:payment_card": 0, + f"{database_name}:orders": 0, + f"{database_name}:address": 0, + }, "Only one record on customer table has targeted data category" + customer_records = postgres_integration_db.execute( + text("select * from backup_schema.customer where id = 1;") + ) + johanna_record = [c for c in customer_records][0] + assert johanna_record.email == customer_email # Not masked + assert johanna_record.name is None # Masked by erasure request + + @pytest.mark.integration_mssql @pytest.mark.integration @pytest.mark.asyncio