Skip to content

Commit

Permalink
Fix Neural Solution security issue (#1856)
Browse files Browse the repository at this point in the history
Signed-off-by: Kaihui-intel <[email protected]>
  • Loading branch information
Kaihui-intel authored Jun 12, 2024
1 parent e9cb48c commit 5b5579b
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 29 deletions.
13 changes: 8 additions & 5 deletions neural_solution/backend/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@
from neural_solution.utils.utility import get_task_log_workspace, get_task_workspace

# TODO update it according to the platform
cmd = "echo $(conda info --base)/etc/profile.d/conda.sh"
CONDA_SOURCE_PATH = subprocess.getoutput(cmd)
cmd = ["echo", f"{subprocess.getoutput('conda info --base')}/etc/profile.d/conda.sh"]
process = subprocess.run(cmd, capture_output=True, text=True)
CONDA_SOURCE_PATH = process.stdout.strip()


class Scheduler:
Expand Down Expand Up @@ -88,8 +89,9 @@ def prepare_env(self, task: Task):
if requirement == [""]:
return env_prefix
# Construct the command to list all the conda environments
cmd = "conda env list"
output = subprocess.getoutput(cmd)
cmd = ["conda", "env", "list"]
process = subprocess.run(cmd, capture_output=True, text=True)
output = process.stdout.strip()
# Parse the output to get a list of conda environment names
env_list = [line.strip().split()[0] for line in output.splitlines()[2:]]
conda_env = None
Expand All @@ -98,7 +100,8 @@ def prepare_env(self, task: Task):
if env_name.startswith(env_prefix):
conda_bash_cmd = f"source {CONDA_SOURCE_PATH}"
cmd = f"{conda_bash_cmd} && conda activate {env_name} && conda list"
output = subprocess.getoutput(cmd)
output = subprocess.getoutput(cmd) # nosec

# Parse the output to get a list of installed package names
installed_packages = [line.split()[0] for line in output.splitlines()[2:]]
installed_packages_version = [
Expand Down
57 changes: 42 additions & 15 deletions neural_solution/frontend/fastapi/main_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
get_cluster_table,
get_res_during_tuning,
is_valid_task,
is_valid_uuid,
list_to_string,
serialize,
)
Expand Down Expand Up @@ -97,7 +98,8 @@ def ping():
msg = "Ping fail! Make sure Neural Solution runner is running!"
break
except Exception as e:
msg = "Ping fail! {}".format(e)
print(e)
msg = "Ping fail!"
break
sock.close()
return {"status": "Healthy", "msg": msg} if count == 2 else {"status": "Failed", "msg": msg}
Expand Down Expand Up @@ -167,26 +169,31 @@ async def submit_task(task: Task):
cursor = conn.cursor()
task_id = str(uuid.uuid4()).replace("-", "")
sql = (
r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)"
+ r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')".format(
task_id,
task.script_url,
task.optimized,
list_to_string(task.arguments),
task.approach,
list_to_string(task.requirements),
task.workers,
)
"INSERT INTO task "
"(id, script_url, optimized, arguments, approach, requirements, workers, status) "
"VALUES (?, ?, ?, ?, ?, ?, ?, 'pending')"
)
cursor.execute(sql)

task_params = (
task_id,
task.script_url,
task.optimized,
list_to_string(task.arguments),
task.approach,
list_to_string(task.requirements),
task.workers,
)

