Skip to content

Commit

Permalink
Merge pull request #817 from gradio-app/csv-sec
Browse files Browse the repository at this point in the history
Sanitize flagging inputs before writing to csv
  • Loading branch information
abidlabs authored Mar 14, 2022
2 parents 5f907e4 + 2c413b6 commit 80fea89
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 12 deletions.
24 changes: 13 additions & 11 deletions gradio/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any, List, Optional

import gradio as gr
from gradio import encryptor
from gradio import encryptor, utils


class FlaggingCallback(ABC):
Expand Down Expand Up @@ -99,7 +99,7 @@ def flag(

with open(log_filepath, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(csv_data)
writer.writerow(utils.santize_for_csv(csv_data))

with open(log_filepath, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
Expand Down Expand Up @@ -186,7 +186,7 @@ def replace_flag_at_index(file_content):
content[flag_index][flag_col_index] = flag_option
output = io.StringIO()
writer = csv.writer(output)
writer.writerows(content)
writer.writerows(utils.santize_for_csv(content))
return output.getvalue()

if interface.encrypt:
Expand All @@ -200,33 +200,35 @@ def replace_flag_at_index(file_content):
file_content = decrypted_csv.decode()
if flag_index is not None:
file_content = replace_flag_at_index(file_content)
output.write(file_content)
output.write(utils.santize_for_csv(file_content))
writer = csv.writer(output)
if flag_index is None:
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
with open(log_fp, "wb") as csvfile:
csvfile.write(
encryptor.encrypt(
interface.encryption_key, output.getvalue().encode()
utils.santize_for_csv(
encryptor.encrypt(
interface.encryption_key, output.getvalue().encode()
)
)
)
else:
if flag_index is None:
with open(log_fp, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
writer.writerow(utils.santize_for_csv(headers))
writer.writerow(utils.santize_for_csv(csv_data))
else:
with open(log_fp) as csvfile:
file_content = csvfile.read()
file_content = replace_flag_at_index(file_content)
with open(
log_fp, "w", newline=""
) as csvfile: # newline parameter needed for Windows
csvfile.write(file_content)
csvfile.write(utils.santize_for_csv(file_content))
with open(log_fp, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count
Expand Down Expand Up @@ -368,7 +370,7 @@ def flag(
"_type": "Value",
}

writer.writerow(headers)
writer.writerow(utils.santize_for_csv(headers))

# Generate the row corresponding to the flagged sample
csv_data = []
Expand Down Expand Up @@ -403,7 +405,7 @@ def flag(
if flag_option is not None:
csv_data.append(flag_option)

writer.writerow(csv_data)
writer.writerow(utils.santize_for_csv(csv_data))

if is_new:
json.dump(infos, open(self.infos_file, "w"))
Expand Down
37 changes: 36 additions & 1 deletion gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import copy
import csv
import inspect
import json
Expand All @@ -10,7 +11,7 @@
import random
import warnings
from distutils.version import StrictVersion
from typing import TYPE_CHECKING, Any, Callable, Dict
from typing import TYPE_CHECKING, Any, Callable, Dict, List

import aiohttp
import analytics
Expand Down Expand Up @@ -286,3 +287,37 @@ def get_default_args(func: Callable) -> Dict[str, Any]:
v.default if v.default is not inspect.Parameter.empty else None
for v in signature.parameters.values()
]


def santize_for_csv(data: str | List[str] | List[List[str]]):
"""Sanitizes data so that it can be safely written to a CSV file."""

def sanitize(item):
return "'" + item

unsafe_prefixes = ("+", "=", "-", "@")
warning_message = "Sanitizing flagged data by escaping cell contents that begin "
"with one of the following characters: '+', '=', '-', '@'."

if isinstance(data, str):
if data.startswith(unsafe_prefixes):
warnings.warn(warning_message)
return sanitize(data)
return data
elif isinstance(data, list) and isinstance(data[0], str):
sanitized_data = copy.deepcopy(data)
for index, item in enumerate(data):
if item.startswith(unsafe_prefixes):
warnings.warn(warning_message)
sanitized_data[index] = sanitize(item)
return sanitized_data
elif isinstance(data[0], list) and isinstance(data[0][0], str):
sanitized_data = copy.deepcopy(data)
for outer_index, sublist in enumerate(data):
for inner_index, item in enumerate(sublist):
if item.startswith(unsafe_prefixes):
warnings.warn(warning_message)
sanitized_data[outer_index][inner_index] = sanitize(item)
return sanitized_data
else:
raise ValueError("Unsupported data type: " + str(type(data)))
19 changes: 19 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
json,
launch_analytics,
readme_to_html,
santize_for_csv,
version_check,
)

Expand Down Expand Up @@ -116,5 +117,23 @@ def test_get_ip_without_internet(self, mock_get):
self.assertEqual(ip, "No internet connection")


class TestSanitizeForCSV(unittest.TestCase):
def test_safe(self):
safe_data = santize_for_csv("abc")
self.assertEquals(safe_data, "abc")
safe_data = santize_for_csv(["def"])
self.assertEquals(safe_data, ["def"])
safe_data = santize_for_csv([["abc"]])
self.assertEquals(safe_data, [["abc"]])

def test_unsafe(self):
safe_data = santize_for_csv("=abc")
self.assertEquals(safe_data, "'=abc")
safe_data = santize_for_csv(["abc", "+abc"])
self.assertEquals(safe_data, ["abc", "'+abc"])
safe_data = santize_for_csv([["abc", "=abc"]])
self.assertEquals(safe_data, [["abc", "'=abc"]])


if __name__ == "__main__":
unittest.main()

0 comments on commit 80fea89

Please sign in to comment.