Rm 95 remove sql injection possibilities #254

Merged 6 commits on Nov 20, 2023
16 changes: 6 additions & 10 deletions records_mover/db/driver.py
@@ -1,13 +1,13 @@
from ..check_db_conn_engine import check_db_conn_engine
from sqlalchemy.schema import CreateTable
from sqlalchemy_privileges import GrantPrivileges # type: ignore[import-untyped]
from ..records.records_format import BaseRecordsFormat
from .loader import LoaderFromFileobj, LoaderFromRecordsDirectory
from .unloader import Unloader
import logging
import sqlalchemy
from sqlalchemy import MetaData
from sqlalchemy import MetaData, text
from sqlalchemy.schema import Table
from records_mover.db.quoting import quote_group_name, quote_user_name, quote_schema_and_table
from abc import ABCMeta, abstractmethod
from records_mover.records import RecordsSchema
from typing import Union, Dict, List, Tuple, Optional, TYPE_CHECKING
@@ -68,16 +68,14 @@ def set_grant_permissions_for_groups(self,
db_engine: Optional[sqlalchemy.engine.Engine] = None
) -> None:
db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine)
schema_and_table: str = quote_schema_and_table(None, schema_name, table,
db_engine=self.db_engine)
for perm_type in groups:
groups_list = groups[perm_type]
for group in groups_list:
group_name: str = quote_group_name(None, group, db_engine=self.db_engine)
if not perm_type.isalpha():
raise TypeError("Please make sure your permission types"
" are an acceptable value.")
perms_sql = f'GRANT {perm_type} ON TABLE {schema_and_table} TO {group_name}'
table_obj = Table(table, MetaData(), schema=schema_name)
perms_sql = text(str(GrantPrivileges(perm_type, table_obj, group)))
if db_conn:
db_conn.execute(perms_sql)
else:
@@ -92,16 +92,14 @@ def set_grant_permissions_for_users(self, schema_name: str, table: str,
db_engine: Optional[sqlalchemy.engine.Engine] = None
) -> None:
db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine)
schema_and_table: str = quote_schema_and_table(None, schema_name, table,
db_engine=self.db_engine)
for perm_type in users:
user_list = users[perm_type]
for user in user_list:
user_name: str = quote_user_name(self.db_engine, user)
if not perm_type.isalpha():
raise TypeError("Please make sure your permission types"
" are an acceptable value.")
perms_sql = f'GRANT {perm_type} ON TABLE {schema_and_table} TO {user_name}'
table_obj = Table(table, MetaData(), schema=schema_name)
perms_sql = text(str(GrantPrivileges(perm_type, table_obj, user)))
if db_conn:
db_conn.execute(perms_sql)
else:
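The substance of this file's change: instead of interpolating quoted schema, table, and group/user names into a GRANT f-string, the driver now builds a `Table` object and compiles a `GrantPrivileges` construct from `sqlalchemy_privileges`. A minimal sketch of that pattern, with an illustrative permission and grantee that are not taken from the PR:

```python
from sqlalchemy import MetaData
from sqlalchemy.schema import Table
from sqlalchemy_privileges import GrantPrivileges  # type: ignore[import-untyped]

# Table/schema names here are illustrative.
table_obj = Table("mytable", MetaData(), schema="myschema")

# The statement is rendered by SQLAlchemy's compiler (a GRANT ... ON ... TO ...
# clause) rather than assembled by f-string interpolation of caller input.
perms_sql = str(GrantPrivileges("SELECT", table_obj, "analysts"))
print(perms_sql)
```

The point of the swap is that the statement text comes out of a compiled construct, so the schema and table identifiers are no longer raw string formatting targets.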
4 changes: 3 additions & 1 deletion records_mover/db/mysql/load_options.py
@@ -69,10 +69,12 @@ def generate_load_data_sql(self,
filename: str,
schema_name: str,
table_name: str) -> TextClause:
remove_backticks_schema_name = schema_name.replace('`', '')
remove_backticks_table_name = table_name.replace('`', '')
sql = f"""\
LOAD DATA
LOCAL INFILE :filename
INTO TABLE {schema_name}.{table_name}
INTO TABLE `{remove_backticks_schema_name}`.`{remove_backticks_table_name}`
CHARACTER SET :character_set
FIELDS
TERMINATED BY :fields_terminated_by
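`LOAD DATA LOCAL INFILE` cannot take the target table as a bind parameter, so this file sanitizes instead: backticks are stripped from the schema and table names, which are then backtick-quoted in the statement. A small sketch of that idea; the helper name and the hostile table name are mine, for illustration only:

```python
def quote_mysql_identifier(name: str) -> str:
    # Mirror of the approach in generate_load_data_sql: drop embedded
    # backticks, then wrap the identifier in backticks.
    return "`{}`".format(name.replace("`", ""))


# A hostile value stays inside the quoted identifier instead of closing it.
print(quote_mysql_identifier("mytable` ; DROP TABLE users; --"))
# -> `mytable ; DROP TABLE users; --`
```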
9 changes: 4 additions & 5 deletions records_mover/db/redshift/redshift_db_driver.py
@@ -1,6 +1,7 @@
from ..driver import DBDriver
import sqlalchemy
from sqlalchemy.schema import Table
from sqlalchemy_privileges import GrantPrivileges # type: ignore[import-untyped]
from sqlalchemy.schema import Table, MetaData
from records_mover.records import RecordsSchema
from records_mover.records.records_format import BaseRecordsFormat, AvroRecordsFormat
import logging
@@ -14,7 +15,6 @@
import timeout_decorator
from typing import Optional, Union, Dict, List, Tuple
from ...url.base import BaseDirectoryUrl
from records_mover.db.quoting import quote_group_name, quote_schema_and_table
from .unloader import RedshiftUnloader
from ..unloader import Unloader
from .loader import RedshiftLoader
@@ -85,15 +85,14 @@ def set_grant_permissions_for_groups(self,
db_engine: Optional[sqlalchemy.engine.Engine] = None
) -> None:
db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine)
schema_and_table = quote_schema_and_table(None, schema_name, table, db_engine=db_engine)
for perm_type in groups:
groups_list = groups[perm_type]
for group in groups_list:
group_name: str = quote_group_name(None, group, db_engine=self.db_engine)
if not perm_type.isalpha():
raise TypeError("Please make sure your permission types"
" are an acceptable value.")
perms_sql = f'GRANT {perm_type} ON TABLE {schema_and_table} TO GROUP {group_name}'
table_obj = Table(table, MetaData(), schema=schema_name)
perms_sql = str(GrantPrivileges(perm_type, table_obj, group))
if db_conn:
db_conn.execute(perms_sql)
else:
9 changes: 4 additions & 5 deletions records_mover/db/redshift/unloader.py
@@ -2,10 +2,10 @@
from sqlalchemy_redshift.commands import UnloadFromSelect
from ...records.records_directory import RecordsDirectory
import sqlalchemy
from sqlalchemy import MetaData
from sqlalchemy.sql import text
from sqlalchemy.schema import Table
import logging
from records_mover.db.quoting import quote_schema_and_table
from records_mover.logging import register_secret
from .records_unload import redshift_unload_options
from ...records.unload_plan import RecordsUnloadPlan
@@ -84,10 +84,9 @@ def unload_to_s3_directory(self,
#
register_secret(aws_creds.token)
register_secret(aws_creds.secret_key)
select = text(
"SELECT * FROM "
f"{quote_schema_and_table(None, schema, table, db_engine=self.db_engine)}")
unload = UnloadFromSelect(select=select,
table_obj = Table(table, MetaData(), schema=schema)
select = str(sqlalchemy.select('*', table_obj)) # type: ignore[arg-type] # noqa: F821
unload = UnloadFromSelect(select=text(select),
access_key_id=aws_creds.access_key,
secret_access_key=aws_creds.secret_key,
session_token=aws_creds.token, manifest=True,
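Here the unload query is built from a `Table` object and compiled, rather than formatted with `quote_schema_and_table`. The PR passes `'*'` straight to `sqlalchemy.select` under a `type: ignore`; the sketch below shows the same idea with the documented `text('*')` / `select_from` calls, and the schema and table names are illustrative:

```python
from sqlalchemy import MetaData, select, text
from sqlalchemy.schema import Table

table_obj = Table("mytable", MetaData(), schema="myschema")

# str() compiles the construct with the default dialect, which quotes
# identifiers whenever the dialect requires it.
stmt = select(text("*")).select_from(table_obj)
print(str(stmt))  # SELECT * FROM myschema.mytable
```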
9 changes: 4 additions & 5 deletions records_mover/db/vertica/vertica_db_driver.py
@@ -1,7 +1,7 @@
from ..driver import DBDriver
import sqlalchemy
from sqlalchemy.sql import text
from records_mover.db.quoting import quote_schema_and_table
from sqlalchemy import select
from sqlalchemy.schema import Table, Column
import logging
from typing import Optional, Union, Tuple
@@ -47,10 +47,9 @@ def unloader(self) -> Optional[Unloader]:

def has_table(self, schema: str, table: str) -> bool:
try:
sql = ("SELECT 1 "
f"from {quote_schema_and_table(None, schema, table, db_engine=self.db_engine)} "
"limit 0;")
self.db_conn.execute(text(sql))
table_to_check = Table(table, self.meta, schema=schema)
self.db_conn.execute(
select(text("1"), table_to_check)) # type: ignore[arg-type] # noqa: F821
return True
except sqlalchemy.exc.ProgrammingError:
return False
14 changes: 10 additions & 4 deletions records_mover/records/prep.py
@@ -5,7 +5,8 @@
from records_mover.db import DBDriver
from records_mover.records.table import TargetTableDetails
import logging
from sqlalchemy import text
from sqlalchemy import text, Table, MetaData, delete
from sqlalchemy.schema import DropTable

logger = logging.getLogger(__name__)

@@ -61,19 +61,24 @@ def prep_table_for_load(self,
db_engine=db_engine,)
if (how_to_prep == ExistingTableHandling.TRUNCATE_AND_OVERWRITE):
logger.info("Truncating...")
db_conn.execute(text(f"TRUNCATE TABLE {schema_and_table}"))
meta = MetaData()
table = Table(self.tbl.table_name, meta, schema=self.tbl.schema_name)
db_conn.execute(table.delete().where(True))
logger.info("Truncated.")
elif (how_to_prep == ExistingTableHandling.DELETE_AND_OVERWRITE):
logger.info("Deleting rows...")
db_conn.execute(text(f"DELETE FROM {schema_and_table} WHERE true"))
table_obj = Table(self.tbl.table_name, MetaData(), schema=self.tbl.schema_name)
db_conn.execute(delete(table_obj).where(True))
logger.info("Deleted")
elif (how_to_prep == ExistingTableHandling.DROP_AND_RECREATE):
with db_engine.connect() as conn:
with conn.begin():
logger.info(f"The connection object is: {conn}")
logger.info("Dropping and recreating...")
meta = MetaData()
table = Table(self.tbl.table_name, meta, schema=self.tbl.schema_name)
drop_table_sql = f"DROP TABLE {schema_and_table}"
conn.execute(text(drop_table_sql))
conn.execute(DropTable(table)) # type: ignore[arg-type] # noqa: F821
logger.info(f"Just ran {drop_table_sql}")
self.create_table(schema_sql, conn, driver)
elif (how_to_prep == ExistingTableHandling.APPEND):
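Truncation, row deletion, and table drops now go through SQLAlchemy constructs (`table.delete()`, `delete(table_obj)`, `DropTable(table)`) rather than f-string SQL, so identifier quoting is handled by the dialect. A quick sketch of what two of those constructs compile to under the default dialect; table and schema names are illustrative:

```python
from sqlalchemy import MetaData, delete
from sqlalchemy.schema import DropTable, Table

table = Table("mytable", MetaData(), schema="myschema")

# Both statements are rendered by SQLAlchemy, which applies identifier
# quoting where the dialect requires it.
print(str(delete(table)))     # DELETE FROM myschema.mytable
print(str(DropTable(table)))  # DROP TABLE myschema.mytable
```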
7 changes: 5 additions & 2 deletions records_mover/records/sources/table.py
@@ -8,7 +8,7 @@
from ..records_format import BaseRecordsFormat
from ..unload_plan import RecordsUnloadPlan
from ..results import MoveResult
from sqlalchemy import text
from sqlalchemy import Table, MetaData, select
from sqlalchemy.engine import Engine
from contextlib import contextmanager
from ..schema import RecordsSchema
@@ -88,10 +88,13 @@ def to_dataframes_source(self,
chunksize = int(entries_per_chunk / num_columns)
logger.info(f"Exporting in chunks of up to {chunksize} rows by {num_columns} columns")

meta = MetaData()
table = Table(self.table_name, meta, schema=self.schema_name)
quoted_table = quote_schema_and_table(None, self.schema_name,
self.table_name, db_engine=db_engine,)
logger.info(f"Reading {quoted_table}...")
chunks: Generator['DataFrame', None, None] = \
pandas.read_sql(text(f"SELECT * FROM {quoted_table}"),
pandas.read_sql(select('*', table), # type: ignore[arg-type] # noqa: F821
con=db_conn,
chunksize=chunksize)
try:
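`pandas.read_sql` accepts a SQLAlchemy selectable directly, which is what lets this source hand it a statement built from a `Table` object instead of an interpolated `SELECT * FROM ...` string. A self-contained sketch under that assumption; the in-memory SQLite engine, toy column, and chunk size are all illustrative:

```python
import pandas
from sqlalchemy import Column, Integer, MetaData, create_engine, select, text
from sqlalchemy.schema import Table

engine = create_engine("sqlite://")  # illustrative in-memory engine
meta = MetaData()
table = Table("mytable", meta, Column("num", Integer))
meta.create_all(engine)

stmt = select(text("*")).select_from(table)
with engine.connect() as conn:
    # read_sql streams DataFrame chunks; the table reference in the compiled
    # SQL comes from the Table object, not from string formatting.
    for chunk in pandas.read_sql(stmt, con=conn, chunksize=1000):
        print(chunk)
```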
5 changes: 3 additions & 2 deletions records_mover/records/targets/spectrum.py
@@ -14,6 +14,7 @@
import logging
import sqlalchemy
from sqlalchemy import text
from sqlalchemy.schema import DropTable


logger = logging.getLogger(__name__)
@@ -84,10 +85,10 @@ def prep_bucket(self) -> None:
db_engine=self.db_engine)
logger.info(f"Dropping external table {schema_and_table}...")
with self.db_engine.connect() as conn:
table = Table(self.table_name, MetaData(), schema=self.schema_name)
# See below note about fix from Spectrify
conn.execution_options(isolation_level='AUTOCOMMIT')
conn.execute(text(f"DROP TABLE IF EXISTS {schema_and_table}"))

conn.execute(DropTable(table, if_exists=True)) # type: ignore[call-arg, arg-type]
logger.info(f"Deleting files in {self.output_loc}...")
self.output_loc.purge_directory()
elif self.existing_table_handling == ExistingTableHandling.APPEND:
1 change: 1 addition & 0 deletions setup.py
@@ -146,6 +146,7 @@ def initialize_options(self) -> None:

db_dependencies = [
'sqlalchemy>=1.4',
'sqlalchemy_privileges>=0.2.0',
]

smart_open_dependencies = [
12 changes: 6 additions & 6 deletions tests/component/db/mysql/test_load_options_class.py
@@ -21,7 +21,7 @@ def test_generate_load_data_sql_boring(self) -> None:
expected_sql = """\
LOAD DATA
LOCAL INFILE 'my_filename.txt'
INTO TABLE myschema.mytable
INTO TABLE `myschema`.`mytable`
CHARACTER SET 'utf8'
FIELDS
TERMINATED BY '\t'
@@ -51,7 +51,7 @@ def test_generate_load_data_sql_different_constants(self) -> None:
expected_sql = """\
LOAD DATA
LOCAL INFILE 'another_filename.txt'
INTO TABLE myschema.mytable
INTO TABLE `myschema`.`mytable`
CHARACTER SET 'utf16'
FIELDS
TERMINATED BY ','
@@ -81,7 +81,7 @@ def test_generate_load_data_sql_field_enclosed_by_None(self) -> None:
expected_sql = """\
LOAD DATA
LOCAL INFILE 'another_filename.txt'
INTO TABLE myschema.mytable
INTO TABLE `myschema`.`mytable`
CHARACTER SET 'utf16'
FIELDS
TERMINATED BY ','
@@ -110,7 +110,7 @@ def test_generate_load_data_sql_field_optionally_enclosed_by(self) -> None:
expected_sql = """\
LOAD DATA
LOCAL INFILE 'another_filename.txt'
INTO TABLE myschema.mytable
INTO TABLE `myschema`.`mytable`
CHARACTER SET 'utf16'
FIELDS
TERMINATED BY ','
@@ -160,7 +160,7 @@ def test_generate_load_data_sql_windows_filename(self) -> None:
expected_sql = """\
LOAD DATA
LOCAL INFILE 'c:\\\\Some Path\\\\OH GOD LET IT END~1.CSV'
INTO TABLE myschema.mytable
INTO TABLE `myschema`.`mytable`
CHARACTER SET 'utf16'
FIELDS
TERMINATED BY ','
@@ -190,7 +190,7 @@ def test_vertica_dialect_style_terminators(self) -> None:
expected_sql = """\
LOAD DATA
LOCAL INFILE 'another_filename.txt'
INTO TABLE myschema.mytable
INTO TABLE `myschema`.`mytable`
CHARACTER SET 'utf16'
FIELDS
TERMINATED BY '\002'
6 changes: 4 additions & 2 deletions tests/integration/records/purge_old_test_tables.py
@@ -3,7 +3,8 @@
from records_mover.db.quoting import quote_schema_and_table
from records_mover import Session
from datetime import datetime, timedelta
from sqlalchemy import inspect
from sqlalchemy import inspect, Table, MetaData
from sqlalchemy.schema import DropTable
from typing import Optional
import sys

@@ -39,9 +40,10 @@ def purge_old_tables(schema_name: str, table_name_prefix: str,
"DROP TABLE "
f"{quote_schema_and_table(None, schema_name, table_name, db_engine=db_engine)}")
print(sql)
table = Table(table_name, MetaData(), schema=schema_name)
with db_engine.connect() as connection:
with connection.begin():
connection.exec_driver_sql(sql)
connection.exec_driver_sql(DropTable(table)) # type: ignore[arg-type] # noqa: F821


if __name__ == '__main__':
17 changes: 9 additions & 8 deletions tests/integration/records/records_database_fixture.py
@@ -1,6 +1,7 @@
from records_mover.db.quoting import quote_schema_and_table
from records_mover.utils.retry import bigquery_retry
from sqlalchemy import text
from sqlalchemy import Table, MetaData
from sqlalchemy.schema import DropTable
import logging

logger = logging.getLogger(__name__)
@@ -20,10 +21,10 @@ def __init__(self, db_engine, schema_name, table_name):

@bigquery_retry()
def drop_table_if_exists(self, schema, table):
sql = f"DROP TABLE IF EXISTS {self.quote_schema_and_table(schema, table)}"
table_to_drop = Table(table, MetaData(), schema=schema)
with self.engine.connect() as connection:
with connection.begin():
connection.execute(text(sql))
connection.execute(DropTable(table_to_drop, if_exists=True))

def tear_down(self):
self.drop_table_if_exists(self.schema_name, f"{self.table_name}_frozen")
@@ -33,7 +34,7 @@ def tear_down(self):
def bring_up(self):
if self.engine.name == 'redshift':
create_tables = f"""
CREATE TABLE {self.schema_name}.{self.table_name} AS
CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS
SELECT 123 AS num,
'123' AS numstr,
'foo' AS str,
@@ -48,7 +49,7 @@ def bring_up(self):
""" # noqa
elif self.engine.name == 'vertica':
create_tables = f"""
CREATE TABLE {self.schema_name}.{self.table_name} AS
CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS
SELECT 123 AS num,
'123' AS numstr,
'foo' AS str,
@@ -63,7 +64,7 @@ def bring_up(self):
""" # noqa
elif self.engine.name == 'bigquery':
create_tables = f"""
CREATE TABLE {self.schema_name}.{self.table_name} AS
CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS
SELECT 123 AS num,
'123' AS numstr,
'foo' AS str,
@@ -78,7 +79,7 @@ def bring_up(self):
""" # noqa
elif self.engine.name == 'postgresql':
create_tables = f"""
CREATE TABLE {self.schema_name}.{self.table_name} AS
CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS
SELECT 123 AS num,
'123' AS numstr,
'foo' AS str,
@@ -93,7 +94,7 @@ def bring_up(self):
""" # noqa
elif self.engine.name == 'mysql':
create_tables = f"""
CREATE TABLE {self.schema_name}.{self.table_name} AS
CREATE TABLE {self.quote_schema_and_table(self.schema_name, self.table_name)} AS
SELECT 123 AS num,
'123' AS numstr,
'foo' AS str,