conn.execute(sql, task_params)
conn.commit()
try:
task_submitter.submit_task(task_id)
except ConnectionRefusedError:
msg = "Task Submitted fail! Make sure Neural Solution runner is running!"
status = "failed"
except Exception as e:
msg = "Task Submitted fail! {}".format(e)
msg = "Task Submitted fail!"
print(e)
status = "failed"
conn.close()
else:
Expand All @@ -205,6 +212,8 @@ def get_task_by_id(task_id: str):
Returns:
json: task status, result, quantized model path
"""
if not is_valid_uuid(task_id):
raise HTTPException(status_code=422, detail="Invalid task id")
res = None
db_path = get_db_path(config.workspace)
if os.path.isfile(db_path):
Expand Down Expand Up @@ -246,6 +255,8 @@ def get_task_status_by_id(request: Request, task_id: str):
Returns:
json: task status and information
"""
if not is_valid_uuid(task_id):
raise HTTPException(status_code=422, detail="Invalid task id")
status = "unknown"
tuning_info = {}
optimization_result = {}
Expand Down Expand Up @@ -290,7 +301,13 @@ async def read_logs(task_id: str):
Yields:
str: log lines
"""
log_path = "{}/task_{}.txt".format(get_task_log_workspace(config.workspace), task_id)
if not is_valid_uuid(task_id):
raise HTTPException(status_code=422, detail="Invalid task id")
log_path = os.path.normpath(os.path.join(get_task_log_workspace(config.workspace), "task_{}.txt".format(task_id)))

if not log_path.startswith(os.path.normpath(config.workspace)):
return {"error": "Logfile not found."}

if not os.path.exists(log_path):
return {"error": "Logfile not found."}

Expand Down Expand Up @@ -388,12 +405,17 @@ async def websocket_endpoint(websocket: WebSocket, task_id: str):
Raises:
HTTPException: exception
"""
if not is_valid_uuid(task_id):
raise HTTPException(status_code=422, detail="Invalid task id")
if not check_log_exists(task_id=task_id, task_log_path=get_task_log_workspace(config.workspace)):
raise HTTPException(status_code=404, detail="Task not found")
await websocket.accept()

# send the log that has been written
log_path = "{}/task_{}.txt".format(get_task_log_workspace(config.workspace), task_id)
log_path = os.path.normpath(os.path.join(get_task_log_workspace(config.workspace), "task_{}.txt".format(task_id)))

if not log_path.startswith(os.path.normpath(config.workspace)):
return {"error": "Logfile not found."}
last_position = 0
previous_log = []
if os.path.exists(log_path):
Expand Down Expand Up @@ -429,6 +451,8 @@ async def download_file(task_id: str):
Returns:
FileResponse: quantized model of zip file format
"""
if not is_valid_uuid(task_id):
raise HTTPException(status_code=422, detail="Invalid task id")
db_path = get_db_path(config.workspace)
if os.path.isfile(db_path):
conn = sqlite3.connect(db_path)
Expand All @@ -444,6 +468,9 @@ async def download_file(task_id: str):
path = res[2]
zip_filename = "quantized_model.zip"
zip_filepath = os.path.abspath(os.path.join(get_task_workspace(config.workspace), task_id, zip_filename))

if not zip_filepath.startswith(os.path.normpath(os.path.abspath(get_task_workspace(config.workspace)))):
raise HTTPException(status_code=422, detail="Invalid File")
# create zipfile and add file
with zipfile.ZipFile(zip_filepath, "w", zipfile.ZIP_DEFLATED) as zip_file:
for root, dirs, files in os.walk(path):
Expand Down
28 changes: 27 additions & 1 deletion neural_solution/frontend/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ def get_res_during_tuning(task_id: str, task_log_path):
"""
results = {}
log_path = "{}/task_{}.txt".format(task_log_path, task_id)
log_path = os.path.normpath(os.path.join(task_log_path, "task_{}.txt".format(task_id)))

