Skip to content

Commit

Permalink
fix: Add parameterized queries where possible to address the risk of …
Browse files Browse the repository at this point in the history
…SQL injection (#2540)

* fix 3 instances of potential SQL injection

* formatting fix

* fix data API rds formatting
  • Loading branch information
LeonLuttenberger authored Dec 5, 2023
1 parent 67efc06 commit 5157184
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 12 deletions.
10 changes: 9 additions & 1 deletion awswrangler/data_api/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,15 @@ def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str,


def _does_table_exist(con: RdsDataApi, table: str, database: str, transaction_id: str) -> bool:
res = con.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{table}'")
res = con.execute(
"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = :table",
parameters=[
{
"name": "table",
"value": {"stringValue": table},
},
],
)
return not res.empty


Expand Down
23 changes: 16 additions & 7 deletions awswrangler/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import logging
import uuid
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union, cast, overload
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union, cast, overload

import boto3
import pyarrow as pa
Expand All @@ -13,12 +13,21 @@
from awswrangler import _databases as _db_utils
from awswrangler._config import apply_configs

pymysql = _utils.import_optional_dependency("pymysql")
if TYPE_CHECKING:
try:
import pymysql
from pymysql.connections import Connection
from pymysql.cursors import Cursor
except ImportError:
pass
else:
pymysql = _utils.import_optional_dependency("pymysql")


_logger: logging.Logger = logging.getLogger(__name__)


def _validate_connection(con: "pymysql.connections.Connection[Any]") -> None:
def _validate_connection(con: "Connection[Any]") -> None:
if not isinstance(con, pymysql.connections.Connection):
raise exceptions.InvalidConnection(
"Invalid 'conn' argument, please pass a "
Expand All @@ -27,16 +36,16 @@ def _validate_connection(con: "pymysql.connections.Connection[Any]") -> None:
)


def _drop_table(cursor: "pymysql.cursors.Cursor", schema: Optional[str], table: str) -> None:
def _drop_table(cursor: "Cursor", schema: Optional[str], table: str) -> None:
schema_str = f"`{schema}`." if schema else ""
sql = f"DROP TABLE IF EXISTS {schema_str}`{table}`"
_logger.debug("Drop table query:\n%s", sql)
cursor.execute(sql)


def _does_table_exist(cursor: "pymysql.cursors.Cursor", schema: Optional[str], table: str) -> bool:
def _does_table_exist(cursor: "Cursor", schema: Optional[str], table: str) -> bool:
schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else ""
cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'")
cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE {schema_str} TABLE_NAME = %s", args=[table])
return len(cursor.fetchall()) > 0


Expand Down Expand Up @@ -164,7 +173,7 @@ def connect(
password=attrs.password,
port=attrs.port,
host=attrs.host,
ssl=attrs.ssl_context,
ssl=attrs.ssl_context, # type: ignore[arg-type]
read_timeout=read_timeout,
write_timeout=write_timeout,
connect_timeout=connect_timeout,
Expand Down
16 changes: 12 additions & 4 deletions awswrangler/sqlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Expand All @@ -26,7 +27,14 @@

__all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"]

pyodbc = _utils.import_optional_dependency("pyodbc")
if TYPE_CHECKING:
try:
import pyodbc
from pyodbc import Cursor
except ImportError:
pass
else:
pyodbc = _utils.import_optional_dependency("pyodbc")

_logger: logging.Logger = logging.getLogger(__name__)
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
Expand All @@ -47,16 +55,16 @@ def _get_table_identifier(schema: Optional[str], table: str) -> str:
return table_identifier


def _drop_table(cursor: "pyodbc.Cursor", schema: Optional[str], table: str) -> None:
def _drop_table(cursor: "Cursor", schema: Optional[str], table: str) -> None:
table_identifier = _get_table_identifier(schema, table)
sql = f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NOT NULL DROP TABLE {table_identifier}"
_logger.debug("Drop table query:\n%s", sql)
cursor.execute(sql)


def _does_table_exist(cursor: "pyodbc.Cursor", schema: Optional[str], table: str) -> bool:
def _does_table_exist(cursor: "Cursor", schema: Optional[str], table: str) -> bool:
schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else ""
cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'")
cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE {schema_str} TABLE_NAME = ?", table)
return len(cursor.fetchall()) > 0


Expand Down

0 comments on commit 5157184

Please sign in to comment.