From 515718456335955e76af0b5f75edf16f15b296fb Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 5 Dec 2023 05:13:35 -0600 Subject: [PATCH] fix: Add parameterized queries where possible to address the risk of SQL injection (#2540) * fix 3 instances of potential SQL injection * formatting fix * fix data API rds formatting --- awswrangler/data_api/rds.py | 10 +++++++++- awswrangler/mysql.py | 23 ++++++++++++++++------- awswrangler/sqlserver.py | 16 ++++++++++++---- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/awswrangler/data_api/rds.py b/awswrangler/data_api/rds.py index e9a5e9d86..698db2621 100644 --- a/awswrangler/data_api/rds.py +++ b/awswrangler/data_api/rds.py @@ -292,7 +292,15 @@ def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str, def _does_table_exist(con: RdsDataApi, table: str, database: str, transaction_id: str) -> bool: - res = con.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{table}'") + res = con.execute( + "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = :table", + parameters=[ + { + "name": "table", + "value": {"stringValue": table}, + }, + ], + ) return not res.empty diff --git a/awswrangler/mysql.py b/awswrangler/mysql.py index 689116897..57eb59f69 100644 --- a/awswrangler/mysql.py +++ b/awswrangler/mysql.py @@ -3,7 +3,7 @@ import logging import uuid -from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union, cast, overload +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union, cast, overload import boto3 import pyarrow as pa @@ -13,12 +13,21 @@ from awswrangler import _databases as _db_utils from awswrangler._config import apply_configs -pymysql = _utils.import_optional_dependency("pymysql") +if TYPE_CHECKING: + try: + import pymysql + from pymysql.connections import Connection + from pymysql.cursors import Cursor + except ImportError: + pass +else: + pymysql = _utils.import_optional_dependency("pymysql") + _logger: logging.Logger = logging.getLogger(__name__) -def _validate_connection(con: "pymysql.connections.Connection[Any]") -> None: +def _validate_connection(con: "Connection[Any]") -> None: if not isinstance(con, pymysql.connections.Connection): raise exceptions.InvalidConnection( "Invalid 'conn' argument, please pass a " @@ -27,16 +36,16 @@ def _validate_connection(con: "pymysql.connections.Connection[Any]") -> None: ) -def _drop_table(cursor: "pymysql.cursors.Cursor", schema: Optional[str], table: str) -> None: +def _drop_table(cursor: "Cursor", schema: Optional[str], table: str) -> None: schema_str = f"`{schema}`." if schema else "" sql = f"DROP TABLE IF EXISTS {schema_str}`{table}`" _logger.debug("Drop table query:\n%s", sql) cursor.execute(sql) -def _does_table_exist(cursor: "pymysql.cursors.Cursor", schema: Optional[str], table: str) -> bool: +def _does_table_exist(cursor: "Cursor", schema: Optional[str], table: str) -> bool: schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else "" - cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'") + cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE {schema_str} TABLE_NAME = %s", args=[table]) return len(cursor.fetchall()) > 0 @@ -164,7 +173,7 @@ def connect( password=attrs.password, port=attrs.port, host=attrs.host, - ssl=attrs.ssl_context, + ssl=attrs.ssl_context, # type: ignore[arg-type] read_timeout=read_timeout, write_timeout=write_timeout, connect_timeout=connect_timeout, diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py index b51c21d6f..2990ec05e 100644 --- a/awswrangler/sqlserver.py +++ b/awswrangler/sqlserver.py @@ -3,6 +3,7 @@ import logging from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -26,7 +27,14 @@ __all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"] -pyodbc = _utils.import_optional_dependency("pyodbc") +if TYPE_CHECKING: + try: + import pyodbc + from pyodbc import Cursor + except ImportError: + pass +else: + pyodbc = _utils.import_optional_dependency("pyodbc") _logger: logging.Logger = logging.getLogger(__name__) FuncT = TypeVar("FuncT", bound=Callable[..., Any]) @@ -47,16 +55,16 @@ def _get_table_identifier(schema: Optional[str], table: str) -> str: return table_identifier -def _drop_table(cursor: "pyodbc.Cursor", schema: Optional[str], table: str) -> None: +def _drop_table(cursor: "Cursor", schema: Optional[str], table: str) -> None: table_identifier = _get_table_identifier(schema, table) sql = f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NOT NULL DROP TABLE {table_identifier}" _logger.debug("Drop table query:\n%s", sql) cursor.execute(sql) -def _does_table_exist(cursor: "pyodbc.Cursor", schema: Optional[str], table: str) -> bool: +def _does_table_exist(cursor: "Cursor", schema: Optional[str], table: str) -> bool: schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else "" - cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'") + cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE {schema_str} TABLE_NAME = ?", table) return len(cursor.fetchall()) > 0