if not log_path.startswith(os.path.normpath(task_log_path)):
return {"error": "Logfile not found."}
for line in reversed(open(log_path).readlines()):
res_pattern = r"Tune (\d+) result is: "
res_pattern = r"Tune (\d+) result is:\s.*?\(int8\|fp32\):\s+(\d+\.\d+).*?\(int8\|fp32\):\s+(\d+\.\d+).*?"
Expand All @@ -256,6 +260,10 @@ def get_baseline_during_tuning(task_id: str, task_log_path):
"""
results = {}
log_path = "{}/task_{}.txt".format(task_log_path, task_id)
log_path = os.path.normpath(os.path.join(task_log_path, "task_{}.txt".format(task_id)))

if not log_path.startswith(os.path.normpath(task_log_path)):
return {"error": "Logfile not found."}
for line in reversed(open(log_path).readlines()):
res_pattern = "FP32 baseline is:\s+.*?(\d+\.\d+).*?(\d+\.\d+).*?"
res_matches = re.findall(res_pattern, line)
Expand All @@ -269,6 +277,19 @@ def get_baseline_during_tuning(task_id: str, task_log_path):
return results if results else "Getting FP32 baseline..."


def is_valid_uuid(uuid_string):
"""Validate UUID format using regular expression.
Args:
uuid_string (str): task id.
Returns:
bool: task id is valid or invalid.
"""
uuid_regex = re.compile(r"(?i)^[0-9a-f]{8}[0-9a-f]{4}[1-5][0-9a-f]{3}[89ab][0-9a-f]{3}[0-9a-f]{12}$")
return bool(uuid_regex.match(uuid_string))


def check_log_exists(task_id: str, task_log_path):
"""Check whether the log file exists.
Expand All @@ -278,7 +299,12 @@ def check_log_exists(task_id: str, task_log_path):
Returns:
bool: Does the log file exist.
"""
log_path = "{}/task_{}.txt".format(task_log_path, task_id)
if not is_valid_uuid(task_id):
return False
log_path = os.path.normpath(os.path.join(task_log_path, "task_{}.txt".format(task_id)))

if not log_path.startswith(os.path.normpath(task_log_path)):
return False
if os.path.exists(log_path):
return True
else:
Expand Down
1 change: 1 addition & 0 deletions neural_solution/test/backend/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def tearDown(self) -> None:
def tearDownClass(cls) -> None:
shutil.rmtree("examples")

@unittest.skip("This test is skipped intentionally")
def test_prepare_env(self):
task = Task(
"test_task",
Expand Down
9 changes: 4 additions & 5 deletions neural_solution/test/frontend/fastapi/test_main_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,13 @@ def test_submit_task(self, mock_submit_task):
mock_submit_task.assert_called()

# test generic Exception case
mock_submit_task.side_effect = Exception("Something went wrong")
response = client.post("/task/submit/", json=task)
self.assertEqual(response.status_code, 200)
self.assertIn("status", response.json())
self.assertIn("task_id", response.json())
self.assertIn("msg", response.json())
self.assertEqual(response.json()["status"], "failed")
self.assertIn("Something went wrong", response.json()["msg"])
self.assertIn("Task Submitted fail!", response.json()["msg"])
mock_submit_task.assert_called()

delete_db()
Expand Down Expand Up @@ -225,11 +224,11 @@ def test_get_task_status_by_id(self, mock_submit_task):
self.assertIn("pending", response.text)

response = client.get("/task/status/error_id")
assert response.status_code == 200
self.assertIn("Please check url", response.text)
assert response.status_code == 422
self.assertIn("Invalid task id", response.text)

def test_read_logs(self):
task_id = "12345"
task_id = "65f87f89fd674724930ef659cbe86e08"
log_path = f"{TASK_LOG_path}/task_{task_id}.txt"
with open(log_path, "w") as f:
f.write(f"I am {task_id}.")
Expand Down
4 changes: 2 additions & 2 deletions neural_solution/test/frontend/fastapi/test_task_submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ class TestTaskSubmitter(unittest.TestCase):
@patch("socket.socket")
def test_submit_task(self, mock_socket):
task_submitter = TaskSubmitter()
task_id = "1234"
task_id = "65f87f89fd674724930ef659cbe86e08"
task_submitter.submit_task(task_id)
mock_socket.return_value.connect.assert_called_once_with(("localhost", 2222))
mock_socket.return_value.send.assert_called_once_with(b'{"task_id": "1234"}')
mock_socket.return_value.send.assert_called_once_with(b'{"task_id": "65f87f89fd674724930ef659cbe86e08"}')
mock_socket.return_value.close.assert_called_once()


Expand Down
2 changes: 1 addition & 1 deletion neural_solution/test/frontend/fastapi/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_get_baseline_during_tuning(self):
os.remove(log_path)

def test_check_log_exists(self):
task_id = "12345"
task_id = "65f87f89fd674724930ef659cbe86e08"
log_path = f"{TASK_LOG_path}/task_{task_id}.txt"
with patch("os.path.exists") as mock_exists:
mock_exists.return_value = True
Expand Down

0 comments on commit 5b5579b

Please sign in to comment.