Skip to content

Commit

Permalink
validate column names, allow spaces & hyphens
Browse files: browse the repository at this point in the history
Signed-off-by: Anton Kukushkin <[email protected]>
  • Loading branch information
kukushking committed Dec 5, 2023
1 parent beb9bc9 commit ccc8380
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
9 changes: 6 additions & 3 deletions awswrangler/_sql_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""SQL utilities."""
import re

from awswrangler import exceptions


Expand All @@ -10,7 +11,9 @@ def identifier(sql):
if len(sql) == 0:
raise exceptions.InvalidArgumentValue("identifier must be > 0 characters in length")

if re.search(r"\W", sql):
raise exceptions.InvalidArgumentValue("identifier can not contain non-alphanumeric characters")
if re.search(r"[^a-zA-Z0-9-_ ]", sql):
raise exceptions.InvalidArgumentValue(
"identifier must contain only alphanumeric characters, spaces, underscores, or hyphens"
)

return f'`{sql}`'
return f"`{sql}`"
8 changes: 5 additions & 3 deletions awswrangler/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _create_table(
varchar_lengths=varchar_lengths,
converter_func=_data_types.pyarrow2mysql,
)
cols_str: str = "".join([f"`{k}` {v},\n" for k, v in mysql_types.items()])[:-2]
cols_str: str = "".join([f"{_sql_utils.identifier(k)} {v},\n" for k, v in mysql_types.items()])[:-2]
sql = f"CREATE TABLE IF NOT EXISTS {_sql_utils.identifier(schema)}.{_sql_utils.identifier(table)} (\n{cols_str})"
_logger.debug("Create table query:\n%s", sql)
cursor.execute(sql)
Expand Down Expand Up @@ -555,9 +555,11 @@ def to_sql(
upsert_str = ""
ignore_str = " IGNORE" if mode == "ignore" else ""
if use_column_names:
insertion_columns = f"(`{'`, `'.join(df.columns)}`)"
insertion_columns = f"({', '.join([_sql_utils.identifier(col) for col in df.columns])})"
if mode == "upsert_duplicate_key":
upsert_columns = ", ".join(df.columns.map(lambda column: f"`{column}`=VALUES(`{column}`)"))
upsert_columns = ", ".join(
df.columns.map(lambda col: f"{_sql_utils.identifier(col)}=VALUES({_sql_utils.identifier(col)})")
)
upsert_str = f" ON DUPLICATE KEY UPDATE {upsert_columns}"
placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
df=df, column_placeholders=column_placeholders, chunksize=chunksize
Expand Down

0 comments on commit ccc8380

Please sign in to comment.