Skip to content

Commit

Permalink
Sanitize filenames in MySQLHook
Browse files Browse the repository at this point in the history
  • Loading branch information
PApostol committed Aug 11, 2023
1 parent f36ba1d commit 72343e1
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
11 changes: 8 additions & 3 deletions airflow/providers/mysql/hooks/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import json
import logging
import re
from typing import TYPE_CHECKING, Any, Union

from airflow.exceptions import AirflowOptionalProviderFeatureException
Expand Down Expand Up @@ -164,6 +165,10 @@ def _get_conn_config_mysql_connector_python(self, conn: Connection) -> dict:

return conn_config

@staticmethod
def _sanitize_filename(filename: str) -> str:
    """
    Return *filename* truncated at its first semicolon.

    Everything from the first ``;`` to the end of the string is removed so
    that text following a semicolon in a file name cannot be appended to
    the SQL statement the name is interpolated into.

    :param filename: file name or path supplied by the caller.
    :return: the part of ``filename`` before the first ``;``; the input is
        returned unchanged when it contains no semicolon.
    """
    # re.DOTALL lets ``.`` match newlines as well. Without it, a value such as
    # "file;\nDROP TABLE t" was returned untouched, because ``.*`` stopped at
    # the newline and the trailing ``$`` anchor then failed to match.
    # NOTE(review): this only strips ``;`` suffixes; quote characters in the
    # name are not escaped here -- confirm whether callers need that as well.
    return re.sub(r";.*", "", filename, flags=re.DOTALL)

def get_conn(self) -> MySQLConnectionTypes:
"""
Connection to a MySQL database.
Expand Down Expand Up @@ -208,7 +213,7 @@ def bulk_load(self, table: str, tmp_file: str) -> None:
cur = conn.cursor()
cur.execute(
f"""
LOAD DATA LOCAL INFILE '{tmp_file}'
LOAD DATA LOCAL INFILE '{self._sanitize_filename(tmp_file)}'
INTO TABLE {table}
"""
)
Expand All @@ -221,7 +226,7 @@ def bulk_dump(self, table: str, tmp_file: str) -> None:
cur = conn.cursor()
cur.execute(
f"""
SELECT * INTO OUTFILE '{tmp_file}'
SELECT * INTO OUTFILE '{self._sanitize_filename(tmp_file)}'
FROM {table}
"""
)
Expand Down Expand Up @@ -288,7 +293,7 @@ def bulk_load_custom(

cursor.execute(
f"""
LOAD DATA LOCAL INFILE '{tmp_file}'
LOAD DATA LOCAL INFILE '{self._sanitize_filename(tmp_file)}'
{duplicate_key_handling}
INTO TABLE {table}
{extra_options}
Expand Down
38 changes: 38 additions & 0 deletions tests/providers/mysql/hooks/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ def test_bulk_load(self):
"""
)

# Checks that bulk_load drops everything from the first semicolon of the
# file name before the name is placed into the LOAD DATA statement.
def test_bulk_load_with_semicolon_in_filename(self):
# The "; SELECT * FROM DUAL" suffix plays the role of an injected statement.
self.db_hook.bulk_load("table", "/tmp/file; SELECT * FROM DUAL")
# The SQL handed to self.cur.execute must reference only '/tmp/file'.
self.cur.execute.assert_called_once_with(
"""
LOAD DATA LOCAL INFILE '/tmp/file'
INTO TABLE table
"""
)

def test_bulk_dump(self):
self.db_hook.bulk_dump("table", "/tmp/file")
self.cur.execute.assert_called_once_with(
Expand All @@ -288,6 +297,15 @@ def test_bulk_dump(self):
"""
)

# Checks that bulk_dump drops everything from the first semicolon of the
# file name before the name is placed into the SELECT ... INTO OUTFILE statement.
def test_bulk_dump_with_semicolon_in_filename(self):
# The "; SELECT * FROM DUAL" suffix plays the role of an injected statement.
self.db_hook.bulk_dump("table", "/tmp/file; SELECT * FROM DUAL")
# The SQL handed to self.cur.execute must reference only '/tmp/file'.
self.cur.execute.assert_called_once_with(
"""
SELECT * INTO OUTFILE '/tmp/file'
FROM table
"""
)

def test_serialize_cell(self):
assert "foo" == self.db_hook._serialize_cell("foo", None)

Expand All @@ -311,6 +329,26 @@ def test_bulk_load_custom(self):
"""
)

# Checks that bulk_load_custom truncates the file name at its first semicolon
# while the other arguments reach the generated statement as given.
def test_bulk_load_custom_with_semicolon_in_filename(self):
# Arguments: table, file name carrying an injected suffix, duplicate-key
# handling keyword, and extra options (which themselves contain a ';').
self.db_hook.bulk_load_custom(
"table",
"/tmp/file; SELECT * FROM DUAL",
"IGNORE",
"""FIELDS TERMINATED BY ';'
OPTIONALLY ENCLOSED BY '"'
IGNORE 1 LINES""",
)
# Only the file name is expected to be cut; the ';' inside the
# FIELDS TERMINATED BY clause must still appear in the executed SQL.
self.cur.execute.assert_called_once_with(
"""
LOAD DATA LOCAL INFILE '/tmp/file'
IGNORE
INTO TABLE table
FIELDS TERMINATED BY ';'
OPTIONALLY ENCLOSED BY '"'
IGNORE 1 LINES
"""
)


DEFAULT_DATE = timezone.datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
Expand Down

0 comments on commit 72343e1

Please sign in to comment.