From 3f8383aa89f45d861ca081e3e9fd2cc9d0b5dfaa Mon Sep 17 00:00:00 2001 From: Tim Ryan Date: Mon, 2 Oct 2023 10:15:14 -0400 Subject: [PATCH 1/4] RM-95 use Table object to prevent SQLInjection RM-95 use Table Revert "RM-95 builtin has_table" This reverts commit 43edb5e3554561cee7f5317e616f1f218118ee65. RM-95 builtin has_table RM-95 remove unused ignore RM-95 add if_exists=True RM-95 use DropTable RM-95 use DropTable RM-95 use DropTable RM-95 update ignore RM-95 update ignore RM-95 add newline RM-95 use target. RM-95 update text expectation RM-95 update to DropTable RM-95 ignore arg-type RM-95 add space afterward RM-95 use newline RM-95 update select statement RM-95 update mock names to strings RM-95 select '*' from table RM-95 explicitly use select * RM-95 re-add quote_schema_and_table RM-95 add newline RM-95 use table.select RM-95 fix test drop table RM-95 use DropTable RM-95 bind connection for drop RM-95 update to use table.drop RM-95 remove mock_schema_and_table ref RM-95 update to use DELETE command RM-95 add mock schema name and table name --- records_mover/db/vertica/vertica_db_driver.py | 8 +++----- records_mover/records/prep.py | 11 ++++++++--- records_mover/records/sources/table.py | 7 +++++-- records_mover/records/targets/spectrum.py | 4 +++- tests/integration/records/purge_old_test_tables.py | 6 ++++-- .../records/records_database_fixture.py | 7 ++++--- .../records/records_datetime_fixture.py | 9 +++++---- .../records/records_numeric_database_fixture.py | 7 ++++--- tests/unit/db/vertica/test_vertica_db_driver.py | 8 ++++---- tests/unit/records/sources/test_table.py | 8 ++++---- tests/unit/records/targets/test_spectrum.py | 4 ++-- tests/unit/records/test_prep.py | 14 ++++++++------ 12 files changed, 54 insertions(+), 39 deletions(-) diff --git a/records_mover/db/vertica/vertica_db_driver.py b/records_mover/db/vertica/vertica_db_driver.py index bc61b2548..4529745d9 100644 --- a/records_mover/db/vertica/vertica_db_driver.py +++ b/records_mover/db/vertica/vertica_db_driver.py @@ -1,7 +1,7 @@ from ..driver import DBDriver import sqlalchemy from sqlalchemy.sql import text -from records_mover.db.quoting import quote_schema_and_table +from sqlalchemy import select from sqlalchemy.schema import Table, Column import logging from typing import Optional, Union, Tuple @@ -47,10 +47,8 @@ def unloader(self) -> Optional[Unloader]: def has_table(self, schema: str, table: str) -> bool: try: - sql = ("SELECT 1 " - f"from {quote_schema_and_table(None, schema, table, db_engine=self.db_engine)} " - "limit 0;") - self.db_conn.execute(text(sql)) + table_to_check = Table(table, self.meta, schema=schema) + self.db_conn.execute(select(text("1"), table_to_check)) return True except sqlalchemy.exc.ProgrammingError: return False diff --git a/records_mover/records/prep.py b/records_mover/records/prep.py index d3a5caecd..d33abcc29 100644 --- a/records_mover/records/prep.py +++ b/records_mover/records/prep.py @@ -5,7 +5,8 @@ from records_mover.db import DBDriver from records_mover.records.table import TargetTableDetails import logging -from sqlalchemy import text +from sqlalchemy import text, Table, MetaData +from sqlalchemy.schema import DropTable logger = logging.getLogger(__name__) @@ -61,7 +62,9 @@ def prep_table_for_load(self, db_engine=db_engine,) if (how_to_prep == ExistingTableHandling.TRUNCATE_AND_OVERWRITE): logger.info("Truncating...") - db_conn.execute(text(f"TRUNCATE TABLE {schema_and_table}")) + meta = MetaData() + table = Table(self.tbl.table_name, meta, schema=self.tbl.schema_name) + 
db_conn.execute(table.delete()) logger.info("Truncated.") elif (how_to_prep == ExistingTableHandling.DELETE_AND_OVERWRITE): logger.info("Deleting rows...") @@ -72,8 +75,10 @@ def prep_table_for_load(self, with conn.begin(): logger.info(f"The connection object is: {conn}") logger.info("Dropping and recreating...") + meta = MetaData() + table = Table(self.tbl.table_name, meta, schema=self.tbl.schema_name) drop_table_sql = f"DROP TABLE {schema_and_table}" - conn.execute(text(drop_table_sql)) + conn.execute(DropTable(table)) # type: ignore[arg-type] logger.info(f"Just ran {drop_table_sql}") self.create_table(schema_sql, conn, driver) elif (how_to_prep == ExistingTableHandling.APPEND): diff --git a/records_mover/records/sources/table.py b/records_mover/records/sources/table.py index 4d7826e95..08196aaa8 100644 --- a/records_mover/records/sources/table.py +++ b/records_mover/records/sources/table.py @@ -8,7 +8,7 @@ from ..records_format import BaseRecordsFormat from ..unload_plan import RecordsUnloadPlan from ..results import MoveResult -from sqlalchemy import text +from sqlalchemy import Table, MetaData, select from sqlalchemy.engine import Engine from contextlib import contextmanager from ..schema import RecordsSchema @@ -88,10 +88,13 @@ def to_dataframes_source(self, chunksize = int(entries_per_chunk / num_columns) logger.info(f"Exporting in chunks of up to {chunksize} rows by {num_columns} columns") + meta = MetaData() + table = Table(self.table_name, meta, schema=self.schema_name) quoted_table = quote_schema_and_table(None, self.schema_name, self.table_name, db_engine=db_engine,) + logger.info(f"Reading {quoted_table}...") chunks: Generator['DataFrame', None, None] = \ - pandas.read_sql(text(f"SELECT * FROM {quoted_table}"), + pandas.read_sql(select('*', table), # type: ignore[arg-type] con=db_conn, chunksize=chunksize) try: diff --git a/records_mover/records/targets/spectrum.py b/records_mover/records/targets/spectrum.py index 7ea1a1dac..a7d746518 100644 --- a/records_mover/records/targets/spectrum.py +++ b/records_mover/records/targets/spectrum.py @@ -14,6 +14,7 @@ import logging import sqlalchemy from sqlalchemy import text +from sqlalchemy.schema import DropTable logger = logging.getLogger(__name__) @@ -84,9 +85,10 @@ def prep_bucket(self) -> None: db_engine=self.db_engine) logger.info(f"Dropping external table {schema_and_table}...") with self.db_engine.connect() as cursor: + table = Table(self.table_name, MetaData(), schema=self.schema_name) # See below note about fix from Spectrify cursor.execution_options(isolation_level='AUTOCOMMIT') - cursor.execute(text(f"DROP TABLE IF EXISTS {schema_and_table}")) + cursor.execute(DropTable(table, if_exists=True)) # type: ignore[call-arg, arg-type] logger.info(f"Deleting files in {self.output_loc}...") self.output_loc.purge_directory() diff --git a/tests/integration/records/purge_old_test_tables.py b/tests/integration/records/purge_old_test_tables.py index cc7caa759..f01d630c4 100755 --- a/tests/integration/records/purge_old_test_tables.py +++ b/tests/integration/records/purge_old_test_tables.py @@ -3,7 +3,8 @@ from records_mover.db.quoting import quote_schema_and_table from records_mover import Session from datetime import datetime, timedelta -from sqlalchemy import inspect +from sqlalchemy import inspect, Table, MetaData +from sqlalchemy.schema import DropTable from typing import Optional import sys @@ -39,9 +40,10 @@ def purge_old_tables(schema_name: str, table_name_prefix: str, "DROP TABLE " f"{quote_schema_and_table(None, schema_name, 
table_name, db_engine=db_engine)}") print(sql) + table = Table(table_name, MetaData(), schema=schema_name) with db_engine.connect() as connection: with connection.begin(): - connection.exec_driver_sql(sql) + connection.exec_driver_sql(DropTable(table)) # type: ignore[arg-type] if __name__ == '__main__': diff --git a/tests/integration/records/records_database_fixture.py b/tests/integration/records/records_database_fixture.py index 685d3e0bd..0953f366e 100644 --- a/tests/integration/records/records_database_fixture.py +++ b/tests/integration/records/records_database_fixture.py @@ -1,6 +1,7 @@ from records_mover.db.quoting import quote_schema_and_table from records_mover.utils.retry import bigquery_retry -from sqlalchemy import text +from sqlalchemy import Table, MetaData +from sqlalchemy.schema import DropTable import logging logger = logging.getLogger(__name__) @@ -20,10 +21,10 @@ def __init__(self, db_engine, schema_name, table_name): @bigquery_retry() def drop_table_if_exists(self, schema, table): - sql = f"DROP TABLE IF EXISTS {self.quote_schema_and_table(schema, table)}" + table_to_drop = Table(table, MetaData(), schema=schema) with self.engine.connect() as connection: with connection.begin(): - connection.execute(text(sql)) + connection.execute(DropTable(table_to_drop, if_exists=True)) def tear_down(self): self.drop_table_if_exists(self.schema_name, f"{self.table_name}_frozen") diff --git a/tests/integration/records/records_datetime_fixture.py b/tests/integration/records/records_datetime_fixture.py index 668ffd421..ffebfdc17 100644 --- a/tests/integration/records/records_datetime_fixture.py +++ b/tests/integration/records/records_datetime_fixture.py @@ -6,8 +6,9 @@ SAMPLE_YEAR, SAMPLE_MONTH, SAMPLE_DAY, SAMPLE_HOUR, SAMPLE_MINUTE, SAMPLE_SECOND, SAMPLE_OFFSET, SAMPLE_LONG_TZ ) -from sqlalchemy import text +from sqlalchemy import text, Table, MetaData from sqlalchemy.engine import Engine, Connection +from sqlalchemy.schema import DropTable from typing import Optional import logging @@ -28,13 +29,13 @@ def quote_schema_and_table(self, schema, table): @bigquery_retry() def drop_table_if_exists(self, schema, table): - sql = f"DROP TABLE IF EXISTS {self.quote_schema_and_table(schema, table)}" + table_to_drop = Table(table, MetaData(), schema=schema) if not self.connection: with self.engine.connect() as connection: with connection.begin(): - connection.execute(text(sql)) + connection.execute(DropTable(table_to_drop, if_exists=True)) else: - self.connection.execute(text(sql)) + self.connection.execute(DropTable(table_to_drop, if_exists=True)) def createDateTimeTzTable(self) -> None: if self.engine.name == 'redshift': diff --git a/tests/integration/records/records_numeric_database_fixture.py b/tests/integration/records/records_numeric_database_fixture.py index 66c381ba9..23cf4014b 100644 --- a/tests/integration/records/records_numeric_database_fixture.py +++ b/tests/integration/records/records_numeric_database_fixture.py @@ -1,5 +1,6 @@ from records_mover.db.quoting import quote_schema_and_table -from sqlalchemy import text +from sqlalchemy import Table, MetaData +from sqlalchemy.schema import DropTable class RecordsNumericDatabaseFixture: @@ -136,10 +137,10 @@ def quote_schema_and_table(self, schema, table): db_engine=self.engine) def drop_table_if_exists(self, schema, table): - sql = f"DROP TABLE IF EXISTS {self.quote_schema_and_table(schema, table)}" + table_to_drop = Table(table, MetaData(), schema=schema) with self.engine.connect() as connection: with connection.begin(): - 
connection.execute(text(sql)) + connection.execute(DropTable(table_to_drop, if_exists=True)) def tear_down(self): self.drop_table_if_exists(self.schema_name, self.table_name) diff --git a/tests/unit/db/vertica/test_vertica_db_driver.py b/tests/unit/db/vertica/test_vertica_db_driver.py index e9b0ba9c5..fda4efc62 100644 --- a/tests/unit/db/vertica/test_vertica_db_driver.py +++ b/tests/unit/db/vertica/test_vertica_db_driver.py @@ -54,14 +54,14 @@ def test_schema_sql_but_not_from_export_objects(self): self.assertTrue(sql is not None) def test_has_table_true(self): - mock_schema = Mock(name='schema') - mock_table = Mock(name='table') + mock_schema = 'myschema' + mock_table = 'mytable' self.assertEqual(True, self.vertica_db_driver.has_table(mock_schema, mock_table)) def test_has_table_false(self): - mock_schema = Mock(name='schema') - mock_table = Mock(name='table') + mock_schema = 'myschema' + mock_table = 'mytable' self.mock_db_engine.execute.side_effect = sqlalchemy.exc.ProgrammingError('statement', {}, 'orig') self.assertEqual(False, diff --git a/tests/unit/records/sources/test_table.py b/tests/unit/records/sources/test_table.py index 0c0ac34df..5b1b13ada 100644 --- a/tests/unit/records/sources/test_table.py +++ b/tests/unit/records/sources/test_table.py @@ -6,8 +6,8 @@ class TestTableRecordsSource(unittest.TestCase): def setUp(self): - self.mock_schema_name = Mock(name='schema_name') - self.mock_table_name = Mock(name='table_name') + self.mock_schema_name = 'mock_schema_name' + self.mock_table_name = 'mock_table_name' self.mock_driver = MagicMock(name='driver') self.mock_loader = self.mock_driver.loader.return_value self.mock_unloader = self.mock_driver.unloader.return_value @@ -37,7 +37,6 @@ def test_to_dataframes_source(self, mock_column = Mock(name='column') mock_columns = [mock_column] mock_db_engine.dialect.get_columns.return_value = mock_columns - mock_quoted_table = mock_quote_schema_and_table.return_value mock_chunks = mock_read_sql.return_value with self.table_records_source.to_dataframes_source(mock_processing_instructions) as\ df_source: @@ -49,7 +48,8 @@ def test_to_dataframes_source(self, self.mock_table_name, driver=self.mock_driver) str_arg = str(mock_read_sql.call_args.args[0]) - self.assertEqual(str_arg, f"SELECT * FROM {mock_quoted_table}") + self.assertEqual(str_arg, + f"SELECT * \nFROM {self.mock_schema_name}.{self.mock_table_name}") kwargs = mock_read_sql.call_args.kwargs self.assertEqual(kwargs['con'], mock_db_conn) self.assertEqual(kwargs['chunksize'], 2000000) diff --git a/tests/unit/records/targets/test_spectrum.py b/tests/unit/records/targets/test_spectrum.py index bddbdbc84..d67e8ae48 100644 --- a/tests/unit/records/targets/test_spectrum.py +++ b/tests/unit/records/targets/test_spectrum.py @@ -40,7 +40,6 @@ def test_init(self): @patch('records_mover.records.targets.spectrum.quote_schema_and_table') def test_pre_load_hook_preps_bucket_with_default_prep(self, mock_quote_schema_and_table): - mock_schema_and_table = mock_quote_schema_and_table.return_value mock_cursor = self.target.driver.db_engine.connect.return_value.__enter__.return_value self.target.pre_load_hook() @@ -51,7 +50,8 @@ def test_pre_load_hook_preps_bucket_with_default_prep(self, mock_quote_schema_an mock_cursor.execution_options.assert_called_with(isolation_level='AUTOCOMMIT') arg = mock_cursor.execute.call_args.args[0] arg_str = str(arg) - self.assertEqual(arg_str, f"DROP TABLE IF EXISTS {mock_schema_and_table}") + self.assertEqual( + arg_str, f"\nDROP TABLE IF EXISTS 
{self.target.schema_name}.{self.target.table_name}") self.mock_output_loc.purge_directory.assert_called_with() @patch('records_mover.records.targets.spectrum.RecordsDirectory') diff --git a/tests/unit/records/test_prep.py b/tests/unit/records/test_prep.py index 16b403d0b..e4ded4932 100644 --- a/tests/unit/records/test_prep.py +++ b/tests/unit/records/test_prep.py @@ -8,6 +8,8 @@ class TestPrep(unittest.TestCase): def setUp(self): self.mock_tbl = Mock(name='target_table_details') + self.mock_tbl.schema_name = 'mock_schema_name' + self.mock_tbl.table_name = 'mock_table_name' self.prep = TablePrep(self.mock_tbl) @patch('records_mover.records.prep.quote_schema_and_table') @@ -33,7 +35,6 @@ def test_prep_table_exists_truncate_implicit(self, mock_quote_schema_and_table): mock_quote_schema_and_table mock_driver.has_table.return_value = True how_to_prep = ExistingTableHandling.TRUNCATE_AND_OVERWRITE - mock_schema_and_table = mock_quote_schema_and_table.return_value self.mock_tbl.existing_table_handling = how_to_prep self.prep.prep(mock_schema_sql, mock_driver) @@ -43,7 +44,8 @@ def test_prep_table_exists_truncate_implicit(self, mock_quote_schema_and_table): self.mock_tbl.table_name, db_engine=mock_driver.db_engine) str_arg = str(mock_driver.db_conn.execute.call_args.args[0]) - self.assertEqual(str_arg, f"TRUNCATE TABLE {mock_schema_and_table}") + self.assertEqual(str_arg, + f"DELETE FROM {self.mock_tbl.schema_name}.{self.mock_tbl.table_name}") @patch('records_mover.records.prep.quote_schema_and_table') def test_prep_table_exists_delete_implicit(self, mock_quote_schema_and_table): @@ -79,7 +81,6 @@ def test_prep_table_exists_drop_implicit(self, mock_quote_schema_and_table): mock_quote_schema_and_table mock_driver.has_table.return_value = True how_to_prep = ExistingTableHandling.DROP_AND_RECREATE - mock_schema_and_table = mock_quote_schema_and_table.return_value self.mock_tbl.existing_table_handling = how_to_prep self.prep.prep(mock_schema_sql, mock_driver) @@ -91,7 +92,8 @@ def test_prep_table_exists_drop_implicit(self, mock_quote_schema_and_table): db_engine=mock_driver.db_engine) str_args = [str(call_arg.args[0]) for call_arg in mock_conn.execute.call_args_list] drop_table_str_arg, mock_schema_sql_str_arg = str_args[0], str_args[1] - self.assertEqual(drop_table_str_arg, f"DROP TABLE {mock_schema_and_table}") + self.assertEqual(drop_table_str_arg, + f"\nDROP TABLE {self.mock_tbl.schema_name}.{self.mock_tbl.table_name}") self.assertEqual(mock_schema_sql_str_arg, mock_schema_sql) mock_driver.set_grant_permissions_for_groups.\ assert_called_with(self.mock_tbl.schema_name, @@ -151,7 +153,6 @@ def test_prep_table_exists_drop_explicit(self, mock_quote_schema_and_table): mock_quote_schema_and_table mock_driver.has_table.return_value = True how_to_prep = ExistingTableHandling.DELETE_AND_OVERWRITE - mock_schema_and_table = mock_quote_schema_and_table.return_value self.mock_tbl.existing_table_handling = how_to_prep self.prep.prep(mock_schema_sql, mock_driver, @@ -163,7 +164,8 @@ def test_prep_table_exists_drop_explicit(self, mock_quote_schema_and_table): db_engine=mock_driver.db_engine) str_args = [str(call_arg.args[0]) for call_arg in mock_conn.execute.call_args_list] drop_table_str_arg, mock_schema_sql_str_arg = str_args[0], str_args[1] - self.assertEqual(drop_table_str_arg, f"DROP TABLE {mock_schema_and_table}") + self.assertEqual(drop_table_str_arg, + f"\nDROP TABLE {self.mock_tbl.schema_name}.{self.mock_tbl.table_name}") self.assertEqual(mock_schema_sql_str_arg, mock_schema_sql) 
mock_driver.set_grant_permissions_for_groups.\ assert_called_with(self.mock_tbl.schema_name, From a7ec7f2adddcd8ea1e8e3d046733a11178b39fd4 Mon Sep 17 00:00:00 2001 From: Tim Ryan Date: Mon, 30 Oct 2023 13:38:16 -0400 Subject: [PATCH 2/4] RM-87 use table_obj --- records_mover/db/redshift/unloader.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/records_mover/db/redshift/unloader.py b/records_mover/db/redshift/unloader.py index e6963ba9f..b4d895faa 100644 --- a/records_mover/db/redshift/unloader.py +++ b/records_mover/db/redshift/unloader.py @@ -2,10 +2,10 @@ from sqlalchemy_redshift.commands import UnloadFromSelect from ...records.records_directory import RecordsDirectory import sqlalchemy +from sqlalchemy import MetaData from sqlalchemy.sql import text from sqlalchemy.schema import Table import logging -from records_mover.db.quoting import quote_schema_and_table from records_mover.logging import register_secret from .records_unload import redshift_unload_options from ...records.unload_plan import RecordsUnloadPlan @@ -84,9 +84,8 @@ def unload_to_s3_directory(self, # register_secret(aws_creds.token) register_secret(aws_creds.secret_key) - select = text( - "SELECT * FROM " - f"{quote_schema_and_table(None, schema, table, db_engine=self.db_engine)}") + table_obj = Table(table, MetaData(), autoload_with=self.db_engine, schema=schema) + select = sqlalchemy.select(table_obj) unload = UnloadFromSelect(select=select, access_key_id=aws_creds.access_key, secret_access_key=aws_creds.secret_key, From 1308686323e87c2beb039ea7d635f9d15ca6e4ac Mon Sep 17 00:00:00 2001 From: Tim Ryan Date: Mon, 30 Oct 2023 14:33:21 -0400 Subject: [PATCH 3/4] RM-95 Protect against SQL Injection attacks RM-95 address Flake quality issue RM-95 quote tables RM-95 add db argument RM-95 use quote_schema_and_table RM-95 use quote_schema_and_table Revert "RM-95 use quoted schema and table" This reverts commit 4404d75250511d71f916202dfe51f2d35b370fea. 
RM-95 use quoted schema and table RM-95 use quoted schema and table RM-95 manually remove backticks RM-95 update test to use table_obj RM-95 update to use table object RM-95 redshift_db_driver to use sqlalchemy-priv RM-95 fix typecheck RM-95 update driver to user sqlalchemy priv RM-95 add sqlalchemy privileges for priv mngmnt RM-95 remove autoload_with RM-95 add where TRUE to delete RM-95 fix flake8 errors RM-95 remove autoload_with RM-95 update to use table_obj RM-95 update typing --- records_mover/db/driver.py | 14 ++-- records_mover/db/mysql/load_options.py | 4 +- .../db/redshift/redshift_db_driver.py | 9 ++- records_mover/db/redshift/unloader.py | 6 +- records_mover/db/vertica/vertica_db_driver.py | 3 +- records_mover/records/prep.py | 9 +-- records_mover/records/sources/table.py | 2 +- setup.py | 1 + .../db/mysql/test_load_options_class.py | 12 ++-- .../records/purge_old_test_tables.py | 2 +- .../records/records_database_fixture.py | 10 +-- .../records/records_datetime_fixture.py | 40 +++++------ .../records_numeric_database_fixture.py | 14 ++-- .../single_db/test_records_load_datetime.py | 8 +-- tests/integration/records/table_validator.py | 26 +++++--- .../db/redshift/test_redshift_db_driver.py | 12 +--- .../test_redshift_db_driver_unload.py | 4 +- tests/unit/db/test_db_driver.py | 66 +++++++------------ tests/unit/records/test_prep.py | 10 +-- 19 files changed, 117 insertions(+), 135 deletions(-) diff --git a/records_mover/db/driver.py b/records_mover/db/driver.py index 466618d0c..c9e16ee47 100644 --- a/records_mover/db/driver.py +++ b/records_mover/db/driver.py @@ -1,5 +1,6 @@ from ..check_db_conn_engine import check_db_conn_engine from sqlalchemy.schema import CreateTable +from sqlalchemy_privileges import GrantPrivileges # type: ignore[import-untyped] from ..records.records_format import BaseRecordsFormat from .loader import LoaderFromFileobj, LoaderFromRecordsDirectory from .unloader import Unloader @@ -7,7 +8,6 @@ import sqlalchemy from sqlalchemy import MetaData from sqlalchemy.schema import Table -from records_mover.db.quoting import quote_group_name, quote_user_name, quote_schema_and_table from abc import ABCMeta, abstractmethod from records_mover.records import RecordsSchema from typing import Union, Dict, List, Tuple, Optional, TYPE_CHECKING @@ -68,16 +68,14 @@ def set_grant_permissions_for_groups(self, db_engine: Optional[sqlalchemy.engine.Engine] = None ) -> None: db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) - schema_and_table: str = quote_schema_and_table(None, schema_name, table, - db_engine=self.db_engine) for perm_type in groups: groups_list = groups[perm_type] for group in groups_list: - group_name: str = quote_group_name(None, group, db_engine=self.db_engine) if not perm_type.isalpha(): raise TypeError("Please make sure your permission types" " are an acceptable value.") - perms_sql = f'GRANT {perm_type} ON TABLE {schema_and_table} TO {group_name}' + table_obj = Table(table, MetaData(), schema=schema_name) + perms_sql = str(GrantPrivileges(perm_type, table_obj, group)) if db_conn: db_conn.execute(perms_sql) else: @@ -92,16 +90,14 @@ def set_grant_permissions_for_users(self, schema_name: str, table: str, db_engine: Optional[sqlalchemy.engine.Engine] = None ) -> None: db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) - schema_and_table: str = quote_schema_and_table(None, schema_name, table, - db_engine=self.db_engine) for perm_type in users: user_list = users[perm_type] for user 
in user_list: - user_name: str = quote_user_name(self.db_engine, user) if not perm_type.isalpha(): raise TypeError("Please make sure your permission types" " are an acceptable value.") - perms_sql = f'GRANT {perm_type} ON TABLE {schema_and_table} TO {user_name}' + table_obj = Table(table, MetaData(), schema=schema_name) + perms_sql = str(GrantPrivileges(perm_type, table_obj, user)) if db_conn: db_conn.execute(perms_sql) else: diff --git a/records_mover/db/mysql/load_options.py b/records_mover/db/mysql/load_options.py index aea885ea4..c7f576e18 100644 --- a/records_mover/db/mysql/load_options.py +++ b/records_mover/db/mysql/load_options.py @@ -69,10 +69,12 @@ def generate_load_data_sql(self, filename: str, schema_name: str, table_name: str) -> TextClause: + remove_backticks_schema_name = schema_name.replace('`', '') + remove_backticks_table_name = table_name.replace('`', '') sql = f"""\ LOAD DATA LOCAL INFILE :filename -INTO TABLE {schema_name}.{table_name} +INTO TABLE `{remove_backticks_schema_name}`.`{remove_backticks_table_name}` CHARACTER SET :character_set FIELDS TERMINATED BY :fields_terminated_by diff --git a/records_mover/db/redshift/redshift_db_driver.py b/records_mover/db/redshift/redshift_db_driver.py index 3846237fa..26e066427 100644 --- a/records_mover/db/redshift/redshift_db_driver.py +++ b/records_mover/db/redshift/redshift_db_driver.py @@ -1,6 +1,7 @@ from ..driver import DBDriver import sqlalchemy -from sqlalchemy.schema import Table +from sqlalchemy_privileges import GrantPrivileges # type: ignore[import-untyped] +from sqlalchemy.schema import Table, MetaData from records_mover.records import RecordsSchema from records_mover.records.records_format import BaseRecordsFormat, AvroRecordsFormat import logging @@ -14,7 +15,6 @@ import timeout_decorator from typing import Optional, Union, Dict, List, Tuple from ...url.base import BaseDirectoryUrl -from records_mover.db.quoting import quote_group_name, quote_schema_and_table from .unloader import RedshiftUnloader from ..unloader import Unloader from .loader import RedshiftLoader @@ -85,15 +85,14 @@ def set_grant_permissions_for_groups(self, db_engine: Optional[sqlalchemy.engine.Engine] = None ) -> None: db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) - schema_and_table = quote_schema_and_table(None, schema_name, table, db_engine=db_engine) for perm_type in groups: groups_list = groups[perm_type] for group in groups_list: - group_name: str = quote_group_name(None, group, db_engine=self.db_engine) if not perm_type.isalpha(): raise TypeError("Please make sure your permission types" " are an acceptable value.") - perms_sql = f'GRANT {perm_type} ON TABLE {schema_and_table} TO GROUP {group_name}' + table_obj = Table(table, MetaData(), schema=schema_name) + perms_sql = str(GrantPrivileges(perm_type, table_obj, group)) if db_conn: db_conn.execute(perms_sql) else: diff --git a/records_mover/db/redshift/unloader.py b/records_mover/db/redshift/unloader.py index b4d895faa..9c122b2b3 100644 --- a/records_mover/db/redshift/unloader.py +++ b/records_mover/db/redshift/unloader.py @@ -84,9 +84,9 @@ def unload_to_s3_directory(self, # register_secret(aws_creds.token) register_secret(aws_creds.secret_key) - table_obj = Table(table, MetaData(), autoload_with=self.db_engine, schema=schema) - select = sqlalchemy.select(table_obj) - unload = UnloadFromSelect(select=select, + table_obj = Table(table, MetaData(), schema=schema) + select = str(sqlalchemy.select('*', table_obj)) # type: ignore[arg-type] # noqa: 
F821 + unload = UnloadFromSelect(select=text(select), access_key_id=aws_creds.access_key, secret_access_key=aws_creds.secret_key, session_token=aws_creds.token, manifest=True, diff --git a/records_mover/db/vertica/vertica_db_driver.py b/records_mover/db/vertica/vertica_db_driver.py index 4529745d9..76fa45592 100644 --- a/records_mover/db/vertica/vertica_db_driver.py +++ b/records_mover/db/vertica/vertica_db_driver.py @@ -48,7 +48,8 @@ def unloader(self) -> Optional[Unloader]: def has_table(self, schema: str, table: str) -> bool: try: table_to_check = Table(table, self.meta, schema=schema) - self.db_conn.execute(select(text("1"), table_to_check)) + self.db_conn.execute( + select(text("1"), table_to_check)) # type: ignore[arg-type] # noqa: F821 return True except sqlalchemy.exc.ProgrammingError: return False diff --git a/records_mover/records/prep.py b/records_mover/records/prep.py index d33abcc29..eb7e21cc8 100644 --- a/records_mover/records/prep.py +++ b/records_mover/records/prep.py @@ -5,7 +5,7 @@ from records_mover.db import DBDriver from records_mover.records.table import TargetTableDetails import logging -from sqlalchemy import text, Table, MetaData +from sqlalchemy import text, Table, MetaData, delete from sqlalchemy.schema import DropTable logger = logging.getLogger(__name__) @@ -64,11 +64,12 @@ def prep_table_for_load(self, logger.info("Truncating...") meta = MetaData() table = Table(self.tbl.table_name, meta, schema=self.tbl.schema_name) - db_conn.execute(table.delete()) + db_conn.execute(table.delete().where(True)) logger.info("Truncated.") elif (how_to_prep == ExistingTableHandling.DELETE_AND_OVERWRITE): logger.info("Deleting rows...") - db_conn.execute(text(f"DELETE FROM {schema_and_table} WHERE true")) + table_obj = Table(self.tbl.table_name, MetaData(), schema=self.tbl.schema_name) + db_conn.execute(delete(table_obj).where(True)) logger.info("Deleted") elif (how_to_prep == ExistingTableHandling.DROP_AND_RECREATE): with db_engine.connect() as conn: @@ -78,7 +79,7 @@ def prep_table_for_load(self, meta = MetaData() table = Table(self.tbl.table_name, meta, schema=self.tbl.schema_name) drop_table_sql = f"DROP TABLE {schema_and_table}" - conn.execute(DropTable(table)) # type: ignore[arg-type] + conn.execute(DropTable(table)) # type: ignore[arg-type] # noqa: F821 logger.info(f"Just ran {drop_table_sql}") self.create_table(schema_sql, conn, driver) elif (how_to_prep == ExistingTableHandling.APPEND): diff --git a/records_mover/records/sources/table.py b/records_mover/records/sources/table.py index 08196aaa8..b0ab3f4ee 100644 --- a/records_mover/records/sources/table.py +++ b/records_mover/records/sources/table.py @@ -94,7 +94,7 @@ def to_dataframes_source(self, self.table_name, db_engine=db_engine,) logger.info(f"Reading {quoted_table}...") chunks: Generator['DataFrame', None, None] = \ - pandas.read_sql(select('*', table), # type: ignore[arg-type] + pandas.read_sql(select('*', table), # type: ignore[arg-type] # noqa: F821 con=db_conn, chunksize=chunksize) try: diff --git a/setup.py b/setup.py index aa7cb45a4..6fe7211a0 100755 --- a/setup.py +++ b/setup.py @@ -146,6 +146,7 @@ def initialize_options(self) -> None: db_dependencies = [ 'sqlalchemy>=1.4', + 'sqlalchemy_privileges>=0.2.0', ] smart_open_dependencies = [ diff --git a/tests/component/db/mysql/test_load_options_class.py b/tests/component/db/mysql/test_load_options_class.py index 021df411c..f785b6bc3 100644 --- a/tests/component/db/mysql/test_load_options_class.py +++ b/tests/component/db/mysql/test_load_options_class.py @@ 
-21,7 +21,7 @@ def test_generate_load_data_sql_boring(self) -> None: expected_sql = """\ LOAD DATA LOCAL INFILE 'my_filename.txt' -INTO TABLE myschema.mytable +INTO TABLE `myschema`.`mytable` CHARACTER SET 'utf8' FIELDS TERMINATED BY '\t' @@ -51,7 +51,7 @@ def test_generate_load_data_sql_different_constants(self) -> None: expected_sql = """\ LOAD DATA LOCAL INFILE 'another_filename.txt' -INTO TABLE myschema.mytable +INTO TABLE `myschema`.`mytable` CHARACTER SET 'utf16' FIELDS TERMINATED BY ',' @@ -81,7 +81,7 @@ def test_generate_load_data_sql_field_enclosed_by_None(self) -> None: expected_sql = """\ LOAD DATA LOCAL INFILE 'another_filename.txt' -INTO TABLE myschema.mytable +INTO TABLE `myschema`.`mytable` CHARACTER SET 'utf16' FIELDS TERMINATED BY ',' @@ -110,7 +110,7 @@ def test_generate_load_data_sql_field_optionally_enclosed_by(self) -> None: expected_sql = """\ LOAD DATA LOCAL INFILE 'another_filename.txt' -INTO TABLE myschema.mytable +INTO TABLE `myschema`.`mytable` CHARACTER SET 'utf16' FIELDS TERMINATED BY ',' @@ -160,7 +160,7 @@ def test_generate_load_data_sql_windows_filename(self) -> None: expected_sql = """\ LOAD DATA LOCAL INFILE 'c:\\\\Some Path\\\\OH GOD LET IT END~1.CSV' -INTO TABLE myschema.mytable +INTO TABLE `myschema`.`mytable` CHARACTER SET 'utf16' FIELDS TERMINATED BY ',' @@ -190,7 +190,7 @@ def test_vertica_dialect_style_terminators(self) -> None: expected_sql = """\ LOAD DATA LOCAL INFILE 'another_filename.txt' -INTO TABLE myschema.mytable +INTO TABLE `myschema`.`mytable` CHARACTER SET 'utf16' FIELDS TERMINATED BY '\002' diff --git a/tests/integration/records/purge_old_test_tables.py b/tests/integration/records/purge_old_test_tables.py index f01d630c4..2f27891a4 100755 --- a/tests/integration/records/purge_old_test_tables.py +++ b/tests/integration/records/purge_old_test_tables.py @@ -43,7 +43,7 @@ def purge_old_tables(schema_name: str, table_name_prefix: str, table = Table(table_name, MetaData(), schema=schema_name) with db_engine.connect() as connection: with connection.begin(): - connection.exec_driver_sql(DropTable(table)) # type: ignore[arg-type] + connection.exec_driver_sql(DropTable(table)) # type: ignore[arg-type] # noqa: F821 if __name__ == '__main__': diff --git a/tests/integration/records/records_database_fixture.py b/tests/integration/records/records_database_fixture.py index 0953f366e..183f9f2ba 100644 --- a/tests/integration/records/records_database_fixture.py +++ b/tests/integration/records/records_database_fixture.py @@ -34,7 +34,7 @@ def tear_down(self): def bring_up(self): if self.engine.name == 'redshift': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT 123 AS num, '123' AS numstr, 'foo' AS str, @@ -49,7 +49,7 @@ def bring_up(self): """ # noqa elif self.engine.name == 'vertica': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT 123 AS num, '123' AS numstr, 'foo' AS str, @@ -64,7 +64,7 @@ def bring_up(self): """ # noqa elif self.engine.name == 'bigquery': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT 123 AS num, '123' AS numstr, 'foo' AS str, @@ -79,7 +79,7 @@ def bring_up(self): """ # noqa elif self.engine.name == 'postgresql': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} 
AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT 123 AS num, '123' AS numstr, 'foo' AS str, @@ -94,7 +94,7 @@ def bring_up(self): """ # noqa elif self.engine.name == 'mysql': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT 123 AS num, '123' AS numstr, 'foo' AS str, diff --git a/tests/integration/records/records_datetime_fixture.py b/tests/integration/records/records_datetime_fixture.py index ffebfdc17..8e83db558 100644 --- a/tests/integration/records/records_datetime_fixture.py +++ b/tests/integration/records/records_datetime_fixture.py @@ -40,27 +40,27 @@ def drop_table_if_exists(self, schema, table): def createDateTimeTzTable(self) -> None: if self.engine.name == 'redshift': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY} {SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d} {SAMPLE_LONG_TZ}'::TIMESTAMPTZ as timestamptz; """ # noqa elif self.engine.name == 'vertica': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY} {SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d} {SAMPLE_LONG_TZ}'::TIMESTAMPTZ as timestamptz; """ # noqa elif self.engine.name == 'bigquery': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT cast('{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY} {SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d} {SAMPLE_LONG_TZ}' AS TIMESTAMP) as timestamptz; """ # noqa elif self.engine.name == 'postgresql': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY} {SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d} {SAMPLE_LONG_TZ}'::TIMESTAMPTZ as "timestamptz"; """ # noqa elif self.engine.name == 'mysql': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT TIMESTAMP '{SAMPLE_YEAR}-{SAMPLE_MONTH:02d}-{SAMPLE_DAY:02d} {SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}.000000{SAMPLE_OFFSET}' AS "timestamptz"; """ # noqa else: @@ -72,27 +72,27 @@ def createDateTimeTzTable(self) -> None: def createDateTimeTable(self) -> None: if self.engine.name == 'redshift': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY} {SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}'::TIMESTAMP AS timestamp; """ # noqa elif self.engine.name == 'vertica': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY} {SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}'::TIMESTAMP AS timestamp; """ # noqa elif self.engine.name == 'bigquery': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE 
{self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT cast('{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY} {SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}' AS DATETIME) AS timestamp; """ # noqa elif self.engine.name == 'postgresql': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY} {SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}'::TIMESTAMP AS "timestamp"; """ # noqa elif self.engine.name == 'mysql': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT TIMESTAMP '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY} {SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}' AS "timestamp"; """ # noqa else: @@ -109,27 +109,27 @@ def createDateTimeTable(self) -> None: def createDateTable(self) -> None: if self.engine.name == 'redshift': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY}'::DATE AS date; """ # noqa elif self.engine.name == 'vertica': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY}'::DATE AS date; """ # noqa elif self.engine.name == 'bigquery': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT cast('{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY}' as DATE) AS date; """ # noqa elif self.engine.name == 'postgresql': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY}'::DATE AS date; """ # noqa elif self.engine.name == 'mysql': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT DATE '{SAMPLE_YEAR}-{SAMPLE_MONTH}-{SAMPLE_DAY}' AS "date"; """ # noqa else: @@ -146,27 +146,27 @@ def createDateTable(self) -> None: def createTimeTable(self): if self.engine.name == 'redshift': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}' AS "time"; """ # noqa elif self.engine.name == 'vertica': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT '{SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}'::TIME AS "time"; """ # noqa elif self.engine.name == 'bigquery': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT cast('{SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}' as TIME) AS time; """ # noqa elif self.engine.name == 'postgresql': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT 
'{SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}'::TIME AS "time"; """ # noqa elif self.engine.name == 'mysql': create_tables = f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT TIME '{SAMPLE_HOUR:02d}:{SAMPLE_MINUTE:02d}:{SAMPLE_SECOND:02d}' AS "time"; """ # noqa else: diff --git a/tests/integration/records/records_numeric_database_fixture.py b/tests/integration/records/records_numeric_database_fixture.py index 23cf4014b..11f7fb9ed 100644 --- a/tests/integration/records/records_numeric_database_fixture.py +++ b/tests/integration/records/records_numeric_database_fixture.py @@ -19,7 +19,7 @@ def bring_up(self): if self.engine.name == 'redshift': # Redshift supports a number of different numeric types create_tables = [f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT 32767::smallint AS int16, 2147483647::INTEGER AS int32, 9223372036854775807::BIGINT AS int64, @@ -31,7 +31,7 @@ def bring_up(self): elif self.engine.name == 'vertica': # Vertica only supports a few large numeric types create_tables = [f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT 9223372036854775807::BIGINT AS int64, 1234.56::NUMERIC(6, 2) AS fixed_6_2, 19223372036854775807.78::FLOAT AS float64; @@ -40,20 +40,20 @@ def bring_up(self): elif self.engine.name == 'bigquery': # BigQuery only supports a few large numeric types create_tables = [f""" - CREATE TABLE {self.schema_name}.{self.table_name} ( + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} ( `int64` INT64, `fixed_6_2` NUMERIC(6, 2), `float64` FLOAT64); """, # noqa f""" - INSERT INTO {self.schema_name}.{self.table_name} (`int64`, `fixed_6_2`, `float64`) + INSERT INTO {self.quote_schema_and_table(self.schema_name, self.table_name)} (`int64`, `fixed_6_2`, `float64`) VALUES (9223372036854775807, 1234.56, 19223372036854775807.78); """, # noqa ] elif self.engine.name == 'postgresql': # Postgres supports a number of different numeric types create_tables = [f""" - CREATE TABLE {self.schema_name}.{self.table_name} AS + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS SELECT 32767::smallint AS int16, 2147483647::INTEGER AS int32, 9223372036854775807::BIGINT AS int64, @@ -67,7 +67,7 @@ def bring_up(self): # https://dev.mysql.com/doc/refman/8.0/en/numeric-types.html # create_tables = [f""" - CREATE TABLE {self.schema_name}.{self.table_name} ( + CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} ( `int8` TINYINT, `uint8` TINYINT UNSIGNED, `int16` SMALLINT, @@ -86,7 +86,7 @@ def bring_up(self): ); """, # noqa f""" - INSERT INTO {self.schema_name}.{self.table_name} + INSERT INTO {self.quote_schema_and_table(self.schema_name, self.table_name)} ( `int8`, `uint8`, diff --git a/tests/integration/records/single_db/test_records_load_datetime.py b/tests/integration/records/single_db/test_records_load_datetime.py index b00b0e1ac..8426eb299 100644 --- a/tests/integration/records/single_db/test_records_load_datetime.py +++ b/tests/integration/records/single_db/test_records_load_datetime.py @@ -8,7 +8,7 @@ ) from records_mover.records import RecordsSchema, RecordsFormat, PartialRecordsHints from records_mover.records.schema.field.field_types import FieldType -from sqlalchemy import text +from sqlalchemy 
import Table, MetaData, select, text logger = logging.getLogger(__name__) @@ -59,11 +59,11 @@ def load(self, def pull_result(self, column_name: str) -> datetime.datetime: + table_obj = Table(self.table_name, MetaData(), schema=self.schema_name) with self.engine.connect() as connection: with connection.begin(): - out = connection.execute(text( - f'SELECT {column_name} ' - f'from {self.schema_name}.{self.table_name}')) + out = connection.execute( + select(text(column_name), table_obj)) # type: ignore[arg-type] # noqa: F821 ret_all = out.fetchall() assert 1 == len(ret_all) ret = ret_all[0] diff --git a/tests/integration/records/table_validator.py b/tests/integration/records/table_validator.py index 2b730466c..e65869c96 100644 --- a/tests/integration/records/table_validator.py +++ b/tests/integration/records/table_validator.py @@ -11,6 +11,7 @@ expected_table2table_column_types ) from records_mover.records import DelimitedVariant +from records_mover.db.quoting import quote_schema_and_table from .mover_test_case import MoverTestCase from .table_timezone_validator import RecordsTableTimezoneValidator @@ -127,7 +128,7 @@ def format_actual_expected_column_types(*expected): self.target_db_engine.name)), expected_single_database_column_types[self.source_db_engine.name], expected_single_database_column_types[self.target_db_engine.name], - expected_df_loaded_database_column_types.get(self.target_db_engine.name))),\ + expected_df_loaded_database_column_types.get(self.target_db_engine.name))), \ f'Could not find column types filed under '\ f"{(self.source_db_engine.name, self.target_db_engine.name)} "\ 'or either individually: '\ @@ -138,6 +139,11 @@ def validate_data_values(self, table_name: str) -> None: params = {} + quoted_schema_and_table = quote_schema_and_table(None, + schema=schema_name, + table=table_name, + db_engine=self.target_db_engine) + load_variant = self.tc.determine_load_variant() with self.target_db_engine.connect() as connection: @@ -170,7 +176,7 @@ def validate_data_values(self, timestamptz, format_timestamp(:tzformatstr, CAST(`timestamptz` as timestamp)) as timestamptzstr - FROM {schema_name}.{table_name} + FROM {quoted_schema_and_table} """) params = { "tzformatstr": "%E4Y-%m-%d %H:%M:%E*S %Z", @@ -183,7 +189,7 @@ def validate_data_values(self, DATE_FORMAT(`timestamp`, '%Y-%m-%d %H:%i:%s.%f') as timestampstr, timestamptz, DATE_FORMAT(timestamptz, '%Y-%m-%d %H:%i:%s.%f+00') as timestamptzstr - FROM {schema_name}.{table_name} + FROM {quoted_schema_and_table} """) elif self.tc.raw_avro_types_written(): # no real date/time column types used, so can't cast types @@ -193,7 +199,7 @@ def validate_data_values(self, "timestamp" as timestampstr, timestamptz, timestamptz as timestamptzstr - FROM {schema_name}.{table_name} + FROM {quoted_schema_and_table} """) else: select_sql = text(f""" @@ -202,7 +208,7 @@ def validate_data_values(self, to_char("timestamp", 'YYYY-MM-DD HH24:MI:SS.US') as timestampstr, timestamptz, to_char(timestamptz, 'YYYY-MM-DD HH24:MI:SS.US TZ') as timestamptzstr - FROM {schema_name}.{table_name} + FROM {quoted_schema_and_table} """) out = connection.execute(select_sql, params) ret_all = out.fetchall() @@ -217,17 +223,17 @@ def validate_data_values(self, if self.tc.raw_avro_types_written(): assert ret.date == 10957, ret.date else: - assert ret.date == datetime.date(2000, 1, 1),\ + assert ret.date == datetime.date(2000, 1, 1), \ f"Expected datetime.date(2000, 1, 1), got {ret.date}" if self.tc.raw_avro_types_written(): assert ret.time == 0, ret.time elif 
self.tc.supports_time_without_date(): if self.tc.selects_time_types_as_timedelta(): - assert ret.time == datetime.timedelta(0, 0),\ + assert ret.time == datetime.timedelta(0, 0), \ f"Incorrect time: {ret.time} (of type {type(ret.time)})" else: - assert ret.time == datetime.time(0, 0),\ + assert ret.time == datetime.time(0, 0), \ f"Incorrect time: {ret.time} (of type {type(ret.time)})" else: # fall back to storing as string @@ -243,11 +249,11 @@ def validate_data_values(self, ((self.file_variant is not None) and self.tc.variant_doesnt_support_seconds(self.file_variant))): assert ret.timestamp ==\ - datetime.datetime(2000, 1, 2, 12, 34),\ + datetime.datetime(2000, 1, 2, 12, 34), \ f"Found timestamp {ret.timestamp}" else: - assert (ret.timestamp == datetime.datetime(2000, 1, 2, 12, 34, 56, 789012)),\ + assert (ret.timestamp == datetime.datetime(2000, 1, 2, 12, 34, 56, 789012)), \ f"ret.timestamp was {ret.timestamp} of type {type(ret.timestamp)}" print("ROW OUTPUT VALUES:", ret) print("ROW OUTPUT AS DICTIONARY:", ret._asdict()) diff --git a/tests/unit/db/redshift/test_redshift_db_driver.py b/tests/unit/db/redshift/test_redshift_db_driver.py index 4f94e9221..93f274d27 100644 --- a/tests/unit/db/redshift/test_redshift_db_driver.py +++ b/tests/unit/db/redshift/test_redshift_db_driver.py @@ -1,5 +1,4 @@ from .base_test_redshift_db_driver import BaseTestRedshiftDBDriver -from unittest.mock import patch import sqlalchemy @@ -16,14 +15,9 @@ def test_schema_sql_no_admin_views(self): sql = self.redshift_db_driver.schema_sql('myschema', 'mytable') self.assertEqual(sql, mock_schema_sql) - @patch('records_mover.db.redshift.redshift_db_driver.quote_group_name') - @patch('records_mover.db.redshift.redshift_db_driver.quote_schema_and_table') - def test_set_grant_permissions_for_group(self, mock_quote_schema_and_table, - mock_quote_group_name): + def test_set_grant_permissions_for_group(self): mock_schema = 'mock_schema' mock_table = 'mock_table' - mock_quote_schema_and_table.return_value = 'mock_schema.mock_table' - mock_quote_group_name.return_value = '"a_group"' groups = {'all': ['a_group']} mock_conn = self.mock_db_engine.engine.connect.return_value.__enter__.return_value self.redshift_db_driver.set_grant_permissions_for_groups(mock_schema, @@ -32,7 +26,7 @@ def test_set_grant_permissions_for_group(self, mock_quote_schema_and_table, None, db_conn=mock_conn) mock_conn.execute.assert_called_with( - f'GRANT all ON TABLE {mock_schema}.{mock_table} TO GROUP "a_group"') + f'GRANT ALL ON {mock_schema}.{mock_table} TO "a_group"\n') def test_best_scheme_to_load_from(self): out = self.redshift_db_driver.loader().best_scheme_to_load_from() @@ -95,7 +89,7 @@ def test_type_for_floating_point(self): (64, 49): 49, (100, 80): 53, } - for (input_fp_total_bits, input_fp_significand_bits),\ + for (input_fp_total_bits, input_fp_significand_bits), \ expected_fp_significand_bits in expectations.items(): actual_col_type =\ self.redshift_db_driver.type_for_floating_point(input_fp_total_bits, diff --git a/tests/unit/db/redshift/test_redshift_db_driver_unload.py b/tests/unit/db/redshift/test_redshift_db_driver_unload.py index a68adf374..86020d9b8 100644 --- a/tests/unit/db/redshift/test_redshift_db_driver_unload.py +++ b/tests/unit/db/redshift/test_redshift_db_driver_unload.py @@ -44,7 +44,7 @@ def test_unload_to_non_s3(self, 'gzip': True, 'manifest': True, 'secret_access_key': mock_secret_key, - 'select': ('SELECT * FROM myschema.mytable',), + 'select': ('SELECT * \nFROM myschema.mytable',), 'session_token': mock_token, 
'unload_location': self.mock_s3_temp_base_loc.temporary_directory().__enter__().url } @@ -80,7 +80,7 @@ def test_unload(self, 'gzip': True, 'manifest': True, 'secret_access_key': 'fake_aws_secret', - 'select': ('SELECT * FROM myschema.mytable',), + 'select': ('SELECT * \nFROM myschema.mytable',), 'session_token': 'fake_aws_token', 'unload_location': 's3://mybucket/myparent/mychild/' } diff --git a/tests/unit/db/test_db_driver.py b/tests/unit/db/test_db_driver.py index c2060105b..e07e9f749 100644 --- a/tests/unit/db/test_db_driver.py +++ b/tests/unit/db/test_db_driver.py @@ -1,6 +1,6 @@ from .fakes import fake_text import unittest -from mock import Mock, MagicMock, patch, call +from mock import Mock, MagicMock, call from records_mover.db.driver import GenericDBDriver import sqlalchemy @@ -74,65 +74,49 @@ def test_has_table(self): self.mock_db_engine._sa_instance_state.has_table.assert_called_with(mock_table, schema=mock_schema) - @patch('records_mover.db.driver.quote_group_name') - @patch('records_mover.db.driver.quote_schema_and_table') - def test_set_grant_permissions_for_groups(self, - mock_quote_schema_and_table, - mock_quote_group_name): - mock_schema_name = Mock(name='schema_name') - mock_table = Mock(name='table') + def test_set_grant_permissions_for_groups(self): + mock_schema_name = 'schema_name' + mock_table = 'table' mock_db = Mock(name='db') groups = { - 'write': ['group_a', 'group_b'] + 'insert': ['group_a', 'group_b'] } - mock_schema_and_table = mock_quote_schema_and_table.return_value - mock_group_name = mock_quote_group_name.return_value self.db_driver.set_grant_permissions_for_groups(mock_schema_name, mock_table, groups, None, db_conn=mock_db) - mock_quote_schema_and_table.assert_called_with(None, - mock_schema_name, - mock_table, - db_engine=self.mock_db_engine) mock_db.execute.assert_has_calls([ - call(f"GRANT write ON TABLE {mock_schema_and_table} TO {mock_group_name}"), - call(f"GRANT write ON TABLE {mock_schema_and_table} TO {mock_group_name}"), + call( + f"GRANT INSERT ON {mock_schema_name}.\"{mock_table}\" " + f"TO \"{groups['insert'][0]}\"\n"), + call( + f"GRANT INSERT ON {mock_schema_name}.\"{mock_table}\" " + f"TO \"{groups['insert'][1]}\"\n"), ]) - @patch('records_mover.db.driver.quote_user_name') - @patch('records_mover.db.driver.quote_schema_and_table') - def test_set_grant_permissions_for_users(self, - mock_quote_schema_and_table, - mock_quote_user_name): - mock_schema_name = Mock(name='schema_name') - mock_table = Mock(name='table') + def test_set_grant_permissions_for_users(self): + mock_schema_name = 'schema_name' + mock_table = 'my_table' mock_db = Mock(name='db') users = { - 'write': ['user_a', 'user_b'] + 'insert': ['user_a', 'user_b'] } - mock_schema_and_table = mock_quote_schema_and_table.return_value - mock_user_name = mock_quote_user_name.return_value self.db_driver.set_grant_permissions_for_users(mock_schema_name, mock_table, users, None, db_conn=mock_db) - mock_quote_schema_and_table.assert_called_with(None, - mock_schema_name, - mock_table, - db_engine=self.mock_db_engine) mock_db.execute.assert_has_calls([ - call(f"GRANT write ON TABLE {mock_schema_and_table} TO {mock_user_name}"), - call(f"GRANT write ON TABLE {mock_schema_and_table} TO {mock_user_name}"), + call( + f"GRANT INSERT ON {mock_schema_name}.{mock_table} " + f"TO \"{users['insert'][0]}\"\n"), + call( + f"GRANT INSERT ON {mock_schema_name}.{mock_table} " + f"TO \"{users['insert'][1]}\"\n"), ]) - @patch('records_mover.db.driver.quote_user_name') - 
@patch('records_mover.db.driver.quote_schema_and_table') - def test_set_grant_permissions_for_users_bobby_tables(self, - mock_quote_schema_and_table, - mock_quote_user_name): + def test_set_grant_permissions_for_users_bobby_tables(self): mock_schema_name = Mock(name='schema_name') mock_table = Mock(name='table') mock_db = Mock(name='db') @@ -146,11 +130,7 @@ def test_set_grant_permissions_for_users_bobby_tables(self, None, db_conn=mock_db) - @patch('records_mover.db.driver.quote_user_name') - @patch('records_mover.db.driver.quote_schema_and_table') - def test_set_grant_permissions_for_groups_bobby_tables(self, - mock_quote_schema_and_table, - mock_quote_user_name): + def test_set_grant_permissions_for_groups_bobby_tables(self): mock_schema_name = Mock(name='schema_name') mock_table = Mock(name='table') mock_db = Mock(name='db') diff --git a/tests/unit/records/test_prep.py b/tests/unit/records/test_prep.py index e4ded4932..737d484d7 100644 --- a/tests/unit/records/test_prep.py +++ b/tests/unit/records/test_prep.py @@ -44,8 +44,9 @@ def test_prep_table_exists_truncate_implicit(self, mock_quote_schema_and_table): self.mock_tbl.table_name, db_engine=mock_driver.db_engine) str_arg = str(mock_driver.db_conn.execute.call_args.args[0]) - self.assertEqual(str_arg, - f"DELETE FROM {self.mock_tbl.schema_name}.{self.mock_tbl.table_name}") + self.assertEqual( + str_arg, + f"DELETE FROM {self.mock_tbl.schema_name}.{self.mock_tbl.table_name} WHERE true") @patch('records_mover.records.prep.quote_schema_and_table') def test_prep_table_exists_delete_implicit(self, mock_quote_schema_and_table): @@ -57,7 +58,6 @@ def test_prep_table_exists_delete_implicit(self, mock_quote_schema_and_table): mock_quote_schema_and_table mock_driver.has_table.return_value = True how_to_prep = ExistingTableHandling.DELETE_AND_OVERWRITE - mock_schema_and_table = mock_quote_schema_and_table.return_value self.mock_tbl.existing_table_handling = how_to_prep self.prep.prep(mock_schema_sql, mock_driver) @@ -67,7 +67,9 @@ def test_prep_table_exists_delete_implicit(self, mock_quote_schema_and_table): self.mock_tbl.table_name, db_engine=mock_driver.db_engine) str_arg = str(mock_driver.db_conn.execute.call_args.args[0]) - self.assertEqual(str_arg, f"DELETE FROM {mock_schema_and_table} WHERE true") + self.assertEqual( + str_arg, + f"DELETE FROM {self.mock_tbl.schema_name}.{self.mock_tbl.table_name} WHERE true") @patch('records_mover.records.prep.quote_schema_and_table') def test_prep_table_exists_drop_implicit(self, mock_quote_schema_and_table): From 4dbf0fe899044e25bfb012d026e0ca96970e7a62 Mon Sep 17 00:00:00 2001 From: Tim Ryan Date: Mon, 20 Nov 2023 17:06:24 +0000 Subject: [PATCH 4/4] RM-95 wrap string in sqlalchemy.text RM-95 quality fix --- records_mover/db/driver.py | 6 +++--- tests/unit/db/test_db_driver.py | 28 +++++++++++----------------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/records_mover/db/driver.py b/records_mover/db/driver.py index c9e16ee47..c27de2ace 100644 --- a/records_mover/db/driver.py +++ b/records_mover/db/driver.py @@ -6,7 +6,7 @@ from .unloader import Unloader import logging import sqlalchemy -from sqlalchemy import MetaData +from sqlalchemy import MetaData, text from sqlalchemy.schema import Table from abc import ABCMeta, abstractmethod from records_mover.records import RecordsSchema @@ -75,7 +75,7 @@ def set_grant_permissions_for_groups(self, raise TypeError("Please make sure your permission types" " are an acceptable value.") table_obj = Table(table, MetaData(), schema=schema_name) - 
perms_sql = str(GrantPrivileges(perm_type, table_obj, group)) + perms_sql = text(str(GrantPrivileges(perm_type, table_obj, group))) if db_conn: db_conn.execute(perms_sql) else: @@ -97,7 +97,7 @@ def set_grant_permissions_for_users(self, schema_name: str, table: str, raise TypeError("Please make sure your permission types" " are an acceptable value.") table_obj = Table(table, MetaData(), schema=schema_name) - perms_sql = str(GrantPrivileges(perm_type, table_obj, user)) + perms_sql = text(str(GrantPrivileges(perm_type, table_obj, user))) if db_conn: db_conn.execute(perms_sql) else: diff --git a/tests/unit/db/test_db_driver.py b/tests/unit/db/test_db_driver.py index e07e9f749..1e8b9e23c 100644 --- a/tests/unit/db/test_db_driver.py +++ b/tests/unit/db/test_db_driver.py @@ -1,6 +1,6 @@ from .fakes import fake_text import unittest -from mock import Mock, MagicMock, call +from mock import Mock, MagicMock from records_mover.db.driver import GenericDBDriver import sqlalchemy @@ -86,14 +86,11 @@ def test_set_grant_permissions_for_groups(self): groups, None, db_conn=mock_db) - mock_db.execute.assert_has_calls([ - call( - f"GRANT INSERT ON {mock_schema_name}.\"{mock_table}\" " - f"TO \"{groups['insert'][0]}\"\n"), - call( - f"GRANT INSERT ON {mock_schema_name}.\"{mock_table}\" " - f"TO \"{groups['insert'][1]}\"\n"), - ]) + str_args = [str(arg[0][0]) for arg in mock_db.execute.call_args_list] + self.assertEqual([ + f"GRANT INSERT ON {mock_schema_name}.\"{mock_table}\" TO \"{groups['insert'][0]}\"\n", + f"GRANT INSERT ON {mock_schema_name}.\"{mock_table}\" TO \"{groups['insert'][1]}\"\n", + ], str_args) def test_set_grant_permissions_for_users(self): mock_schema_name = 'schema_name' @@ -107,14 +104,11 @@ def test_set_grant_permissions_for_users(self): users, None, db_conn=mock_db) - mock_db.execute.assert_has_calls([ - call( - f"GRANT INSERT ON {mock_schema_name}.{mock_table} " - f"TO \"{users['insert'][0]}\"\n"), - call( - f"GRANT INSERT ON {mock_schema_name}.{mock_table} " - f"TO \"{users['insert'][1]}\"\n"), - ]) + str_args = [str(arg[0][0]) for arg in mock_db.execute.call_args_list] + self.assertEqual([ + f"GRANT INSERT ON {mock_schema_name}.{mock_table} TO \"{users['insert'][0]}\"\n", + f"GRANT INSERT ON {mock_schema_name}.{mock_table} TO \"{users['insert'][1]}\"\n", + ], str_args) def test_set_grant_permissions_for_users_bobby_tables(self): mock_schema_name = Mock(name='schema_name')
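
Taken together, these patches replace hand-assembled SQL strings with compiled SQLAlchemy constructs (Table, DropTable, select, delete) and route GRANT statements through sqlalchemy_privileges. A minimal sketch of why the Table-object approach defuses injection follows — the schema name and the hostile table name below are invented for the demo, not taken from the patches:

    from sqlalchemy import MetaData, Table
    from sqlalchemy.schema import DropTable

    meta = MetaData()
    # An adversarial "table name" that would escape a naive
    # f"DROP TABLE {schema}.{table}" and smuggle in a second statement.
    hostile = 'payments"; DROP TABLE users; --'
    table = Table(hostile, meta, schema="myschema")

    # The DDL compiler quotes the identifier (doubling the embedded quote),
    # so the whole payload stays one inert name instead of executable SQL:
    print(DropTable(table, if_exists=True))
    # DROP TABLE IF EXISTS myschema."payments""; DROP TABLE users; --"
    print(table.delete())
    # DELETE FROM myschema."payments""; DROP TABLE users; --"

The GRANT path in patches 3–4 works the same way, assuming sqlalchemy_privileges>=0.2.0 (added to setup.py above) is installed; "etl_group" here is a placeholder name:

    from sqlalchemy import MetaData, Table, text
    from sqlalchemy_privileges import GrantPrivileges

    table_obj = Table("my_table", MetaData(), schema="schema_name")
    perms_sql = text(str(GrantPrivileges("insert", table_obj, "etl_group")))
    # Renders as in the updated tests:
    # GRANT INSERT ON schema_name.my_table TO "etl_group"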