diff --git a/neural_solution/__init__.py b/neural_solution/__init__.py
index 7017c64526d..300d6010498 100644
--- a/neural_solution/__init__.py
+++ b/neural_solution/__init__.py
@@ -13,4 +13,4 @@
# limitations under the License.
"""Neural Solution."""
-from neural_solution.utils import logger
\ No newline at end of file
+from neural_solution.utils import logger
diff --git a/neural_solution/backend/__init__.py b/neural_solution/backend/__init__.py
index 07d2f346fcd..c93847d4e0b 100644
--- a/neural_solution/backend/__init__.py
+++ b/neural_solution/backend/__init__.py
@@ -13,8 +13,8 @@
# limitations under the License.
"""Neural Solution backend."""
+from neural_solution.backend.cluster import Cluster
from neural_solution.backend.result_monitor import ResultMonitor
from neural_solution.backend.scheduler import Scheduler
-from neural_solution.backend.task_monitor import TaskMonitor
-from neural_solution.backend.cluster import Cluster
from neural_solution.backend.task_db import TaskDB
+from neural_solution.backend.task_monitor import TaskMonitor
diff --git a/neural_solution/backend/cluster.py b/neural_solution/backend/cluster.py
index c3eb2aa5501..cee9a23dc9a 100644
--- a/neural_solution/backend/cluster.py
+++ b/neural_solution/backend/cluster.py
@@ -13,12 +13,14 @@
# limitations under the License.
"""Neural Solution cluster."""
-import threading
import sqlite3
+import threading
+from collections import Counter
from typing import List
-from neural_solution.backend.utils.utility import synchronized, create_dir
+
+from neural_solution.backend.utils.utility import create_dir, synchronized
from neural_solution.utils import logger
-from collections import Counter
+
class Cluster:
"""Cluster resource management based on sockets."""
@@ -35,7 +37,7 @@ def __init__(self, node_lst=[], db_path=None):
self.socket_queue = []
self.db_path = db_path
create_dir(db_path)
- self.conn = sqlite3.connect(f'{db_path}', check_same_thread=False)
+ self.conn = sqlite3.connect(f"{db_path}", check_same_thread=False)
self.initial_cluster_from_node_lst(node_lst)
self.lock = threading.Lock()
@@ -63,7 +65,6 @@ def reserve_resource(self, task):
logger.info(f"[Cluster] Assign {reserved_resource_lst} to task {task.task_id}")
return reserved_resource_lst
-
@synchronized
def free_resource(self, reserved_resource_lst):
"""Free the resource by adding the previous occupied resources to the socket queue."""
@@ -71,9 +72,9 @@ def free_resource(self, reserved_resource_lst):
counts = Counter(int(item.split()[0]) for item in reserved_resource_lst)
free_resources = {}
for node_id, count in counts.items():
- free_resources[node_id] = count
+ free_resources[node_id] = count
for node_id, count in counts.items():
- free_resources[node_id] = count
+ free_resources[node_id] = count
for node_id in free_resources:
sql = """
UPDATE cluster
@@ -105,28 +106,31 @@ def initial_cluster_from_node_lst(self, node_lst):
node_lst (List): the node list.
"""
# sqlite should set this check_same_thread to False
- self.conn = sqlite3.connect(f'{self.db_path}', check_same_thread=False)
+ self.conn = sqlite3.connect(f"{self.db_path}", check_same_thread=False)
self.cursor = self.conn.cursor()
- self.cursor.execute('drop table if exists cluster ')
- self.cursor.execute(r'create table cluster(id INTEGER PRIMARY KEY AUTOINCREMENT,' +
- 'node_info varchar(500),' +
- 'status varchar(100),' +
- 'free_sockets int,' +
- 'busy_sockets int,' +
- 'total_sockets int)')
+ self.cursor.execute("drop table if exists cluster ")
+ self.cursor.execute(
+ r"create table cluster(id INTEGER PRIMARY KEY AUTOINCREMENT,"
+ + "node_info varchar(500),"
+ + "status varchar(100),"
+ + "free_sockets int,"
+ + "busy_sockets int,"
+ + "total_sockets int)"
+ )
self.node_lst = node_lst
for index, node in enumerate(self.node_lst):
- self.socket_queue += [str(index+1) + " " + node.name] * node.num_sockets
- self.cursor.execute(r"insert into cluster(node_info, status, free_sockets, busy_sockets, total_sockets)" +
- "values ('{}', '{}', {}, {}, {})".format(repr(node).replace("Node", f"Node{index+1}"),
- "alive",
- node.num_sockets,
- 0,
- node.num_sockets))
+ self.socket_queue += [str(index + 1) + " " + node.name] * node.num_sockets
+ self.cursor.execute(
+ r"insert into cluster(node_info, status, free_sockets, busy_sockets, total_sockets)"
+ + "values ('{}', '{}', {}, {}, {})".format(
+ repr(node).replace("Node", f"Node{index+1}"), "alive", node.num_sockets, 0, node.num_sockets
+ )
+ )
self.conn.commit()
logger.info(f"socket_queue: {self.socket_queue}")
+
class Node:
"""Node definition."""
@@ -134,14 +138,11 @@ class Node:
ip: str = "unknown_ip"
num_sockets: int = 0
num_cores_per_socket: int = 0
- num_gpus: int = 0 # For future use
-
- def __init__(self,
- name: str,
- ip: str = "unknown_ip",
- num_sockets: int = 0,
- num_cores_per_socket: int = 0,
- num_gpus: int = 0) -> None:
+ num_gpus: int = 0 # For future use
+
+ def __init__(
+ self, name: str, ip: str = "unknown_ip", num_sockets: int = 0, num_cores_per_socket: int = 0, num_gpus: int = 0
+ ) -> None:
"""Init node.
hostfile template:
@@ -167,8 +168,7 @@ def __repr__(self) -> str:
Returns:
str: node info.
"""
- return f"Node: {self.name}(ip: {self.ip}) has {self.num_sockets} socket(s) " \
+ return (
+ f"Node: {self.name}(ip: {self.ip}) has {self.num_sockets} socket(s) "
f"and each socket has {self.num_cores_per_socket} cores."
-
-
-
+ )
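
For reference, the socket-queue bookkeeping that `cluster.py` implements can be exercised standalone; a minimal sketch, where the node names, socket counts, and db path are illustrative assumptions rather than values from this patch:

```python
# Minimal sketch of Cluster/Node usage; all values here are illustrative.
from neural_solution.backend.cluster import Cluster, Node

node_lst = [
    Node(name="host-a", num_sockets=2, num_cores_per_socket=20),
    Node(name="host-b", num_sockets=2, num_cores_per_socket=20),
]
cluster = Cluster(node_lst=node_lst, db_path="/tmp/ns_demo/db/task.db")

# Each socket becomes one "<node_id> <hostname>" entry in the queue:
# ["1 host-a", "1 host-a", "2 host-b", "2 host-b"]. reserve_resource(task)
# pops entries off this queue, and free_resource() pushes them back while
# updating the free/busy socket counters in the cluster table.
print(cluster.socket_queue)
```
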
diff --git a/neural_solution/backend/result_monitor.py b/neural_solution/backend/result_monitor.py
index 8061cee4a81..0972889f890 100644
--- a/neural_solution/backend/result_monitor.py
+++ b/neural_solution/backend/result_monitor.py
@@ -15,9 +15,11 @@
"""Neural Solution result monitor."""
import socket
-from neural_solution.backend.utils.utility import serialize, deserialize
-from neural_solution.utils import logger
+
from neural_solution.backend.task_db import TaskDB
+from neural_solution.backend.utils.utility import deserialize, serialize
+from neural_solution.utils import logger
+
class ResultMonitor:
"""ResultMonitor is a thread that monitors the coming task results and update the task collection in the TaskDb.
@@ -40,7 +42,7 @@ def __init__(self, port, task_db: TaskDB):
def wait_result(self):
"""Monitor the task results and update them in the task db and send back to studio."""
- self.s.bind(("localhost", self.port)) # open a port as the serving port for results
+ self.s.bind(("localhost", self.port)) # open a port as the serving port for results
self.s.listen(10)
while True:
logger.info("[ResultMonitor] waiting for results...")
@@ -53,7 +55,7 @@ def wait_result(self):
c.close()
continue
logger.info("[ResultMonitor] getting result: {}".format(result))
- logger.info("[ResultMonitor] getting q_model path: {}".format(result['q_model_path']))
+ logger.info("[ResultMonitor] getting q_model path: {}".format(result["q_model_path"]))
self.task_db.update_q_model_path_and_result(result["task_id"], result["q_model_path"], result["result"])
c.close()
# TODO send back the result to the studio
@@ -63,4 +65,3 @@ def query_task_status(self, task_id):
"""Synchronize query on the task status."""
# TODO send back the result to the studio? RPC for query?
logger.info(self.task_db.lookup_task_status(task_id))
-
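
The socket protocol here is just `serialize()`d JSON; a hypothetical reporting client, mirroring what `Scheduler.report_result()` sends later in this patch (the port and payload values are illustrative):

```python
# Sketch of a client reporting a finished task to the ResultMonitor.
# wait_result() deserializes this payload and reads the "task_id",
# "result", and "q_model_path" keys; the values below are made up.
import socket

from neural_solution.backend.utils.utility import serialize

s = socket.socket()
s.connect(("localhost", 3333))  # default result_monitor_port
s.send(
    serialize(
        {
            "task_id": "d3e10a49326449fb9d0d62f2bfc1cb43",
            "result": {"optimization time (seconds)": "42.00"},
            "q_model_path": "/path/to/task_workspace/q_model",
        }
    )
)
s.close()
```
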
diff --git a/neural_solution/backend/runner.py b/neural_solution/backend/runner.py
index 173b9ca016a..8162b7b524f 100644
--- a/neural_solution/backend/runner.py
+++ b/neural_solution/backend/runner.py
@@ -13,13 +13,14 @@
# limitations under the License.
"""Main backend runner."""
-import threading
import argparse
+import threading
-from neural_solution.backend import TaskDB, Scheduler, TaskMonitor, ResultMonitor
-from neural_solution.utils import logger
+from neural_solution.backend import ResultMonitor, Scheduler, TaskDB, TaskMonitor
from neural_solution.backend.utils.utility import build_cluster, get_db_path
from neural_solution.config import config
+from neural_solution.utils import logger
+
def parse_args(args=None):
"""Parse the command line options.
@@ -30,24 +31,25 @@ def parse_args(args=None):
Returns:
argparse.Namespace: arguments.
"""
- parser = argparse.ArgumentParser(description="Neural Solution runner automatically schedules multiple inc tasks and\
- executes multi-node distributed tuning.")
-
- parser.add_argument("-H", "--hostfile", type=str, default=None, \
- help="Path to the host file which contains all available nodes.")
- parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, \
- help="Port to monitor task.")
- parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, \
- help="Port to monitor result.")
- parser.add_argument("-WS", "--workspace", type=str, default="./", \
- help="Work space.")
- parser.add_argument("-CEN", "--conda_env_name", type=str, default="inc", \
- help="Conda environment for task execution")
- parser.add_argument("-UP", "--upload_path", type=str, default="./examples", \
- help="Custom example path.")
+ parser = argparse.ArgumentParser(
+ description="Neural Solution runner automatically schedules multiple inc tasks and\
+ executes multi-node distributed tuning."
+ )
+
+ parser.add_argument(
+ "-H", "--hostfile", type=str, default=None, help="Path to the host file which contains all available nodes."
+ )
+ parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, help="Port to monitor task.")
+ parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, help="Port to monitor result.")
+ parser.add_argument("-WS", "--workspace", type=str, default="./", help="Work space.")
+ parser.add_argument(
+ "-CEN", "--conda_env_name", type=str, default="inc", help="Conda environment for task execution"
+ )
+ parser.add_argument("-UP", "--upload_path", type=str, default="./examples", help="Custom example path.")
return parser.parse_args(args=args)
+
def main(args=None):
"""Implement the main entry of backend.
@@ -72,9 +74,15 @@ def main(args=None):
t_rm = threading.Thread(target=rm.wait_result)
config.workspace = args.workspace
- ts = Scheduler(cluster, task_db, args.result_monitor_port, \
- conda_env_name=args.conda_env_name, upload_path=args.upload_path, config=config, \
- num_threads_per_process=num_threads_per_process)
+ ts = Scheduler(
+ cluster,
+ task_db,
+ args.result_monitor_port,
+ conda_env_name=args.conda_env_name,
+ upload_path=args.upload_path,
+ config=config,
+ num_threads_per_process=num_threads_per_process,
+ )
t_ts = threading.Thread(target=ts.schedule_tasks)
tm = TaskMonitor(args.task_monitor_port, task_db)
@@ -83,8 +91,9 @@ def main(args=None):
t_rm.start()
t_ts.start()
t_tm.start()
- logger.info("task monitor port {} and result monitor port {}".\
- format(args.task_monitor_port, args.result_monitor_port))
+ logger.info(
+ "task monitor port {} and result monitor port {}".format(args.task_monitor_port, args.result_monitor_port)
+ )
logger.info("server start...")
t_rm.join()
@@ -92,5 +101,5 @@ def main(args=None):
t_tm.join()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
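
Since `parse_args()` takes an optional `args` list, the reflowed parser can be exercised without touching `sys.argv`; a quick sketch with made-up values:

```python
# Smoke test for the runner's argument parser (paths are illustrative).
from neural_solution.backend.runner import parse_args

args = parse_args(["-H", "hosts.txt", "-TMP", "2222", "-RMP", "3333", "-WS", "./ns_workspace"])
assert args.hostfile == "hosts.txt"
assert args.task_monitor_port == 2222 and args.result_monitor_port == 3333
assert args.conda_env_name == "inc"  # default
```
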
diff --git a/neural_solution/backend/scheduler.py b/neural_solution/backend/scheduler.py
index fc991a47436..1a494122cbf 100644
--- a/neural_solution/backend/scheduler.py
+++ b/neural_solution/backend/scheduler.py
@@ -13,39 +13,49 @@
# limitations under the License.
"""Neural Solution scheduler."""
-import time
-import threading
-import subprocess
+import glob
+import json
import os
-import socket
import re
-import json
-import glob
import shutil
-from neural_solution.backend.task import Task
+import socket
+import subprocess
+import threading
+import time
+
from neural_solution.backend.cluster import Cluster
+from neural_solution.backend.task import Task
from neural_solution.backend.task_db import TaskDB
from neural_solution.backend.utils.utility import (
- serialize,
+ build_workspace,
dump_elapsed_time,
- get_task_log_path,
get_current_time,
- build_workspace,
+ get_q_model_path,
+ get_task_log_path,
is_remote_url,
- get_q_model_path
+ serialize,
)
-from neural_solution.utils.utility import get_task_log_workspace, get_task_workspace
from neural_solution.utils import logger
+from neural_solution.utils.utility import get_task_log_workspace, get_task_workspace
-#TODO update it according to the platform
-cmd="echo $(conda info --base)/etc/profile.d/conda.sh"
+# TODO update it according to the platform
+cmd = "echo $(conda info --base)/etc/profile.d/conda.sh"
CONDA_SOURCE_PATH = subprocess.getoutput(cmd)
+
class Scheduler:
"""Scheduler dispatches the task with the available resources, calls the mpi command and report results."""
- def __init__(self, cluster: Cluster, task_db: TaskDB, result_monitor_port, \
- conda_env_name=None, upload_path="./examples", config=None, num_threads_per_process=5):
+ def __init__(
+ self,
+ cluster: Cluster,
+ task_db: TaskDB,
+ result_monitor_port,
+ conda_env_name=None,
+ upload_path="./examples",
+ config=None,
+ num_threads_per_process=5,
+ ):
"""Scheduler dispatches the task with the available resources, calls the mpi command and report results.
Attributes:
@@ -63,9 +73,9 @@ def __init__(self, cluster: Cluster, task_db: TaskDB, result_monitor_port, \
self.config = config
self.num_threads_per_process = num_threads_per_process
- def prepare_env(self, task:Task):
+ def prepare_env(self, task: Task):
"""Check and create a conda environment.
-
+
If the required packages are not installed in the conda environment,
create a new conda environment and install the required packages.
@@ -76,7 +86,7 @@ def prepare_env(self, task:Task):
env_prefix = self.conda_env_name
requirement = task.requirement.split(" ")
# Skip check when requirement is empty.
- if requirement == ['']:
+ if requirement == [""]:
return env_prefix
# Construct the command to list all the conda environments
cmd = f"conda env list"
@@ -92,8 +102,9 @@ def prepare_env(self, task:Task):
output = subprocess.getoutput(cmd)
# Parse the output to get a list of installed package names
installed_packages = [line.split()[0] for line in output.splitlines()[2:]]
- installed_packages_version = \
- [line.split()[0] + "=" + line.split()[1] for line in output.splitlines()[2:]]
+ installed_packages_version = [
+ line.split()[0] + "=" + line.split()[1] for line in output.splitlines()[2:]
+ ]
missing_packages = set(requirement) - set(installed_packages) - set(installed_packages_version)
if not missing_packages:
conda_env = env_name
@@ -101,19 +112,22 @@ def prepare_env(self, task:Task):
if conda_env is None:
# Construct the command to create a new conda environment and install the required packages
from datetime import datetime
+
now = datetime.now()
suffix = now.strftime("%Y%m%d-%H%M%S")
conda_env = f"{env_prefix}_{suffix}"
# Construct the name of the new conda environment
- cmd = (f"source {CONDA_SOURCE_PATH} && conda create -n {conda_env} --clone {env_prefix}"
- f" && conda activate {conda_env} && pip install {task.requirement.replace('=','==')}")
- p = subprocess.Popen(cmd, shell=True) # nosec
+ cmd = (
+ f"source {CONDA_SOURCE_PATH} && conda create -n {conda_env} --clone {env_prefix}"
+ f" && conda activate {conda_env} && pip install {task.requirement.replace('=','==')}"
+ )
+ p = subprocess.Popen(cmd, shell=True) # nosec
logger.info(f"[Scheduler] Creating new environment {conda_env} start.")
p.wait()
logger.info(f"[Scheduler] Creating new environment {conda_env} end.")
return conda_env
- def prepare_task(self, task:Task):
+ def prepare_task(self, task: Task):
"""Prepare workspace and download run_task.py for task.
Args:
@@ -122,24 +136,25 @@ def prepare_task(self, task:Task):
self.task_path = build_workspace(path=get_task_workspace(self.config.workspace), task_id=task.task_id)
logger.info(f"****TASK PATH: {self.task_path}")
if is_remote_url(task.script_url):
- task_url = task.script_url.replace('github.com', 'raw.githubusercontent.com').replace('blob','')
+ task_url = task.script_url.replace("github.com", "raw.githubusercontent.com").replace("blob", "")
try:
- subprocess.check_call(['wget', '-P', self.task_path, task_url])
+ subprocess.check_call(["wget", "-P", self.task_path, task_url])
except subprocess.CalledProcessError as e:
logger.info("Failed: {}".format(e.cmd))
else:
# Assuming the file is uploaded in directory examples
example_path = os.path.abspath(os.path.join(self.upload_path, task.script_url))
# only one python file
- script_path = glob.glob(os.path.join(example_path, '*.py'))[0]
+ script_path = glob.glob(os.path.join(example_path, "*.py"))[0]
# script_path = glob.glob(os.path.join(example_path, f'*{extension}'))[0]
self.script_name = script_path.split("/")[-1]
shutil.copy(script_path, os.path.abspath(self.task_path))
- task.arguments = task.arguments.replace("=dataset", "=" + os.path.join(example_path,"dataset")).\
- replace("=model", "=" + os.path.join(example_path,"model"))
+ task.arguments = task.arguments.replace("=dataset", "=" + os.path.join(example_path, "dataset")).replace(
+ "=model", "=" + os.path.join(example_path, "model")
+ )
if not task.optimized:
# Generate quantization code with Neural Coder API
- neural_coder_cmd = ['python -m neural_coder --enable --approach']
+ neural_coder_cmd = ["python -m neural_coder --enable --approach"]
# for users to define approach: "static, ""static_ipex", "dynamic", "auto"
approach = task.approach
neural_coder_cmd.append(approach)
@@ -148,7 +163,7 @@ def prepare_task(self, task:Task):
neural_coder_cmd.append(self.script_name)
neural_coder_cmd = " ".join(neural_coder_cmd)
full_cmd = """cd {}\n{}""".format(self.task_path, neural_coder_cmd)
- p = subprocess.Popen(full_cmd, shell=True) # nosec
+ p = subprocess.Popen(full_cmd, shell=True) # nosec
logger.info("[Neural Coder] Generating optimized code start.")
p.wait()
logger.info("[Neural Coder] Generating optimized code end.")
@@ -163,7 +178,7 @@ def check_task_status(self, log_path):
str: status "done" or "failed"
"""
for line in reversed(open(log_path).readlines()):
- res_pattern = r'[INFO] Save deploy yaml to'
+ res_pattern = r"[INFO] Save deploy yaml to"
# res_matches = re.findall(res_pattern, line)
if res_pattern in line:
return "done"
@@ -180,24 +195,25 @@ def _parse_cmd(self, task: Task, resource):
# Activate environment
conda_bash_cmd = f"source {CONDA_SOURCE_PATH}"
conda_env_cmd = f"conda activate {conda_env}"
- mpi_cmd = ["mpirun",
- "-np",
- "{}".format(task.workers),
- "-host",
- "{}".format(host_str),
- "-map-by",
- "socket:pe={}".format(self.num_threads_per_process),
- "-mca",
- "btl_tcp_if_include",
- "192.168.20.0/24", # TODO replace it according to the node
- "-x",
- "OMP_NUM_THREADS={}".format(self.num_threads_per_process),
- '--report-bindings'
- ]
+ mpi_cmd = [
+ "mpirun",
+ "-np",
+ "{}".format(task.workers),
+ "-host",
+ "{}".format(host_str),
+ "-map-by",
+ "socket:pe={}".format(self.num_threads_per_process),
+ "-mca",
+ "btl_tcp_if_include",
+ "192.168.20.0/24", # TODO replace it according to the node
+ "-x",
+ "OMP_NUM_THREADS={}".format(self.num_threads_per_process),
+ "--report-bindings",
+ ]
mpi_cmd = " ".join(mpi_cmd)
# Initial Task command
- task_cmd = ['python']
+ task_cmd = ["python"]
task_cmd.append(self.script_name)
task_cmd.append(self.sanitize_arguments(task.arguments))
task_cmd = " ".join(task_cmd)
@@ -208,10 +224,9 @@ def _parse_cmd(self, task: Task, resource):
# build a bash script to run task.
bash_script_name = "distributed_run.sh" if task.workers > 1 else "run.sh"
- bash_script = """{}\n{}\ncd {}\n{}""".format(conda_bash_cmd, conda_env_cmd, \
- self.task_path, task_cmd)
+ bash_script = """{}\n{}\ncd {}\n{}""".format(conda_bash_cmd, conda_env_cmd, self.task_path, task_cmd)
bash_script_path = os.path.join(self.task_path, bash_script_name)
- with open(bash_script_path, 'w', encoding="utf-8") as f:
+ with open(bash_script_path, "w", encoding="utf-8") as f:
f.write(bash_script)
full_cmd = """cd {}\n{} bash {}""".format(self.task_path, mpi_cmd, bash_script_name)
@@ -221,9 +236,9 @@ def report_result(self, task_id, log_path, task_runtime):
"""Report the result to the result monitor."""
s = socket.socket()
s.connect(("localhost", self.result_monitor_port))
- results = {"optimization time (seconds)" : "{:.2f}".format(task_runtime)}
+ results = {"optimization time (seconds)": "{:.2f}".format(task_runtime)}
for line in reversed(open(log_path).readlines()):
- res_pattern = r'Tune (\d+) result is:\s.*?\(int8\|fp32\):\s+(\d+\.\d+).*?\(int8\|fp32\):\s+(\d+\.\d+).*?'
+ res_pattern = r"Tune (\d+) result is:\s.*?\(int8\|fp32\):\s+(\d+\.\d+).*?\(int8\|fp32\):\s+(\d+\.\d+).*?"
res_matches = re.findall(res_pattern, line)
if res_matches:
# results["Tuning count"] = res_matches[0][0]
@@ -237,7 +252,6 @@ def report_result(self, task_id, log_path, task_runtime):
s.send(serialize({"task_id": task_id, "result": results, "q_model_path": self.q_model_path}))
s.close()
-
@dump_elapsed_time("Task execution")
def launch_task(self, task: Task, resource):
"""Generate the mpi command and execute the task.
@@ -247,14 +261,15 @@ def launch_task(self, task: Task, resource):
full_cmd = self._parse_cmd(task, resource)
logger.info(f"[TaskScheduler] Parsed the command from task: {full_cmd}")
log_path = get_task_log_path(log_path=get_task_log_workspace(self.config.workspace), task_id=task.task_id)
- p = subprocess.Popen(full_cmd, stdout=open(log_path, 'w+'), stderr=subprocess.STDOUT, shell=True) # nosec
+ p = subprocess.Popen(full_cmd, stdout=open(log_path, "w+"), stderr=subprocess.STDOUT, shell=True) # nosec
logger.info(f"[TaskScheduler] Start run task {task.task_id}, dump log into {log_path}")
start_time = time.time()
p.wait()
self.cluster.free_resource(resource)
task_runtime = time.time() - start_time
- logger.info(\
- f"[TaskScheduler] Finished task {task.task_id}, and free resource {resource}, dump log into {log_path}")
+ logger.info(
+ f"[TaskScheduler] Finished task {task.task_id}, and free resource {resource}, dump log into {log_path}"
+ )
task_status = self.check_task_status(log_path)
self.task_db.update_task_status(task.task_id, task_status)
self.q_model_path = get_q_model_path(log_path=log_path) if task_status == "done" else None
@@ -271,8 +286,10 @@ def schedule_tasks(self):
time.sleep(5)
logger.info(f"[TaskScheduler {get_current_time()}] try to dispatch a task...")
if self.task_db.get_pending_task_num() > 0:
- logger.info(f"[TaskScheduler {get_current_time()}], " + \
- f"there are {self.task_db.get_pending_task_num()} task pending.")
+ logger.info(
+ f"[TaskScheduler {get_current_time()}], "
+ + f"there are {self.task_db.get_pending_task_num()} task pending."
+ )
task_id = self.task_db.task_queue[0]
task = self.task_db.get_task_by_id(task_id)
resource = self.cluster.reserve_resource(task)
@@ -287,4 +304,4 @@ def schedule_tasks(self):
def sanitize_arguments(self, arguments: str):
"""Replace space encoding with space."""
- return arguments.replace(u'\xa0', u' ')
\ No newline at end of file
+ return arguments.replace("\xa0", " ")
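
The quote-style change to `res_pattern` in `report_result()` is behavior-preserving; for illustration, this is how the pattern scrapes a tuning line. The sample log line below is an assumption shaped to match the pattern, not output from a real run:

```python
import re

res_pattern = r"Tune (\d+) result is:\s.*?\(int8\|fp32\):\s+(\d+\.\d+).*?\(int8\|fp32\):\s+(\d+\.\d+).*?"
line = "Tune 1 result is: accuracy (int8|fp32): 0.7614 duration (int8|fp32): 12.34"
print(re.findall(res_pattern, line))  # [('1', '0.7614', '12.34')]
```
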
diff --git a/neural_solution/backend/task.py b/neural_solution/backend/task.py
index 722ec76e6db..00a6ec22a9a 100644
--- a/neural_solution/backend/task.py
+++ b/neural_solution/backend/task.py
@@ -13,6 +13,8 @@
# limitations under the License.
"""Neural Solution task."""
+
+
class Task:
"""A Task is an abstraction of a user tuning request that is handled in neural solution service.
@@ -24,8 +26,19 @@ class Task:
result: The result of the task, which is only value-assigned when the task is done
"""
- def __init__(self, task_id, arguments, workers, status, script_url, \
- optimized, approach, requirement, result="", q_model_path=""):
+ def __init__(
+ self,
+ task_id,
+ arguments,
+ workers,
+ status,
+ script_url,
+ optimized,
+ approach,
+ requirement,
+ result="",
+ q_model_path="",
+ ):
"""Init task.
Args:
@@ -49,4 +62,4 @@ def __init__(self, task_id, arguments, workers, status, script_url, \
self.approach = approach
self.requirement = requirement
self.result = result
- self.q_model_path = q_model_path
\ No newline at end of file
+ self.q_model_path = q_model_path
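
With the keyword-per-line signature it is easier to see a fully populated task; an illustrative construction (every field value below is made up, and the `approach`/`requirement` conventions follow the usage in `scheduler.py`):

```python
from neural_solution.backend.task import Task

task = Task(
    task_id="1234abcd",
    arguments="--model_path=model --dataset_location=dataset",
    workers=2,
    status="pending",
    script_url="custom_models_optimized/tf_example1",
    optimized=True,
    approach="static",  # one of "static", "static_ipex", "dynamic", "auto"
    requirement="tensorflow=2.11.0",  # "=" is rewritten to "==" for pip
)
```
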
diff --git a/neural_solution/backend/task_db.py b/neural_solution/backend/task_db.py
index db8e3e7c0ec..35708d63264 100644
--- a/neural_solution/backend/task_db.py
+++ b/neural_solution/backend/task_db.py
@@ -13,11 +13,13 @@
# limitations under the License.
"""Neural Solution task database."""
-import threading
import sqlite3
+import threading
from collections import deque
-from neural_solution.backend.utils.utility import create_dir
+
from neural_solution.backend.task import Task
+from neural_solution.backend.utils.utility import create_dir
+
class TaskDB:
"""TaskDb manages all the tasks.
@@ -39,12 +41,13 @@ def __init__(self, db_path):
self.task_queue = deque()
create_dir(db_path)
# sqlite should set this check_same_thread to False
- self.conn = sqlite3.connect(f'{db_path}', check_same_thread=False)
+ self.conn = sqlite3.connect(f"{db_path}", check_same_thread=False)
self.cursor = self.conn.cursor()
self.cursor.execute(
- 'create table if not exists task(id TEXT PRIMARY KEY, arguments varchar(100), ' +
- 'workers int, status varchar(20), script_url varchar(500), optimized integer, ' +
- 'approach varchar(20), requirements varchar(500), result varchar(500), q_model_path varchar(200))')
+ "create table if not exists task(id TEXT PRIMARY KEY, arguments varchar(100), "
+ + "workers int, status varchar(20), script_url varchar(500), optimized integer, "
+ + "approach varchar(20), requirements varchar(500), result varchar(500), q_model_path varchar(200))"
+ )
self.conn.commit()
# self.task_collections = []
self.lock = threading.Lock()
@@ -72,7 +75,7 @@ def update_task_status(self, task_id, status):
"""Update the task status with the task id and the status."""
if status not in ["pending", "running", "done", "failed"]:
raise Exception("status invalid, should be one of pending/running/done")
- self.cursor.execute(r"update task set status='{}' where id=?".format(status), (task_id, ))
+ self.cursor.execute(r"update task set status='{}' where id=?".format(status), (task_id,))
self.conn.commit()
def update_result(self, task_id, result_str):
@@ -82,13 +85,14 @@ def update_result(self, task_id, result_str):
def update_q_model_path_and_result(self, task_id, q_model_path, result_str):
"""Update the task result with the result string."""
- self.cursor.execute(r"update task set q_model_path='{}', result='{}' where id=?"
- .format(q_model_path, result_str), (task_id, ))
+ self.cursor.execute(
+ r"update task set q_model_path='{}', result='{}' where id=?".format(q_model_path, result_str), (task_id,)
+ )
self.conn.commit()
def lookup_task_status(self, task_id):
"""Look up the current task status and result."""
- self.cursor.execute(r"select status, result from task where id=?", (task_id, ))
+ self.cursor.execute(r"select status, result from task where id=?", (task_id,))
status, result = self.cursor.fetchone()
return {"status": status, "result": result}
@@ -98,7 +102,6 @@ def get_task_by_id(self, task_id):
attr_tuple = self.cursor.fetchone()
return Task(*attr_tuple)
- def remove_task(self, task_id): # currently no garbage collection
+ def remove_task(self, task_id): # currently no garbage collection
"""Remove task."""
pass
-
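
A short round-trip sketch of the reformatted `TaskDB` statements; since task insertion normally happens in the FastAPI frontend, this seeds a row directly through the cursor (the id and db path are illustrative):

```python
from neural_solution.backend.task_db import TaskDB

task_db = TaskDB(db_path="/tmp/ns_demo/db/task.db")
task_db.cursor.execute("insert or replace into task(id, status) values ('1234abcd', 'pending')")
task_db.conn.commit()

task_db.update_task_status("1234abcd", "running")  # rejects unknown statuses
print(task_db.lookup_task_status("1234abcd"))  # {'status': 'running', 'result': None}
```
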
diff --git a/neural_solution/backend/task_monitor.py b/neural_solution/backend/task_monitor.py
index 7b52a1d5375..c6aa3745072 100644
--- a/neural_solution/backend/task_monitor.py
+++ b/neural_solution/backend/task_monitor.py
@@ -14,9 +14,11 @@
"""Neural Solution task monitor."""
import socket
-from neural_solution.backend.utils.utility import serialize, deserialize
+
+from neural_solution.backend.utils.utility import deserialize, serialize
from neural_solution.utils import logger
+
class TaskMonitor:
"""TaskMonitor is a thread that monitors the coming tasks and appends them to the task queue.
@@ -32,7 +34,7 @@ def __init__(self, port, task_db):
self.task_db = task_db
def _start_listening(self, host, port, max_parallelism):
- self.s.bind(("localhost", port)) # open a port as the serving port for tasks
+ self.s.bind(("localhost", port)) # open a port as the serving port for tasks
self.s.listen(max_parallelism)
def _receive_task(self):
@@ -62,4 +64,4 @@ def wait_new_task(self):
task = self._receive_task()
if not task:
continue
- self._append_task(task)
\ No newline at end of file
+ self._append_task(task)
diff --git a/neural_solution/backend/utils/__init__.py b/neural_solution/backend/utils/__init__.py
index 1fdfb51540c..a716ddb1a6e 100644
--- a/neural_solution/backend/utils/__init__.py
+++ b/neural_solution/backend/utils/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Neural Solution backend utils."""
\ No newline at end of file
+"""Neural Solution backend utils."""
diff --git a/neural_solution/backend/utils/utility.py b/neural_solution/backend/utils/utility.py
index 1b4a2b9f3f4..31a65984b07 100644
--- a/neural_solution/backend/utils/utility.py
+++ b/neural_solution/backend/utils/utility.py
@@ -15,17 +15,21 @@
"""Neural Solution backend utils."""
import json
import os
-from neural_solution.utils import logger
from urllib.parse import urlparse
+from neural_solution.utils import logger
+
+
def serialize(request: dict) -> bytes:
"""Serialize a dict object to bytes for inter-process communication."""
return json.dumps(request).encode()
+
def deserialize(request: bytes) -> dict:
"""Deserialize the recived bytes to a dict object."""
return json.loads(request)
+
def dump_elapsed_time(customized_msg=""):
"""Get the elapsed time for decorated functions.
@@ -33,17 +37,23 @@ def dump_elapsed_time(customized_msg=""):
customized_msg (string, optional): The parameter passed to decorator. Defaults to None.
"""
import time
+
def f(func):
def fi(*args, **kwargs):
start = time.time()
res = func(*args, **kwargs)
end = time.time()
- logger.info('%s elapsed time: %s ms' %(customized_msg if customized_msg else func.__qualname__,
- round((end - start) * 1000, 2)))
+ logger.info(
+ "%s elapsed time: %s ms"
+ % (customized_msg if customized_msg else func.__qualname__, round((end - start) * 1000, 2))
+ )
return res
+
return fi
+
return f
+
def get_task_log_path(log_path, task_id):
"""Get the path of task log according id.
@@ -56,7 +66,7 @@ def get_task_log_path(log_path, task_id):
"""
if not os.path.exists(log_path):
os.makedirs(log_path)
- log_file_path = "{}/task_{}.txt".format(log_path,task_id)
+ log_file_path = "{}/task_{}.txt".format(log_path, task_id)
return log_file_path
@@ -71,6 +81,7 @@ def get_db_path(workspace="./"):
"""
return os.path.join(workspace, "db", "task.db")
+
def get_task_workspace(workspace="./"):
"""Get the workspace of task.
@@ -82,6 +93,7 @@ def get_task_workspace(workspace="./"):
"""
return os.path.join(workspace, "task_workspace")
+
def get_task_log_workspace(workspace="./"):
"""Get the log workspace for task.
@@ -93,6 +105,7 @@ def get_task_log_workspace(workspace="./"):
"""
return os.path.join(workspace, "task_log")
+
def get_serve_log_workspace(workspace="./"):
"""Get log workspace for service.
@@ -114,16 +127,18 @@ def build_local_cluster(db_path):
Returns:
(Cluster, int): cluster and num threads per process
"""
- from neural_solution.backend.cluster import Node, Cluster
- hostname = 'localhost'
- node1 = Node(name=hostname,num_sockets=2, num_cores_per_socket=5)
- node2 = Node(name=hostname,num_sockets=2, num_cores_per_socket=5)
- node3 = Node(name=hostname,num_sockets=2, num_cores_per_socket=5)
+ from neural_solution.backend.cluster import Cluster, Node
+
+ hostname = "localhost"
+ node1 = Node(name=hostname, num_sockets=2, num_cores_per_socket=5)
+ node2 = Node(name=hostname, num_sockets=2, num_cores_per_socket=5)
+ node3 = Node(name=hostname, num_sockets=2, num_cores_per_socket=5)
node_lst = [node1, node2, node3]
cluster = Cluster(node_lst=node_lst, db_path=db_path)
return cluster, 5
+
def build_cluster(file_path, db_path):
"""Build cluster according to the host file.
@@ -133,7 +148,8 @@ def build_cluster(file_path, db_path):
Returns:
Cluster: return cluster object.
"""
- from neural_solution.backend.cluster import Node, Cluster
+ from neural_solution.backend.cluster import Cluster, Node
+
# If no file is specified, build a local cluster
if file_path == "None" or file_path is None:
return build_local_cluster(db_path)
@@ -143,7 +159,7 @@ def build_cluster(file_path, db_path):
node_lst = []
num_threads_per_process = 5
- with open(file_path, 'r') as f:
+ with open(file_path, "r") as f:
for line in f:
hostname, num_sockets, num_cores_per_socket = line.strip().split(" ")
num_sockets, num_cores_per_socket = int(num_sockets), int(num_cores_per_socket)
@@ -153,6 +169,7 @@ def build_cluster(file_path, db_path):
cluster = Cluster(node_lst=node_lst, db_path=db_path)
return cluster, num_threads_per_process
+
def get_current_time():
"""Get current time.
@@ -160,7 +177,9 @@ def get_current_time():
str: the current time in hours, minutes, and seconds.
"""
from datetime import datetime
- return datetime.now().strftime('%H:%M:%S')
+
+ return datetime.now().strftime("%H:%M:%S")
+
def synchronized(func):
"""Locking for synchronization.
@@ -168,11 +187,14 @@ def synchronized(func):
Args:
func (function): decorative function
"""
+
def wrapper(self, *args, **kwargs):
with self.lock:
return func(self, *args, **kwargs)
+
return wrapper
+
def build_workspace(path, task_id=""):
"""Build workspace of running tasks.
@@ -185,6 +207,7 @@ def build_workspace(path, task_id=""):
os.makedirs(task_path)
return os.path.abspath(task_path)
+
def is_remote_url(url_or_filename):
"""Check if input is a URL.
@@ -197,11 +220,13 @@ def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
+
def create_dir(path):
"""Create the (nested) path if not exist."""
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
+
def get_q_model_path(log_path):
"""Get the quantized model path from task log.
@@ -212,9 +237,10 @@ def get_q_model_path(log_path):
str: quantized model path
"""
import re
+
for line in reversed(open(log_path).readlines()):
- match = re.search(r'(Save quantized model to|Save config file and weights of quantized model to) (.+?)\.', line)
+ match = re.search(r"(Save quantized model to|Save config file and weights of quantized model to) (.+?)\.", line)
if match:
q_model_path = match.group(2)
return q_model_path
- return "quantized model path not found"
\ No newline at end of file
+ return "quantized model path not found"
diff --git a/neural_solution/bin/__init__.py b/neural_solution/bin/__init__.py
index 4f8f95fdc2f..63ed8c79470 100644
--- a/neural_solution/bin/__init__.py
+++ b/neural_solution/bin/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Neural Solution."""
\ No newline at end of file
+"""Neural Solution."""
diff --git a/neural_solution/bin/neural_solution.py b/neural_solution/bin/neural_solution.py
index d41cfb34013..dccaa9bfdd8 100644
--- a/neural_solution/bin/neural_solution.py
+++ b/neural_solution/bin/neural_solution.py
@@ -18,7 +18,9 @@
def exec():
"""Execute Neural Solution launch."""
from neural_solution.launcher import main
+
main()
-if __name__ == '__main__':
- exec()
\ No newline at end of file
+
+if __name__ == "__main__":
+ exec()
diff --git a/neural_solution/config.py b/neural_solution/config.py
index 52c872557e2..8c30b4c6118 100644
--- a/neural_solution/config.py
+++ b/neural_solution/config.py
@@ -16,6 +16,7 @@
INTERVAL_TIME_BETWEEN_DISPATCH_TASK = 3
+
class Config:
"""Config for services."""
@@ -24,6 +25,7 @@ class Config:
result_monitor_port: int = 3333
service_address: str = "localhost"
grpc_api_port: int = 4444
- #TODO add set and get methods for each attribute
+ # TODO add set and get methods for each attribute
+
-config = Config()
\ No newline at end of file
+config = Config()
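
`config` is a module-level singleton; both `runner.main()` and the FastAPI frontend mutate it in place after argument parsing, along these lines (values illustrative):

```python
from neural_solution.config import config

config.workspace = "./ns_workspace"  # set by runner.main() from -WS
config.task_monitor_port = 2222      # set by the frontend from -TMP
config.result_monitor_port = 3333    # set by the frontend from -RMP
```
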
diff --git a/neural_solution/examples/custom_models_optimized/tf_example1/test.py b/neural_solution/examples/custom_models_optimized/tf_example1/test.py
index 3f345669a1d..006e241551d 100644
--- a/neural_solution/examples/custom_models_optimized/tf_example1/test.py
+++ b/neural_solution/examples/custom_models_optimized/tf_example1/test.py
@@ -1,41 +1,44 @@
"""Running script."""
import tensorflow as tf
-
-from neural_compressor.data import TensorflowImageRecord
-from neural_compressor.data import BilinearImagenetTransform
-from neural_compressor.data import ComposeTransform
-from neural_compressor.data import DefaultDataLoader
-from neural_compressor.quantization import fit
-from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor import Metric
+from neural_compressor.config import PostTrainingQuantConfig
+from neural_compressor.data import BilinearImagenetTransform, ComposeTransform, DefaultDataLoader, TensorflowImageRecord
+from neural_compressor.quantization import fit
flags = tf.compat.v1.flags
FLAGS = flags.FLAGS
-flags.DEFINE_string('dataset_location', None, 'location of calibration dataset and evaluate dataset')
+flags.DEFINE_string("dataset_location", None, "location of calibration dataset and evaluate dataset")
-flags.DEFINE_string('model_path', None, 'location of model')
+flags.DEFINE_string("model_path", None, "location of model")
-calib_dataset = TensorflowImageRecord(root=FLAGS.dataset_location, transform= \
- ComposeTransform(transform_list= [BilinearImagenetTransform(height=224, width=224)]))
+calib_dataset = TensorflowImageRecord(
+ root=FLAGS.dataset_location,
+ transform=ComposeTransform(transform_list=[BilinearImagenetTransform(height=224, width=224)]),
+)
calib_dataloader = DefaultDataLoader(dataset=calib_dataset, batch_size=10)
-eval_dataset = TensorflowImageRecord(root=FLAGS.dataset_location, transform=ComposeTransform(transform_list= \
- [BilinearImagenetTransform(height=224, width=224)]))
+eval_dataset = TensorflowImageRecord(
+ root=FLAGS.dataset_location,
+ transform=ComposeTransform(transform_list=[BilinearImagenetTransform(height=224, width=224)]),
+)
eval_dataloader = DefaultDataLoader(dataset=eval_dataset, batch_size=1)
+
def main():
"""Implement running function."""
top1 = Metric(name="topk", k=1)
config = PostTrainingQuantConfig(calibration_sampling_size=[20])
model_path = FLAGS.model_path + "/mobilenet_v1_1.0_224_frozen.pb"
q_model = fit(
- model= model_path,
+ model=model_path,
conf=config,
calib_dataloader=calib_dataloader,
eval_dataloader=eval_dataloader,
- eval_metric=top1)
+ eval_metric=top1,
+ )
q_model.save("./q_model_path/q_model")
+
if __name__ == "__main__":
main()
diff --git a/neural_solution/frontend/fastapi/__init__.py b/neural_solution/frontend/fastapi/__init__.py
index 3fb002264a5..b862242e65b 100644
--- a/neural_solution/frontend/fastapi/__init__.py
+++ b/neural_solution/frontend/fastapi/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""FastAPI frontend."""
\ No newline at end of file
+"""FastAPI frontend."""
diff --git a/neural_solution/frontend/fastapi/main_server.py b/neural_solution/frontend/fastapi/main_server.py
index d7a363cd122..1d88e0cade7 100644
--- a/neural_solution/frontend/fastapi/main_server.py
+++ b/neural_solution/frontend/fastapi/main_server.py
@@ -14,36 +14,32 @@
"""Fast api server."""
-from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
-from fastapi.responses import StreamingResponse, HTMLResponse
-from neural_solution.frontend.task_submitter import Task, task_submitter
-from neural_solution.frontend.utility import (
- get_cluster_info,
- get_cluster_table,
- serialize,
- deserialize,
- get_res_during_tuning,
- get_baseline_during_tuning,
- check_log_exists,
- list_to_string)
-
-import sqlite3
-import os
-import uuid
-from watchdog.observers import Observer
-from watchdog.events import FileSystemEventHandler
import asyncio
import json
+import os
import socket
-import uvicorn
+import sqlite3
+import uuid
-from neural_solution.utils.utility import (
- get_task_log_workspace,
- get_db_path
-)
+import uvicorn
+from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
+from fastapi.responses import HTMLResponse, StreamingResponse
+from watchdog.events import FileSystemEventHandler
+from watchdog.observers import Observer
from neural_solution.config import config
-
+from neural_solution.frontend.task_submitter import Task, task_submitter
+from neural_solution.frontend.utility import (
+ check_log_exists,
+ deserialize,
+ get_baseline_during_tuning,
+ get_cluster_info,
+ get_cluster_table,
+ get_res_during_tuning,
+ list_to_string,
+ serialize,
+)
+from neural_solution.utils.utility import get_db_path, get_task_log_workspace
# Get config from Launcher.sh
task_monitor_port = None
@@ -57,19 +53,15 @@
args = None
+
def parse_arguments():
"""Parse the command line options."""
parser = argparse.ArgumentParser(description="Frontend with RESTful API")
- parser.add_argument("-H", "--host", type=str, default="0.0.0.0", \
- help="The address to submit task.")
- parser.add_argument("-FP", "--fastapi_port", type=int, default=8000, \
- help="Port to submit task by user.")
- parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, \
- help="Port to monitor task.")
- parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, \
- help="Port to monitor result.")
- parser.add_argument("-WS", "--workspace", type=str, default="./", \
- help="Work space.")
+ parser.add_argument("-H", "--host", type=str, default="0.0.0.0", help="The address to submit task.")
+ parser.add_argument("-FP", "--fastapi_port", type=int, default=8000, help="Port to submit task by user.")
+ parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, help="Port to monitor task.")
+ parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, help="Port to monitor result.")
+ parser.add_argument("-WS", "--workspace", type=str, default="./", help="Work space.")
args = parser.parse_args()
return args
@@ -79,6 +71,7 @@ def read_root():
"""Root route."""
return {"message": "Welcome to Neural Solution!"}
+
@app.get("/ping")
def ping():
"""Test status of services.
@@ -100,14 +93,15 @@ def ping():
sock.close()
continue
except ConnectionRefusedError:
- msg = "Ping fail! Make sure Neural Solution runner is running!"
- break
+ msg = "Ping fail! Make sure Neural Solution runner is running!"
+ break
except Exception as e:
msg = "Ping fail! {}".format(e)
break
sock.close()
return {"status": "Healthy", "msg": msg} if count == 2 else {"status": "Failed", "msg": msg}
+
@app.get("/cluster")
def get_cluster():
"""Get the cluster info.
@@ -118,6 +112,7 @@ def get_cluster():
db_path = get_db_path(config.workspace)
return get_cluster_info(db_path=db_path)
+
@app.get("/clusters")
def get_clusters():
"""Get the cluster info.
@@ -128,6 +123,7 @@ def get_clusters():
db_path = get_db_path(config.workspace)
return HTMLResponse(content=get_cluster_table(db_path=db_path))
+
@app.get("/description")
async def get_description():
"""Get user oriented API descriptions.
@@ -140,6 +136,7 @@ async def get_description():
data = json.load(f)
return data
+
@app.post("/task/submit/")
async def submit_task(task: Task):
"""Submit task.
@@ -163,10 +160,19 @@ async def submit_task(task: Task):
if os.path.isfile(db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
- task_id = str(uuid.uuid4()).replace('-','')
- sql = r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)" +\
- r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')".format(task_id, task.script_url, task.optimized,
- list_to_string(task.arguments), task.approach, list_to_string(task.requirements), task.workers)
+ task_id = str(uuid.uuid4()).replace("-", "")
+ sql = (
+ r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)"
+ + r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')".format(
+ task_id,
+ task.script_url,
+ task.optimized,
+ list_to_string(task.arguments),
+ task.approach,
+ list_to_string(task.requirements),
+ task.workers,
+ )
+ )
cursor.execute(sql)
conn.commit()
try:
@@ -180,9 +186,10 @@ async def submit_task(task: Task):
conn.close()
else:
msg = "Task Submitted fail! db not found!"
- return {"msg": msg} # TODO to align with return message when submit task successfully
+        return {"msg": msg}  # TODO align with the return message of a successful submission
return {"status": status, "task_id": task_id, "msg": msg}
+
@app.get("/task/{task_id}")
def get_task_by_id(task_id: str):
"""Get task status, result, quantized model path according to id.
@@ -202,7 +209,8 @@ def get_task_by_id(task_id: str):
res = cursor.fetchone()
cursor.close()
conn.close()
- return {"status": res[0], 'optimized_result': deserialize(res[1]) if res[1] else res[1], "result_path": res[2]}
+ return {"status": res[0], "optimized_result": deserialize(res[1]) if res[1] else res[1], "result_path": res[2]}
+
@app.get("/task/")
def get_all_tasks():
@@ -222,6 +230,7 @@ def get_all_tasks():
conn.close()
return {"message": res}
+
@app.get("/task/status/{task_id}")
def get_task_status_by_id(task_id: str):
"""Get task status and information according to id.
@@ -241,12 +250,12 @@ def get_task_status_by_id(task_id: str):
if os.path.isfile(db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
- cursor.execute(r"select status, result, q_model_path from task where id=?", (task_id, ))
+ cursor.execute(r"select status, result, q_model_path from task where id=?", (task_id,))
res = cursor.fetchone()
cursor.close()
conn.close()
if not res:
- status = "Please check url."
+ status = "Please check url."
elif res[0] == "done":
status = res[0]
optimization_result = deserialize(res[1]) if res[1] else res[1]
@@ -254,15 +263,14 @@ def get_task_status_by_id(task_id: str):
elif res[0] == "pending":
status = "pending"
else:
- baseline = get_baseline_during_tuning(task_id,get_task_log_workspace(config.workspace))
+ baseline = get_baseline_during_tuning(task_id, get_task_log_workspace(config.workspace))
tuning_result = get_res_during_tuning(task_id, get_task_log_workspace(config.workspace))
status = res[0]
- tuning_info = {
- "baseline": baseline,
- "message": tuning_result}
+ tuning_info = {"baseline": baseline, "message": tuning_result}
result = {"status": status, "tuning_info": tuning_info, "optimization_result": optimization_result}
return result
+
@app.get("/task/log/{task_id}")
async def read_logs(task_id: str):
"""Get the log of task according to id.
@@ -279,6 +287,7 @@ async def read_logs(task_id: str):
log_path = "{}/task_{}.txt".format(get_task_log_workspace(config.workspace), task_id)
if not os.path.exists(log_path):
return {"error": "Logfile not found."}
+
def stream_logs():
with open(log_path) as f:
while True:
@@ -286,8 +295,10 @@ def stream_logs():
if not line:
break
yield line.encode()
+
return StreamingResponse(stream_logs(), media_type="text/plain")
+
# Real time output log
class LogEventHandler(FileSystemEventHandler):
"""Responsible for monitoring log changes and sending logs to clients.
@@ -308,11 +319,10 @@ def __init__(self, websocket: WebSocket, task_id, last_position):
self.websocket = websocket
self.task_id = task_id
self.loop = asyncio.get_event_loop()
- self.last_position = last_position # record last line
+ self.last_position = last_position # record last line
self.queue = asyncio.Queue()
self.timer = self.loop.create_task(self.send_messages())
-
async def send_messages(self):
"""Send messages to the client."""
while True:
@@ -340,6 +350,7 @@ def on_modified(self, event):
for line in lines:
self.queue.put_nowait(line.strip())
+
# start log watcher
def start_log_watcher(websocket, task_id, last_position):
"""Start log watcher.
@@ -406,8 +417,8 @@ async def websocket_endpoint(websocket: WebSocket, task_id: str):
config.task_monitor_port = args.task_monitor_port
config.result_monitor_port = args.result_monitor_port
# initialize the task submitter
- task_submitter.task_monitor_port=config.task_monitor_port
- task_submitter.result_monitor_port=config.result_monitor_port
+ task_submitter.task_monitor_port = config.task_monitor_port
+ task_submitter.result_monitor_port = config.result_monitor_port
config.service_address = task_submitter.service_address
# start the app
uvicorn.run(app, host=args.host, port=args.fastapi_port)
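
For reference, a hypothetical client for the reshaped `/task/submit/` route using the `requests` library; the field names follow the columns in the insert statement above, and the field types (string `optimized`, list-valued `arguments`/`requirements`) are inferred from `list_to_string()` and the gRPC client, so treat them as assumptions:

```python
import requests

# Illustrative payload; port 8000 is the default fastapi_port.
payload = {
    "script_url": "custom_models_optimized/tf_example1",  # relative to the upload path
    "optimized": "True",
    "arguments": ["--dataset_location=dataset", "--model_path=model"],
    "approach": "static",
    "requirements": [],
    "workers": 1,
}
resp = requests.post("http://localhost:8000/task/submit/", json=payload)
print(resp.json())  # {"status": ..., "task_id": ..., "msg": ...}
```
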
diff --git a/neural_solution/frontend/gRPC/__init__.py b/neural_solution/frontend/gRPC/__init__.py
index de608a1758d..1c1b8ab6053 100644
--- a/neural_solution/frontend/gRPC/__init__.py
+++ b/neural_solution/frontend/gRPC/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""gRPC frontend."""
\ No newline at end of file
+"""gRPC frontend."""
diff --git a/neural_solution/frontend/gRPC/client.py b/neural_solution/frontend/gRPC/client.py
index ef8750954ac..c344e439017 100644
--- a/neural_solution/frontend/gRPC/client.py
+++ b/neural_solution/frontend/gRPC/client.py
@@ -16,16 +16,14 @@
import argparse
import json
-import grpc
import os
-from neural_solution.frontend.gRPC.proto import (
- neural_solution_pb2,
- neural_solution_pb2_grpc
-)
+import grpc
-from neural_solution.utils import logger
from neural_solution.config import config
+from neural_solution.frontend.gRPC.proto import neural_solution_pb2, neural_solution_pb2_grpc
+from neural_solution.utils import logger
+
def _parse_task_from_json(request_path):
file_path = os.path.abspath(request_path)
@@ -33,33 +31,34 @@ def _parse_task_from_json(request_path):
task = json.load(fp)
return task
+
def submit_task(args):
"""Implement main entry point for the client of gRPC frontend."""
task = _parse_task_from_json(args.request)
logger.info("Parsed task:")
logger.info(task)
-
+
# Create a gRPC channel
port = str(config.grpc_api_port)
- channel = grpc.insecure_channel('localhost:' + port)
+ channel = grpc.insecure_channel("localhost:" + port)
# Create a stub (client)
stub = neural_solution_pb2_grpc.TaskServiceStub(channel)
     # Ping server
- request = neural_solution_pb2.EmptyRequest() # pylint: disable=no-member
+ request = neural_solution_pb2.EmptyRequest() # pylint: disable=no-member
response = stub.Ping(request)
logger.info(response.status)
logger.info(response.msg)
# Create a task request with the desired fields
- request = neural_solution_pb2.Task( # pylint: disable=no-member
- script_url=task['script_url'],
- optimized=task['optimized'] == 'True',
- arguments=task['arguments'],
- approach=task['approach'],
- requirements=task['requirements'],
- workers=task['workers']
+ request = neural_solution_pb2.Task( # pylint: disable=no-member
+ script_url=task["script_url"],
+ optimized=task["optimized"] == "True",
+ arguments=task["arguments"],
+ approach=task["approach"],
+ requirements=task["requirements"],
+ workers=task["workers"],
)
# Call the SubmitTask RPC on the server
@@ -80,17 +79,18 @@ def run_query_task_result(args):
task_id = args.task_id
# Create a gRPC channel
port = str(config.grpc_api_port)
- channel = grpc.insecure_channel('localhost:' + port)
+ channel = grpc.insecure_channel("localhost:" + port)
# Create a stub (client)
stub = neural_solution_pb2_grpc.TaskServiceStub(channel)
- request = neural_solution_pb2.TaskId(task_id=task_id) # pylint: disable=no-member
+ request = neural_solution_pb2.TaskId(task_id=task_id) # pylint: disable=no-member
response = stub.QueryTaskResult(request)
logger.info(response.status)
logger.info(response.tuning_information)
logger.info(response.optimization_result)
+
def run_query_task_status(args):
"""Query task status according to id.
@@ -100,26 +100,26 @@ def run_query_task_status(args):
task_id = args.task_id
# Create a gRPC channel
port = str(config.grpc_api_port)
- channel = grpc.insecure_channel('localhost:' + port)
+ channel = grpc.insecure_channel("localhost:" + port)
# Create a stub (client)
stub = neural_solution_pb2_grpc.TaskServiceStub(channel)
- request = neural_solution_pb2.TaskId(task_id=task_id) # pylint: disable=no-member
+ request = neural_solution_pb2.TaskId(task_id=task_id) # pylint: disable=no-member
response = stub.GetTaskById(request)
logger.info(response.status)
logger.info(response.optimized_result)
logger.info(response.result_path)
-if __name__ == '__main__':
+if __name__ == "__main__":
logger.info(f"Try to start gRPC server.")
"""Parse the command line options."""
parser = argparse.ArgumentParser(description="gRPC Client")
subparsers = parser.add_subparsers(help="Action", dest="action")
submit_action_parser = subparsers.add_parser("submit", help="Submit help")
-
+
submit_action_parser.set_defaults(func=submit_task)
submit_action_parser.add_argument("--request", type=str, default=None, help="Request json file path.")
@@ -132,4 +132,4 @@ def run_query_task_status(args):
# for test:
# python client.py query --task_id="d3e10a49326449fb9d0d62f2bfc1cb43"
-# python client.py submit --request="test_task_request.json"
\ No newline at end of file
+# python client.py submit --request="test_task_request.json"
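
And a sketch of a request file that `_parse_task_from_json()` would accept; note `submit_task()` compares `optimized` against the string `'True'`, so it stays a string (all values illustrative):

```python
import json

request = {
    "script_url": "custom_models_optimized/tf_example1",
    "optimized": "True",  # compared as a string in submit_task()
    "arguments": ["--dataset_location=dataset", "--model_path=model"],
    "approach": "static",
    "requirements": [],
    "workers": 1,
}
with open("test_task_request.json", "w") as fp:
    json.dump(request, fp)
# then: python client.py submit --request="test_task_request.json"
```
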
diff --git a/neural_solution/frontend/gRPC/proto/__init__.py b/neural_solution/frontend/gRPC/proto/__init__.py
index facfd9f388a..059f44462f5 100644
--- a/neural_solution/frontend/gRPC/proto/__init__.py
+++ b/neural_solution/frontend/gRPC/proto/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""gRPC proto."""
\ No newline at end of file
+"""gRPC proto."""
diff --git a/neural_solution/frontend/gRPC/proto/neural_solution_pb2.py b/neural_solution/frontend/gRPC/proto/neural_solution_pb2.py
index 8c8f874d6d3..65e8ff7ca34 100644
--- a/neural_solution/frontend/gRPC/proto/neural_solution_pb2.py
+++ b/neural_solution/frontend/gRPC/proto/neural_solution_pb2.py
@@ -3,10 +3,11 @@
# source: neural_solution.proto
# pylint: disable=all
"""Generated protocol buffer code."""
-from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -14,30 +15,30 @@
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15neural_solution.proto\x12\x0fneural_solution\x1a\x1bgoogle/protobuf/empty.proto\"y\n\x04Task\x12\x12\n\nscript_url\x18\x01 \x01(\t\x12\x11\n\toptimized\x18\x02 \x01(\x08\x12\x11\n\targuments\x18\x03 \x03(\t\x12\x10\n\x08\x61pproach\x18\x04 \x01(\t\x12\x14\n\x0crequirements\x18\x05 \x03(\t\x12\x0f\n\x07workers\x18\x06 \x01(\x05\"<\n\x0cTaskResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07task_id\x18\x02 \x01(\t\x12\x0b\n\x03msg\x18\x03 \x01(\t\"\x19\n\x06TaskId\x12\x0f\n\x07task_id\x18\x01 \x01(\t\"K\n\nTaskStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x18\n\x10optimized_result\x18\x02 \x01(\t\x12\x13\n\x0bresult_path\x18\x03 \x01(\t\"\x0e\n\x0c\x45mptyRequest\"!\n\x0eWelcomeMessage\x12\x0f\n\x07message\x18\x01 \x01(\t\"2\n\x13ResponsePingMessage\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0b\n\x03msg\x18\x02 \x01(\t\"]\n\x12ResponseTaskResult\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x1a\n\x12tuning_information\x18\x02 \x01(\t\x12\x1b\n\x13optimization_result\x18\x03 \x01(\t2\xb5\x02\n\x0bTaskService\x12\x46\n\x04Ping\x12\x16.google.protobuf.Empty\x1a$.neural_solution.ResponsePingMessage\"\x00\x12\x44\n\nSubmitTask\x12\x15.neural_solution.Task\x1a\x1d.neural_solution.TaskResponse\"\x00\x12\x45\n\x0bGetTaskById\x12\x17.neural_solution.TaskId\x1a\x1b.neural_solution.TaskStatus\"\x00\x12Q\n\x0fQueryTaskResult\x12\x17.neural_solution.TaskId\x1a#.neural_solution.ResponseTaskResult\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x15neural_solution.proto\x12\x0fneural_solution\x1a\x1bgoogle/protobuf/empty.proto"y\n\x04Task\x12\x12\n\nscript_url\x18\x01 \x01(\t\x12\x11\n\toptimized\x18\x02 \x01(\x08\x12\x11\n\targuments\x18\x03 \x03(\t\x12\x10\n\x08\x61pproach\x18\x04 \x01(\t\x12\x14\n\x0crequirements\x18\x05 \x03(\t\x12\x0f\n\x07workers\x18\x06 \x01(\x05"<\n\x0cTaskResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07task_id\x18\x02 \x01(\t\x12\x0b\n\x03msg\x18\x03 \x01(\t"\x19\n\x06TaskId\x12\x0f\n\x07task_id\x18\x01 \x01(\t"K\n\nTaskStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x18\n\x10optimized_result\x18\x02 \x01(\t\x12\x13\n\x0bresult_path\x18\x03 \x01(\t"\x0e\n\x0c\x45mptyRequest"!\n\x0eWelcomeMessage\x12\x0f\n\x07message\x18\x01 \x01(\t"2\n\x13ResponsePingMessage\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0b\n\x03msg\x18\x02 \x01(\t"]\n\x12ResponseTaskResult\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x1a\n\x12tuning_information\x18\x02 \x01(\t\x12\x1b\n\x13optimization_result\x18\x03 \x01(\t2\xb5\x02\n\x0bTaskService\x12\x46\n\x04Ping\x12\x16.google.protobuf.Empty\x1a$.neural_solution.ResponsePingMessage"\x00\x12\x44\n\nSubmitTask\x12\x15.neural_solution.Task\x1a\x1d.neural_solution.TaskResponse"\x00\x12\x45\n\x0bGetTaskById\x12\x17.neural_solution.TaskId\x1a\x1b.neural_solution.TaskStatus"\x00\x12Q\n\x0fQueryTaskResult\x12\x17.neural_solution.TaskId\x1a#.neural_solution.ResponseTaskResult"\x00\x62\x06proto3'
+)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'neural_solution_pb2', globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "neural_solution_pb2", globals())
if _descriptor._USE_C_DESCRIPTORS == False:
-
- DESCRIPTOR._options = None
- _TASK._serialized_start=71
- _TASK._serialized_end=192
- _TASKRESPONSE._serialized_start=194
- _TASKRESPONSE._serialized_end=254
- _TASKID._serialized_start=256
- _TASKID._serialized_end=281
- _TASKSTATUS._serialized_start=283
- _TASKSTATUS._serialized_end=358
- _EMPTYREQUEST._serialized_start=360
- _EMPTYREQUEST._serialized_end=374
- _WELCOMEMESSAGE._serialized_start=376
- _WELCOMEMESSAGE._serialized_end=409
- _RESPONSEPINGMESSAGE._serialized_start=411
- _RESPONSEPINGMESSAGE._serialized_end=461
- _RESPONSETASKRESULT._serialized_start=463
- _RESPONSETASKRESULT._serialized_end=556
- _TASKSERVICE._serialized_start=559
- _TASKSERVICE._serialized_end=868
+ DESCRIPTOR._options = None
+ _TASK._serialized_start = 71
+ _TASK._serialized_end = 192
+ _TASKRESPONSE._serialized_start = 194
+ _TASKRESPONSE._serialized_end = 254
+ _TASKID._serialized_start = 256
+ _TASKID._serialized_end = 281
+ _TASKSTATUS._serialized_start = 283
+ _TASKSTATUS._serialized_end = 358
+ _EMPTYREQUEST._serialized_start = 360
+ _EMPTYREQUEST._serialized_end = 374
+ _WELCOMEMESSAGE._serialized_start = 376
+ _WELCOMEMESSAGE._serialized_end = 409
+ _RESPONSEPINGMESSAGE._serialized_start = 411
+ _RESPONSEPINGMESSAGE._serialized_end = 461
+ _RESPONSETASKRESULT._serialized_start = 463
+ _RESPONSETASKRESULT._serialized_end = 556
+ _TASKSERVICE._serialized_start = 559
+ _TASKSERVICE._serialized_end = 868
# @@protoc_insertion_point(module_scope)
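
For orientation, here is a minimal usage sketch of the regenerated `neural_solution_pb2` module. The field names are read off the serialized descriptor above (`Task`: script_url, optimized, arguments, approach, requirements, workers); the example values are placeholders, not taken from the patch.

```python
# Sketch only: exercises the Task message defined by the descriptor above.
from neural_solution.frontend.gRPC.proto import neural_solution_pb2

task = neural_solution_pb2.Task(
    script_url="https://example.com/run.py",  # placeholder URL
    optimized=False,
    arguments=["--batch_size", "32"],
    approach="static",
    requirements=["torch"],
    workers=2,
)
payload = task.SerializeToString()  # wire bytes sent over gRPC
restored = neural_solution_pb2.Task.FromString(payload)
assert restored.approach == "static"
```
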
diff --git a/neural_solution/frontend/gRPC/proto/neural_solution_pb2_grpc.py b/neural_solution/frontend/gRPC/proto/neural_solution_pb2_grpc.py
index 285511e05f4..fcc1132c5c1 100644
--- a/neural_solution/frontend/gRPC/proto/neural_solution_pb2_grpc.py
+++ b/neural_solution/frontend/gRPC/proto/neural_solution_pb2_grpc.py
@@ -2,8 +2,8 @@
# pylint: disable=all
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
-
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
+
import neural_solution.frontend.gRPC.proto.neural_solution_pb2 as neural__solution__pb2
@@ -17,25 +17,25 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.Ping = channel.unary_unary(
- '/neural_solution.TaskService/Ping',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=neural__solution__pb2.ResponsePingMessage.FromString,
- )
+ "/neural_solution.TaskService/Ping",
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=neural__solution__pb2.ResponsePingMessage.FromString,
+ )
self.SubmitTask = channel.unary_unary(
- '/neural_solution.TaskService/SubmitTask',
- request_serializer=neural__solution__pb2.Task.SerializeToString,
- response_deserializer=neural__solution__pb2.TaskResponse.FromString,
- )
+ "/neural_solution.TaskService/SubmitTask",
+ request_serializer=neural__solution__pb2.Task.SerializeToString,
+ response_deserializer=neural__solution__pb2.TaskResponse.FromString,
+ )
self.GetTaskById = channel.unary_unary(
- '/neural_solution.TaskService/GetTaskById',
- request_serializer=neural__solution__pb2.TaskId.SerializeToString,
- response_deserializer=neural__solution__pb2.TaskStatus.FromString,
- )
+ "/neural_solution.TaskService/GetTaskById",
+ request_serializer=neural__solution__pb2.TaskId.SerializeToString,
+ response_deserializer=neural__solution__pb2.TaskStatus.FromString,
+ )
self.QueryTaskResult = channel.unary_unary(
- '/neural_solution.TaskService/QueryTaskResult',
- request_serializer=neural__solution__pb2.TaskId.SerializeToString,
- response_deserializer=neural__solution__pb2.ResponseTaskResult.FromString,
- )
+ "/neural_solution.TaskService/QueryTaskResult",
+ request_serializer=neural__solution__pb2.TaskId.SerializeToString,
+ response_deserializer=neural__solution__pb2.ResponseTaskResult.FromString,
+ )
class TaskServiceServicer(object):
@@ -44,26 +44,26 @@ class TaskServiceServicer(object):
def Ping(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
- context.set_details('Method not implemented!')
- raise NotImplementedError('Method not implemented!')
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
def SubmitTask(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
- context.set_details('Method not implemented!')
- raise NotImplementedError('Method not implemented!')
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
def GetTaskById(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
- context.set_details('Method not implemented!')
- raise NotImplementedError('Method not implemented!')
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
def QueryTaskResult(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
- context.set_details('Method not implemented!')
- raise NotImplementedError('Method not implemented!')
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
def add_TaskServiceServicer_to_server(servicer, server):
@@ -74,47 +74,48 @@ def add_TaskServiceServicer_to_server(servicer, server):
server (grpc._server._Server): server
"""
rpc_method_handlers = {
- 'Ping': grpc.unary_unary_rpc_method_handler(
- servicer.Ping,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=neural__solution__pb2.ResponsePingMessage.SerializeToString,
- ),
- 'SubmitTask': grpc.unary_unary_rpc_method_handler(
- servicer.SubmitTask,
- request_deserializer=neural__solution__pb2.Task.FromString,
- response_serializer=neural__solution__pb2.TaskResponse.SerializeToString,
- ),
- 'GetTaskById': grpc.unary_unary_rpc_method_handler(
- servicer.GetTaskById,
- request_deserializer=neural__solution__pb2.TaskId.FromString,
- response_serializer=neural__solution__pb2.TaskStatus.SerializeToString,
- ),
- 'QueryTaskResult': grpc.unary_unary_rpc_method_handler(
- servicer.QueryTaskResult,
- request_deserializer=neural__solution__pb2.TaskId.FromString,
- response_serializer=neural__solution__pb2.ResponseTaskResult.SerializeToString,
- ),
+ "Ping": grpc.unary_unary_rpc_method_handler(
+ servicer.Ping,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=neural__solution__pb2.ResponsePingMessage.SerializeToString,
+ ),
+ "SubmitTask": grpc.unary_unary_rpc_method_handler(
+ servicer.SubmitTask,
+ request_deserializer=neural__solution__pb2.Task.FromString,
+ response_serializer=neural__solution__pb2.TaskResponse.SerializeToString,
+ ),
+ "GetTaskById": grpc.unary_unary_rpc_method_handler(
+ servicer.GetTaskById,
+ request_deserializer=neural__solution__pb2.TaskId.FromString,
+ response_serializer=neural__solution__pb2.TaskStatus.SerializeToString,
+ ),
+ "QueryTaskResult": grpc.unary_unary_rpc_method_handler(
+ servicer.QueryTaskResult,
+ request_deserializer=neural__solution__pb2.TaskId.FromString,
+ response_serializer=neural__solution__pb2.ResponseTaskResult.SerializeToString,
+ ),
}
- generic_handler = grpc.method_handlers_generic_handler(
- 'neural_solution.TaskService', rpc_method_handlers)
+ generic_handler = grpc.method_handlers_generic_handler("neural_solution.TaskService", rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class TaskService(object):
"""Interface exported by the server."""
@staticmethod
- def Ping(request,
- target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
+ def Ping(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
"""Tets server status.
Args:
@@ -154,23 +155,35 @@ def Ping(request,
Returns:
The response to the RPC.
"""
- return grpc.experimental.unary_unary(request, target, '/neural_solution.TaskService/Ping',
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/neural_solution.TaskService/Ping",
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
neural__solution__pb2.ResponsePingMessage.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
@staticmethod
- def SubmitTask(request,
- target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
+ def SubmitTask(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
"""Submit task.
Args:
@@ -210,23 +223,35 @@ def SubmitTask(request,
Returns:
The response to the RPC.
"""
- return grpc.experimental.unary_unary(request, target, '/neural_solution.TaskService/SubmitTask',
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/neural_solution.TaskService/SubmitTask",
neural__solution__pb2.Task.SerializeToString,
neural__solution__pb2.TaskResponse.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
@staticmethod
- def GetTaskById(request,
- target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
+ def GetTaskById(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
"""Get task status according to id.
Args:
@@ -266,23 +291,35 @@ def GetTaskById(request,
Returns:
The response to the RPC.
"""
- return grpc.experimental.unary_unary(request, target, '/neural_solution.TaskService/GetTaskById',
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/neural_solution.TaskService/GetTaskById",
neural__solution__pb2.TaskId.SerializeToString,
neural__solution__pb2.TaskStatus.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
@staticmethod
- def QueryTaskResult(request,
- target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
+ def QueryTaskResult(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
"""Get task result according to id.
Args:
@@ -322,8 +359,18 @@ def QueryTaskResult(request,
Returns:
The response to the RPC.
"""
- return grpc.experimental.unary_unary(request, target, '/neural_solution.TaskService/QueryTaskResult',
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/neural_solution.TaskService/QueryTaskResult",
neural__solution__pb2.TaskId.SerializeToString,
neural__solution__pb2.ResponseTaskResult.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
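
A minimal client-side sketch of the regenerated stub. It assumes the conventional `TaskServiceStub` class name that protoc emits (the class header sits outside these hunks) and a server listening on port 8001, the default gRPC port used elsewhere in this patch.

```python
# Sketch only: ping the TaskService over an insecure local channel.
import grpc
from google.protobuf import empty_pb2

from neural_solution.frontend.gRPC.proto import neural_solution_pb2_grpc

with grpc.insecure_channel("localhost:8001") as channel:  # assumed default port
    stub = neural_solution_pb2_grpc.TaskServiceStub(channel)
    pong = stub.Ping(empty_pb2.Empty())  # served by TaskSubmitterServicer.Ping
    print(pong.status, pong.msg)
```
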
diff --git a/neural_solution/frontend/gRPC/server.py b/neural_solution/frontend/gRPC/server.py
index e451ee3dad0..b838e375ada 100644
--- a/neural_solution/frontend/gRPC/server.py
+++ b/neural_solution/frontend/gRPC/server.py
@@ -14,25 +14,24 @@
"""Server of gRPC frontend."""
-from concurrent import futures
-import logging
-import grpc
import argparse
+import logging
+from concurrent import futures
-from neural_solution.frontend.gRPC.proto import (
- neural_solution_pb2,
- neural_solution_pb2_grpc)
+import grpc
from neural_solution.config import config
-from neural_solution.utils import logger
-from neural_solution.utils.utility import get_db_path, dict_to_str
+from neural_solution.frontend.gRPC.proto import neural_solution_pb2, neural_solution_pb2_grpc
from neural_solution.frontend.task_submitter import task_submitter
-
from neural_solution.frontend.utility import (
- submit_task_to_db,
check_service_status,
+ query_task_result,
query_task_status,
- query_task_result)
+ submit_task_to_db,
+)
+from neural_solution.utils import logger
+from neural_solution.utils.utility import dict_to_str, get_db_path
+
class TaskSubmitterServicer(neural_solution_pb2_grpc.TaskServiceServicer):
"""Deliver services.
@@ -58,7 +57,7 @@ def Ping(self, empty_msg, context):
print(f"Ping grpc serve.")
port_lst = [config.result_monitor_port]
result = check_service_status(port_lst, service_address=config.service_address)
- response = neural_solution_pb2.ResponsePingMessage(**result) # pylint: disable=no-member
+ response = neural_solution_pb2.ResponsePingMessage(**result) # pylint: disable=no-member
return response
def SubmitTask(self, task, context):
@@ -82,7 +81,7 @@ def SubmitTask(self, task, context):
print(db_path)
result = submit_task_to_db(task=task, task_submitter=task_submitter, db_path=get_db_path(config.workspace))
# Return a response
- response = neural_solution_pb2.TaskResponse(**result) # pylint: disable=no-member
+ response = neural_solution_pb2.TaskResponse(**result) # pylint: disable=no-member
return response
def GetTaskById(self, task_id, context):
@@ -97,7 +96,7 @@ def GetTaskById(self, task_id, context):
db_path = get_db_path(config.workspace)
result = query_task_status(task_id.task_id, db_path)
print(f"query result : result")
- response = neural_solution_pb2.TaskStatus(**result) # pylint: disable=no-member
+ response = neural_solution_pb2.TaskStatus(**result) # pylint: disable=no-member
return response
def QueryTaskResult(self, task_id, context):
@@ -111,9 +110,9 @@ def QueryTaskResult(self, task_id, context):
"""
db_path = get_db_path(config.workspace)
result = query_task_result(task_id.task_id, db_path, config.workspace)
- result['tuning_information'] = dict_to_str(result["tuning_information"])
- result['optimization_result'] = dict_to_str(result["optimization_result"])
- response = neural_solution_pb2.ResponseTaskResult(**result) # pylint: disable=no-member
+ result["tuning_information"] = dict_to_str(result["tuning_information"])
+ result["optimization_result"] = dict_to_str(result["optimization_result"])
+ response = neural_solution_pb2.ResponseTaskResult(**result) # pylint: disable=no-member
return response
@@ -121,9 +120,8 @@ def serve():
"""Service entrance."""
port = str(config.grpc_api_port)
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
- neural_solution_pb2_grpc.add_TaskServiceServicer_to_server(
- TaskSubmitterServicer(), server)
- server.add_insecure_port('[::]:' + port)
+ neural_solution_pb2_grpc.add_TaskServiceServicer_to_server(TaskSubmitterServicer(), server)
+ server.add_insecure_port("[::]:" + port)
server.start()
print("Server started, listening on " + port)
server.wait_for_termination()
@@ -132,21 +130,16 @@ def serve():
def parse_arguments():
"""Parse the command line options."""
parser = argparse.ArgumentParser(description="Frontend with gRPC API")
- parser.add_argument("-H", "--host", type=str, default="0.0.0.0", \
- help="The address to submit task.")
- parser.add_argument("-FP", "--grpc_api_port", type=int, default=8001, \
- help="Port to submit task by user.")
- parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, \
- help="Port to monitor task.")
- parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, \
- help="Port to monitor result.")
- parser.add_argument("-WS", "--workspace", type=str, default="./ns_workspace", \
- help="Work space.")
+ parser.add_argument("-H", "--host", type=str, default="0.0.0.0", help="The address to submit task.")
+ parser.add_argument("-FP", "--grpc_api_port", type=int, default=8001, help="Port to submit task by user.")
+ parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, help="Port to monitor task.")
+ parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, help="Port to monitor result.")
+ parser.add_argument("-WS", "--workspace", type=str, default="./ns_workspace", help="Work space.")
args = parser.parse_args()
return args
-if __name__ == '__main__':
+if __name__ == "__main__":
logger.info(f"Try to start gRPC server.")
logging.basicConfig()
args = parse_arguments()
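
For reference, a hedged sketch of starting this frontend programmatically with the same flags that `parse_arguments()` defines; the port values shown are just the defaults above.

```python
# Sketch only: launch the gRPC frontend as a subprocess with default ports.
import subprocess

proc = subprocess.Popen(
    [
        "python", "-m", "neural_solution.frontend.gRPC.server",
        "--grpc_api_port", "8001",
        "--task_monitor_port", "2222",
        "--result_monitor_port", "3333",
        "--workspace", "./ns_workspace",
    ]
)  # Popen returns immediately; the server keeps running in the child process
```
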
diff --git a/neural_solution/frontend/task_submitter.py b/neural_solution/frontend/task_submitter.py
index 880496fd14e..e4274260c3d 100644
--- a/neural_solution/frontend/task_submitter.py
+++ b/neural_solution/frontend/task_submitter.py
@@ -14,12 +14,14 @@
"""Neural Solution task submitter."""
-import socket
import json
+import socket
+
from pydantic import BaseModel # pylint: disable=no-name-in-module
from neural_solution.config import config
+
class Task(BaseModel):
"""Task definition for submitting requests.
@@ -34,6 +36,7 @@ class Task(BaseModel):
requirements: list
workers: int
+
class TaskSubmitter:
"""Responsible for submitting tasks."""
@@ -65,5 +68,7 @@ def submit_task(self, tid):
s.send(self.serialize(tid))
s.close()
-task_submitter = TaskSubmitter(task_monitor_port=config.task_monitor_port,
- result_monitor_port=config.result_monitor_port)
\ No newline at end of file
+
+task_submitter = TaskSubmitter(
+ task_monitor_port=config.task_monitor_port, result_monitor_port=config.result_monitor_port
+)
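
A short usage sketch of the module-level singleton created above; the task id is a hypothetical placeholder. `submit_task` serializes the id and sends it over a plain socket to the configured task monitor port.

```python
# Sketch only: hand a task id to the backend's task monitor.
from neural_solution.frontend.task_submitter import task_submitter

task_submitter.submit_task("0123abcd")  # hypothetical task id
```
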
diff --git a/neural_solution/frontend/utility.py b/neural_solution/frontend/utility.py
index 454603ad1db..fd8fc2dcc25 100644
--- a/neural_solution/frontend/utility.py
+++ b/neural_solution/frontend/utility.py
@@ -15,13 +15,16 @@
"""Common utilities for all frontend components."""
import json
-import sqlite3
import os
import re
+import socket
+import sqlite3
import uuid
+
import pandas as pd
-import socket
-from neural_solution.utils.utility import get_task_log_workspace, dict_to_str
+
+from neural_solution.utils.utility import dict_to_str, get_task_log_workspace
+
def query_task_status(task_id, db_path):
"""Query task status according to id.
@@ -41,9 +44,12 @@ def query_task_status(task_id, db_path):
res = cursor.fetchone()
cursor.close()
conn.close()
- return {"status": res[0],
- 'optimized_result': dict_to_str(deserialize(res[1]) if res[1] else res[1]),
- "result_path": res[2]}
+ return {
+ "status": res[0],
+ "optimized_result": dict_to_str(deserialize(res[1]) if res[1] else res[1]),
+ "result_path": res[2],
+ }
+
def query_task_result(task_id, db_path, workspace):
"""Query the task result according id.
@@ -64,13 +70,13 @@ def query_task_result(task_id, db_path, workspace):
if os.path.isfile(db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
- cursor.execute(r"select status, result, q_model_path from task where id=?", (task_id, ))
+ cursor.execute(r"select status, result, q_model_path from task where id=?", (task_id,))
res = cursor.fetchone()
cursor.close()
conn.close()
print(f"in query")
if not res:
- status = "Please check url."
+ status = "Please check url."
elif res[0] == "done":
status = res[0]
optimization_result = deserialize(res[1]) if res[1] else res[1]
@@ -85,6 +91,7 @@ def query_task_result(task_id, db_path, workspace):
result = {"status": status, "tuning_information": tuning_info, "optimization_result": optimization_result}
return result
+
def check_service_status(port_lst, service_address):
"""Check server status.
@@ -109,8 +116,8 @@ def check_service_status(port_lst, service_address):
sock.close()
continue
except ConnectionRefusedError:
- msg = "Ping fail! Make sure Neural Solution runner is running!"
- break
+ msg = "Ping fail! Make sure Neural Solution runner is running!"
+ break
except Exception as e:
msg = "Ping fail! {}".format(e)
break
@@ -136,16 +143,19 @@ def submit_task_to_db(task, task_submitter, db_path):
if os.path.isfile(db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
- task_id = str(uuid.uuid4()).replace('-','')
- sql = r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)" +\
- r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')".format(
- task_id,
- task.script_url,
- task.optimized,
- list_to_string(task.arguments),
- task.approach,
- list_to_string(task.requirements),
- task.workers)
+ task_id = str(uuid.uuid4()).replace("-", "")
+ sql = (
+ r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)"
+ + r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')".format(
+ task_id,
+ task.script_url,
+ task.optimized,
+ list_to_string(task.arguments),
+ task.approach,
+ list_to_string(task.requirements),
+ task.workers,
+ )
+ )
cursor.execute(sql)
conn.commit()
try:
@@ -161,18 +171,21 @@ def submit_task_to_db(task, task_submitter, db_path):
msg = "Task Submitted fail! db not found!"
result["status"] = status
result["task_id"] = task_id
- result["msg"]=msg
+ result["msg"] = msg
return result
+
def serialize(request: dict) -> bytes:
"""Serialize a dict object to bytes for inter-process communication."""
return json.dumps(request).encode()
+
def deserialize(request: bytes) -> dict:
"""Deserialize the received bytes to a dict object."""
return json.loads(request)
-def get_cluster_info(db_path:str):
+
+def get_cluster_info(db_path: str):
"""Get cluster information from database.
Returns:
@@ -186,7 +199,8 @@ def get_cluster_info(db_path:str):
conn.close()
return {"Cluster info": rows}
-def get_cluster_table(db_path:str):
+
+def get_cluster_table(db_path: str):
"""Get cluster table from database.
Returns:
@@ -197,11 +211,14 @@ def get_cluster_table(db_path:str):
cursor.execute(r"select * from cluster")
conn.commit()
rows = cursor.fetchall()
- df = pd.DataFrame(rows, columns=["Node", "Node info", "status","free workers", "busy workers", "total workers"])
- html_table = df.to_html(index=False, )
+ df = pd.DataFrame(rows, columns=["Node", "Node info", "status", "free workers", "busy workers", "total workers"])
+ html_table = df.to_html(
+ index=False,
+ )
conn.close()
return html_table
+
def get_res_during_tuning(task_id: str, task_log_path):
"""Get result during tuning.
@@ -214,8 +231,8 @@ def get_res_during_tuning(task_id: str, task_log_path):
results = {}
log_path = "{}/task_{}.txt".format(task_log_path, task_id)
for line in reversed(open(log_path).readlines()):
- res_pattern = r'Tune (\d+) result is: '
- res_pattern = r'Tune (\d+) result is:\s.*?\(int8\|fp32\):\s+(\d+\.\d+).*?\(int8\|fp32\):\s+(\d+\.\d+).*?'
+ res_pattern = r"Tune (\d+) result is: "
+ res_pattern = r"Tune (\d+) result is:\s.*?\(int8\|fp32\):\s+(\d+\.\d+).*?\(int8\|fp32\):\s+(\d+\.\d+).*?"
res_matches = re.findall(res_pattern, line)
if res_matches:
results["Tuning count"] = res_matches[0][0]
@@ -225,7 +242,8 @@ def get_res_during_tuning(task_id: str, task_log_path):
break
print("Query results: {}".format(results))
- return results if results else "Tune 1 running..."
+ return results if results else "Tune 1 running..."
+
def get_baseline_during_tuning(task_id: str, task_log_path):
"""Get result during tuning.
@@ -248,7 +266,8 @@ def get_baseline_during_tuning(task_id: str, task_log_path):
break
print("FP32 baseline: {}".format(results))
- return results if results else "Getting FP32 baseline..."
+ return results if results else "Getting FP32 baseline..."
+
def check_log_exists(task_id: str, task_log_path):
"""Check whether the log file exists.
@@ -265,6 +284,7 @@ def check_log_exists(task_id: str, task_log_path):
else:
return False
+
def list_to_string(lst: list):
"""Convert the list to a space concatenated string.
@@ -274,4 +294,4 @@ def list_to_string(lst: list):
Returns:
str: string
"""
- return " ".join(str(i) for i in lst)
\ No newline at end of file
+ return " ".join(str(i) for i in lst)
diff --git a/neural_solution/launcher.py b/neural_solution/launcher.py
index c2628af12a4..3711bed982f 100644
--- a/neural_solution/launcher.py
+++ b/neural_solution/launcher.py
@@ -13,16 +13,18 @@
# limitations under the License.
"""The entry of Neural Solution."""
-import sys
-import subprocess
-import os
import argparse
+import os
+import shlex
import socket
-import psutil
+import subprocess
+import sys
import time
-import shlex
from datetime import datetime
+import psutil
+
+
def check_ports(args):
"""Check parameters ending in '_port'.
@@ -30,9 +32,10 @@ def check_ports(args):
args (argparse.Namespace): parameters.
"""
for arg in vars(args):
- if '_port' in arg:
+ if "_port" in arg:
check_port(getattr(args, arg))
+
def check_port(port):
"""Check if the given port is standardized.
@@ -43,6 +46,7 @@ def check_port(port):
print(f"Error: Invalid port number: {port}")
sys.exit(1)
+
def get_local_service_ip(port):
"""Get the local IP address of the machine running the service.
@@ -53,32 +57,34 @@ def get_local_service_ip(port):
str: The IP address of the machine running the service.
"""
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', port))
+ s.connect(("8.8.8.8", port))
return s.getsockname()[0]
+
def stop_service():
"""Stop service."""
# Get all running processes
for proc in psutil.process_iter():
try:
# Get the process details
- pinfo = proc.as_dict(attrs=['pid', 'name', 'cmdline'])
+ pinfo = proc.as_dict(attrs=["pid", "name", "cmdline"])
# Check if the process is the target process
- if "neural_solution.backend.runner" in pinfo['cmdline']:
+ if "neural_solution.backend.runner" in pinfo["cmdline"]:
# Terminate the process using Process.kill() method
- process = psutil.Process(pinfo['pid'])
+ process = psutil.Process(pinfo["pid"])
process.kill()
- elif "neural_solution.frontend.fastapi.main_server" in pinfo['cmdline']:
- process = psutil.Process(pinfo['pid'])
+ elif "neural_solution.frontend.fastapi.main_server" in pinfo["cmdline"]:
+ process = psutil.Process(pinfo["pid"])
process.kill()
- elif "neural_solution.frontend.gRPC.server" in pinfo['cmdline']:
- process = psutil.Process(pinfo['pid'])
+ elif "neural_solution.frontend.gRPC.server" in pinfo["cmdline"]:
+ process = psutil.Process(pinfo["pid"])
process.kill()
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
pass
# Service End
print("Neural Solution Service Stopped!")
+
def check_port_free(port):
"""Check if the port is free.
@@ -89,10 +95,11 @@ def check_port_free(port):
bool : the free state of the port.
"""
for conn in psutil.net_connections():
- if conn.status == 'LISTEN' and conn.laddr.port == port:
+ if conn.status == "LISTEN" and conn.laddr.port == port:
return False
return True
+
def start_service(args):
"""Start service.
@@ -118,8 +125,10 @@ def start_service(args):
print("No environment specified or conda environment activated !!!")
sys.exit(1)
else:
- print(f"No environment specified, use environment activated:" + \
- f" ({conda_env}) as the task runtime environment.")
+ print(
+ f"No environment specified, use environment activated:"
+ + f" ({conda_env}) as the task runtime environment."
+ )
conda_env_name = conda_env
else:
conda_env_name = args.conda_env
@@ -132,34 +141,67 @@ def start_service(args):
date_suffix = "_" + date_time.strftime("%Y%m%d-%H%M%S")
date_suffix = ""
with open(f"{serve_log_dir}/backend{date_suffix}.log", "w") as f:
- subprocess.Popen([
- "python", "-m", "neural_solution.backend.runner",
- "--hostfile", shlex.quote(str(args.hostfile)),
- "--task_monitor_port", shlex.quote(str(args.task_monitor_port)),
- "--result_monitor_port", shlex.quote(str(args.result_monitor_port)),
- "--workspace", shlex.quote(str(args.workspace)),
- "--conda_env_name", shlex.quote(str(conda_env_name)),
- "--upload_path", shlex.quote(str(args.upload_path))
- ], stdout=os.dup(f.fileno()), stderr=subprocess.STDOUT)
+ subprocess.Popen(
+ [
+ "python",
+ "-m",
+ "neural_solution.backend.runner",
+ "--hostfile",
+ shlex.quote(str(args.hostfile)),
+ "--task_monitor_port",
+ shlex.quote(str(args.task_monitor_port)),
+ "--result_monitor_port",
+ shlex.quote(str(args.result_monitor_port)),
+ "--workspace",
+ shlex.quote(str(args.workspace)),
+ "--conda_env_name",
+ shlex.quote(str(conda_env_name)),
+ "--upload_path",
+ shlex.quote(str(args.upload_path)),
+ ],
+ stdout=os.dup(f.fileno()),
+ stderr=subprocess.STDOUT,
+ )
if args.api_type in ["all", "restful"]:
with open(f"{serve_log_dir}/frontend{date_suffix}.log", "w") as f:
- subprocess.Popen([
- "python", "-m", "neural_solution.frontend.fastapi.main_server",
- "--host", "0.0.0.0",
- "--fastapi_port", shlex.quote(str(args.restful_api_port)),
- "--task_monitor_port", shlex.quote(str(args.task_monitor_port)),
- "--result_monitor_port", shlex.quote(str(args.result_monitor_port)),
- "--workspace", shlex.quote(str(args.workspace))
- ], stdout=os.dup(f.fileno()), stderr=subprocess.STDOUT)
+ subprocess.Popen(
+ [
+ "python",
+ "-m",
+ "neural_solution.frontend.fastapi.main_server",
+ "--host",
+ "0.0.0.0",
+ "--fastapi_port",
+ shlex.quote(str(args.restful_api_port)),
+ "--task_monitor_port",
+ shlex.quote(str(args.task_monitor_port)),
+ "--result_monitor_port",
+ shlex.quote(str(args.result_monitor_port)),
+ "--workspace",
+ shlex.quote(str(args.workspace)),
+ ],
+ stdout=os.dup(f.fileno()),
+ stderr=subprocess.STDOUT,
+ )
if args.api_type in ["all", "grpc"]:
with open(f"{serve_log_dir}/frontend_grpc.log", "w") as f:
- subprocess.Popen([
- "python", "-m", "neural_solution.frontend.gRPC.server",
- "--grpc_api_port", shlex.quote(str(args.grpc_api_port)),
- "--task_monitor_port", shlex.quote(str(args.task_monitor_port)),
- "--result_monitor_port", shlex.quote(str(args.result_monitor_port)),
- "--workspace", shlex.quote(str(args.workspace))
- ], stdout=os.dup(f.fileno()), stderr=subprocess.STDOUT)
+ subprocess.Popen(
+ [
+ "python",
+ "-m",
+ "neural_solution.frontend.gRPC.server",
+ "--grpc_api_port",
+ shlex.quote(str(args.grpc_api_port)),
+ "--task_monitor_port",
+ shlex.quote(str(args.task_monitor_port)),
+ "--result_monitor_port",
+ shlex.quote(str(args.result_monitor_port)),
+ "--workspace",
+ shlex.quote(str(args.workspace)),
+ ],
+ stdout=os.dup(f.fileno()),
+ stderr=subprocess.STDOUT,
+ )
ip_address = get_local_service_ip(80)
# Check if the service is started
@@ -169,8 +211,11 @@ def start_service(args):
start_time = time.time()
while True:
# Check if the ports are in use
- if check_port_free(args.task_monitor_port) or check_port_free(args.result_monitor_port) \
- or check_port_free(args.restful_api_port):
+ if (
+ check_port_free(args.task_monitor_port)
+ or check_port_free(args.result_monitor_port)
+ or check_port_free(args.restful_api_port)
+ ):
# If the ports are not in use, wait for a second and check again
time.sleep(0.5)
# Check if timed out
@@ -205,28 +250,40 @@ def start_service(args):
# Check completed
print("Neural Solution Service Started!")
- print(f"Service log saving path is in \"{os.path.abspath(serve_log_dir)}\"")
+ print(f'Service log saving path is in "{os.path.abspath(serve_log_dir)}"')
print(f"To submit task at: {ip_address}:{args.restful_api_port}/task/submit/")
print("[For information] neural_solution help")
+
def main():
"""Implement the main function."""
parser = argparse.ArgumentParser(description="Neural Solution")
- parser.add_argument('action', choices=['start', 'stop'], help='start/stop service')
- parser.add_argument("--hostfile", default=None,
- help="start backend serve host file which contains all available nodes")
- parser.add_argument("--restful_api_port", type=int, default=8000,
- help="start restful serve with {restful_api_port}, default 8000")
- parser.add_argument("--grpc_api_port", type=int, default=8001,
- help="start gRPC with {restful_api_port}, default 8001")
- parser.add_argument("--result_monitor_port", type=int, default=3333,
- help="start serve for result monitor at {result_monitor_port}, default 3333")
- parser.add_argument("--task_monitor_port", type=int, default=2222,
- help="start serve for task monitor at {task_monitor_port}, default 2222")
- parser.add_argument("--api_type", default="all",
- help="start web serve with all/grpc/restful, default all")
- parser.add_argument("--workspace", default="./ns_workspace",
- help="neural solution workspace, default \"./ns_workspace\"")
+ parser.add_argument("action", choices=["start", "stop"], help="start/stop service")
+ parser.add_argument(
+ "--hostfile", default=None, help="start backend serve host file which contains all available nodes"
+ )
+ parser.add_argument(
+ "--restful_api_port", type=int, default=8000, help="start restful serve with {restful_api_port}, default 8000"
+ )
+ parser.add_argument(
+ "--grpc_api_port", type=int, default=8001, help="start gRPC with {restful_api_port}, default 8001"
+ )
+ parser.add_argument(
+ "--result_monitor_port",
+ type=int,
+ default=3333,
+ help="start serve for result monitor at {result_monitor_port}, default 3333",
+ )
+ parser.add_argument(
+ "--task_monitor_port",
+ type=int,
+ default=2222,
+ help="start serve for task monitor at {task_monitor_port}, default 2222",
+ )
+ parser.add_argument("--api_type", default="all", help="start web serve with all/grpc/restful, default all")
+ parser.add_argument(
+ "--workspace", default="./ns_workspace", help='neural solution workspace, default "./ns_workspace"'
+ )
parser.add_argument("--conda_env", default=None, help="specify the running environment for the task")
parser.add_argument("--upload_path", default="examples", help="specify the file path for the tasks")
args = parser.parse_args()
@@ -234,10 +291,11 @@ def main():
# Check parameters ending in '_port'
check_ports(args)
- if args.action == 'start':
+ if args.action == "start":
start_service(args)
- elif args.action == 'stop':
+ elif args.action == "stop":
stop_service()
-
-if __name__ == '__main__':
+
+
+if __name__ == "__main__":
main()
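
To make the reformatted `check_port_free()` logic concrete, here is the same LISTEN-port scan as a standalone sketch; the port number is just the default RESTful API port.

```python
# Sketch only: reproduce check_port_free()'s LISTEN-port scan.
import psutil

def port_is_free(port: int) -> bool:
    """Return True when no local socket is listening on the given port."""
    return not any(
        conn.status == "LISTEN" and conn.laddr.port == port
        for conn in psutil.net_connections()
    )

print(port_is_free(8000))  # default restful_api_port
```
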
diff --git a/neural_solution/scripts/prepare_deps.py b/neural_solution/scripts/prepare_deps.py
index 33c711b9273..c7df3524ac2 100644
--- a/neural_solution/scripts/prepare_deps.py
+++ b/neural_solution/scripts/prepare_deps.py
@@ -20,4 +20,3 @@
- CONDA
 - other packages, such as mpi4py
"""
-
diff --git a/neural_solution/test/backend/test_cluster.py b/neural_solution/test/backend/test_cluster.py
index fbb001d14c7..8ea687e81b6 100644
--- a/neural_solution/test/backend/test_cluster.py
+++ b/neural_solution/test/backend/test_cluster.py
@@ -1,16 +1,12 @@
"""Tests for cluster"""
import importlib
-import shutil
import os
+import shutil
import unittest
-
-import unittest
from neural_solution.backend.cluster import Cluster, Node
from neural_solution.backend.task import Task
-
-
-from neural_solution.utils.utility import get_task_workspace, get_task_log_workspace, get_db_path
+from neural_solution.utils.utility import get_db_path, get_task_log_workspace, get_task_workspace
NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace")
db_path = get_db_path(NEURAL_SOLUTION_WORKSPACE)
@@ -23,16 +19,16 @@ def setUp(self):
self.cluster = Cluster(node_lst, db_path=db_path)
self.task = Task(
- task_id = "1",
- arguments = ["arg1", "arg2"],
- workers = 2,
- status = "pending",
- script_url = "https://example.com",
- optimized = True,
- approach = "static",
- requirement = ["req1", "req2"],
- result = "",
- q_model_path = "q_model_path"
+ task_id="1",
+ arguments=["arg1", "arg2"],
+ workers=2,
+ status="pending",
+ script_url="https://example.com",
+ optimized=True,
+ approach="static",
+ requirement=["req1", "req2"],
+ result="",
+ q_model_path="q_model_path",
)
@classmethod
@@ -49,7 +45,7 @@ def test_free_resource(self):
task = self.task
reserved_resource_lst = self.cluster.reserve_resource(task)
self.cluster.free_resource(reserved_resource_lst)
- self.assertEqual(self.cluster.socket_queue, ['2 node2', '2 node2', '1 node1', '1 node1'])
+ self.assertEqual(self.cluster.socket_queue, ["2 node2", "2 node2", "1 node1", "1 node1"])
def test_get_free_socket(self):
free_socket_lst = self.cluster.get_free_socket(4)
@@ -61,5 +57,6 @@ def test_get_free_socket(self):
free_socket_lst = self.cluster.get_free_socket(10)
self.assertEqual(free_socket_lst, 0)
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/neural_solution/test/backend/test_result_monitor.py b/neural_solution/test/backend/test_result_monitor.py
index 18f83a78196..1691826729f 100644
--- a/neural_solution/test/backend/test_result_monitor.py
+++ b/neural_solution/test/backend/test_result_monitor.py
@@ -2,28 +2,29 @@
import threading
import unittest
from unittest.mock import MagicMock, patch
+
from neural_solution.backend.result_monitor import ResultMonitor
-class TestResultMonitor(unittest.TestCase):
- @patch('socket.socket')
+class TestResultMonitor(unittest.TestCase):
+ @patch("socket.socket")
def test_wait_result(self, mock_socket):
# Mock data for testing
task_db = MagicMock()
task_db.lookup_task_status.return_value = "COMPLETED"
- result = {'task_id': 1, 'q_model_path': "path/to/q_model", 'result': 0.8}
+ result = {"task_id": 1, "q_model_path": "path/to/q_model", "result": 0.8}
serialized_result = json.dumps(result)
-
+
mock_c = MagicMock()
mock_c.recv.return_value = serialized_result
-
+
mock_socket.return_value.accept.return_value = (mock_c, MagicMock())
mock_socket.return_value.recv.return_value = serialized_result
mock_socket.return_value.__enter__.return_value = mock_socket.return_value
# Create a ResultMonitor object and call the wait_result method
result_monitor = ResultMonitor(8080, task_db)
- with patch('neural_solution.backend.result_monitor.deserialize', return_value={"ping": "test"}):
+ with patch("neural_solution.backend.result_monitor.deserialize", return_value={"ping": "test"}):
adding_abort = threading.Thread(
target=result_monitor.wait_result,
args=(),
@@ -43,7 +44,7 @@ def test_query_task_status(self):
# Assert that the task_db.lookup_task_status method was called with the correct argument
task_db.lookup_task_status.assert_called_once_with(1)
-
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/neural_solution/test/backend/test_runner.py b/neural_solution/test/backend/test_runner.py
index 1c0e447079b..c400f00c427 100644
--- a/neural_solution/test/backend/test_runner.py
+++ b/neural_solution/test/backend/test_runner.py
@@ -1,26 +1,33 @@
-import unittest
-from unittest.mock import patch
-import threading
-import os
import argparse
+import os
import shutil
-from neural_solution.backend.runner import parse_args, main
+import threading
+import unittest
+from unittest.mock import patch
-class TestMain(unittest.TestCase):
+from neural_solution.backend.runner import main, parse_args
+
+class TestMain(unittest.TestCase):
@classmethod
def tearDownClass(cls) -> None:
- os.remove('test.txt')
+ os.remove("test.txt")
shutil.rmtree("ns_workspace", ignore_errors=True)
def test_parse_args(self):
- args = ['-H', 'path/to/hostfile', '-TMP', '2222', '-RMP', '3333', '-CEN', 'inc']
- with patch('argparse.ArgumentParser.parse_args', \
- return_value=argparse.Namespace(hostfile='path/to/hostfile', \
- task_monitor_port=2222, result_monitor_port=3333, conda_env_name='inc')):
- self.assertEqual(parse_args(args), \
- argparse.Namespace(hostfile='path/to/hostfile', \
- task_monitor_port=2222, result_monitor_port=3333, conda_env_name='inc'))
+ args = ["-H", "path/to/hostfile", "-TMP", "2222", "-RMP", "3333", "-CEN", "inc"]
+ with patch(
+ "argparse.ArgumentParser.parse_args",
+ return_value=argparse.Namespace(
+ hostfile="path/to/hostfile", task_monitor_port=2222, result_monitor_port=3333, conda_env_name="inc"
+ ),
+ ):
+ self.assertEqual(
+ parse_args(args),
+ argparse.Namespace(
+ hostfile="path/to/hostfile", task_monitor_port=2222, result_monitor_port=3333, conda_env_name="inc"
+ ),
+ )
def test_main(self):
"""Test blocking flag in abort_job method."""
@@ -29,11 +36,12 @@ def test_main(self):
f.write("hostname1 2 20\nhostname2 2 20")
adding_abort = threading.Thread(
target=main,
- kwargs={'args': ['-H', 'test.txt', '-TMP', '2222', '-RMP', '3333', '-CEN', 'inc_conda_env']},
+ kwargs={"args": ["-H", "test.txt", "-TMP", "2222", "-RMP", "3333", "-CEN", "inc_conda_env"]},
daemon=True,
)
adding_abort.start()
adding_abort.join(timeout=2)
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/neural_solution/test/backend/test_scheduler.py b/neural_solution/test/backend/test_scheduler.py
index 8139a5afb7b..a84689d4658 100644
--- a/neural_solution/test/backend/test_scheduler.py
+++ b/neural_solution/test/backend/test_scheduler.py
@@ -1,30 +1,31 @@
-import unittest
import os
-import threading
import shutil
+import threading
+import unittest
from subprocess import CalledProcessError
-from unittest.mock import MagicMock, patch, mock_open, Mock
+from unittest.mock import MagicMock, Mock, mock_open, patch
+
+from neural_solution.backend.cluster import Cluster
from neural_solution.backend.scheduler import Scheduler
from neural_solution.backend.task import Task
-from neural_solution.backend.cluster import Cluster
from neural_solution.backend.task_db import TaskDB
from neural_solution.backend.utils.utility import dump_elapsed_time, get_task_log_path
-
-import os
-
-from neural_solution.utils.utility import get_db_path
from neural_solution.config import config
+from neural_solution.utils.utility import get_db_path
NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace")
db_path = get_db_path(NEURAL_SOLUTION_WORKSPACE)
config.workspace = NEURAL_SOLUTION_WORKSPACE
+
class TestScheduler(unittest.TestCase):
def setUp(self):
self.cluster = Cluster(db_path=db_path)
self.task_db = TaskDB(db_path=db_path)
self.result_monitor_port = 1234
- self.scheduler = Scheduler(self.cluster, self.task_db, self.result_monitor_port,conda_env_name="for_ns_test", config=config)
+ self.scheduler = Scheduler(
+ self.cluster, self.task_db, self.result_monitor_port, conda_env_name="for_ns_test", config=config
+ )
def tearDown(self) -> None:
shutil.rmtree("ns_workspace", ignore_errors=True)
@@ -34,28 +35,69 @@ def tearDownClass(cls) -> None:
shutil.rmtree("examples")
def test_prepare_env(self):
- task = Task("test_task", "test_arguments", "test_workers", "test_status", "test_script_url", \
- "test_optimized", "test_approach", "pip", "test_result", "test_q_model_path")
+ task = Task(
+ "test_task",
+ "test_arguments",
+ "test_workers",
+ "test_status",
+ "test_script_url",
+ "test_optimized",
+ "test_approach",
+ "pip",
+ "test_result",
+ "test_q_model_path",
+ )
result = self.scheduler.prepare_env(task)
self.assertTrue(result.startswith(self.scheduler.conda_env_name))
# Test requirement in {conda_env} case
- task = Task("test_task", "test_arguments", "test_workers", "test_status", "test_script_url", \
- "test_optimized", "test_approach", "pip", "test_result", "test_q_model_path")
- scheduler_test = Scheduler(self.cluster, self.task_db, self.result_monitor_port,conda_env_name="base", config=config)
+ task = Task(
+ "test_task",
+ "test_arguments",
+ "test_workers",
+ "test_status",
+ "test_script_url",
+ "test_optimized",
+ "test_approach",
+ "pip",
+ "test_result",
+ "test_q_model_path",
+ )
+ scheduler_test = Scheduler(
+ self.cluster, self.task_db, self.result_monitor_port, conda_env_name="base", config=config
+ )
result = scheduler_test.prepare_env(task)
- self.assertTrue(result.startswith('base'))
+ self.assertTrue(result.startswith("base"))
# Test requirement is '' case
- task = Task("test_task", "test_arguments", "test_workers", "test_status", "test_script_url", \
- "test_optimized", "test_approach", "", "test_result", "test_q_model_path")
+ task = Task(
+ "test_task",
+ "test_arguments",
+ "test_workers",
+ "test_status",
+ "test_script_url",
+ "test_optimized",
+ "test_approach",
+ "",
+ "test_result",
+ "test_q_model_path",
+ )
result = self.scheduler.prepare_env(task)
self.assertEqual(result, self.scheduler.conda_env_name)
def test_prepare_task(self):
- task = Task("test_task", "test_arguments", "test_workers", "test_status",\
- "test_example", \
- "test_optimized", "static", "test_requirement", "test_result", "test_q_model_path")
+ task = Task(
+ "test_task",
+ "test_arguments",
+ "test_workers",
+ "test_status",
+ "test_example",
+ "test_optimized",
+ "static",
+ "test_requirement",
+ "test_result",
+ "test_q_model_path",
+ )
test_path = "examples/test_example"
if not os.path.exists(test_path):
os.makedirs(test_path)
@@ -65,33 +107,50 @@ def test_prepare_task(self):
self.scheduler.prepare_task(task)
# url case
- with patch('neural_solution.backend.scheduler.is_remote_url', return_value=True):
+ with patch("neural_solution.backend.scheduler.is_remote_url", return_value=True):
self.scheduler.prepare_task(task)
# optimized is False case
- task = Task("test_task", "test_arguments", "test_workers", "test_status",\
- "test_example", \
- False, "static", "test_requirement", "test_result", "test_q_model_path")
+ task = Task(
+ "test_task",
+ "test_arguments",
+ "test_workers",
+ "test_status",
+ "test_example",
+ False,
+ "static",
+ "test_requirement",
+ "test_result",
+ "test_q_model_path",
+ )
self.scheduler.prepare_task(task)
- with patch('neural_solution.backend.scheduler.is_remote_url', return_value=True):
- task = Task("test_task", "test_arguments", "test_workers", "test_status",\
- "test_example/test.py", \
- False, "static", "test_requirement", "test_result", "test_q_model_path")
+ with patch("neural_solution.backend.scheduler.is_remote_url", return_value=True):
+ task = Task(
+ "test_task",
+ "test_arguments",
+ "test_workers",
+ "test_status",
+ "test_example/test.py",
+ False,
+ "static",
+ "test_requirement",
+ "test_result",
+ "test_q_model_path",
+ )
self.scheduler.prepare_task(task)
-
def test_check_task_status(self):
log_path = "test_log_path"
# done case
- with patch('builtins.open', mock_open(read_data='[INFO] Save deploy yaml to\n')) as mock_file:
+ with patch("builtins.open", mock_open(read_data="[INFO] Save deploy yaml to\n")) as mock_file:
result = self.scheduler.check_task_status(log_path)
- self.assertEqual(result, 'done')
+ self.assertEqual(result, "done")
# failed case
- with patch('builtins.open', mock_open(read_data='[INFO] Deploying...\n')) as mock_file:
+ with patch("builtins.open", mock_open(read_data="[INFO] Deploying...\n")) as mock_file:
result = self.scheduler.check_task_status(log_path)
- self.assertEqual(result, 'failed')
+ self.assertEqual(result, "failed")
def test_sanitize_arguments(self):
arguments = "test_arguments\xa0"
@@ -99,15 +158,25 @@ def test_sanitize_arguments(self):
self.assertEqual(sanitized_arguments, "test_arguments ")
def test_dispatch_task(self):
- task = Task("test_task", "test_arguments", "test_workers", "test_status", "test_script_url", \
- "test_optimized", "test_approach", "test_requirement", "test_result", "test_q_model_path")
+ task = Task(
+ "test_task",
+ "test_arguments",
+ "test_workers",
+ "test_status",
+ "test_script_url",
+ "test_optimized",
+ "test_approach",
+ "test_requirement",
+ "test_result",
+ "test_q_model_path",
+ )
resource = [("node1", "8"), ("node2", "8")]
- with patch('neural_solution.backend.scheduler.Scheduler.launch_task') as mock_launch_task:
+ with patch("neural_solution.backend.scheduler.Scheduler.launch_task") as mock_launch_task:
self.scheduler.dispatch_task(task, resource)
self.assertTrue(mock_launch_task.called)
- @patch('socket.socket')
- @patch('builtins.open')
+ @patch("socket.socket")
+ @patch("builtins.open")
def test_report_result(self, mock_open, mock_socket):
task_id = "test_task"
log_path = "test_log_path"
@@ -115,29 +184,42 @@ def test_report_result(self, mock_open, mock_socket):
self.scheduler.q_model_path = None
mock_socket.return_value.connect.return_value = None
mock_open.return_value.readlines.return_value = ["Tune 1 result is: (int8|fp32): 0.8 (int8|fp32): 0.9"]
- expected_result = {
- "optimization time (seconds)": "10.00",
- "Accuracy": "0.8",
- "Duration (seconds)": "0.9"
- }
+ expected_result = {"optimization time (seconds)": "10.00", "Accuracy": "0.8", "Duration (seconds)": "0.9"}
self.scheduler.report_result(task_id, log_path, task_runtime)
mock_open.assert_called_once_with(log_path)
mock_socket.assert_called_once_with()
- mock_socket.return_value.connect.assert_called_once_with(('localhost', 1234))
+ mock_socket.return_value.connect.assert_called_once_with(("localhost", 1234))
mock_socket.return_value.send.assert_called_once()
-
- @patch('neural_solution.backend.scheduler.Scheduler.prepare_task')
- @patch('neural_solution.backend.scheduler.Scheduler.prepare_env')
- @patch('neural_solution.backend.scheduler.Scheduler._parse_cmd')
- @patch('subprocess.Popen')
- @patch('neural_solution.backend.scheduler.Scheduler.check_task_status')
- @patch('neural_solution.backend.scheduler.Scheduler.report_result')
- def test_launch_task(self, mock_report_result, mock_check_task_status, mock_popen, mock_parse_cmd, mock_prepare_env, mock_prepare_task):
- task = Task("test_task", "test_arguments", "test_workers", "test_status", "test_script_url", \
- "test_optimized", "test_approach", "test_requirement", "test_result", "test_q_model_path")
+ @patch("neural_solution.backend.scheduler.Scheduler.prepare_task")
+ @patch("neural_solution.backend.scheduler.Scheduler.prepare_env")
+ @patch("neural_solution.backend.scheduler.Scheduler._parse_cmd")
+ @patch("subprocess.Popen")
+ @patch("neural_solution.backend.scheduler.Scheduler.check_task_status")
+ @patch("neural_solution.backend.scheduler.Scheduler.report_result")
+ def test_launch_task(
+ self,
+ mock_report_result,
+ mock_check_task_status,
+ mock_popen,
+ mock_parse_cmd,
+ mock_prepare_env,
+ mock_prepare_task,
+ ):
+ task = Task(
+ "test_task",
+ "test_arguments",
+ "test_workers",
+ "test_status",
+ "test_script_url",
+ "test_optimized",
+ "test_approach",
+ "test_requirement",
+ "test_result",
+ "test_q_model_path",
+ )
resource = ["1 node1", "2 node2"]
mock_parse_cmd.return_value = "test_cmd"
mock_check_task_status.return_value = "done"
@@ -148,16 +230,26 @@ def test_launch_task(self, mock_report_result, mock_check_task_status, mock_pope
self.scheduler.launch_task(task, resource)
-
-
- @patch('neural_solution.backend.scheduler.Scheduler.launch_task')
- @patch('neural_solution.backend.cluster.Cluster.reserve_resource')
+ @patch("neural_solution.backend.scheduler.Scheduler.launch_task")
+ @patch("neural_solution.backend.cluster.Cluster.reserve_resource")
def test_schedule_tasks(self, mock_reserve_resource, mock_launch_task):
- task1 = Task("1", "test_arguments", "test_workers", "test_status", "test_script_url", \
- "test_optimized", "test_approach", "test_requirement", "test_result", "test_q_model_path")
- self.task_db.cursor.execute("insert into task values ('1', 'test_arguments', 'test_workers', \
+ task1 = Task(
+ "1",
+ "test_arguments",
+ "test_workers",
+ "test_status",
+ "test_script_url",
+ "test_optimized",
+ "test_approach",
+ "test_requirement",
+ "test_result",
+ "test_q_model_path",
+ )
+ self.task_db.cursor.execute(
+ "insert into task values ('1', 'test_arguments', 'test_workers', \
'test_status', 'test_script_url', \
- 'test_optimized', 'test_approach', 'test_requirement', 'test_result', 'test_q_model_path')")
+ 'test_optimized', 'test_approach', 'test_requirement', 'test_result', 'test_q_model_path')"
+ )
# no pending task case
adding_abort = threading.Thread(
@@ -194,20 +286,20 @@ def test_schedule_tasks(self, mock_reserve_resource, mock_launch_task):
class TestParseCmd(unittest.TestCase):
-
def setUp(self):
self.cluster = Cluster(db_path=db_path)
self.task_db = TaskDB(db_path=db_path)
self.result_monitor_port = 1234
- self.task_scheduler = \
- Scheduler(self.cluster, self.task_db, self.result_monitor_port,conda_env_name="for_ns_test", config=config)
+ self.task_scheduler = Scheduler(
+ self.cluster, self.task_db, self.result_monitor_port, conda_env_name="for_ns_test", config=config
+ )
self.task = MagicMock()
- self.resource = ['1 node1', '2 node2', '3 node3']
+ self.resource = ["1 node1", "2 node2", "3 node3"]
self.task.workers = 3
- self.task_name = 'test_task'
- self.script_name = 'test_script.py'
- self.task_path = '/path/to/task'
- self.arguments = 'arg1 arg2'
+ self.task_name = "test_task"
+ self.script_name = "test_script.py"
+ self.task_path = "/path/to/task"
+ self.arguments = "arg1 arg2"
self.task.arguments = self.arguments
self.task.name = self.task_name
self.task.optimized = True
@@ -217,33 +309,37 @@ def setUp(self):
self.task_scheduler.task_path = self.task_path
def test__parse_cmd(self):
- expected_cmd = 'cd /path/to/task\nmpirun -np 3 -host node1,node2,node3 -map-by socket:pe=5' + \
- ' -mca btl_tcp_if_include 192.168.20.0/24 -x OMP_NUM_THREADS=5 --report-bindings bash distributed_run.sh'
- with patch('neural_solution.backend.scheduler.Scheduler.prepare_task') as mock_prepare_task, \
- patch('neural_solution.backend.scheduler.Scheduler.prepare_env') as mock_prepare_env, \
- patch('neural_solution.backend.scheduler.logger.info') as mock_logger_info, \
- patch('builtins.open', create=True) as mock_open, \
- patch('neural_solution.backend.scheduler.os.path.join') as mock_os_path_join:
-
+ expected_cmd = (
+ "cd /path/to/task\nmpirun -np 3 -host node1,node2,node3 -map-by socket:pe=5"
+ + " -mca btl_tcp_if_include 192.168.20.0/24 -x OMP_NUM_THREADS=5 --report-bindings bash distributed_run.sh"
+ )
+ with patch("neural_solution.backend.scheduler.Scheduler.prepare_task") as mock_prepare_task, patch(
+ "neural_solution.backend.scheduler.Scheduler.prepare_env"
+ ) as mock_prepare_env, patch("neural_solution.backend.scheduler.logger.info") as mock_logger_info, patch(
+ "builtins.open", create=True
+ ) as mock_open, patch(
+ "neural_solution.backend.scheduler.os.path.join"
+ ) as mock_os_path_join:
mock_prepare_task.return_value = None
- mock_prepare_env.return_value = 'test_env'
+ mock_prepare_env.return_value = "test_env"
mock_logger_info.return_value = None
mock_open.return_value.__enter__ = lambda x: x
mock_open.return_value.__exit__ = MagicMock()
- mock_os_path_join.return_value = '/path/to/task/distributed_run.sh'
+ mock_os_path_join.return_value = "/path/to/task/distributed_run.sh"
result = self.task_scheduler._parse_cmd(self.task, self.resource)
self.assertEqual(result, expected_cmd)
mock_prepare_task.assert_called_once_with(self.task)
mock_prepare_env.assert_called_once_with(self.task)
- mock_logger_info.assert_called_once_with('[TaskScheduler] host resource: node1,node2,node3')
- mock_open.assert_called_once_with('/path/to/task/distributed_run.sh', 'w', encoding='utf-8')
- mock_os_path_join.assert_called_once_with('/path/to/task', 'distributed_run.sh')
+ mock_logger_info.assert_called_once_with("[TaskScheduler] host resource: node1,node2,node3")
+ mock_open.assert_called_once_with("/path/to/task/distributed_run.sh", "w", encoding="utf-8")
+ mock_os_path_join.assert_called_once_with("/path/to/task", "distributed_run.sh")
self.task.optimized = False
result = self.task_scheduler._parse_cmd(self.task, self.resource)
self.assertEqual(result, expected_cmd)
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/neural_solution/test/backend/test_task.py b/neural_solution/test/backend/test_task.py
index 55073c92946..43d3c649ec5 100644
--- a/neural_solution/test/backend/test_task.py
+++ b/neural_solution/test/backend/test_task.py
@@ -1,15 +1,13 @@
import unittest
+
from neural_solution.backend.task import Task
+
class TestTask(unittest.TestCase):
def setUp(self):
- self.task = Task("123",
- "python script.py",
- 4, "pending",
- "http://example.com/script.py",
- True,
- "approach",
- "requirement")
+ self.task = Task(
+ "123", "python script.py", 4, "pending", "http://example.com/script.py", True, "approach", "requirement"
+ )
def test_task_attributes(self):
self.assertEqual(self.task.task_id, "123")
@@ -23,5 +21,6 @@ def test_task_attributes(self):
self.assertEqual(self.task.result, "")
self.assertEqual(self.task.q_model_path, "")
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/neural_solution/test/backend/test_task_db.py b/neural_solution/test/backend/test_task_db.py
index 2b1c5b6f28b..38e9690a8e7 100644
--- a/neural_solution/test/backend/test_task_db.py
+++ b/neural_solution/test/backend/test_task_db.py
@@ -1,29 +1,21 @@
-
-import unittest
-from unittest.mock import patch, MagicMock
-from neural_solution.backend.task_db import TaskDB, Task
-import shutil
import os
+import shutil
+import unittest
+from unittest.mock import MagicMock, patch
+from neural_solution.backend.task_db import Task, TaskDB
from neural_solution.utils.utility import get_db_path
NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace")
db_path = get_db_path(NEURAL_SOLUTION_WORKSPACE)
-class TestTaskDB(unittest.TestCase):
+class TestTaskDB(unittest.TestCase):
def setUp(self):
self.taskdb = TaskDB(db_path=db_path)
self.task = Task(
- '1',
- 'arguments',
- 1,
- 'pending',
- 'script_url',
- 0, 'approach',
- 'requirement',
- 'result',
- 'q_model_path')
+ "1", "arguments", 1, "pending", "script_url", 0, "approach", "requirement", "result", "q_model_path"
+ )
@classmethod
def tearDownClass(cls) -> None:
@@ -32,71 +24,85 @@ def tearDownClass(cls) -> None:
def test_append_task(self):
self.taskdb.append_task(self.task)
self.assertEqual(len(self.taskdb.task_queue), 1)
- self.assertEqual(self.taskdb.task_queue[0], '1')
+ self.assertEqual(self.taskdb.task_queue[0], "1")
def test_get_pending_task_num(self):
self.taskdb.append_task(self.task)
self.assertEqual(self.taskdb.get_pending_task_num(), 1)
def test_get_all_pending_tasks(self):
- self.taskdb.cursor.execute("insert into task values ('2', 'arguments', 1, \
- 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')")
+ self.taskdb.cursor.execute(
+ "insert into task values ('2', 'arguments', 1, \
+ 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')"
+ )
self.taskdb.conn.commit()
pending_tasks = self.taskdb.get_all_pending_tasks()
self.assertEqual(len(pending_tasks), 1)
- self.assertEqual(pending_tasks[0].task_id, '2')
+ self.assertEqual(pending_tasks[0].task_id, "2")
def test_update_task_status(self):
- self.taskdb.cursor.execute("insert into task values ('3', 'arguments', 1, \
- 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')")
+ self.taskdb.cursor.execute(
+ "insert into task values ('3', 'arguments', 1, \
+ 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')"
+ )
self.taskdb.conn.commit()
- self.taskdb.update_task_status('3', 'running')
+ self.taskdb.update_task_status("3", "running")
self.taskdb.cursor.execute("select status from task where id='3'")
status = self.taskdb.cursor.fetchone()[0]
- self.assertEqual(status, 'running')
+ self.assertEqual(status, "running")
with self.assertRaises(Exception):
- self.taskdb.update_task_status('3', 'invalid_status')
+ self.taskdb.update_task_status("3", "invalid_status")
def test_update_result(self):
- self.taskdb.cursor.execute("insert into task values ('4', 'arguments', 1, \
- 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')")
+ self.taskdb.cursor.execute(
+ "insert into task values ('4', 'arguments', 1, \
+ 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')"
+ )
self.taskdb.conn.commit()
- self.taskdb.update_result('4', 'new_result')
+ self.taskdb.update_result("4", "new_result")
self.taskdb.cursor.execute("select result from task where id='4'")
result = self.taskdb.cursor.fetchone()[0]
- self.assertEqual(result, 'new_result')
+ self.assertEqual(result, "new_result")
def test_update_q_model_path_and_result(self):
- self.taskdb.cursor.execute("insert into task values ('5', 'arguments', 1, \
- 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')")
+ self.taskdb.cursor.execute(
+ "insert into task values ('5', 'arguments', 1, \
+ 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')"
+ )
self.taskdb.conn.commit()
- self.taskdb.update_q_model_path_and_result('5', 'new_q_model_path', 'new_result')
+ self.taskdb.update_q_model_path_and_result("5", "new_q_model_path", "new_result")
self.taskdb.cursor.execute("select q_model_path, result from task where id='5'")
q_model_path, result = self.taskdb.cursor.fetchone()
- self.assertEqual(q_model_path, 'new_q_model_path')
- self.assertEqual(result, 'new_result')
+ self.assertEqual(q_model_path, "new_q_model_path")
+ self.assertEqual(result, "new_result")
def test_lookup_task_status(self):
- self.taskdb.cursor.execute("insert into task values ('6', 'arguments', 1, \
- 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')")
+ self.taskdb.cursor.execute(
+ "insert into task values ('6', 'arguments', 1, \
+ 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')"
+ )
self.taskdb.conn.commit()
- status_dict = self.taskdb.lookup_task_status('6')
- self.assertEqual(status_dict['status'], 'pending')
- self.assertEqual(status_dict['result'], 'result')
+ status_dict = self.taskdb.lookup_task_status("6")
+ self.assertEqual(status_dict["status"], "pending")
+ self.assertEqual(status_dict["result"], "result")
def test_get_task_by_id(self):
- self.taskdb.cursor.execute("insert into task values ('7', 'arguments', 1, \
- 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')")
+ self.taskdb.cursor.execute(
+ "insert into task values ('7', 'arguments', 1, \
+ 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')"
+ )
self.taskdb.conn.commit()
- task = self.taskdb.get_task_by_id('7')
- self.assertEqual(task.task_id, '7')
- self.assertEqual(task.arguments, 'arguments')
+ task = self.taskdb.get_task_by_id("7")
+ self.assertEqual(task.task_id, "7")
+ self.assertEqual(task.arguments, "arguments")
self.assertEqual(task.workers, 1)
- self.assertEqual(task.status, 'pending')
- self.assertEqual(task.result, 'result')
+ self.assertEqual(task.status, "pending")
+ self.assertEqual(task.result, "result")
def test_remove_task(self):
- self.taskdb.remove_task('1')
+ self.taskdb.remove_task("1")
# currently no garbage collection, so this function does nothing
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/neural_solution/test/backend/test_task_monitor.py b/neural_solution/test/backend/test_task_monitor.py
index 68ba7ee40bb..1e4bb069726 100644
--- a/neural_solution/test/backend/test_task_monitor.py
+++ b/neural_solution/test/backend/test_task_monitor.py
@@ -1,8 +1,10 @@
+import threading
import unittest
-from unittest.mock import Mock, patch, MagicMock
-from neural_solution.backend.task_monitor import TaskMonitor
+from unittest.mock import MagicMock, Mock, patch
+
from neural_solution.backend.task import Task
-import threading
+from neural_solution.backend.task_monitor import TaskMonitor
+
class TestTaskMonitor(unittest.TestCase):
def setUp(self):
@@ -18,37 +20,66 @@ def test__start_listening(self):
mock_socket.return_value.bind = mock_bind
mock_socket.return_value.listen = mock_listen
self.task_monitor._start_listening("localhost", 8888, 10)
-
+
def test_receive_task(self):
- self.mock_socket.accept.return_value = (Mock(), b'{"task_id": 123, "arguments": {}, "workers": 1, \
+ self.mock_socket.accept.return_value = (
+ Mock(),
+ b'{"task_id": 123, "arguments": {}, "workers": 1, \
"status": "pending", "script_url": "http://example.com", "optimized": True, \
- "approach": "static", "requirement": "neural_solution", "result": "", "q_model_path": ""}')
+ "approach": "static", "requirement": "neural_solution", "result": "", "q_model_path": ""}',
+ )
self.mock_task_db.get_task_by_id.return_value = Task(
- task_id=123, arguments={}, workers=1,
- status="pending", script_url="http://example.com", optimized=True,
- approach="static", requirement="neural_solution", result="", q_model_path="")
-
+ task_id=123,
+ arguments={},
+ workers=1,
+ status="pending",
+ script_url="http://example.com",
+ optimized=True,
+ approach="static",
+ requirement="neural_solution",
+ result="",
+ q_model_path="",
+ )
+
# Test normal task case
- with patch('neural_solution.backend.task_monitor.deserialize', return_value={
- "task_id": 123, "arguments": {}, "workers": 1,
- "status": "pending", "script_url": "http://example.com",
- "optimized": True, "approach": "static",
- "requirement": "neural_solution", "result": "", "q_model_path": ""}):
+ with patch(
+ "neural_solution.backend.task_monitor.deserialize",
+ return_value={
+ "task_id": 123,
+ "arguments": {},
+ "workers": 1,
+ "status": "pending",
+ "script_url": "http://example.com",
+ "optimized": True,
+ "approach": "static",
+ "requirement": "neural_solution",
+ "result": "",
+ "q_model_path": "",
+ },
+ ):
task = self.task_monitor._receive_task()
self.assertEqual(task.task_id, 123)
self.mock_task_db.get_task_by_id.assert_called_once_with(123)
-
+
# Test ping case
- with patch('neural_solution.backend.task_monitor.deserialize', return_value={"ping": "test"}):
+ with patch("neural_solution.backend.task_monitor.deserialize", return_value={"ping": "test"}):
response = self.task_monitor._receive_task()
self.assertEqual(response, False)
self.mock_task_db.get_task_by_id.assert_called_once_with(123)
def test_append_task(self):
task = Task(
- task_id=123, arguments={}, workers=1,
- status="pending", script_url="http://example.com", optimized=True,
- approach="static", requirement="neural_solution", result="", q_model_path="")
+ task_id=123,
+ arguments={},
+ workers=1,
+ status="pending",
+ script_url="http://example.com",
+ optimized=True,
+ approach="static",
+ requirement="neural_solution",
+ result="",
+ q_model_path="",
+ )
self.task_monitor._append_task(task)
self.mock_task_db.append_task.assert_called_once_with(task)
@@ -61,7 +92,7 @@ def test_wait_new_task(self):
self.task_monitor._receive_task = mock_receive_task
self.task_monitor._append_task = mock_append_task
self.task_monitor._start_listening = MagicMock()
-
+
# Call the function to be tested
adding_abort = threading.Thread(
target=self.task_monitor.wait_new_task,
@@ -70,13 +101,14 @@ def test_wait_new_task(self):
)
adding_abort.start()
adding_abort.join(timeout=1)
-
+
# Test task is False
mock_receive_task = MagicMock(return_value=False)
mock_append_task = MagicMock()
self.task_monitor._receive_task = mock_receive_task
-
+
adding_abort.join(timeout=1)
-
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/neural_solution/test/backend/utils/test_utility.py b/neural_solution/test/backend/utils/test_utility.py
index 8816179a2a6..296e2dc3e44 100644
--- a/neural_solution/test/backend/utils/test_utility.py
+++ b/neural_solution/test/backend/utils/test_utility.py
@@ -1,28 +1,36 @@
-import unittest
import os
import shutil
-from unittest.mock import patch, MagicMock, mock_open
-from neural_solution.backend.utils.utility import (
- serialize, deserialize,
- dump_elapsed_time, get_task_log_path, build_local_cluster,
- build_cluster, get_current_time,
- synchronized, build_workspace, is_remote_url, create_dir,
- get_q_model_path)
-from neural_solution.utils.utility import get_task_workspace, get_task_log_workspace, get_db_path
+import unittest
+from unittest.mock import MagicMock, mock_open, patch
+
+from neural_solution.backend.utils.utility import (
+ build_cluster,
+ build_local_cluster,
+ build_workspace,
+ create_dir,
+ deserialize,
+ dump_elapsed_time,
+ get_current_time,
+ get_q_model_path,
+ get_task_log_path,
+ is_remote_url,
+ serialize,
+ synchronized,
+)
from neural_solution.config import config
+from neural_solution.utils.utility import get_db_path, get_task_log_workspace, get_task_workspace
NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace")
DB_PATH = NEURAL_SOLUTION_WORKSPACE + "/db"
-TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace"
+TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace"
TASK_LOG_path = NEURAL_SOLUTION_WORKSPACE + "/task_log"
SERVE_LOG_PATH = NEURAL_SOLUTION_WORKSPACE + "/serve_log"
config.workspace = NEURAL_SOLUTION_WORKSPACE
db_path = get_db_path(config.workspace)
-class TestUtils(unittest.TestCase):
+class TestUtils(unittest.TestCase):
@classmethod
def tearDown(self) -> None:
if os.path.exists("ns_workspace"):
@@ -42,24 +50,28 @@ def test_dump_elapsed_time(self):
@dump_elapsed_time("test function")
def test_function():
return True
- with patch('neural_solution.utils.logger') as mock_logger:
+
+ with patch("neural_solution.utils.logger") as mock_logger:
test_function()
def test_get_task_log_path(self):
task_id = 123
expected_output = f"{TASK_LOG_path}/task_{task_id}.txt"
- self.assertEqual(get_task_log_path(log_path=get_task_log_workspace(config.workspace), task_id=task_id), expected_output)
+ self.assertEqual(
+ get_task_log_path(log_path=get_task_log_workspace(config.workspace), task_id=task_id), expected_output
+ )
def test_build_local_cluster(self):
- with patch('neural_solution.backend.cluster.Node') as mock_node, \
- patch('neural_solution.backend.cluster.Cluster') as mock_cluster:
+ with patch("neural_solution.backend.cluster.Node") as mock_node, patch(
+ "neural_solution.backend.cluster.Cluster"
+ ) as mock_cluster:
mock_node_obj = MagicMock()
mock_node.return_value = mock_node_obj
- mock_node_obj.name = 'localhost'
+ mock_node_obj.name = "localhost"
mock_node_obj.num_sockets = 2
mock_node_obj.num_cores_per_socket = 5
build_local_cluster(db_path=db_path)
- mock_node.assert_called_with(name='localhost', num_sockets=2, num_cores_per_socket=5)
+ mock_node.assert_called_with(name="localhost", num_sockets=2, num_cores_per_socket=5)
mock_cluster.assert_called_once()
def test_build_cluster(self):
@@ -73,11 +85,9 @@ def test_build_cluster(self):
os.remove("test.txt")
file_path = "test_host_file"
- with patch('neural_solution.backend.cluster.Node') as mock_node, \
- patch('neural_solution.backend.cluster.Cluster') as mock_cluster, \
- patch('builtins.open') as mock_open, \
- patch('os.path.exists') as mock_exists:
-
+ with patch("neural_solution.backend.cluster.Node") as mock_node, patch(
+ "neural_solution.backend.cluster.Cluster"
+ ) as mock_cluster, patch("builtins.open") as mock_open, patch("os.path.exists") as mock_exists:
# Test None
cluster, _ = build_cluster(file_path=None, db_path=db_path)
mock_cluster.assert_called()
@@ -87,11 +97,9 @@ def test_build_cluster(self):
# test_build_cluster_file_not_exist
file_path = "test_file"
- with patch('neural_solution.backend.cluster.Node'), \
- patch('neural_solution.backend.cluster.Cluster'), \
- patch('builtins.open'), \
- patch('os.path.exists') as mock_exists, \
- patch('neural_solution.utils.logger') as mock_logger:
+ with patch("neural_solution.backend.cluster.Node"), patch("neural_solution.backend.cluster.Cluster"), patch(
+ "builtins.open"
+ ), patch("os.path.exists") as mock_exists, patch("neural_solution.utils.logger") as mock_logger:
mock_exists.return_value = False
self.assertRaises(Exception, build_cluster, file_path)
mock_logger.reset_mock()
@@ -103,17 +111,19 @@ def test_synchronized(self):
class TestClass:
def __init__(self):
self.lock = MagicMock()
+
@synchronized
def test_function(self):
return True
+
test_class = TestClass()
- with patch.object(test_class, 'lock'):
+ with patch.object(test_class, "lock"):
test_class.test_function()
def test_build_workspace(self):
task_id = 123
expected_output = os.path.abspath(f"{TASK_WORKSPACE}/{task_id}")
- self.assertEqual(build_workspace(path=get_task_workspace(config.workspace) ,task_id=task_id), expected_output)
+ self.assertEqual(build_workspace(path=get_task_workspace(config.workspace), task_id=task_id), expected_output)
def test_is_remote_url(self):
self.assertTrue(is_remote_url("http://test.com"))
@@ -125,17 +135,18 @@ def test_create_dir(self):
create_dir(path)
self.assertTrue(os.path.exists(os.path.dirname(path)))
- @patch('builtins.open', mock_open(read_data='Save quantized model to /path/to/model.'))
+ @patch("builtins.open", mock_open(read_data="Save quantized model to /path/to/model."))
def test_get_q_model_path_success(self):
- log_path = 'fake_log_path'
+ log_path = "fake_log_path"
q_model_path = get_q_model_path(log_path)
- self.assertEqual(q_model_path, '/path/to/model')
+ self.assertEqual(q_model_path, "/path/to/model")
- @patch('builtins.open', mock_open(read_data='No quantized model saved.'))
+ @patch("builtins.open", mock_open(read_data="No quantized model saved."))
def test_get_q_model_path_failure(self):
- log_path = 'fake_log_path'
+ log_path = "fake_log_path"
q_model_path = get_q_model_path(log_path)
- self.assertEqual(q_model_path, 'quantized model path not found')
+ self.assertEqual(q_model_path, "quantized model path not found")
+
if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
+ unittest.main()
diff --git a/neural_solution/test/frontend/fastapi/test_main_server.py b/neural_solution/test/frontend/fastapi/test_main_server.py
index 0f3473019b2..c101e616216 100644
--- a/neural_solution/test/frontend/fastapi/test_main_server.py
+++ b/neural_solution/test/frontend/fastapi/test_main_server.py
@@ -1,18 +1,19 @@
-import unittest
import asyncio
-from unittest.mock import patch, Mock, MagicMock
-from fastapi.testclient import TestClient
-from fastapi import WebSocket
-from neural_solution.frontend.fastapi.main_server import app, LogEventHandler, start_log_watcher, Observer
-import sqlite3
import os
import shutil
+import sqlite3
+import unittest
+from unittest.mock import MagicMock, Mock, patch
+
+from fastapi import WebSocket
+from fastapi.testclient import TestClient
from neural_solution.config import config
+from neural_solution.frontend.fastapi.main_server import LogEventHandler, Observer, app, start_log_watcher
NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace")
DB_PATH = NEURAL_SOLUTION_WORKSPACE + "/db"
-TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace"
+TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace"
TASK_LOG_path = NEURAL_SOLUTION_WORKSPACE + "/task_log"
SERVE_LOG_PATH = NEURAL_SOLUTION_WORKSPACE + "/serve_log"
@@ -22,38 +23,47 @@
def build_db():
if not os.path.exists(DB_PATH):
os.makedirs(DB_PATH)
- conn = sqlite3.connect(f'{DB_PATH}/task.db', check_same_thread=False) # sqlite should set this check_same_thread to False
+ conn = sqlite3.connect(
+ f"{DB_PATH}/task.db", check_same_thread=False
+ ) # sqlite should set this check_same_thread to False
cursor = conn.cursor()
cursor.execute(
- 'create table if not exists task(id TEXT PRIMARY KEY, arguments varchar(100), ' +
- 'workers int, status varchar(20), script_url varchar(500), optimized integer, ' +
- 'approach varchar(20), requirements varchar(500), result varchar(500), q_model_path varchar(200))')
- cursor.execute('drop table if exists cluster ')
- cursor.execute(r'create table cluster(id INTEGER PRIMARY KEY AUTOINCREMENT,' +
- 'node_info varchar(500),' +
- 'status varchar(100),' +
- 'free_sockets int,' +
- 'busy_sockets int,' +
- 'total_sockets int)')
+ "create table if not exists task(id TEXT PRIMARY KEY, arguments varchar(100), "
+ + "workers int, status varchar(20), script_url varchar(500), optimized integer, "
+ + "approach varchar(20), requirements varchar(500), result varchar(500), q_model_path varchar(200))"
+ )
+ cursor.execute("drop table if exists cluster ")
+ cursor.execute(
+ r"create table cluster(id INTEGER PRIMARY KEY AUTOINCREMENT,"
+ + "node_info varchar(500),"
+ + "status varchar(100),"
+ + "free_sockets int,"
+ + "busy_sockets int,"
+ + "total_sockets int)"
+ )
conn.commit()
    conn.close()
+
def delete_db():
if os.path.exists(DB_PATH):
shutil.rmtree(DB_PATH)
+
def use_db():
def f(func):
def fi(*args, **kwargs):
build_db()
res = func(*args, **kwargs)
delete_db()
+
return fi
+
return f
-class TestMain(unittest.TestCase):
+class TestMain(unittest.TestCase):
@classmethod
def setUpClass(self):
if not os.path.exists(TASK_LOG_path):
@@ -69,7 +79,7 @@ def test_read_root(self):
assert response.status_code == 200
self.assertEqual(response.json(), {"message": "Welcome to Neural Solution!"})
- @patch('neural_solution.frontend.fastapi.main_server.socket')
+ @patch("neural_solution.frontend.fastapi.main_server.socket")
def test_ping(self, mock_socket):
response = client.get("/ping")
self.assertEqual(response.status_code, 200)
@@ -92,11 +102,12 @@ def test_get_description(self):
data = {
"description": "",
}
- path = "../../doc"
+ path = "../../doc"
if not os.path.exists(path):
os.makedirs(path)
with open(os.path.join(path, "user_facing_api.json"), "w") as f:
import json
+
json.dump(data, f)
response = client.get("/description")
assert response.status_code == 200
@@ -111,7 +122,7 @@ def test_submit_task(self, mock_submit_task):
"arguments": ["arg1", "arg2"],
"approach": "approach1",
"requirements": ["req1", "req2"],
- "workers": 3
+ "workers": 3,
}
# test no db case
@@ -132,7 +143,6 @@ def test_submit_task(self, mock_submit_task):
self.assertIn("successfully", response.json()["status"])
mock_submit_task.assert_called_once()
-
# test ConnectionRefusedError case
mock_submit_task.side_effect = ConnectionRefusedError
response = client.post("/task/submit/", json=task)
@@ -166,7 +176,7 @@ def test_get_task_by_id(self, mock_submit_task):
"arguments": ["arg1", "arg2"],
"approach": "approach1",
"requirements": ["req1", "req2"],
- "workers": 3
+ "workers": 3,
}
response = client.post("/task/submit/", json=task)
task_id = response.json()["task_id"]
@@ -192,7 +202,7 @@ def test_get_task_status_by_id(self, mock_submit_task):
"arguments": ["arg1", "arg2"],
"approach": "approach1",
"requirements": ["req1", "req2"],
- "workers": 3
+ "workers": 3,
}
response = client.post("/task/submit/", json=task)
task_id = response.json()["task_id"]
@@ -214,8 +224,8 @@ def test_read_logs(self):
self.assertIn(task_id, response.text)
os.remove(log_path)
-class TestLogEventHandler(unittest.TestCase):
+class TestLogEventHandler(unittest.TestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
@@ -231,6 +241,7 @@ def test_init(self):
def test_on_modified(self):
from neural_solution.config import config
+
config.workspace = NEURAL_SOLUTION_WORKSPACE
mock_websocket = MagicMock()
mock_websocket.send_text = MagicMock()
@@ -249,7 +260,7 @@ def test_on_modified(self):
task_id = "1234"
log_path = f"{TASK_LOG_path}/task_{task_id}.txt"
event.src_path = log_path
- with patch('builtins.open', MagicMock()) as mock_file:
+ with patch("builtins.open", MagicMock()) as mock_file:
mock_file.return_value.__enter__.return_value.seek.return_value = None
mock_file.return_value.__enter__.return_value.readlines.return_value = ["test line"]
handler.on_modified(event)
@@ -259,20 +270,20 @@ def test_on_modified(self):
# handler.queue.put_nowait.assert_called_once_with("test line")
os.remove(log_path)
-class TestStartLogWatcher(unittest.TestCase):
+class TestStartLogWatcher(unittest.TestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
def test_start_log_watcher(self):
mock_observer = MagicMock()
mock_observer.schedule = MagicMock()
- with patch('neural_solution.frontend.fastapi.main_server.Observer', MagicMock(return_value=mock_observer)):
+ with patch("neural_solution.frontend.fastapi.main_server.Observer", MagicMock(return_value=mock_observer)):
observer = start_log_watcher("test_websocket", "1234", 0)
self.assertIsInstance(observer, type(mock_observer))
-class TestWebsocketEndpoint(unittest.TestCase):
+class TestWebsocketEndpoint(unittest.TestCase):
def setUp(self):
self.loop = asyncio.get_event_loop()
self.client = TestClient(app)
@@ -282,5 +293,6 @@ def test_websocket_endpoint(self):
# with self.assertRaises(HTTPException):
# asyncio.run(websocket_endpoint(WebSocket, "nonexistent_task"))
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/neural_solution/test/frontend/fastapi/test_task_submitter.py b/neural_solution/test/frontend/fastapi/test_task_submitter.py
index 48a0f312662..c08c2cd605e 100644
--- a/neural_solution/test/frontend/fastapi/test_task_submitter.py
+++ b/neural_solution/test/frontend/fastapi/test_task_submitter.py
@@ -1,7 +1,9 @@
+import socket
import unittest
from unittest.mock import patch
-import socket
-from neural_solution.frontend.task_submitter import TaskSubmitter, Task
+
+from neural_solution.frontend.task_submitter import Task, TaskSubmitter
+
class TestTask(unittest.TestCase):
def test_task_creation(self):
@@ -18,7 +20,7 @@ def test_task_creation(self):
arguments=arguments,
approach=approach,
requirements=requirements,
- workers=workers
+ workers=workers,
)
self.assertEqual(task.script_url, script_url)
@@ -28,15 +30,17 @@ def test_task_creation(self):
self.assertEqual(task.requirements, requirements)
self.assertEqual(task.workers, workers)
+
class TestTaskSubmitter(unittest.TestCase):
- @patch('socket.socket')
+ @patch("socket.socket")
def test_submit_task(self, mock_socket):
task_submitter = TaskSubmitter()
- task_id = '1234'
+ task_id = "1234"
task_submitter.submit_task(task_id)
- mock_socket.return_value.connect.assert_called_once_with(('localhost', 2222))
+ mock_socket.return_value.connect.assert_called_once_with(("localhost", 2222))
mock_socket.return_value.send.assert_called_once_with(b'{"task_id": "1234"}')
mock_socket.return_value.close.assert_called_once()
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/neural_solution/test/frontend/fastapi/test_utils.py b/neural_solution/test/frontend/fastapi/test_utils.py
index d60364f0d26..abb57fff92b 100644
--- a/neural_solution/test/frontend/fastapi/test_utils.py
+++ b/neural_solution/test/frontend/fastapi/test_utils.py
@@ -1,21 +1,27 @@
-import unittest
import os
import shutil
-from unittest.mock import patch, mock_open
+import unittest
+from unittest.mock import mock_open, patch
from neural_solution.frontend.utility import (
- serialize, deserialize,
- get_cluster_info,get_cluster_table,
- get_res_during_tuning, get_baseline_during_tuning,
- check_log_exists, list_to_string)
+ check_log_exists,
+ deserialize,
+ get_baseline_during_tuning,
+ get_cluster_info,
+ get_cluster_table,
+ get_res_during_tuning,
+ list_to_string,
+ serialize,
+)
+
NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace")
DB_PATH = NEURAL_SOLUTION_WORKSPACE + "/db/task.db"
-TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace"
+TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace"
TASK_LOG_path = NEURAL_SOLUTION_WORKSPACE + "/task_log"
SERVE_LOG_PATH = NEURAL_SOLUTION_WORKSPACE + "/serve_log"
-class TestMyModule(unittest.TestCase):
+class TestMyModule(unittest.TestCase):
@classmethod
def setUpClass(self):
if not os.path.exists(TASK_LOG_path):
@@ -35,39 +41,41 @@ def test_deserialize(self):
expected_result = {"key": "value"}
self.assertEqual(deserialize(request), expected_result)
- @patch('sqlite3.connect')
+ @patch("sqlite3.connect")
def test_get_cluster_info(self, mock_connect):
mock_cursor = mock_connect().cursor.return_value
mock_cursor.fetchall.return_value = [(1, "node info", "status", 1, 2, 3)]
expected_result = {"Cluster info": [(1, "node info", "status", 1, 2, 3)]}
self.assertEqual(get_cluster_info(TASK_LOG_path), expected_result)
- @patch('sqlite3.connect')
+ @patch("sqlite3.connect")
def test_get_cluster_table(self, mock_connect):
mock_cursor = mock_connect().cursor.return_value
mock_cursor.fetchall.return_value = [(1, "node info", "status", 1, 2, 3)]
-        expected_result = ('<table border="1" class="dataframe">\n'
-                           '  <thead>\n'
-                           '    <tr style="text-align: right;">\n'
-                           '      <th>Node</th>\n'
-                           '      <th>Node info</th>\n'
-                           '      <th>status</th>\n'
-                           '      <th>free workers</th>\n'
-                           '      <th>busy workers</th>\n'
-                           '      <th>total workers</th>\n'
-                           '    </tr>\n'
-                           '  </thead>\n'
-                           '  <tbody>\n'
-                           '    <tr>\n'
-                           '      <td>1</td>\n'
-                           '      <td>node info</td>\n'
-                           '      <td>status</td>\n'
-                           '      <td>1</td>\n'
-                           '      <td>2</td>\n'
-                           '      <td>3</td>\n'
-                           '    </tr>\n'
-                           '  </tbody>\n'
-                           '</table>')
+        expected_result = (
+            '<table border="1" class="dataframe">\n'
+            "  <thead>\n"
+            '    <tr style="text-align: right;">\n'
+            "      <th>Node</th>\n"
+            "      <th>Node info</th>\n"
+            "      <th>status</th>\n"
+            "      <th>free workers</th>\n"
+            "      <th>busy workers</th>\n"
+            "      <th>total workers</th>\n"
+            "    </tr>\n"
+            "  </thead>\n"
+            "  <tbody>\n"
+            "    <tr>\n"
+            "      <td>1</td>\n"
+            "      <td>node info</td>\n"
+            "      <td>status</td>\n"
+            "      <td>1</td>\n"
+            "      <td>2</td>\n"
+            "      <td>3</td>\n"
+            "    </tr>\n"
+            "  </tbody>\n"
+            "</table>"
+        )
self.assertEqual(get_cluster_table(TASK_LOG_path), expected_result)
def test_get_res_during_tuning(self):
@@ -102,5 +110,6 @@ def test_list_to_string(self):
expected_result = "Hello Neural Solution"
self.assertEqual(list_to_string(lst), expected_result)
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/neural_solution/test/test_logger.py b/neural_solution/test/test_logger.py
index db1df1d4c78..cd184a12e48 100644
--- a/neural_solution/test/test_logger.py
+++ b/neural_solution/test/test_logger.py
@@ -1,7 +1,9 @@
"""Tests for logging utilities."""
-from neural_solution.utils import logger
import unittest
+from neural_solution.utils import logger
+
+
class TestLogger(unittest.TestCase):
def test_logger(self):
logger.log(0, "call logger log function.")
@@ -20,15 +22,13 @@ def test_logger(self):
logger.warning({"msg": "call logger warning function"})
logger.warning(["call logger warning function", "done"])
logger.warning(("call logger warning function", "done"))
- logger.warning({"msg": {('bert', "embedding"): {'weight': {'dtype': ['unint8', 'int8']}}}})
- logger.warning({"msg": {('bert', "embedding"): {'op': ('a', 'b')}}})
+ logger.warning({"msg": {("bert", "embedding"): {"weight": {"dtype": ["unint8", "int8"]}}}})
+ logger.warning({"msg": {("bert", "embedding"): {"op": ("a", "b")}}})
# the following log will not be prettified
logger.warning([{"msg": "call logger warning function"}, {"msg2": "done"}])
logger.warning(({"msg": "call logger warning function"}, {"msg2": "done"}))
- logger.warning(({"msg": [{"sub_msg":"call logger"},
- {"sub_msg2":"call warning function"}]},
- {"msg2": "done"}))
+ logger.warning(({"msg": [{"sub_msg": "call logger"}, {"sub_msg2": "call warning function"}]}, {"msg2": "done"}))
if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
+ unittest.main()
diff --git a/neural_solution/utils/__init__.py b/neural_solution/utils/__init__.py
index 3f778a4dcd7..0415332cb3f 100644
--- a/neural_solution/utils/__init__.py
+++ b/neural_solution/utils/__init__.py
@@ -15,4 +15,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""All common functions for both backend and frontend."""
\ No newline at end of file
+"""All common functions for both backend and frontend."""
diff --git a/neural_solution/utils/logger.py b/neural_solution/utils/logger.py
index 2bfbbba7dc7..03c8b86be56 100644
--- a/neural_solution/utils/logger.py
+++ b/neural_solution/utils/logger.py
@@ -17,8 +17,8 @@
"""Logger: handles logging functionalities."""
-import os
import logging
+import os
class Logger(object):
@@ -35,13 +35,11 @@ def __new__(cls):
def _log(self):
"""Set up the logger format and handler."""
- LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()
+ LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()
self._logger = logging.getLogger("neural_compressor")
self._logger.handlers.clear()
self._logger.setLevel(LOGLEVEL)
- formatter = logging.Formatter(
- '%(asctime)s [%(levelname)s] %(message)s',
- "%Y-%m-%d %H:%M:%S")
+ formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", "%Y-%m-%d %H:%M:%S")
streamHandler = logging.StreamHandler()
streamHandler.setFormatter(formatter)
self._logger.addHandler(streamHandler)
@@ -54,25 +52,16 @@ def get_logger(self):
def _pretty_dict(value, indent=0):
"""Make the logger dict pretty."""
- prefix = '\n' + ' ' * (indent + 4)
+ prefix = "\n" + " " * (indent + 4)
if isinstance(value, dict):
- items = [
- prefix + repr(key) + ': ' + _pretty_dict(value[key], indent + 4)
- for key in value
- ]
- return '{%s}' % (','.join(items) + '\n' + ' ' * indent)
+ items = [prefix + repr(key) + ": " + _pretty_dict(value[key], indent + 4) for key in value]
+ return "{%s}" % (",".join(items) + "\n" + " " * indent)
elif isinstance(value, list):
- items = [
- prefix + _pretty_dict(item, indent + 4)
- for item in value
- ]
- return '[%s]' % (','.join(items) + '\n' + ' ' * indent)
+ items = [prefix + _pretty_dict(item, indent + 4) for item in value]
+ return "[%s]" % (",".join(items) + "\n" + " " * indent)
elif isinstance(value, tuple):
- items = [
- prefix + _pretty_dict(item, indent + 4)
- for item in value
- ]
- return '(%s)' % (','.join(items) + '\n' + ' ' * indent)
+ items = [prefix + _pretty_dict(item, indent + 4) for item in value]
+ return "(%s)" % (",".join(items) + "\n" + " " * indent)
else:
return repr(value)
@@ -84,7 +73,7 @@ def _pretty_dict(value, indent=0):
def log(level, msg, *args, **kwargs):
"""Output log with the level as a parameter."""
if isinstance(msg, dict):
- for _, line in enumerate(_pretty_dict(msg).split('\n')):
+ for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().log(level, line, *args, **kwargs)
else:
Logger().get_logger().log(level, msg, *args, **kwargs)
@@ -93,7 +82,7 @@ def log(level, msg, *args, **kwargs):
def debug(msg, *args, **kwargs):
"""Output log with the debug level."""
if isinstance(msg, dict):
- for _, line in enumerate(_pretty_dict(msg).split('\n')):
+ for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().debug(line, *args, **kwargs)
else:
Logger().get_logger().debug(msg, *args, **kwargs)
@@ -102,7 +91,7 @@ def debug(msg, *args, **kwargs):
def error(msg, *args, **kwargs):
"""Output log with the error level."""
if isinstance(msg, dict):
- for _, line in enumerate(_pretty_dict(msg).split('\n')):
+ for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().error(line, *args, **kwargs)
else:
Logger().get_logger().error(msg, *args, **kwargs)
@@ -111,7 +100,7 @@ def error(msg, *args, **kwargs):
def fatal(msg, *args, **kwargs):
"""Output log with the fatal level."""
if isinstance(msg, dict):
- for _, line in enumerate(_pretty_dict(msg).split('\n')):
+ for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().fatal(line, *args, **kwargs)
else:
Logger().get_logger().fatal(msg, *args, **kwargs)
@@ -120,7 +109,7 @@ def fatal(msg, *args, **kwargs):
def info(msg, *args, **kwargs):
"""Output log with the info level."""
if isinstance(msg, dict):
- for _, line in enumerate(_pretty_dict(msg).split('\n')):
+ for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().info(line, *args, **kwargs)
else:
Logger().get_logger().info(msg, *args, **kwargs)
@@ -129,7 +118,7 @@ def info(msg, *args, **kwargs):
def warn(msg, *args, **kwargs):
"""Output log with the warning level."""
if isinstance(msg, dict):
- for _, line in enumerate(_pretty_dict(msg).split('\n')):
+ for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().warning(line, *args, **kwargs)
else:
Logger().get_logger().warning(msg, *args, **kwargs)
@@ -138,7 +127,7 @@ def warn(msg, *args, **kwargs):
def warning(msg, *args, **kwargs):
"""Output log with the warining level (Alias of the method warn)."""
if isinstance(msg, dict):
- for _, line in enumerate(_pretty_dict(msg).split('\n')):
+ for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().warning(line, *args, **kwargs)
else:
- Logger().get_logger().warning(msg, *args, **kwargs)
\ No newline at end of file
+ Logger().get_logger().warning(msg, *args, **kwargs)
diff --git a/neural_solution/utils/utility.py b/neural_solution/utils/utility.py
index 8c3999e343d..13ef5da1558 100644
--- a/neural_solution/utils/utility.py
+++ b/neural_solution/utils/utility.py
@@ -14,8 +14,9 @@
"""Neural Solution utility."""
-import os
import json
+import os
+
def get_db_path(workspace="./"):
"""Get the database path.
@@ -29,6 +30,7 @@ def get_db_path(workspace="./"):
db_path = os.path.join(workspace, "db", "task.db")
return os.path.abspath(db_path)
+
def get_task_workspace(workspace="./"):
"""Get the workspace of task.
@@ -40,6 +42,7 @@ def get_task_workspace(workspace="./"):
"""
return os.path.join(workspace, "task_workspace")
+
def get_task_log_workspace(workspace="./"):
"""Get the log workspace for task.
@@ -51,6 +54,7 @@ def get_task_log_workspace(workspace="./"):
"""
return os.path.join(workspace, "task_log")
+
def get_serve_log_workspace(workspace="./"):
"""Get log workspace for service.
@@ -62,6 +66,7 @@ def get_serve_log_workspace(workspace="./"):
"""
return os.path.join(workspace, "serve_log")
+
def dict_to_str(d):
"""Covert a dict object to a string object.
@@ -72,4 +77,4 @@ def dict_to_str(d):
str: string
"""
result = json.dumps(d)
- return result
\ No newline at end of file
+ return result