diff --git a/neural_solution/__init__.py b/neural_solution/__init__.py index 7017c64526d..300d6010498 100644 --- a/neural_solution/__init__.py +++ b/neural_solution/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. """Neural Solution.""" -from neural_solution.utils import logger \ No newline at end of file +from neural_solution.utils import logger diff --git a/neural_solution/backend/__init__.py b/neural_solution/backend/__init__.py index 07d2f346fcd..c93847d4e0b 100644 --- a/neural_solution/backend/__init__.py +++ b/neural_solution/backend/__init__.py @@ -13,8 +13,8 @@ # limitations under the License. """Neural Solution backend.""" +from neural_solution.backend.cluster import Cluster from neural_solution.backend.result_monitor import ResultMonitor from neural_solution.backend.scheduler import Scheduler -from neural_solution.backend.task_monitor import TaskMonitor -from neural_solution.backend.cluster import Cluster from neural_solution.backend.task_db import TaskDB +from neural_solution.backend.task_monitor import TaskMonitor diff --git a/neural_solution/backend/cluster.py b/neural_solution/backend/cluster.py index c3eb2aa5501..cee9a23dc9a 100644 --- a/neural_solution/backend/cluster.py +++ b/neural_solution/backend/cluster.py @@ -13,12 +13,14 @@ # limitations under the License. """Neural Solution cluster.""" -import threading import sqlite3 +import threading +from collections import Counter from typing import List -from neural_solution.backend.utils.utility import synchronized, create_dir + +from neural_solution.backend.utils.utility import create_dir, synchronized from neural_solution.utils import logger -from collections import Counter + class Cluster: """Cluster resource management based on sockets.""" @@ -35,7 +37,7 @@ def __init__(self, node_lst=[], db_path=None): self.socket_queue = [] self.db_path = db_path create_dir(db_path) - self.conn = sqlite3.connect(f'{db_path}', check_same_thread=False) + self.conn = sqlite3.connect(f"{db_path}", check_same_thread=False) self.initial_cluster_from_node_lst(node_lst) self.lock = threading.Lock() @@ -63,7 +65,6 @@ def reserve_resource(self, task): logger.info(f"[Cluster] Assign {reserved_resource_lst} to task {task.task_id}") return reserved_resource_lst - @synchronized def free_resource(self, reserved_resource_lst): """Free the resource by adding the previous occupied resources to the socket queue.""" @@ -71,9 +72,9 @@ def free_resource(self, reserved_resource_lst): counts = Counter(int(item.split()[0]) for item in reserved_resource_lst) free_resources = {} for node_id, count in counts.items(): - free_resources[node_id] = count + free_resources[node_id] = count for node_id, count in counts.items(): - free_resources[node_id] = count + free_resources[node_id] = count for node_id in free_resources: sql = """ UPDATE cluster @@ -105,28 +106,31 @@ def initial_cluster_from_node_lst(self, node_lst): node_lst (List): the node list. """ # sqlite should set this check_same_thread to False - self.conn = sqlite3.connect(f'{self.db_path}', check_same_thread=False) + self.conn = sqlite3.connect(f"{self.db_path}", check_same_thread=False) self.cursor = self.conn.cursor() - self.cursor.execute('drop table if exists cluster ') - self.cursor.execute(r'create table cluster(id INTEGER PRIMARY KEY AUTOINCREMENT,' + - 'node_info varchar(500),' + - 'status varchar(100),' + - 'free_sockets int,' + - 'busy_sockets int,' + - 'total_sockets int)') + self.cursor.execute("drop table if exists cluster ") + self.cursor.execute( + r"create table cluster(id INTEGER PRIMARY KEY AUTOINCREMENT," + + "node_info varchar(500)," + + "status varchar(100)," + + "free_sockets int," + + "busy_sockets int," + + "total_sockets int)" + ) self.node_lst = node_lst for index, node in enumerate(self.node_lst): - self.socket_queue += [str(index+1) + " " + node.name] * node.num_sockets - self.cursor.execute(r"insert into cluster(node_info, status, free_sockets, busy_sockets, total_sockets)" + - "values ('{}', '{}', {}, {}, {})".format(repr(node).replace("Node", f"Node{index+1}"), - "alive", - node.num_sockets, - 0, - node.num_sockets)) + self.socket_queue += [str(index + 1) + " " + node.name] * node.num_sockets + self.cursor.execute( + r"insert into cluster(node_info, status, free_sockets, busy_sockets, total_sockets)" + + "values ('{}', '{}', {}, {}, {})".format( + repr(node).replace("Node", f"Node{index+1}"), "alive", node.num_sockets, 0, node.num_sockets + ) + ) self.conn.commit() logger.info(f"socket_queue: {self.socket_queue}") + class Node: """Node definition.""" @@ -134,14 +138,11 @@ class Node: ip: str = "unknown_ip" num_sockets: int = 0 num_cores_per_socket: int = 0 - num_gpus: int = 0 # For future use - - def __init__(self, - name: str, - ip: str = "unknown_ip", - num_sockets: int = 0, - num_cores_per_socket: int = 0, - num_gpus: int = 0) -> None: + num_gpus: int = 0 # For future use + + def __init__( + self, name: str, ip: str = "unknown_ip", num_sockets: int = 0, num_cores_per_socket: int = 0, num_gpus: int = 0 + ) -> None: """Init node. hostfile template: @@ -167,8 +168,7 @@ def __repr__(self) -> str: Returns: str: node info. """ - return f"Node: {self.name}(ip: {self.ip}) has {self.num_sockets} socket(s) " \ + return ( + f"Node: {self.name}(ip: {self.ip}) has {self.num_sockets} socket(s) " f"and each socket has {self.num_cores_per_socket} cores." - - - + ) diff --git a/neural_solution/backend/result_monitor.py b/neural_solution/backend/result_monitor.py index 8061cee4a81..0972889f890 100644 --- a/neural_solution/backend/result_monitor.py +++ b/neural_solution/backend/result_monitor.py @@ -15,9 +15,11 @@ """Neural Solution result monitor.""" import socket -from neural_solution.backend.utils.utility import serialize, deserialize -from neural_solution.utils import logger + from neural_solution.backend.task_db import TaskDB +from neural_solution.backend.utils.utility import deserialize, serialize +from neural_solution.utils import logger + class ResultMonitor: """ResultMonitor is a thread that monitors the coming task results and update the task collection in the TaskDb. @@ -40,7 +42,7 @@ def __init__(self, port, task_db: TaskDB): def wait_result(self): """Monitor the task results and update them in the task db and send back to studio.""" - self.s.bind(("localhost", self.port)) # open a port as the serving port for results + self.s.bind(("localhost", self.port)) # open a port as the serving port for results self.s.listen(10) while True: logger.info("[ResultMonitor] waiting for results...") @@ -53,7 +55,7 @@ def wait_result(self): c.close() continue logger.info("[ResultMonitor] getting result: {}".format(result)) - logger.info("[ResultMonitor] getting q_model path: {}".format(result['q_model_path'])) + logger.info("[ResultMonitor] getting q_model path: {}".format(result["q_model_path"])) self.task_db.update_q_model_path_and_result(result["task_id"], result["q_model_path"], result["result"]) c.close() # TODO send back the result to the studio @@ -63,4 +65,3 @@ def query_task_status(self, task_id): """Synchronize query on the task status.""" # TODO send back the result to the studio? RPC for query? logger.info(self.task_db.lookup_task_status(task_id)) - diff --git a/neural_solution/backend/runner.py b/neural_solution/backend/runner.py index 173b9ca016a..8162b7b524f 100644 --- a/neural_solution/backend/runner.py +++ b/neural_solution/backend/runner.py @@ -13,13 +13,14 @@ # limitations under the License. """Main backend runner.""" -import threading import argparse +import threading -from neural_solution.backend import TaskDB, Scheduler, TaskMonitor, ResultMonitor -from neural_solution.utils import logger +from neural_solution.backend import ResultMonitor, Scheduler, TaskDB, TaskMonitor from neural_solution.backend.utils.utility import build_cluster, get_db_path from neural_solution.config import config +from neural_solution.utils import logger + def parse_args(args=None): """Parse the command line options. @@ -30,24 +31,25 @@ def parse_args(args=None): Returns: argparse.Namespace: arguments. """ - parser = argparse.ArgumentParser(description="Neural Solution runner automatically schedules multiple inc tasks and\ - executes multi-node distributed tuning.") - - parser.add_argument("-H", "--hostfile", type=str, default=None, \ - help="Path to the host file which contains all available nodes.") - parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, \ - help="Port to monitor task.") - parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, \ - help="Port to monitor result.") - parser.add_argument("-WS", "--workspace", type=str, default="./", \ - help="Work space.") - parser.add_argument("-CEN", "--conda_env_name", type=str, default="inc", \ - help="Conda environment for task execution") - parser.add_argument("-UP", "--upload_path", type=str, default="./examples", \ - help="Custom example path.") + parser = argparse.ArgumentParser( + description="Neural Solution runner automatically schedules multiple inc tasks and\ + executes multi-node distributed tuning." + ) + + parser.add_argument( + "-H", "--hostfile", type=str, default=None, help="Path to the host file which contains all available nodes." + ) + parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, help="Port to monitor task.") + parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, help="Port to monitor result.") + parser.add_argument("-WS", "--workspace", type=str, default="./", help="Work space.") + parser.add_argument( + "-CEN", "--conda_env_name", type=str, default="inc", help="Conda environment for task execution" + ) + parser.add_argument("-UP", "--upload_path", type=str, default="./examples", help="Custom example path.") return parser.parse_args(args=args) + def main(args=None): """Implement the main entry of backend. @@ -72,9 +74,15 @@ def main(args=None): t_rm = threading.Thread(target=rm.wait_result) config.workspace = args.workspace - ts = Scheduler(cluster, task_db, args.result_monitor_port, \ - conda_env_name=args.conda_env_name, upload_path=args.upload_path, config=config, \ - num_threads_per_process=num_threads_per_process) + ts = Scheduler( + cluster, + task_db, + args.result_monitor_port, + conda_env_name=args.conda_env_name, + upload_path=args.upload_path, + config=config, + num_threads_per_process=num_threads_per_process, + ) t_ts = threading.Thread(target=ts.schedule_tasks) tm = TaskMonitor(args.task_monitor_port, task_db) @@ -83,8 +91,9 @@ def main(args=None): t_rm.start() t_ts.start() t_tm.start() - logger.info("task monitor port {} and result monitor port {}".\ - format(args.task_monitor_port, args.result_monitor_port)) + logger.info( + "task monitor port {} and result monitor port {}".format(args.task_monitor_port, args.result_monitor_port) + ) logger.info("server start...") t_rm.join() @@ -92,5 +101,5 @@ def main(args=None): t_tm.join() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/neural_solution/backend/scheduler.py b/neural_solution/backend/scheduler.py index fc991a47436..1a494122cbf 100644 --- a/neural_solution/backend/scheduler.py +++ b/neural_solution/backend/scheduler.py @@ -13,39 +13,49 @@ # limitations under the License. """Neural Solution scheduler.""" -import time -import threading -import subprocess +import glob +import json import os -import socket import re -import json -import glob import shutil -from neural_solution.backend.task import Task +import socket +import subprocess +import threading +import time + from neural_solution.backend.cluster import Cluster +from neural_solution.backend.task import Task from neural_solution.backend.task_db import TaskDB from neural_solution.backend.utils.utility import ( - serialize, + build_workspace, dump_elapsed_time, - get_task_log_path, get_current_time, - build_workspace, + get_q_model_path, + get_task_log_path, is_remote_url, - get_q_model_path + serialize, ) -from neural_solution.utils.utility import get_task_log_workspace, get_task_workspace from neural_solution.utils import logger +from neural_solution.utils.utility import get_task_log_workspace, get_task_workspace -#TODO update it according to the platform -cmd="echo $(conda info --base)/etc/profile.d/conda.sh" +# TODO update it according to the platform +cmd = "echo $(conda info --base)/etc/profile.d/conda.sh" CONDA_SOURCE_PATH = subprocess.getoutput(cmd) + class Scheduler: """Scheduler dispatches the task with the available resources, calls the mpi command and report results.""" - def __init__(self, cluster: Cluster, task_db: TaskDB, result_monitor_port, \ - conda_env_name=None, upload_path="./examples", config=None, num_threads_per_process=5): + def __init__( + self, + cluster: Cluster, + task_db: TaskDB, + result_monitor_port, + conda_env_name=None, + upload_path="./examples", + config=None, + num_threads_per_process=5, + ): """Scheduler dispatches the task with the available resources, calls the mpi command and report results. Attributes: @@ -63,9 +73,9 @@ def __init__(self, cluster: Cluster, task_db: TaskDB, result_monitor_port, \ self.config = config self.num_threads_per_process = num_threads_per_process - def prepare_env(self, task:Task): + def prepare_env(self, task: Task): """Check and create a conda environment. - + If the required packages are not installed in the conda environment, create a new conda environment and install the required packages. @@ -76,7 +86,7 @@ def prepare_env(self, task:Task): env_prefix = self.conda_env_name requirement = task.requirement.split(" ") # Skip check when requirement is empty. - if requirement == ['']: + if requirement == [""]: return env_prefix # Construct the command to list all the conda environments cmd = f"conda env list" @@ -92,8 +102,9 @@ def prepare_env(self, task:Task): output = subprocess.getoutput(cmd) # Parse the output to get a list of installed package names installed_packages = [line.split()[0] for line in output.splitlines()[2:]] - installed_packages_version = \ - [line.split()[0] + "=" + line.split()[1] for line in output.splitlines()[2:]] + installed_packages_version = [ + line.split()[0] + "=" + line.split()[1] for line in output.splitlines()[2:] + ] missing_packages = set(requirement) - set(installed_packages) - set(installed_packages_version) if not missing_packages: conda_env = env_name @@ -101,19 +112,22 @@ def prepare_env(self, task:Task): if conda_env is None: # Construct the command to create a new conda environment and install the required packages from datetime import datetime + now = datetime.now() suffix = now.strftime("%Y%m%d-%H%M%S") conda_env = f"{env_prefix}_{suffix}" # Construct the name of the new conda environment - cmd = (f"source {CONDA_SOURCE_PATH} && conda create -n {conda_env} --clone {env_prefix}" - f" && conda activate {conda_env} && pip install {task.requirement.replace('=','==')}") - p = subprocess.Popen(cmd, shell=True) # nosec + cmd = ( + f"source {CONDA_SOURCE_PATH} && conda create -n {conda_env} --clone {env_prefix}" + f" && conda activate {conda_env} && pip install {task.requirement.replace('=','==')}" + ) + p = subprocess.Popen(cmd, shell=True) # nosec logger.info(f"[Scheduler] Creating new environment {conda_env} start.") p.wait() logger.info(f"[Scheduler] Creating new environment {conda_env} end.") return conda_env - def prepare_task(self, task:Task): + def prepare_task(self, task: Task): """Prepare workspace and download run_task.py for task. Args: @@ -122,24 +136,25 @@ def prepare_task(self, task:Task): self.task_path = build_workspace(path=get_task_workspace(self.config.workspace), task_id=task.task_id) logger.info(f"****TASK PATH: {self.task_path}") if is_remote_url(task.script_url): - task_url = task.script_url.replace('github.com', 'raw.githubusercontent.com').replace('blob','') + task_url = task.script_url.replace("github.com", "raw.githubusercontent.com").replace("blob", "") try: - subprocess.check_call(['wget', '-P', self.task_path, task_url]) + subprocess.check_call(["wget", "-P", self.task_path, task_url]) except subprocess.CalledProcessError as e: logger.info("Failed: {}".format(e.cmd)) else: # Assuming the file is uploaded in directory examples example_path = os.path.abspath(os.path.join(self.upload_path, task.script_url)) # only one python file - script_path = glob.glob(os.path.join(example_path, '*.py'))[0] + script_path = glob.glob(os.path.join(example_path, "*.py"))[0] # script_path = glob.glob(os.path.join(example_path, f'*{extension}'))[0] self.script_name = script_path.split("/")[-1] shutil.copy(script_path, os.path.abspath(self.task_path)) - task.arguments = task.arguments.replace("=dataset", "=" + os.path.join(example_path,"dataset")).\ - replace("=model", "=" + os.path.join(example_path,"model")) + task.arguments = task.arguments.replace("=dataset", "=" + os.path.join(example_path, "dataset")).replace( + "=model", "=" + os.path.join(example_path, "model") + ) if not task.optimized: # Generate quantization code with Neural Coder API - neural_coder_cmd = ['python -m neural_coder --enable --approach'] + neural_coder_cmd = ["python -m neural_coder --enable --approach"] # for users to define approach: "static, ""static_ipex", "dynamic", "auto" approach = task.approach neural_coder_cmd.append(approach) @@ -148,7 +163,7 @@ def prepare_task(self, task:Task): neural_coder_cmd.append(self.script_name) neural_coder_cmd = " ".join(neural_coder_cmd) full_cmd = """cd {}\n{}""".format(self.task_path, neural_coder_cmd) - p = subprocess.Popen(full_cmd, shell=True) # nosec + p = subprocess.Popen(full_cmd, shell=True) # nosec logger.info("[Neural Coder] Generating optimized code start.") p.wait() logger.info("[Neural Coder] Generating optimized code end.") @@ -163,7 +178,7 @@ def check_task_status(self, log_path): str: status "done" or "failed" """ for line in reversed(open(log_path).readlines()): - res_pattern = r'[INFO] Save deploy yaml to' + res_pattern = r"[INFO] Save deploy yaml to" # res_matches = re.findall(res_pattern, line) if res_pattern in line: return "done" @@ -180,24 +195,25 @@ def _parse_cmd(self, task: Task, resource): # Activate environment conda_bash_cmd = f"source {CONDA_SOURCE_PATH}" conda_env_cmd = f"conda activate {conda_env}" - mpi_cmd = ["mpirun", - "-np", - "{}".format(task.workers), - "-host", - "{}".format(host_str), - "-map-by", - "socket:pe={}".format(self.num_threads_per_process), - "-mca", - "btl_tcp_if_include", - "192.168.20.0/24", # TODO replace it according to the node - "-x", - "OMP_NUM_THREADS={}".format(self.num_threads_per_process), - '--report-bindings' - ] + mpi_cmd = [ + "mpirun", + "-np", + "{}".format(task.workers), + "-host", + "{}".format(host_str), + "-map-by", + "socket:pe={}".format(self.num_threads_per_process), + "-mca", + "btl_tcp_if_include", + "192.168.20.0/24", # TODO replace it according to the node + "-x", + "OMP_NUM_THREADS={}".format(self.num_threads_per_process), + "--report-bindings", + ] mpi_cmd = " ".join(mpi_cmd) # Initial Task command - task_cmd = ['python'] + task_cmd = ["python"] task_cmd.append(self.script_name) task_cmd.append(self.sanitize_arguments(task.arguments)) task_cmd = " ".join(task_cmd) @@ -208,10 +224,9 @@ def _parse_cmd(self, task: Task, resource): # build a bash script to run task. bash_script_name = "distributed_run.sh" if task.workers > 1 else "run.sh" - bash_script = """{}\n{}\ncd {}\n{}""".format(conda_bash_cmd, conda_env_cmd, \ - self.task_path, task_cmd) + bash_script = """{}\n{}\ncd {}\n{}""".format(conda_bash_cmd, conda_env_cmd, self.task_path, task_cmd) bash_script_path = os.path.join(self.task_path, bash_script_name) - with open(bash_script_path, 'w', encoding="utf-8") as f: + with open(bash_script_path, "w", encoding="utf-8") as f: f.write(bash_script) full_cmd = """cd {}\n{} bash {}""".format(self.task_path, mpi_cmd, bash_script_name) @@ -221,9 +236,9 @@ def report_result(self, task_id, log_path, task_runtime): """Report the result to the result monitor.""" s = socket.socket() s.connect(("localhost", self.result_monitor_port)) - results = {"optimization time (seconds)" : "{:.2f}".format(task_runtime)} + results = {"optimization time (seconds)": "{:.2f}".format(task_runtime)} for line in reversed(open(log_path).readlines()): - res_pattern = r'Tune (\d+) result is:\s.*?\(int8\|fp32\):\s+(\d+\.\d+).*?\(int8\|fp32\):\s+(\d+\.\d+).*?' + res_pattern = r"Tune (\d+) result is:\s.*?\(int8\|fp32\):\s+(\d+\.\d+).*?\(int8\|fp32\):\s+(\d+\.\d+).*?" res_matches = re.findall(res_pattern, line) if res_matches: # results["Tuning count"] = res_matches[0][0] @@ -237,7 +252,6 @@ def report_result(self, task_id, log_path, task_runtime): s.send(serialize({"task_id": task_id, "result": results, "q_model_path": self.q_model_path})) s.close() - @dump_elapsed_time("Task execution") def launch_task(self, task: Task, resource): """Generate the mpi command and execute the task. @@ -247,14 +261,15 @@ def launch_task(self, task: Task, resource): full_cmd = self._parse_cmd(task, resource) logger.info(f"[TaskScheduler] Parsed the command from task: {full_cmd}") log_path = get_task_log_path(log_path=get_task_log_workspace(self.config.workspace), task_id=task.task_id) - p = subprocess.Popen(full_cmd, stdout=open(log_path, 'w+'), stderr=subprocess.STDOUT, shell=True) # nosec + p = subprocess.Popen(full_cmd, stdout=open(log_path, "w+"), stderr=subprocess.STDOUT, shell=True) # nosec logger.info(f"[TaskScheduler] Start run task {task.task_id}, dump log into {log_path}") start_time = time.time() p.wait() self.cluster.free_resource(resource) task_runtime = time.time() - start_time - logger.info(\ - f"[TaskScheduler] Finished task {task.task_id}, and free resource {resource}, dump log into {log_path}") + logger.info( + f"[TaskScheduler] Finished task {task.task_id}, and free resource {resource}, dump log into {log_path}" + ) task_status = self.check_task_status(log_path) self.task_db.update_task_status(task.task_id, task_status) self.q_model_path = get_q_model_path(log_path=log_path) if task_status == "done" else None @@ -271,8 +286,10 @@ def schedule_tasks(self): time.sleep(5) logger.info(f"[TaskScheduler {get_current_time()}] try to dispatch a task...") if self.task_db.get_pending_task_num() > 0: - logger.info(f"[TaskScheduler {get_current_time()}], " + \ - f"there are {self.task_db.get_pending_task_num()} task pending.") + logger.info( + f"[TaskScheduler {get_current_time()}], " + + f"there are {self.task_db.get_pending_task_num()} task pending." + ) task_id = self.task_db.task_queue[0] task = self.task_db.get_task_by_id(task_id) resource = self.cluster.reserve_resource(task) @@ -287,4 +304,4 @@ def schedule_tasks(self): def sanitize_arguments(self, arguments: str): """Replace space encoding with space.""" - return arguments.replace(u'\xa0', u' ') \ No newline at end of file + return arguments.replace("\xa0", " ") diff --git a/neural_solution/backend/task.py b/neural_solution/backend/task.py index 722ec76e6db..00a6ec22a9a 100644 --- a/neural_solution/backend/task.py +++ b/neural_solution/backend/task.py @@ -13,6 +13,8 @@ # limitations under the License. """Neural Solution task.""" + + class Task: """A Task is an abstraction of a user tuning request that is handled in neural solution service. @@ -24,8 +26,19 @@ class Task: result: The result of the task, which is only value-assigned when the task is done """ - def __init__(self, task_id, arguments, workers, status, script_url, \ - optimized, approach, requirement, result="", q_model_path=""): + def __init__( + self, + task_id, + arguments, + workers, + status, + script_url, + optimized, + approach, + requirement, + result="", + q_model_path="", + ): """Init task. Args: @@ -49,4 +62,4 @@ def __init__(self, task_id, arguments, workers, status, script_url, \ self.approach = approach self.requirement = requirement self.result = result - self.q_model_path = q_model_path \ No newline at end of file + self.q_model_path = q_model_path diff --git a/neural_solution/backend/task_db.py b/neural_solution/backend/task_db.py index db8e3e7c0ec..35708d63264 100644 --- a/neural_solution/backend/task_db.py +++ b/neural_solution/backend/task_db.py @@ -13,11 +13,13 @@ # limitations under the License. """Neural Solution task database.""" -import threading import sqlite3 +import threading from collections import deque -from neural_solution.backend.utils.utility import create_dir + from neural_solution.backend.task import Task +from neural_solution.backend.utils.utility import create_dir + class TaskDB: """TaskDb manages all the tasks. @@ -39,12 +41,13 @@ def __init__(self, db_path): self.task_queue = deque() create_dir(db_path) # sqlite should set this check_same_thread to False - self.conn = sqlite3.connect(f'{db_path}', check_same_thread=False) + self.conn = sqlite3.connect(f"{db_path}", check_same_thread=False) self.cursor = self.conn.cursor() self.cursor.execute( - 'create table if not exists task(id TEXT PRIMARY KEY, arguments varchar(100), ' + - 'workers int, status varchar(20), script_url varchar(500), optimized integer, ' + - 'approach varchar(20), requirements varchar(500), result varchar(500), q_model_path varchar(200))') + "create table if not exists task(id TEXT PRIMARY KEY, arguments varchar(100), " + + "workers int, status varchar(20), script_url varchar(500), optimized integer, " + + "approach varchar(20), requirements varchar(500), result varchar(500), q_model_path varchar(200))" + ) self.conn.commit() # self.task_collections = [] self.lock = threading.Lock() @@ -72,7 +75,7 @@ def update_task_status(self, task_id, status): """Update the task status with the task id and the status.""" if status not in ["pending", "running", "done", "failed"]: raise Exception("status invalid, should be one of pending/running/done") - self.cursor.execute(r"update task set status='{}' where id=?".format(status), (task_id, )) + self.cursor.execute(r"update task set status='{}' where id=?".format(status), (task_id,)) self.conn.commit() def update_result(self, task_id, result_str): @@ -82,13 +85,14 @@ def update_result(self, task_id, result_str): def update_q_model_path_and_result(self, task_id, q_model_path, result_str): """Update the task result with the result string.""" - self.cursor.execute(r"update task set q_model_path='{}', result='{}' where id=?" - .format(q_model_path, result_str), (task_id, )) + self.cursor.execute( + r"update task set q_model_path='{}', result='{}' where id=?".format(q_model_path, result_str), (task_id,) + ) self.conn.commit() def lookup_task_status(self, task_id): """Look up the current task status and result.""" - self.cursor.execute(r"select status, result from task where id=?", (task_id, )) + self.cursor.execute(r"select status, result from task where id=?", (task_id,)) status, result = self.cursor.fetchone() return {"status": status, "result": result} @@ -98,7 +102,6 @@ def get_task_by_id(self, task_id): attr_tuple = self.cursor.fetchone() return Task(*attr_tuple) - def remove_task(self, task_id): # currently no garbage collection + def remove_task(self, task_id): # currently no garbage collection """Remove task.""" pass - diff --git a/neural_solution/backend/task_monitor.py b/neural_solution/backend/task_monitor.py index 7b52a1d5375..c6aa3745072 100644 --- a/neural_solution/backend/task_monitor.py +++ b/neural_solution/backend/task_monitor.py @@ -14,9 +14,11 @@ """Neural Solution task monitor.""" import socket -from neural_solution.backend.utils.utility import serialize, deserialize + +from neural_solution.backend.utils.utility import deserialize, serialize from neural_solution.utils import logger + class TaskMonitor: """TaskMonitor is a thread that monitors the coming tasks and appends them to the task queue. @@ -32,7 +34,7 @@ def __init__(self, port, task_db): self.task_db = task_db def _start_listening(self, host, port, max_parallelism): - self.s.bind(("localhost", port)) # open a port as the serving port for tasks + self.s.bind(("localhost", port)) # open a port as the serving port for tasks self.s.listen(max_parallelism) def _receive_task(self): @@ -62,4 +64,4 @@ def wait_new_task(self): task = self._receive_task() if not task: continue - self._append_task(task) \ No newline at end of file + self._append_task(task) diff --git a/neural_solution/backend/utils/__init__.py b/neural_solution/backend/utils/__init__.py index 1fdfb51540c..a716ddb1a6e 100644 --- a/neural_solution/backend/utils/__init__.py +++ b/neural_solution/backend/utils/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Neural Solution backend utils.""" \ No newline at end of file +"""Neural Solution backend utils.""" diff --git a/neural_solution/backend/utils/utility.py b/neural_solution/backend/utils/utility.py index 1b4a2b9f3f4..31a65984b07 100644 --- a/neural_solution/backend/utils/utility.py +++ b/neural_solution/backend/utils/utility.py @@ -15,17 +15,21 @@ """Neural Solution backend utils.""" import json import os -from neural_solution.utils import logger from urllib.parse import urlparse +from neural_solution.utils import logger + + def serialize(request: dict) -> bytes: """Serialize a dict object to bytes for inter-process communication.""" return json.dumps(request).encode() + def deserialize(request: bytes) -> dict: """Deserialize the recived bytes to a dict object.""" return json.loads(request) + def dump_elapsed_time(customized_msg=""): """Get the elapsed time for decorated functions. @@ -33,17 +37,23 @@ def dump_elapsed_time(customized_msg=""): customized_msg (string, optional): The parameter passed to decorator. Defaults to None. """ import time + def f(func): def fi(*args, **kwargs): start = time.time() res = func(*args, **kwargs) end = time.time() - logger.info('%s elapsed time: %s ms' %(customized_msg if customized_msg else func.__qualname__, - round((end - start) * 1000, 2))) + logger.info( + "%s elapsed time: %s ms" + % (customized_msg if customized_msg else func.__qualname__, round((end - start) * 1000, 2)) + ) return res + return fi + return f + def get_task_log_path(log_path, task_id): """Get the path of task log according id. @@ -56,7 +66,7 @@ def get_task_log_path(log_path, task_id): """ if not os.path.exists(log_path): os.makedirs(log_path) - log_file_path = "{}/task_{}.txt".format(log_path,task_id) + log_file_path = "{}/task_{}.txt".format(log_path, task_id) return log_file_path @@ -71,6 +81,7 @@ def get_db_path(workspace="./"): """ return os.path.join(workspace, "db", "task.db") + def get_task_workspace(workspace="./"): """Get the workspace of task. @@ -82,6 +93,7 @@ def get_task_workspace(workspace="./"): """ return os.path.join(workspace, "task_workspace") + def get_task_log_workspace(workspace="./"): """Get the log workspace for task. @@ -93,6 +105,7 @@ def get_task_log_workspace(workspace="./"): """ return os.path.join(workspace, "task_log") + def get_serve_log_workspace(workspace="./"): """Get log workspace for service. @@ -114,16 +127,18 @@ def build_local_cluster(db_path): Returns: (Cluster, int): cluster and num threads per process """ - from neural_solution.backend.cluster import Node, Cluster - hostname = 'localhost' - node1 = Node(name=hostname,num_sockets=2, num_cores_per_socket=5) - node2 = Node(name=hostname,num_sockets=2, num_cores_per_socket=5) - node3 = Node(name=hostname,num_sockets=2, num_cores_per_socket=5) + from neural_solution.backend.cluster import Cluster, Node + + hostname = "localhost" + node1 = Node(name=hostname, num_sockets=2, num_cores_per_socket=5) + node2 = Node(name=hostname, num_sockets=2, num_cores_per_socket=5) + node3 = Node(name=hostname, num_sockets=2, num_cores_per_socket=5) node_lst = [node1, node2, node3] cluster = Cluster(node_lst=node_lst, db_path=db_path) return cluster, 5 + def build_cluster(file_path, db_path): """Build cluster according to the host file. @@ -133,7 +148,8 @@ def build_cluster(file_path, db_path): Returns: Cluster: return cluster object. """ - from neural_solution.backend.cluster import Node, Cluster + from neural_solution.backend.cluster import Cluster, Node + # If no file is specified, build a local cluster if file_path == "None" or file_path is None: return build_local_cluster(db_path) @@ -143,7 +159,7 @@ def build_cluster(file_path, db_path): node_lst = [] num_threads_per_process = 5 - with open(file_path, 'r') as f: + with open(file_path, "r") as f: for line in f: hostname, num_sockets, num_cores_per_socket = line.strip().split(" ") num_sockets, num_cores_per_socket = int(num_sockets), int(num_cores_per_socket) @@ -153,6 +169,7 @@ def build_cluster(file_path, db_path): cluster = Cluster(node_lst=node_lst, db_path=db_path) return cluster, num_threads_per_process + def get_current_time(): """Get current time. @@ -160,7 +177,9 @@ def get_current_time(): str: the current time in hours, minutes, and seconds. """ from datetime import datetime - return datetime.now().strftime('%H:%M:%S') + + return datetime.now().strftime("%H:%M:%S") + def synchronized(func): """Locking for synchronization. @@ -168,11 +187,14 @@ def synchronized(func): Args: func (function): decorative function """ + def wrapper(self, *args, **kwargs): with self.lock: return func(self, *args, **kwargs) + return wrapper + def build_workspace(path, task_id=""): """Build workspace of running tasks. @@ -185,6 +207,7 @@ def build_workspace(path, task_id=""): os.makedirs(task_path) return os.path.abspath(task_path) + def is_remote_url(url_or_filename): """Check if input is a URL. @@ -197,11 +220,13 @@ def is_remote_url(url_or_filename): parsed = urlparse(url_or_filename) return parsed.scheme in ("http", "https") + def create_dir(path): """Create the (nested) path if not exist.""" if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) + def get_q_model_path(log_path): """Get the quantized model path from task log. @@ -212,9 +237,10 @@ def get_q_model_path(log_path): str: quantized model path """ import re + for line in reversed(open(log_path).readlines()): - match = re.search(r'(Save quantized model to|Save config file and weights of quantized model to) (.+?)\.', line) + match = re.search(r"(Save quantized model to|Save config file and weights of quantized model to) (.+?)\.", line) if match: q_model_path = match.group(2) return q_model_path - return "quantized model path not found" \ No newline at end of file + return "quantized model path not found" diff --git a/neural_solution/bin/__init__.py b/neural_solution/bin/__init__.py index 4f8f95fdc2f..63ed8c79470 100644 --- a/neural_solution/bin/__init__.py +++ b/neural_solution/bin/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Neural Solution.""" \ No newline at end of file +"""Neural Solution.""" diff --git a/neural_solution/bin/neural_solution.py b/neural_solution/bin/neural_solution.py index d41cfb34013..dccaa9bfdd8 100644 --- a/neural_solution/bin/neural_solution.py +++ b/neural_solution/bin/neural_solution.py @@ -18,7 +18,9 @@ def exec(): """Execute Neural Solution launch.""" from neural_solution.launcher import main + main() -if __name__ == '__main__': - exec() \ No newline at end of file + +if __name__ == "__main__": + exec() diff --git a/neural_solution/config.py b/neural_solution/config.py index 52c872557e2..8c30b4c6118 100644 --- a/neural_solution/config.py +++ b/neural_solution/config.py @@ -16,6 +16,7 @@ INTERVAL_TIME_BETWEEN_DISPATCH_TASK = 3 + class Config: """Config for services.""" @@ -24,6 +25,7 @@ class Config: result_monitor_port: int = 3333 service_address: str = "localhost" grpc_api_port: int = 4444 - #TODO add set and get methods for each attribute + # TODO add set and get methods for each attribute + -config = Config() \ No newline at end of file +config = Config() diff --git a/neural_solution/examples/custom_models_optimized/tf_example1/test.py b/neural_solution/examples/custom_models_optimized/tf_example1/test.py index 3f345669a1d..006e241551d 100644 --- a/neural_solution/examples/custom_models_optimized/tf_example1/test.py +++ b/neural_solution/examples/custom_models_optimized/tf_example1/test.py @@ -1,41 +1,44 @@ """Running script.""" import tensorflow as tf - -from neural_compressor.data import TensorflowImageRecord -from neural_compressor.data import BilinearImagenetTransform -from neural_compressor.data import ComposeTransform -from neural_compressor.data import DefaultDataLoader -from neural_compressor.quantization import fit -from neural_compressor.config import PostTrainingQuantConfig from neural_compressor import Metric +from neural_compressor.config import PostTrainingQuantConfig +from neural_compressor.data import BilinearImagenetTransform, ComposeTransform, DefaultDataLoader, TensorflowImageRecord +from neural_compressor.quantization import fit flags = tf.compat.v1.flags FLAGS = flags.FLAGS -flags.DEFINE_string('dataset_location', None, 'location of calibration dataset and evaluate dataset') +flags.DEFINE_string("dataset_location", None, "location of calibration dataset and evaluate dataset") -flags.DEFINE_string('model_path', None, 'location of model') +flags.DEFINE_string("model_path", None, "location of model") -calib_dataset = TensorflowImageRecord(root=FLAGS.dataset_location, transform= \ - ComposeTransform(transform_list= [BilinearImagenetTransform(height=224, width=224)])) +calib_dataset = TensorflowImageRecord( + root=FLAGS.dataset_location, + transform=ComposeTransform(transform_list=[BilinearImagenetTransform(height=224, width=224)]), +) calib_dataloader = DefaultDataLoader(dataset=calib_dataset, batch_size=10) -eval_dataset = TensorflowImageRecord(root=FLAGS.dataset_location, transform=ComposeTransform(transform_list= \ - [BilinearImagenetTransform(height=224, width=224)])) +eval_dataset = TensorflowImageRecord( + root=FLAGS.dataset_location, + transform=ComposeTransform(transform_list=[BilinearImagenetTransform(height=224, width=224)]), +) eval_dataloader = DefaultDataLoader(dataset=eval_dataset, batch_size=1) + def main(): """Implement running function.""" top1 = Metric(name="topk", k=1) config = PostTrainingQuantConfig(calibration_sampling_size=[20]) model_path = FLAGS.model_path + "/mobilenet_v1_1.0_224_frozen.pb" q_model = fit( - model= model_path, + model=model_path, conf=config, calib_dataloader=calib_dataloader, eval_dataloader=eval_dataloader, - eval_metric=top1) + eval_metric=top1, + ) q_model.save("./q_model_path/q_model") + if __name__ == "__main__": main() diff --git a/neural_solution/frontend/fastapi/__init__.py b/neural_solution/frontend/fastapi/__init__.py index 3fb002264a5..b862242e65b 100644 --- a/neural_solution/frontend/fastapi/__init__.py +++ b/neural_solution/frontend/fastapi/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""FastAPI frontend.""" \ No newline at end of file +"""FastAPI frontend.""" diff --git a/neural_solution/frontend/fastapi/main_server.py b/neural_solution/frontend/fastapi/main_server.py index d7a363cd122..1d88e0cade7 100644 --- a/neural_solution/frontend/fastapi/main_server.py +++ b/neural_solution/frontend/fastapi/main_server.py @@ -14,36 +14,32 @@ """Fast api server.""" -from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException -from fastapi.responses import StreamingResponse, HTMLResponse -from neural_solution.frontend.task_submitter import Task, task_submitter -from neural_solution.frontend.utility import ( - get_cluster_info, - get_cluster_table, - serialize, - deserialize, - get_res_during_tuning, - get_baseline_during_tuning, - check_log_exists, - list_to_string) - -import sqlite3 -import os -import uuid -from watchdog.observers import Observer -from watchdog.events import FileSystemEventHandler import asyncio import json +import os import socket -import uvicorn +import sqlite3 +import uuid -from neural_solution.utils.utility import ( - get_task_log_workspace, - get_db_path -) +import uvicorn +from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect +from fastapi.responses import HTMLResponse, StreamingResponse +from watchdog.events import FileSystemEventHandler +from watchdog.observers import Observer from neural_solution.config import config - +from neural_solution.frontend.task_submitter import Task, task_submitter +from neural_solution.frontend.utility import ( + check_log_exists, + deserialize, + get_baseline_during_tuning, + get_cluster_info, + get_cluster_table, + get_res_during_tuning, + list_to_string, + serialize, +) +from neural_solution.utils.utility import get_db_path, get_task_log_workspace # Get config from Launcher.sh task_monitor_port = None @@ -57,19 +53,15 @@ args = None + def parse_arguments(): """Parse the command line options.""" parser = argparse.ArgumentParser(description="Frontend with RESTful API") - parser.add_argument("-H", "--host", type=str, default="0.0.0.0", \ - help="The address to submit task.") - parser.add_argument("-FP", "--fastapi_port", type=int, default=8000, \ - help="Port to submit task by user.") - parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, \ - help="Port to monitor task.") - parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, \ - help="Port to monitor result.") - parser.add_argument("-WS", "--workspace", type=str, default="./", \ - help="Work space.") + parser.add_argument("-H", "--host", type=str, default="0.0.0.0", help="The address to submit task.") + parser.add_argument("-FP", "--fastapi_port", type=int, default=8000, help="Port to submit task by user.") + parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, help="Port to monitor task.") + parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, help="Port to monitor result.") + parser.add_argument("-WS", "--workspace", type=str, default="./", help="Work space.") args = parser.parse_args() return args @@ -79,6 +71,7 @@ def read_root(): """Root route.""" return {"message": "Welcome to Neural Solution!"} + @app.get("/ping") def ping(): """Test status of services. @@ -100,14 +93,15 @@ def ping(): sock.close() continue except ConnectionRefusedError: - msg = "Ping fail! Make sure Neural Solution runner is running!" - break + msg = "Ping fail! Make sure Neural Solution runner is running!" + break except Exception as e: msg = "Ping fail! {}".format(e) break sock.close() return {"status": "Healthy", "msg": msg} if count == 2 else {"status": "Failed", "msg": msg} + @app.get("/cluster") def get_cluster(): """Get the cluster info. @@ -118,6 +112,7 @@ def get_cluster(): db_path = get_db_path(config.workspace) return get_cluster_info(db_path=db_path) + @app.get("/clusters") def get_clusters(): """Get the cluster info. @@ -128,6 +123,7 @@ def get_clusters(): db_path = get_db_path(config.workspace) return HTMLResponse(content=get_cluster_table(db_path=db_path)) + @app.get("/description") async def get_description(): """Get user oriented API descriptions. @@ -140,6 +136,7 @@ async def get_description(): data = json.load(f) return data + @app.post("/task/submit/") async def submit_task(task: Task): """Submit task. @@ -163,10 +160,19 @@ async def submit_task(task: Task): if os.path.isfile(db_path): conn = sqlite3.connect(db_path) cursor = conn.cursor() - task_id = str(uuid.uuid4()).replace('-','') - sql = r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)" +\ - r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')".format(task_id, task.script_url, task.optimized, - list_to_string(task.arguments), task.approach, list_to_string(task.requirements), task.workers) + task_id = str(uuid.uuid4()).replace("-", "") + sql = ( + r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)" + + r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')".format( + task_id, + task.script_url, + task.optimized, + list_to_string(task.arguments), + task.approach, + list_to_string(task.requirements), + task.workers, + ) + ) cursor.execute(sql) conn.commit() try: @@ -180,9 +186,10 @@ async def submit_task(task: Task): conn.close() else: msg = "Task Submitted fail! db not found!" - return {"msg": msg} # TODO to align with return message when submit task successfully + return {"msg": msg} # TODO to align with return message when submit task successfully return {"status": status, "task_id": task_id, "msg": msg} + @app.get("/task/{task_id}") def get_task_by_id(task_id: str): """Get task status, result, quantized model path according to id. @@ -202,7 +209,8 @@ def get_task_by_id(task_id: str): res = cursor.fetchone() cursor.close() conn.close() - return {"status": res[0], 'optimized_result': deserialize(res[1]) if res[1] else res[1], "result_path": res[2]} + return {"status": res[0], "optimized_result": deserialize(res[1]) if res[1] else res[1], "result_path": res[2]} + @app.get("/task/") def get_all_tasks(): @@ -222,6 +230,7 @@ def get_all_tasks(): conn.close() return {"message": res} + @app.get("/task/status/{task_id}") def get_task_status_by_id(task_id: str): """Get task status and information according to id. @@ -241,12 +250,12 @@ def get_task_status_by_id(task_id: str): if os.path.isfile(db_path): conn = sqlite3.connect(db_path) cursor = conn.cursor() - cursor.execute(r"select status, result, q_model_path from task where id=?", (task_id, )) + cursor.execute(r"select status, result, q_model_path from task where id=?", (task_id,)) res = cursor.fetchone() cursor.close() conn.close() if not res: - status = "Please check url." + status = "Please check url." elif res[0] == "done": status = res[0] optimization_result = deserialize(res[1]) if res[1] else res[1] @@ -254,15 +263,14 @@ def get_task_status_by_id(task_id: str): elif res[0] == "pending": status = "pending" else: - baseline = get_baseline_during_tuning(task_id,get_task_log_workspace(config.workspace)) + baseline = get_baseline_during_tuning(task_id, get_task_log_workspace(config.workspace)) tuning_result = get_res_during_tuning(task_id, get_task_log_workspace(config.workspace)) status = res[0] - tuning_info = { - "baseline": baseline, - "message": tuning_result} + tuning_info = {"baseline": baseline, "message": tuning_result} result = {"status": status, "tuning_info": tuning_info, "optimization_result": optimization_result} return result + @app.get("/task/log/{task_id}") async def read_logs(task_id: str): """Get the log of task according to id. @@ -279,6 +287,7 @@ async def read_logs(task_id: str): log_path = "{}/task_{}.txt".format(get_task_log_workspace(config.workspace), task_id) if not os.path.exists(log_path): return {"error": "Logfile not found."} + def stream_logs(): with open(log_path) as f: while True: @@ -286,8 +295,10 @@ def stream_logs(): if not line: break yield line.encode() + return StreamingResponse(stream_logs(), media_type="text/plain") + # Real time output log class LogEventHandler(FileSystemEventHandler): """Responsible for monitoring log changes and sending logs to clients. @@ -308,11 +319,10 @@ def __init__(self, websocket: WebSocket, task_id, last_position): self.websocket = websocket self.task_id = task_id self.loop = asyncio.get_event_loop() - self.last_position = last_position # record last line + self.last_position = last_position # record last line self.queue = asyncio.Queue() self.timer = self.loop.create_task(self.send_messages()) - async def send_messages(self): """Send messages to the client.""" while True: @@ -340,6 +350,7 @@ def on_modified(self, event): for line in lines: self.queue.put_nowait(line.strip()) + # start log watcher def start_log_watcher(websocket, task_id, last_position): """Start log watcher. @@ -406,8 +417,8 @@ async def websocket_endpoint(websocket: WebSocket, task_id: str): config.task_monitor_port = args.task_monitor_port config.result_monitor_port = args.result_monitor_port # initialize the task submitter - task_submitter.task_monitor_port=config.task_monitor_port - task_submitter.result_monitor_port=config.result_monitor_port + task_submitter.task_monitor_port = config.task_monitor_port + task_submitter.result_monitor_port = config.result_monitor_port config.service_address = task_submitter.service_address # start the app uvicorn.run(app, host=args.host, port=args.fastapi_port) diff --git a/neural_solution/frontend/gRPC/__init__.py b/neural_solution/frontend/gRPC/__init__.py index de608a1758d..1c1b8ab6053 100644 --- a/neural_solution/frontend/gRPC/__init__.py +++ b/neural_solution/frontend/gRPC/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""gRPC frontend.""" \ No newline at end of file +"""gRPC frontend.""" diff --git a/neural_solution/frontend/gRPC/client.py b/neural_solution/frontend/gRPC/client.py index ef8750954ac..c344e439017 100644 --- a/neural_solution/frontend/gRPC/client.py +++ b/neural_solution/frontend/gRPC/client.py @@ -16,16 +16,14 @@ import argparse import json -import grpc import os -from neural_solution.frontend.gRPC.proto import ( - neural_solution_pb2, - neural_solution_pb2_grpc -) +import grpc -from neural_solution.utils import logger from neural_solution.config import config +from neural_solution.frontend.gRPC.proto import neural_solution_pb2, neural_solution_pb2_grpc +from neural_solution.utils import logger + def _parse_task_from_json(request_path): file_path = os.path.abspath(request_path) @@ -33,33 +31,34 @@ def _parse_task_from_json(request_path): task = json.load(fp) return task + def submit_task(args): """Implement main entry point for the client of gRPC frontend.""" task = _parse_task_from_json(args.request) logger.info("Parsed task:") logger.info(task) - + # Create a gRPC channel port = str(config.grpc_api_port) - channel = grpc.insecure_channel('localhost:' + port) + channel = grpc.insecure_channel("localhost:" + port) # Create a stub (client) stub = neural_solution_pb2_grpc.TaskServiceStub(channel) # Ping serve - request = neural_solution_pb2.EmptyRequest() # pylint: disable=no-member + request = neural_solution_pb2.EmptyRequest() # pylint: disable=no-member response = stub.Ping(request) logger.info(response.status) logger.info(response.msg) # Create a task request with the desired fields - request = neural_solution_pb2.Task( # pylint: disable=no-member - script_url=task['script_url'], - optimized=task['optimized'] == 'True', - arguments=task['arguments'], - approach=task['approach'], - requirements=task['requirements'], - workers=task['workers'] + request = neural_solution_pb2.Task( # pylint: disable=no-member + script_url=task["script_url"], + optimized=task["optimized"] == "True", + arguments=task["arguments"], + approach=task["approach"], + requirements=task["requirements"], + workers=task["workers"], ) # Call the SubmitTask RPC on the server @@ -80,17 +79,18 @@ def run_query_task_result(args): task_id = args.task_id # Create a gRPC channel port = str(config.grpc_api_port) - channel = grpc.insecure_channel('localhost:' + port) + channel = grpc.insecure_channel("localhost:" + port) # Create a stub (client) stub = neural_solution_pb2_grpc.TaskServiceStub(channel) - request = neural_solution_pb2.TaskId(task_id=task_id) # pylint: disable=no-member + request = neural_solution_pb2.TaskId(task_id=task_id) # pylint: disable=no-member response = stub.QueryTaskResult(request) logger.info(response.status) logger.info(response.tuning_information) logger.info(response.optimization_result) + def run_query_task_status(args): """Query task status according to id. @@ -100,26 +100,26 @@ def run_query_task_status(args): task_id = args.task_id # Create a gRPC channel port = str(config.grpc_api_port) - channel = grpc.insecure_channel('localhost:' + port) + channel = grpc.insecure_channel("localhost:" + port) # Create a stub (client) stub = neural_solution_pb2_grpc.TaskServiceStub(channel) - request = neural_solution_pb2.TaskId(task_id=task_id) # pylint: disable=no-member + request = neural_solution_pb2.TaskId(task_id=task_id) # pylint: disable=no-member response = stub.GetTaskById(request) logger.info(response.status) logger.info(response.optimized_result) logger.info(response.result_path) -if __name__ == '__main__': +if __name__ == "__main__": logger.info(f"Try to start gRPC server.") """Parse the command line options.""" parser = argparse.ArgumentParser(description="gRPC Client") subparsers = parser.add_subparsers(help="Action", dest="action") submit_action_parser = subparsers.add_parser("submit", help="Submit help") - + submit_action_parser.set_defaults(func=submit_task) submit_action_parser.add_argument("--request", type=str, default=None, help="Request json file path.") @@ -132,4 +132,4 @@ def run_query_task_status(args): # for test: # python client.py query --task_id="d3e10a49326449fb9d0d62f2bfc1cb43" -# python client.py submit --request="test_task_request.json" \ No newline at end of file +# python client.py submit --request="test_task_request.json" diff --git a/neural_solution/frontend/gRPC/proto/__init__.py b/neural_solution/frontend/gRPC/proto/__init__.py index facfd9f388a..059f44462f5 100644 --- a/neural_solution/frontend/gRPC/proto/__init__.py +++ b/neural_solution/frontend/gRPC/proto/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""gRPC proto.""" \ No newline at end of file +"""gRPC proto.""" diff --git a/neural_solution/frontend/gRPC/proto/neural_solution_pb2.py b/neural_solution/frontend/gRPC/proto/neural_solution_pb2.py index 8c8f874d6d3..65e8ff7ca34 100644 --- a/neural_solution/frontend/gRPC/proto/neural_solution_pb2.py +++ b/neural_solution/frontend/gRPC/proto/neural_solution_pb2.py @@ -3,10 +3,11 @@ # source: neural_solution.proto # pylint: disable=all """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -14,30 +15,30 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15neural_solution.proto\x12\x0fneural_solution\x1a\x1bgoogle/protobuf/empty.proto\"y\n\x04Task\x12\x12\n\nscript_url\x18\x01 \x01(\t\x12\x11\n\toptimized\x18\x02 \x01(\x08\x12\x11\n\targuments\x18\x03 \x03(\t\x12\x10\n\x08\x61pproach\x18\x04 \x01(\t\x12\x14\n\x0crequirements\x18\x05 \x03(\t\x12\x0f\n\x07workers\x18\x06 \x01(\x05\"<\n\x0cTaskResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07task_id\x18\x02 \x01(\t\x12\x0b\n\x03msg\x18\x03 \x01(\t\"\x19\n\x06TaskId\x12\x0f\n\x07task_id\x18\x01 \x01(\t\"K\n\nTaskStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x18\n\x10optimized_result\x18\x02 \x01(\t\x12\x13\n\x0bresult_path\x18\x03 \x01(\t\"\x0e\n\x0c\x45mptyRequest\"!\n\x0eWelcomeMessage\x12\x0f\n\x07message\x18\x01 \x01(\t\"2\n\x13ResponsePingMessage\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0b\n\x03msg\x18\x02 \x01(\t\"]\n\x12ResponseTaskResult\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x1a\n\x12tuning_information\x18\x02 \x01(\t\x12\x1b\n\x13optimization_result\x18\x03 \x01(\t2\xb5\x02\n\x0bTaskService\x12\x46\n\x04Ping\x12\x16.google.protobuf.Empty\x1a$.neural_solution.ResponsePingMessage\"\x00\x12\x44\n\nSubmitTask\x12\x15.neural_solution.Task\x1a\x1d.neural_solution.TaskResponse\"\x00\x12\x45\n\x0bGetTaskById\x12\x17.neural_solution.TaskId\x1a\x1b.neural_solution.TaskStatus\"\x00\x12Q\n\x0fQueryTaskResult\x12\x17.neural_solution.TaskId\x1a#.neural_solution.ResponseTaskResult\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x15neural_solution.proto\x12\x0fneural_solution\x1a\x1bgoogle/protobuf/empty.proto"y\n\x04Task\x12\x12\n\nscript_url\x18\x01 \x01(\t\x12\x11\n\toptimized\x18\x02 \x01(\x08\x12\x11\n\targuments\x18\x03 \x03(\t\x12\x10\n\x08\x61pproach\x18\x04 \x01(\t\x12\x14\n\x0crequirements\x18\x05 \x03(\t\x12\x0f\n\x07workers\x18\x06 \x01(\x05"<\n\x0cTaskResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07task_id\x18\x02 \x01(\t\x12\x0b\n\x03msg\x18\x03 \x01(\t"\x19\n\x06TaskId\x12\x0f\n\x07task_id\x18\x01 \x01(\t"K\n\nTaskStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x18\n\x10optimized_result\x18\x02 \x01(\t\x12\x13\n\x0bresult_path\x18\x03 \x01(\t"\x0e\n\x0c\x45mptyRequest"!\n\x0eWelcomeMessage\x12\x0f\n\x07message\x18\x01 \x01(\t"2\n\x13ResponsePingMessage\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0b\n\x03msg\x18\x02 \x01(\t"]\n\x12ResponseTaskResult\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x1a\n\x12tuning_information\x18\x02 \x01(\t\x12\x1b\n\x13optimization_result\x18\x03 \x01(\t2\xb5\x02\n\x0bTaskService\x12\x46\n\x04Ping\x12\x16.google.protobuf.Empty\x1a$.neural_solution.ResponsePingMessage"\x00\x12\x44\n\nSubmitTask\x12\x15.neural_solution.Task\x1a\x1d.neural_solution.TaskResponse"\x00\x12\x45\n\x0bGetTaskById\x12\x17.neural_solution.TaskId\x1a\x1b.neural_solution.TaskStatus"\x00\x12Q\n\x0fQueryTaskResult\x12\x17.neural_solution.TaskId\x1a#.neural_solution.ResponseTaskResult"\x00\x62\x06proto3' +) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'neural_solution_pb2', globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "neural_solution_pb2", globals()) if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _TASK._serialized_start=71 - _TASK._serialized_end=192 - _TASKRESPONSE._serialized_start=194 - _TASKRESPONSE._serialized_end=254 - _TASKID._serialized_start=256 - _TASKID._serialized_end=281 - _TASKSTATUS._serialized_start=283 - _TASKSTATUS._serialized_end=358 - _EMPTYREQUEST._serialized_start=360 - _EMPTYREQUEST._serialized_end=374 - _WELCOMEMESSAGE._serialized_start=376 - _WELCOMEMESSAGE._serialized_end=409 - _RESPONSEPINGMESSAGE._serialized_start=411 - _RESPONSEPINGMESSAGE._serialized_end=461 - _RESPONSETASKRESULT._serialized_start=463 - _RESPONSETASKRESULT._serialized_end=556 - _TASKSERVICE._serialized_start=559 - _TASKSERVICE._serialized_end=868 + DESCRIPTOR._options = None + _TASK._serialized_start = 71 + _TASK._serialized_end = 192 + _TASKRESPONSE._serialized_start = 194 + _TASKRESPONSE._serialized_end = 254 + _TASKID._serialized_start = 256 + _TASKID._serialized_end = 281 + _TASKSTATUS._serialized_start = 283 + _TASKSTATUS._serialized_end = 358 + _EMPTYREQUEST._serialized_start = 360 + _EMPTYREQUEST._serialized_end = 374 + _WELCOMEMESSAGE._serialized_start = 376 + _WELCOMEMESSAGE._serialized_end = 409 + _RESPONSEPINGMESSAGE._serialized_start = 411 + _RESPONSEPINGMESSAGE._serialized_end = 461 + _RESPONSETASKRESULT._serialized_start = 463 + _RESPONSETASKRESULT._serialized_end = 556 + _TASKSERVICE._serialized_start = 559 + _TASKSERVICE._serialized_end = 868 # @@protoc_insertion_point(module_scope) diff --git a/neural_solution/frontend/gRPC/proto/neural_solution_pb2_grpc.py b/neural_solution/frontend/gRPC/proto/neural_solution_pb2_grpc.py index 285511e05f4..fcc1132c5c1 100644 --- a/neural_solution/frontend/gRPC/proto/neural_solution_pb2_grpc.py +++ b/neural_solution/frontend/gRPC/proto/neural_solution_pb2_grpc.py @@ -2,8 +2,8 @@ # pylint: disable=all """Client and server classes corresponding to protobuf-defined services.""" import grpc - from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 + import neural_solution.frontend.gRPC.proto.neural_solution_pb2 as neural__solution__pb2 @@ -17,25 +17,25 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Ping = channel.unary_unary( - '/neural_solution.TaskService/Ping', - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - response_deserializer=neural__solution__pb2.ResponsePingMessage.FromString, - ) + "/neural_solution.TaskService/Ping", + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=neural__solution__pb2.ResponsePingMessage.FromString, + ) self.SubmitTask = channel.unary_unary( - '/neural_solution.TaskService/SubmitTask', - request_serializer=neural__solution__pb2.Task.SerializeToString, - response_deserializer=neural__solution__pb2.TaskResponse.FromString, - ) + "/neural_solution.TaskService/SubmitTask", + request_serializer=neural__solution__pb2.Task.SerializeToString, + response_deserializer=neural__solution__pb2.TaskResponse.FromString, + ) self.GetTaskById = channel.unary_unary( - '/neural_solution.TaskService/GetTaskById', - request_serializer=neural__solution__pb2.TaskId.SerializeToString, - response_deserializer=neural__solution__pb2.TaskStatus.FromString, - ) + "/neural_solution.TaskService/GetTaskById", + request_serializer=neural__solution__pb2.TaskId.SerializeToString, + response_deserializer=neural__solution__pb2.TaskStatus.FromString, + ) self.QueryTaskResult = channel.unary_unary( - '/neural_solution.TaskService/QueryTaskResult', - request_serializer=neural__solution__pb2.TaskId.SerializeToString, - response_deserializer=neural__solution__pb2.ResponseTaskResult.FromString, - ) + "/neural_solution.TaskService/QueryTaskResult", + request_serializer=neural__solution__pb2.TaskId.SerializeToString, + response_deserializer=neural__solution__pb2.ResponseTaskResult.FromString, + ) class TaskServiceServicer(object): @@ -44,26 +44,26 @@ class TaskServiceServicer(object): def Ping(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def SubmitTask(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def GetTaskById(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def QueryTaskResult(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_TaskServiceServicer_to_server(servicer, server): @@ -74,47 +74,48 @@ def add_TaskServiceServicer_to_server(servicer, server): server (grpc._server._Server): server """ rpc_method_handlers = { - 'Ping': grpc.unary_unary_rpc_method_handler( - servicer.Ping, - request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=neural__solution__pb2.ResponsePingMessage.SerializeToString, - ), - 'SubmitTask': grpc.unary_unary_rpc_method_handler( - servicer.SubmitTask, - request_deserializer=neural__solution__pb2.Task.FromString, - response_serializer=neural__solution__pb2.TaskResponse.SerializeToString, - ), - 'GetTaskById': grpc.unary_unary_rpc_method_handler( - servicer.GetTaskById, - request_deserializer=neural__solution__pb2.TaskId.FromString, - response_serializer=neural__solution__pb2.TaskStatus.SerializeToString, - ), - 'QueryTaskResult': grpc.unary_unary_rpc_method_handler( - servicer.QueryTaskResult, - request_deserializer=neural__solution__pb2.TaskId.FromString, - response_serializer=neural__solution__pb2.ResponseTaskResult.SerializeToString, - ), + "Ping": grpc.unary_unary_rpc_method_handler( + servicer.Ping, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=neural__solution__pb2.ResponsePingMessage.SerializeToString, + ), + "SubmitTask": grpc.unary_unary_rpc_method_handler( + servicer.SubmitTask, + request_deserializer=neural__solution__pb2.Task.FromString, + response_serializer=neural__solution__pb2.TaskResponse.SerializeToString, + ), + "GetTaskById": grpc.unary_unary_rpc_method_handler( + servicer.GetTaskById, + request_deserializer=neural__solution__pb2.TaskId.FromString, + response_serializer=neural__solution__pb2.TaskStatus.SerializeToString, + ), + "QueryTaskResult": grpc.unary_unary_rpc_method_handler( + servicer.QueryTaskResult, + request_deserializer=neural__solution__pb2.TaskId.FromString, + response_serializer=neural__solution__pb2.ResponseTaskResult.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler( - 'neural_solution.TaskService', rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler("neural_solution.TaskService", rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class TaskService(object): """Interface exported by the server.""" @staticmethod - def Ping(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def Ping( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): """Tets server status. Args: @@ -154,23 +155,35 @@ def Ping(request, Returns: The response to the RPC. """ - return grpc.experimental.unary_unary(request, target, '/neural_solution.TaskService/Ping', + return grpc.experimental.unary_unary( + request, + target, + "/neural_solution.TaskService/Ping", google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, neural__solution__pb2.ResponsePingMessage.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def SubmitTask(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def SubmitTask( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): """Submit task. Args: @@ -210,23 +223,35 @@ def SubmitTask(request, Returns: The response to the RPC. """ - return grpc.experimental.unary_unary(request, target, '/neural_solution.TaskService/SubmitTask', + return grpc.experimental.unary_unary( + request, + target, + "/neural_solution.TaskService/SubmitTask", neural__solution__pb2.Task.SerializeToString, neural__solution__pb2.TaskResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def GetTaskById(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def GetTaskById( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): """Get task status according to id. Args: @@ -266,23 +291,35 @@ def GetTaskById(request, Returns: The response to the RPC. """ - return grpc.experimental.unary_unary(request, target, '/neural_solution.TaskService/GetTaskById', + return grpc.experimental.unary_unary( + request, + target, + "/neural_solution.TaskService/GetTaskById", neural__solution__pb2.TaskId.SerializeToString, neural__solution__pb2.TaskStatus.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def QueryTaskResult(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def QueryTaskResult( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): """Get task result according to id. Args: @@ -322,8 +359,18 @@ def QueryTaskResult(request, Returns: The response to the RPC. """ - return grpc.experimental.unary_unary(request, target, '/neural_solution.TaskService/QueryTaskResult', + return grpc.experimental.unary_unary( + request, + target, + "/neural_solution.TaskService/QueryTaskResult", neural__solution__pb2.TaskId.SerializeToString, neural__solution__pb2.ResponseTaskResult.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/neural_solution/frontend/gRPC/server.py b/neural_solution/frontend/gRPC/server.py index e451ee3dad0..b838e375ada 100644 --- a/neural_solution/frontend/gRPC/server.py +++ b/neural_solution/frontend/gRPC/server.py @@ -14,25 +14,24 @@ """Server of gRPC frontend.""" -from concurrent import futures -import logging -import grpc import argparse +import logging +from concurrent import futures -from neural_solution.frontend.gRPC.proto import ( - neural_solution_pb2, - neural_solution_pb2_grpc) +import grpc from neural_solution.config import config -from neural_solution.utils import logger -from neural_solution.utils.utility import get_db_path, dict_to_str +from neural_solution.frontend.gRPC.proto import neural_solution_pb2, neural_solution_pb2_grpc from neural_solution.frontend.task_submitter import task_submitter - from neural_solution.frontend.utility import ( - submit_task_to_db, check_service_status, + query_task_result, query_task_status, - query_task_result) + submit_task_to_db, +) +from neural_solution.utils import logger +from neural_solution.utils.utility import dict_to_str, get_db_path + class TaskSubmitterServicer(neural_solution_pb2_grpc.TaskServiceServicer): """Deliver services. @@ -58,7 +57,7 @@ def Ping(self, empty_msg, context): print(f"Ping grpc serve.") port_lst = [config.result_monitor_port] result = check_service_status(port_lst, service_address=config.service_address) - response = neural_solution_pb2.ResponsePingMessage(**result) # pylint: disable=no-member + response = neural_solution_pb2.ResponsePingMessage(**result) # pylint: disable=no-member return response def SubmitTask(self, task, context): @@ -82,7 +81,7 @@ def SubmitTask(self, task, context): print(db_path) result = submit_task_to_db(task=task, task_submitter=task_submitter, db_path=get_db_path(config.workspace)) # Return a response - response = neural_solution_pb2.TaskResponse(**result) # pylint: disable=no-member + response = neural_solution_pb2.TaskResponse(**result) # pylint: disable=no-member return response def GetTaskById(self, task_id, context): @@ -97,7 +96,7 @@ def GetTaskById(self, task_id, context): db_path = get_db_path(config.workspace) result = query_task_status(task_id.task_id, db_path) print(f"query result : result") - response = neural_solution_pb2.TaskStatus(**result) # pylint: disable=no-member + response = neural_solution_pb2.TaskStatus(**result) # pylint: disable=no-member return response def QueryTaskResult(self, task_id, context): @@ -111,9 +110,9 @@ def QueryTaskResult(self, task_id, context): """ db_path = get_db_path(config.workspace) result = query_task_result(task_id.task_id, db_path, config.workspace) - result['tuning_information'] = dict_to_str(result["tuning_information"]) - result['optimization_result'] = dict_to_str(result["optimization_result"]) - response = neural_solution_pb2.ResponseTaskResult(**result) # pylint: disable=no-member + result["tuning_information"] = dict_to_str(result["tuning_information"]) + result["optimization_result"] = dict_to_str(result["optimization_result"]) + response = neural_solution_pb2.ResponseTaskResult(**result) # pylint: disable=no-member return response @@ -121,9 +120,8 @@ def serve(): """Service entrance.""" port = str(config.grpc_api_port) server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - neural_solution_pb2_grpc.add_TaskServiceServicer_to_server( - TaskSubmitterServicer(), server) - server.add_insecure_port('[::]:' + port) + neural_solution_pb2_grpc.add_TaskServiceServicer_to_server(TaskSubmitterServicer(), server) + server.add_insecure_port("[::]:" + port) server.start() print("Server started, listening on " + port) server.wait_for_termination() @@ -132,21 +130,16 @@ def serve(): def parse_arguments(): """Parse the command line options.""" parser = argparse.ArgumentParser(description="Frontend with gRPC API") - parser.add_argument("-H", "--host", type=str, default="0.0.0.0", \ - help="The address to submit task.") - parser.add_argument("-FP", "--grpc_api_port", type=int, default=8001, \ - help="Port to submit task by user.") - parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, \ - help="Port to monitor task.") - parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, \ - help="Port to monitor result.") - parser.add_argument("-WS", "--workspace", type=str, default="./ns_workspace", \ - help="Work space.") + parser.add_argument("-H", "--host", type=str, default="0.0.0.0", help="The address to submit task.") + parser.add_argument("-FP", "--grpc_api_port", type=int, default=8001, help="Port to submit task by user.") + parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, help="Port to monitor task.") + parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, help="Port to monitor result.") + parser.add_argument("-WS", "--workspace", type=str, default="./ns_workspace", help="Work space.") args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": logger.info(f"Try to start gRPC server.") logging.basicConfig() args = parse_arguments() diff --git a/neural_solution/frontend/task_submitter.py b/neural_solution/frontend/task_submitter.py index 880496fd14e..e4274260c3d 100644 --- a/neural_solution/frontend/task_submitter.py +++ b/neural_solution/frontend/task_submitter.py @@ -14,12 +14,14 @@ """Neural Solution task submitter.""" -import socket import json +import socket + from pydantic import BaseModel # pylint: disable=no-name-in-module from neural_solution.config import config + class Task(BaseModel): """Task definition for submitting requests. @@ -34,6 +36,7 @@ class Task(BaseModel): requirements: list workers: int + class TaskSubmitter: """Responsible for submitting tasks.""" @@ -65,5 +68,7 @@ def submit_task(self, tid): s.send(self.serialize(tid)) s.close() -task_submitter = TaskSubmitter(task_monitor_port=config.task_monitor_port, - result_monitor_port=config.result_monitor_port) \ No newline at end of file + +task_submitter = TaskSubmitter( + task_monitor_port=config.task_monitor_port, result_monitor_port=config.result_monitor_port +) diff --git a/neural_solution/frontend/utility.py b/neural_solution/frontend/utility.py index 454603ad1db..fd8fc2dcc25 100644 --- a/neural_solution/frontend/utility.py +++ b/neural_solution/frontend/utility.py @@ -15,13 +15,16 @@ """Common utilities for all frontend components.""" import json -import sqlite3 import os import re +import socket +import sqlite3 import uuid + import pandas as pd -import socket -from neural_solution.utils.utility import get_task_log_workspace, dict_to_str + +from neural_solution.utils.utility import dict_to_str, get_task_log_workspace + def query_task_status(task_id, db_path): """Query task status according to id. @@ -41,9 +44,12 @@ def query_task_status(task_id, db_path): res = cursor.fetchone() cursor.close() conn.close() - return {"status": res[0], - 'optimized_result': dict_to_str(deserialize(res[1]) if res[1] else res[1]), - "result_path": res[2]} + return { + "status": res[0], + "optimized_result": dict_to_str(deserialize(res[1]) if res[1] else res[1]), + "result_path": res[2], + } + def query_task_result(task_id, db_path, workspace): """Query the task result according id. @@ -64,13 +70,13 @@ def query_task_result(task_id, db_path, workspace): if os.path.isfile(db_path): conn = sqlite3.connect(db_path) cursor = conn.cursor() - cursor.execute(r"select status, result, q_model_path from task where id=?", (task_id, )) + cursor.execute(r"select status, result, q_model_path from task where id=?", (task_id,)) res = cursor.fetchone() cursor.close() conn.close() print(f"in query") if not res: - status = "Please check url." + status = "Please check url." elif res[0] == "done": status = res[0] optimization_result = deserialize(res[1]) if res[1] else res[1] @@ -85,6 +91,7 @@ def query_task_result(task_id, db_path, workspace): result = {"status": status, "tuning_information": tuning_info, "optimization_result": optimization_result} return result + def check_service_status(port_lst, service_address): """Check server status. @@ -109,8 +116,8 @@ def check_service_status(port_lst, service_address): sock.close() continue except ConnectionRefusedError: - msg = "Ping fail! Make sure Neural Solution runner is running!" - break + msg = "Ping fail! Make sure Neural Solution runner is running!" + break except Exception as e: msg = "Ping fail! {}".format(e) break @@ -136,16 +143,19 @@ def submit_task_to_db(task, task_submitter, db_path): if os.path.isfile(db_path): conn = sqlite3.connect(db_path) cursor = conn.cursor() - task_id = str(uuid.uuid4()).replace('-','') - sql = r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)" +\ - r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')".format( - task_id, - task.script_url, - task.optimized, - list_to_string(task.arguments), - task.approach, - list_to_string(task.requirements), - task.workers) + task_id = str(uuid.uuid4()).replace("-", "") + sql = ( + r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)" + + r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')".format( + task_id, + task.script_url, + task.optimized, + list_to_string(task.arguments), + task.approach, + list_to_string(task.requirements), + task.workers, + ) + ) cursor.execute(sql) conn.commit() try: @@ -161,18 +171,21 @@ def submit_task_to_db(task, task_submitter, db_path): msg = "Task Submitted fail! db not found!" result["status"] = status result["task_id"] = task_id - result["msg"]=msg + result["msg"] = msg return result + def serialize(request: dict) -> bytes: """Serialize a dict object to bytes for inter-process communication.""" return json.dumps(request).encode() + def deserialize(request: bytes) -> dict: """Deserialize the received bytes to a dict object.""" return json.loads(request) -def get_cluster_info(db_path:str): + +def get_cluster_info(db_path: str): """Get cluster information from database. Returns: @@ -186,7 +199,8 @@ def get_cluster_info(db_path:str): conn.close() return {"Cluster info": rows} -def get_cluster_table(db_path:str): + +def get_cluster_table(db_path: str): """Get cluster table from database. Returns: @@ -197,11 +211,14 @@ def get_cluster_table(db_path:str): cursor.execute(r"select * from cluster") conn.commit() rows = cursor.fetchall() - df = pd.DataFrame(rows, columns=["Node", "Node info", "status","free workers", "busy workers", "total workers"]) - html_table = df.to_html(index=False, ) + df = pd.DataFrame(rows, columns=["Node", "Node info", "status", "free workers", "busy workers", "total workers"]) + html_table = df.to_html( + index=False, + ) conn.close() return html_table + def get_res_during_tuning(task_id: str, task_log_path): """Get result during tuning. @@ -214,8 +231,8 @@ def get_res_during_tuning(task_id: str, task_log_path): results = {} log_path = "{}/task_{}.txt".format(task_log_path, task_id) for line in reversed(open(log_path).readlines()): - res_pattern = r'Tune (\d+) result is: ' - res_pattern = r'Tune (\d+) result is:\s.*?\(int8\|fp32\):\s+(\d+\.\d+).*?\(int8\|fp32\):\s+(\d+\.\d+).*?' + res_pattern = r"Tune (\d+) result is: " + res_pattern = r"Tune (\d+) result is:\s.*?\(int8\|fp32\):\s+(\d+\.\d+).*?\(int8\|fp32\):\s+(\d+\.\d+).*?" res_matches = re.findall(res_pattern, line) if res_matches: results["Tuning count"] = res_matches[0][0] @@ -225,7 +242,8 @@ def get_res_during_tuning(task_id: str, task_log_path): break print("Query results: {}".format(results)) - return results if results else "Tune 1 running..." + return results if results else "Tune 1 running..." + def get_baseline_during_tuning(task_id: str, task_log_path): """Get result during tuning. @@ -248,7 +266,8 @@ def get_baseline_during_tuning(task_id: str, task_log_path): break print("FP32 baseline: {}".format(results)) - return results if results else "Getting FP32 baseline..." + return results if results else "Getting FP32 baseline..." + def check_log_exists(task_id: str, task_log_path): """Check whether the log file exists. @@ -265,6 +284,7 @@ def check_log_exists(task_id: str, task_log_path): else: return False + def list_to_string(lst: list): """Convert the list to a space concatenated string. @@ -274,4 +294,4 @@ def list_to_string(lst: list): Returns: str: string """ - return " ".join(str(i) for i in lst) \ No newline at end of file + return " ".join(str(i) for i in lst) diff --git a/neural_solution/launcher.py b/neural_solution/launcher.py index c2628af12a4..3711bed982f 100644 --- a/neural_solution/launcher.py +++ b/neural_solution/launcher.py @@ -13,16 +13,18 @@ # limitations under the License. """The entry of Neural Solution.""" -import sys -import subprocess -import os import argparse +import os +import shlex import socket -import psutil +import subprocess +import sys import time -import shlex from datetime import datetime +import psutil + + def check_ports(args): """Check parameters ending in '_port'. @@ -30,9 +32,10 @@ def check_ports(args): args (argparse.Namespace): parameters. """ for arg in vars(args): - if '_port' in arg: + if "_port" in arg: check_port(getattr(args, arg)) + def check_port(port): """Check if the given port is standardized. @@ -43,6 +46,7 @@ def check_port(port): print(f"Error: Invalid port number: {port}") sys.exit(1) + def get_local_service_ip(port): """Get the local IP address of the machine running the service. @@ -53,32 +57,34 @@ def get_local_service_ip(port): str: The IP address of the machine running the service. """ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', port)) + s.connect(("8.8.8.8", port)) return s.getsockname()[0] + def stop_service(): """Stop service.""" # Get all running processes for proc in psutil.process_iter(): try: # Get the process details - pinfo = proc.as_dict(attrs=['pid', 'name', 'cmdline']) + pinfo = proc.as_dict(attrs=["pid", "name", "cmdline"]) # Check if the process is the target process - if "neural_solution.backend.runner" in pinfo['cmdline']: + if "neural_solution.backend.runner" in pinfo["cmdline"]: # Terminate the process using Process.kill() method - process = psutil.Process(pinfo['pid']) + process = psutil.Process(pinfo["pid"]) process.kill() - elif "neural_solution.frontend.fastapi.main_server" in pinfo['cmdline']: - process = psutil.Process(pinfo['pid']) + elif "neural_solution.frontend.fastapi.main_server" in pinfo["cmdline"]: + process = psutil.Process(pinfo["pid"]) process.kill() - elif "neural_solution.frontend.gRPC.server" in pinfo['cmdline']: - process = psutil.Process(pinfo['pid']) + elif "neural_solution.frontend.gRPC.server" in pinfo["cmdline"]: + process = psutil.Process(pinfo["pid"]) process.kill() except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): pass # Service End print("Neural Solution Service Stopped!") + def check_port_free(port): """Check if the port is free. @@ -89,10 +95,11 @@ def check_port_free(port): bool : the free state of the port. """ for conn in psutil.net_connections(): - if conn.status == 'LISTEN' and conn.laddr.port == port: + if conn.status == "LISTEN" and conn.laddr.port == port: return False return True + def start_service(args): """Start service. @@ -118,8 +125,10 @@ def start_service(args): print("No environment specified or conda environment activated !!!") sys.exit(1) else: - print(f"No environment specified, use environment activated:" + \ - f" ({conda_env}) as the task runtime environment.") + print( + f"No environment specified, use environment activated:" + + f" ({conda_env}) as the task runtime environment." + ) conda_env_name = conda_env else: conda_env_name = args.conda_env @@ -132,34 +141,67 @@ def start_service(args): date_suffix = "_" + date_time.strftime("%Y%m%d-%H%M%S") date_suffix = "" with open(f"{serve_log_dir}/backend{date_suffix}.log", "w") as f: - subprocess.Popen([ - "python", "-m", "neural_solution.backend.runner", - "--hostfile", shlex.quote(str(args.hostfile)), - "--task_monitor_port", shlex.quote(str(args.task_monitor_port)), - "--result_monitor_port", shlex.quote(str(args.result_monitor_port)), - "--workspace", shlex.quote(str(args.workspace)), - "--conda_env_name", shlex.quote(str(conda_env_name)), - "--upload_path", shlex.quote(str(args.upload_path)) - ], stdout=os.dup(f.fileno()), stderr=subprocess.STDOUT) + subprocess.Popen( + [ + "python", + "-m", + "neural_solution.backend.runner", + "--hostfile", + shlex.quote(str(args.hostfile)), + "--task_monitor_port", + shlex.quote(str(args.task_monitor_port)), + "--result_monitor_port", + shlex.quote(str(args.result_monitor_port)), + "--workspace", + shlex.quote(str(args.workspace)), + "--conda_env_name", + shlex.quote(str(conda_env_name)), + "--upload_path", + shlex.quote(str(args.upload_path)), + ], + stdout=os.dup(f.fileno()), + stderr=subprocess.STDOUT, + ) if args.api_type in ["all", "restful"]: with open(f"{serve_log_dir}/frontend{date_suffix}.log", "w") as f: - subprocess.Popen([ - "python", "-m", "neural_solution.frontend.fastapi.main_server", - "--host", "0.0.0.0", - "--fastapi_port", shlex.quote(str(args.restful_api_port)), - "--task_monitor_port", shlex.quote(str(args.task_monitor_port)), - "--result_monitor_port", shlex.quote(str(args.result_monitor_port)), - "--workspace", shlex.quote(str(args.workspace)) - ], stdout=os.dup(f.fileno()), stderr=subprocess.STDOUT) + subprocess.Popen( + [ + "python", + "-m", + "neural_solution.frontend.fastapi.main_server", + "--host", + "0.0.0.0", + "--fastapi_port", + shlex.quote(str(args.restful_api_port)), + "--task_monitor_port", + shlex.quote(str(args.task_monitor_port)), + "--result_monitor_port", + shlex.quote(str(args.result_monitor_port)), + "--workspace", + shlex.quote(str(args.workspace)), + ], + stdout=os.dup(f.fileno()), + stderr=subprocess.STDOUT, + ) if args.api_type in ["all", "grpc"]: with open(f"{serve_log_dir}/frontend_grpc.log", "w") as f: - subprocess.Popen([ - "python", "-m", "neural_solution.frontend.gRPC.server", - "--grpc_api_port", shlex.quote(str(args.grpc_api_port)), - "--task_monitor_port", shlex.quote(str(args.task_monitor_port)), - "--result_monitor_port", shlex.quote(str(args.result_monitor_port)), - "--workspace", shlex.quote(str(args.workspace)) - ], stdout=os.dup(f.fileno()), stderr=subprocess.STDOUT) + subprocess.Popen( + [ + "python", + "-m", + "neural_solution.frontend.gRPC.server", + "--grpc_api_port", + shlex.quote(str(args.grpc_api_port)), + "--task_monitor_port", + shlex.quote(str(args.task_monitor_port)), + "--result_monitor_port", + shlex.quote(str(args.result_monitor_port)), + "--workspace", + shlex.quote(str(args.workspace)), + ], + stdout=os.dup(f.fileno()), + stderr=subprocess.STDOUT, + ) ip_address = get_local_service_ip(80) # Check if the service is started @@ -169,8 +211,11 @@ def start_service(args): start_time = time.time() while True: # Check if the ports are in use - if check_port_free(args.task_monitor_port) or check_port_free(args.result_monitor_port) \ - or check_port_free(args.restful_api_port): + if ( + check_port_free(args.task_monitor_port) + or check_port_free(args.result_monitor_port) + or check_port_free(args.restful_api_port) + ): # If the ports are not in use, wait for a second and check again time.sleep(0.5) # Check if timed out @@ -205,28 +250,40 @@ def start_service(args): # Check completed print("Neural Solution Service Started!") - print(f"Service log saving path is in \"{os.path.abspath(serve_log_dir)}\"") + print(f'Service log saving path is in "{os.path.abspath(serve_log_dir)}"') print(f"To submit task at: {ip_address}:{args.restful_api_port}/task/submit/") print("[For information] neural_solution help") + def main(): """Implement the main function.""" parser = argparse.ArgumentParser(description="Neural Solution") - parser.add_argument('action', choices=['start', 'stop'], help='start/stop service') - parser.add_argument("--hostfile", default=None, - help="start backend serve host file which contains all available nodes") - parser.add_argument("--restful_api_port", type=int, default=8000, - help="start restful serve with {restful_api_port}, default 8000") - parser.add_argument("--grpc_api_port", type=int, default=8001, - help="start gRPC with {restful_api_port}, default 8001") - parser.add_argument("--result_monitor_port", type=int, default=3333, - help="start serve for result monitor at {result_monitor_port}, default 3333") - parser.add_argument("--task_monitor_port", type=int, default=2222, - help="start serve for task monitor at {task_monitor_port}, default 2222") - parser.add_argument("--api_type", default="all", - help="start web serve with all/grpc/restful, default all") - parser.add_argument("--workspace", default="./ns_workspace", - help="neural solution workspace, default \"./ns_workspace\"") + parser.add_argument("action", choices=["start", "stop"], help="start/stop service") + parser.add_argument( + "--hostfile", default=None, help="start backend serve host file which contains all available nodes" + ) + parser.add_argument( + "--restful_api_port", type=int, default=8000, help="start restful serve with {restful_api_port}, default 8000" + ) + parser.add_argument( + "--grpc_api_port", type=int, default=8001, help="start gRPC with {restful_api_port}, default 8001" + ) + parser.add_argument( + "--result_monitor_port", + type=int, + default=3333, + help="start serve for result monitor at {result_monitor_port}, default 3333", + ) + parser.add_argument( + "--task_monitor_port", + type=int, + default=2222, + help="start serve for task monitor at {task_monitor_port}, default 2222", + ) + parser.add_argument("--api_type", default="all", help="start web serve with all/grpc/restful, default all") + parser.add_argument( + "--workspace", default="./ns_workspace", help='neural solution workspace, default "./ns_workspace"' + ) parser.add_argument("--conda_env", default=None, help="specify the running environment for the task") parser.add_argument("--upload_path", default="examples", help="specify the file path for the tasks") args = parser.parse_args() @@ -234,10 +291,11 @@ def main(): # Check parameters ending in '_port' check_ports(args) - if args.action == 'start': + if args.action == "start": start_service(args) - elif args.action == 'stop': + elif args.action == "stop": stop_service() - -if __name__ == '__main__': + + +if __name__ == "__main__": main() diff --git a/neural_solution/scripts/prepare_deps.py b/neural_solution/scripts/prepare_deps.py index 33c711b9273..c7df3524ac2 100644 --- a/neural_solution/scripts/prepare_deps.py +++ b/neural_solution/scripts/prepare_deps.py @@ -20,4 +20,3 @@ - CONDA - other packages, such as, mpi4py """ - diff --git a/neural_solution/test/backend/test_cluster.py b/neural_solution/test/backend/test_cluster.py index fbb001d14c7..8ea687e81b6 100644 --- a/neural_solution/test/backend/test_cluster.py +++ b/neural_solution/test/backend/test_cluster.py @@ -1,16 +1,12 @@ """Tests for cluster""" import importlib -import shutil import os +import shutil import unittest - -import unittest from neural_solution.backend.cluster import Cluster, Node from neural_solution.backend.task import Task - - -from neural_solution.utils.utility import get_task_workspace, get_task_log_workspace, get_db_path +from neural_solution.utils.utility import get_db_path, get_task_log_workspace, get_task_workspace NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace") db_path = get_db_path(NEURAL_SOLUTION_WORKSPACE) @@ -23,16 +19,16 @@ def setUp(self): self.cluster = Cluster(node_lst, db_path=db_path) self.task = Task( - task_id = "1", - arguments = ["arg1", "arg2"], - workers = 2, - status = "pending", - script_url = "https://example.com", - optimized = True, - approach = "static", - requirement = ["req1", "req2"], - result = "", - q_model_path = "q_model_path" + task_id="1", + arguments=["arg1", "arg2"], + workers=2, + status="pending", + script_url="https://example.com", + optimized=True, + approach="static", + requirement=["req1", "req2"], + result="", + q_model_path="q_model_path", ) @classmethod @@ -49,7 +45,7 @@ def test_free_resource(self): task = self.task reserved_resource_lst = self.cluster.reserve_resource(task) self.cluster.free_resource(reserved_resource_lst) - self.assertEqual(self.cluster.socket_queue, ['2 node2', '2 node2', '1 node1', '1 node1']) + self.assertEqual(self.cluster.socket_queue, ["2 node2", "2 node2", "1 node1", "1 node1"]) def test_get_free_socket(self): free_socket_lst = self.cluster.get_free_socket(4) @@ -61,5 +57,6 @@ def test_get_free_socket(self): free_socket_lst = self.cluster.get_free_socket(10) self.assertEqual(free_socket_lst, 0) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/neural_solution/test/backend/test_result_monitor.py b/neural_solution/test/backend/test_result_monitor.py index 18f83a78196..1691826729f 100644 --- a/neural_solution/test/backend/test_result_monitor.py +++ b/neural_solution/test/backend/test_result_monitor.py @@ -2,28 +2,29 @@ import threading import unittest from unittest.mock import MagicMock, patch + from neural_solution.backend.result_monitor import ResultMonitor -class TestResultMonitor(unittest.TestCase): - @patch('socket.socket') +class TestResultMonitor(unittest.TestCase): + @patch("socket.socket") def test_wait_result(self, mock_socket): # Mock data for testing task_db = MagicMock() task_db.lookup_task_status.return_value = "COMPLETED" - result = {'task_id': 1, 'q_model_path': "path/to/q_model", 'result': 0.8} + result = {"task_id": 1, "q_model_path": "path/to/q_model", "result": 0.8} serialized_result = json.dumps(result) - + mock_c = MagicMock() mock_c.recv.return_value = serialized_result - + mock_socket.return_value.accept.return_value = (mock_c, MagicMock()) mock_socket.return_value.recv.return_value = serialized_result mock_socket.return_value.__enter__.return_value = mock_socket.return_value # Create a ResultMonitor object and call the wait_result method result_monitor = ResultMonitor(8080, task_db) - with patch('neural_solution.backend.result_monitor.deserialize', return_value={"ping": "test"}): + with patch("neural_solution.backend.result_monitor.deserialize", return_value={"ping": "test"}): adding_abort = threading.Thread( target=result_monitor.wait_result, args=(), @@ -43,7 +44,7 @@ def test_query_task_status(self): # Assert that the task_db.lookup_task_status method was called with the correct argument task_db.lookup_task_status.assert_called_once_with(1) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/neural_solution/test/backend/test_runner.py b/neural_solution/test/backend/test_runner.py index 1c0e447079b..c400f00c427 100644 --- a/neural_solution/test/backend/test_runner.py +++ b/neural_solution/test/backend/test_runner.py @@ -1,26 +1,33 @@ -import unittest -from unittest.mock import patch -import threading -import os import argparse +import os import shutil -from neural_solution.backend.runner import parse_args, main +import threading +import unittest +from unittest.mock import patch -class TestMain(unittest.TestCase): +from neural_solution.backend.runner import main, parse_args + +class TestMain(unittest.TestCase): @classmethod def tearDownClass(cls) -> None: - os.remove('test.txt') + os.remove("test.txt") shutil.rmtree("ns_workspace", ignore_errors=True) def test_parse_args(self): - args = ['-H', 'path/to/hostfile', '-TMP', '2222', '-RMP', '3333', '-CEN', 'inc'] - with patch('argparse.ArgumentParser.parse_args', \ - return_value=argparse.Namespace(hostfile='path/to/hostfile', \ - task_monitor_port=2222, result_monitor_port=3333, conda_env_name='inc')): - self.assertEqual(parse_args(args), \ - argparse.Namespace(hostfile='path/to/hostfile', \ - task_monitor_port=2222, result_monitor_port=3333, conda_env_name='inc')) + args = ["-H", "path/to/hostfile", "-TMP", "2222", "-RMP", "3333", "-CEN", "inc"] + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + hostfile="path/to/hostfile", task_monitor_port=2222, result_monitor_port=3333, conda_env_name="inc" + ), + ): + self.assertEqual( + parse_args(args), + argparse.Namespace( + hostfile="path/to/hostfile", task_monitor_port=2222, result_monitor_port=3333, conda_env_name="inc" + ), + ) def test_main(self): """Test blocking flag in abort_job method.""" @@ -29,11 +36,12 @@ def test_main(self): f.write("hostname1 2 20\nhostname2 2 20") adding_abort = threading.Thread( target=main, - kwargs={'args': ['-H', 'test.txt', '-TMP', '2222', '-RMP', '3333', '-CEN', 'inc_conda_env']}, + kwargs={"args": ["-H", "test.txt", "-TMP", "2222", "-RMP", "3333", "-CEN", "inc_conda_env"]}, daemon=True, ) adding_abort.start() adding_abort.join(timeout=2) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/neural_solution/test/backend/test_scheduler.py b/neural_solution/test/backend/test_scheduler.py index 8139a5afb7b..a84689d4658 100644 --- a/neural_solution/test/backend/test_scheduler.py +++ b/neural_solution/test/backend/test_scheduler.py @@ -1,30 +1,31 @@ -import unittest import os -import threading import shutil +import threading +import unittest from subprocess import CalledProcessError -from unittest.mock import MagicMock, patch, mock_open, Mock +from unittest.mock import MagicMock, Mock, mock_open, patch + +from neural_solution.backend.cluster import Cluster from neural_solution.backend.scheduler import Scheduler from neural_solution.backend.task import Task -from neural_solution.backend.cluster import Cluster from neural_solution.backend.task_db import TaskDB from neural_solution.backend.utils.utility import dump_elapsed_time, get_task_log_path - -import os - -from neural_solution.utils.utility import get_db_path from neural_solution.config import config +from neural_solution.utils.utility import get_db_path NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace") db_path = get_db_path(NEURAL_SOLUTION_WORKSPACE) config.workspace = NEURAL_SOLUTION_WORKSPACE + class TestScheduler(unittest.TestCase): def setUp(self): self.cluster = Cluster(db_path=db_path) self.task_db = TaskDB(db_path=db_path) self.result_monitor_port = 1234 - self.scheduler = Scheduler(self.cluster, self.task_db, self.result_monitor_port,conda_env_name="for_ns_test", config=config) + self.scheduler = Scheduler( + self.cluster, self.task_db, self.result_monitor_port, conda_env_name="for_ns_test", config=config + ) def tearDown(self) -> None: shutil.rmtree("ns_workspace", ignore_errors=True) @@ -34,28 +35,69 @@ def tearDownClass(cls) -> None: shutil.rmtree("examples") def test_prepare_env(self): - task = Task("test_task", "test_arguments", "test_workers", "test_status", "test_script_url", \ - "test_optimized", "test_approach", "pip", "test_result", "test_q_model_path") + task = Task( + "test_task", + "test_arguments", + "test_workers", + "test_status", + "test_script_url", + "test_optimized", + "test_approach", + "pip", + "test_result", + "test_q_model_path", + ) result = self.scheduler.prepare_env(task) self.assertTrue(result.startswith(self.scheduler.conda_env_name)) # Test requirement in {conda_env} case - task = Task("test_task", "test_arguments", "test_workers", "test_status", "test_script_url", \ - "test_optimized", "test_approach", "pip", "test_result", "test_q_model_path") - scheduler_test = Scheduler(self.cluster, self.task_db, self.result_monitor_port,conda_env_name="base", config=config) + task = Task( + "test_task", + "test_arguments", + "test_workers", + "test_status", + "test_script_url", + "test_optimized", + "test_approach", + "pip", + "test_result", + "test_q_model_path", + ) + scheduler_test = Scheduler( + self.cluster, self.task_db, self.result_monitor_port, conda_env_name="base", config=config + ) result = scheduler_test.prepare_env(task) - self.assertTrue(result.startswith('base')) + self.assertTrue(result.startswith("base")) # Test requirement is '' case - task = Task("test_task", "test_arguments", "test_workers", "test_status", "test_script_url", \ - "test_optimized", "test_approach", "", "test_result", "test_q_model_path") + task = Task( + "test_task", + "test_arguments", + "test_workers", + "test_status", + "test_script_url", + "test_optimized", + "test_approach", + "", + "test_result", + "test_q_model_path", + ) result = self.scheduler.prepare_env(task) self.assertEqual(result, self.scheduler.conda_env_name) def test_prepare_task(self): - task = Task("test_task", "test_arguments", "test_workers", "test_status",\ - "test_example", \ - "test_optimized", "static", "test_requirement", "test_result", "test_q_model_path") + task = Task( + "test_task", + "test_arguments", + "test_workers", + "test_status", + "test_example", + "test_optimized", + "static", + "test_requirement", + "test_result", + "test_q_model_path", + ) test_path = "examples/test_example" if not os.path.exists(test_path): os.makedirs(test_path) @@ -65,33 +107,50 @@ def test_prepare_task(self): self.scheduler.prepare_task(task) # url case - with patch('neural_solution.backend.scheduler.is_remote_url', return_value=True): + with patch("neural_solution.backend.scheduler.is_remote_url", return_value=True): self.scheduler.prepare_task(task) # optimized is False case - task = Task("test_task", "test_arguments", "test_workers", "test_status",\ - "test_example", \ - False, "static", "test_requirement", "test_result", "test_q_model_path") + task = Task( + "test_task", + "test_arguments", + "test_workers", + "test_status", + "test_example", + False, + "static", + "test_requirement", + "test_result", + "test_q_model_path", + ) self.scheduler.prepare_task(task) - with patch('neural_solution.backend.scheduler.is_remote_url', return_value=True): - task = Task("test_task", "test_arguments", "test_workers", "test_status",\ - "test_example/test.py", \ - False, "static", "test_requirement", "test_result", "test_q_model_path") + with patch("neural_solution.backend.scheduler.is_remote_url", return_value=True): + task = Task( + "test_task", + "test_arguments", + "test_workers", + "test_status", + "test_example/test.py", + False, + "static", + "test_requirement", + "test_result", + "test_q_model_path", + ) self.scheduler.prepare_task(task) - def test_check_task_status(self): log_path = "test_log_path" # done case - with patch('builtins.open', mock_open(read_data='[INFO] Save deploy yaml to\n')) as mock_file: + with patch("builtins.open", mock_open(read_data="[INFO] Save deploy yaml to\n")) as mock_file: result = self.scheduler.check_task_status(log_path) - self.assertEqual(result, 'done') + self.assertEqual(result, "done") # failed case - with patch('builtins.open', mock_open(read_data='[INFO] Deploying...\n')) as mock_file: + with patch("builtins.open", mock_open(read_data="[INFO] Deploying...\n")) as mock_file: result = self.scheduler.check_task_status(log_path) - self.assertEqual(result, 'failed') + self.assertEqual(result, "failed") def test_sanitize_arguments(self): arguments = "test_arguments\xa0" @@ -99,15 +158,25 @@ def test_sanitize_arguments(self): self.assertEqual(sanitized_arguments, "test_arguments ") def test_dispatch_task(self): - task = Task("test_task", "test_arguments", "test_workers", "test_status", "test_script_url", \ - "test_optimized", "test_approach", "test_requirement", "test_result", "test_q_model_path") + task = Task( + "test_task", + "test_arguments", + "test_workers", + "test_status", + "test_script_url", + "test_optimized", + "test_approach", + "test_requirement", + "test_result", + "test_q_model_path", + ) resource = [("node1", "8"), ("node2", "8")] - with patch('neural_solution.backend.scheduler.Scheduler.launch_task') as mock_launch_task: + with patch("neural_solution.backend.scheduler.Scheduler.launch_task") as mock_launch_task: self.scheduler.dispatch_task(task, resource) self.assertTrue(mock_launch_task.called) - @patch('socket.socket') - @patch('builtins.open') + @patch("socket.socket") + @patch("builtins.open") def test_report_result(self, mock_open, mock_socket): task_id = "test_task" log_path = "test_log_path" @@ -115,29 +184,42 @@ def test_report_result(self, mock_open, mock_socket): self.scheduler.q_model_path = None mock_socket.return_value.connect.return_value = None mock_open.return_value.readlines.return_value = ["Tune 1 result is: (int8|fp32): 0.8 (int8|fp32): 0.9"] - expected_result = { - "optimization time (seconds)": "10.00", - "Accuracy": "0.8", - "Duration (seconds)": "0.9" - } + expected_result = {"optimization time (seconds)": "10.00", "Accuracy": "0.8", "Duration (seconds)": "0.9"} self.scheduler.report_result(task_id, log_path, task_runtime) mock_open.assert_called_once_with(log_path) mock_socket.assert_called_once_with() - mock_socket.return_value.connect.assert_called_once_with(('localhost', 1234)) + mock_socket.return_value.connect.assert_called_once_with(("localhost", 1234)) mock_socket.return_value.send.assert_called_once() - - @patch('neural_solution.backend.scheduler.Scheduler.prepare_task') - @patch('neural_solution.backend.scheduler.Scheduler.prepare_env') - @patch('neural_solution.backend.scheduler.Scheduler._parse_cmd') - @patch('subprocess.Popen') - @patch('neural_solution.backend.scheduler.Scheduler.check_task_status') - @patch('neural_solution.backend.scheduler.Scheduler.report_result') - def test_launch_task(self, mock_report_result, mock_check_task_status, mock_popen, mock_parse_cmd, mock_prepare_env, mock_prepare_task): - task = Task("test_task", "test_arguments", "test_workers", "test_status", "test_script_url", \ - "test_optimized", "test_approach", "test_requirement", "test_result", "test_q_model_path") + @patch("neural_solution.backend.scheduler.Scheduler.prepare_task") + @patch("neural_solution.backend.scheduler.Scheduler.prepare_env") + @patch("neural_solution.backend.scheduler.Scheduler._parse_cmd") + @patch("subprocess.Popen") + @patch("neural_solution.backend.scheduler.Scheduler.check_task_status") + @patch("neural_solution.backend.scheduler.Scheduler.report_result") + def test_launch_task( + self, + mock_report_result, + mock_check_task_status, + mock_popen, + mock_parse_cmd, + mock_prepare_env, + mock_prepare_task, + ): + task = Task( + "test_task", + "test_arguments", + "test_workers", + "test_status", + "test_script_url", + "test_optimized", + "test_approach", + "test_requirement", + "test_result", + "test_q_model_path", + ) resource = ["1 node1", "2 node2"] mock_parse_cmd.return_value = "test_cmd" mock_check_task_status.return_value = "done" @@ -148,16 +230,26 @@ def test_launch_task(self, mock_report_result, mock_check_task_status, mock_pope self.scheduler.launch_task(task, resource) - - - @patch('neural_solution.backend.scheduler.Scheduler.launch_task') - @patch('neural_solution.backend.cluster.Cluster.reserve_resource') + @patch("neural_solution.backend.scheduler.Scheduler.launch_task") + @patch("neural_solution.backend.cluster.Cluster.reserve_resource") def test_schedule_tasks(self, mock_reserve_resource, mock_launch_task): - task1 = Task("1", "test_arguments", "test_workers", "test_status", "test_script_url", \ - "test_optimized", "test_approach", "test_requirement", "test_result", "test_q_model_path") - self.task_db.cursor.execute("insert into task values ('1', 'test_arguments', 'test_workers', \ + task1 = Task( + "1", + "test_arguments", + "test_workers", + "test_status", + "test_script_url", + "test_optimized", + "test_approach", + "test_requirement", + "test_result", + "test_q_model_path", + ) + self.task_db.cursor.execute( + "insert into task values ('1', 'test_arguments', 'test_workers', \ 'test_status', 'test_script_url', \ - 'test_optimized', 'test_approach', 'test_requirement', 'test_result', 'test_q_model_path')") + 'test_optimized', 'test_approach', 'test_requirement', 'test_result', 'test_q_model_path')" + ) # no pending task case adding_abort = threading.Thread( @@ -194,20 +286,20 @@ def test_schedule_tasks(self, mock_reserve_resource, mock_launch_task): class TestParseCmd(unittest.TestCase): - def setUp(self): self.cluster = Cluster(db_path=db_path) self.task_db = TaskDB(db_path=db_path) self.result_monitor_port = 1234 - self.task_scheduler = \ - Scheduler(self.cluster, self.task_db, self.result_monitor_port,conda_env_name="for_ns_test", config=config) + self.task_scheduler = Scheduler( + self.cluster, self.task_db, self.result_monitor_port, conda_env_name="for_ns_test", config=config + ) self.task = MagicMock() - self.resource = ['1 node1', '2 node2', '3 node3'] + self.resource = ["1 node1", "2 node2", "3 node3"] self.task.workers = 3 - self.task_name = 'test_task' - self.script_name = 'test_script.py' - self.task_path = '/path/to/task' - self.arguments = 'arg1 arg2' + self.task_name = "test_task" + self.script_name = "test_script.py" + self.task_path = "/path/to/task" + self.arguments = "arg1 arg2" self.task.arguments = self.arguments self.task.name = self.task_name self.task.optimized = True @@ -217,33 +309,37 @@ def setUp(self): self.task_scheduler.task_path = self.task_path def test__parse_cmd(self): - expected_cmd = 'cd /path/to/task\nmpirun -np 3 -host node1,node2,node3 -map-by socket:pe=5' + \ - ' -mca btl_tcp_if_include 192.168.20.0/24 -x OMP_NUM_THREADS=5 --report-bindings bash distributed_run.sh' - with patch('neural_solution.backend.scheduler.Scheduler.prepare_task') as mock_prepare_task, \ - patch('neural_solution.backend.scheduler.Scheduler.prepare_env') as mock_prepare_env, \ - patch('neural_solution.backend.scheduler.logger.info') as mock_logger_info, \ - patch('builtins.open', create=True) as mock_open, \ - patch('neural_solution.backend.scheduler.os.path.join') as mock_os_path_join: - + expected_cmd = ( + "cd /path/to/task\nmpirun -np 3 -host node1,node2,node3 -map-by socket:pe=5" + + " -mca btl_tcp_if_include 192.168.20.0/24 -x OMP_NUM_THREADS=5 --report-bindings bash distributed_run.sh" + ) + with patch("neural_solution.backend.scheduler.Scheduler.prepare_task") as mock_prepare_task, patch( + "neural_solution.backend.scheduler.Scheduler.prepare_env" + ) as mock_prepare_env, patch("neural_solution.backend.scheduler.logger.info") as mock_logger_info, patch( + "builtins.open", create=True + ) as mock_open, patch( + "neural_solution.backend.scheduler.os.path.join" + ) as mock_os_path_join: mock_prepare_task.return_value = None - mock_prepare_env.return_value = 'test_env' + mock_prepare_env.return_value = "test_env" mock_logger_info.return_value = None mock_open.return_value.__enter__ = lambda x: x mock_open.return_value.__exit__ = MagicMock() - mock_os_path_join.return_value = '/path/to/task/distributed_run.sh' + mock_os_path_join.return_value = "/path/to/task/distributed_run.sh" result = self.task_scheduler._parse_cmd(self.task, self.resource) self.assertEqual(result, expected_cmd) mock_prepare_task.assert_called_once_with(self.task) mock_prepare_env.assert_called_once_with(self.task) - mock_logger_info.assert_called_once_with('[TaskScheduler] host resource: node1,node2,node3') - mock_open.assert_called_once_with('/path/to/task/distributed_run.sh', 'w', encoding='utf-8') - mock_os_path_join.assert_called_once_with('/path/to/task', 'distributed_run.sh') + mock_logger_info.assert_called_once_with("[TaskScheduler] host resource: node1,node2,node3") + mock_open.assert_called_once_with("/path/to/task/distributed_run.sh", "w", encoding="utf-8") + mock_os_path_join.assert_called_once_with("/path/to/task", "distributed_run.sh") self.task.optimized = False result = self.task_scheduler._parse_cmd(self.task, self.resource) self.assertEqual(result, expected_cmd) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/neural_solution/test/backend/test_task.py b/neural_solution/test/backend/test_task.py index 55073c92946..43d3c649ec5 100644 --- a/neural_solution/test/backend/test_task.py +++ b/neural_solution/test/backend/test_task.py @@ -1,15 +1,13 @@ import unittest + from neural_solution.backend.task import Task + class TestTask(unittest.TestCase): def setUp(self): - self.task = Task("123", - "python script.py", - 4, "pending", - "http://example.com/script.py", - True, - "approach", - "requirement") + self.task = Task( + "123", "python script.py", 4, "pending", "http://example.com/script.py", True, "approach", "requirement" + ) def test_task_attributes(self): self.assertEqual(self.task.task_id, "123") @@ -23,5 +21,6 @@ def test_task_attributes(self): self.assertEqual(self.task.result, "") self.assertEqual(self.task.q_model_path, "") -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/neural_solution/test/backend/test_task_db.py b/neural_solution/test/backend/test_task_db.py index 2b1c5b6f28b..38e9690a8e7 100644 --- a/neural_solution/test/backend/test_task_db.py +++ b/neural_solution/test/backend/test_task_db.py @@ -1,29 +1,21 @@ - -import unittest -from unittest.mock import patch, MagicMock -from neural_solution.backend.task_db import TaskDB, Task -import shutil import os +import shutil +import unittest +from unittest.mock import MagicMock, patch +from neural_solution.backend.task_db import Task, TaskDB from neural_solution.utils.utility import get_db_path NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace") db_path = get_db_path(NEURAL_SOLUTION_WORKSPACE) -class TestTaskDB(unittest.TestCase): +class TestTaskDB(unittest.TestCase): def setUp(self): self.taskdb = TaskDB(db_path=db_path) self.task = Task( - '1', - 'arguments', - 1, - 'pending', - 'script_url', - 0, 'approach', - 'requirement', - 'result', - 'q_model_path') + "1", "arguments", 1, "pending", "script_url", 0, "approach", "requirement", "result", "q_model_path" + ) @classmethod def tearDownClass(cls) -> None: @@ -32,71 +24,85 @@ def tearDownClass(cls) -> None: def test_append_task(self): self.taskdb.append_task(self.task) self.assertEqual(len(self.taskdb.task_queue), 1) - self.assertEqual(self.taskdb.task_queue[0], '1') + self.assertEqual(self.taskdb.task_queue[0], "1") def test_get_pending_task_num(self): self.taskdb.append_task(self.task) self.assertEqual(self.taskdb.get_pending_task_num(), 1) def test_get_all_pending_tasks(self): - self.taskdb.cursor.execute("insert into task values ('2', 'arguments', 1, \ - 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')") + self.taskdb.cursor.execute( + "insert into task values ('2', 'arguments', 1, \ + 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')" + ) self.taskdb.conn.commit() pending_tasks = self.taskdb.get_all_pending_tasks() self.assertEqual(len(pending_tasks), 1) - self.assertEqual(pending_tasks[0].task_id, '2') + self.assertEqual(pending_tasks[0].task_id, "2") def test_update_task_status(self): - self.taskdb.cursor.execute("insert into task values ('3', 'arguments', 1, \ - 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')") + self.taskdb.cursor.execute( + "insert into task values ('3', 'arguments', 1, \ + 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')" + ) self.taskdb.conn.commit() - self.taskdb.update_task_status('3', 'running') + self.taskdb.update_task_status("3", "running") self.taskdb.cursor.execute("select status from task where id='3'") status = self.taskdb.cursor.fetchone()[0] - self.assertEqual(status, 'running') + self.assertEqual(status, "running") with self.assertRaises(Exception): - self.taskdb.update_task_status('3', 'invalid_status') + self.taskdb.update_task_status("3", "invalid_status") def test_update_result(self): - self.taskdb.cursor.execute("insert into task values ('4', 'arguments', 1, \ - 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')") + self.taskdb.cursor.execute( + "insert into task values ('4', 'arguments', 1, \ + 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')" + ) self.taskdb.conn.commit() - self.taskdb.update_result('4', 'new_result') + self.taskdb.update_result("4", "new_result") self.taskdb.cursor.execute("select result from task where id='4'") result = self.taskdb.cursor.fetchone()[0] - self.assertEqual(result, 'new_result') + self.assertEqual(result, "new_result") def test_update_q_model_path_and_result(self): - self.taskdb.cursor.execute("insert into task values ('5', 'arguments', 1, \ - 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')") + self.taskdb.cursor.execute( + "insert into task values ('5', 'arguments', 1, \ + 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')" + ) self.taskdb.conn.commit() - self.taskdb.update_q_model_path_and_result('5', 'new_q_model_path', 'new_result') + self.taskdb.update_q_model_path_and_result("5", "new_q_model_path", "new_result") self.taskdb.cursor.execute("select q_model_path, result from task where id='5'") q_model_path, result = self.taskdb.cursor.fetchone() - self.assertEqual(q_model_path, 'new_q_model_path') - self.assertEqual(result, 'new_result') + self.assertEqual(q_model_path, "new_q_model_path") + self.assertEqual(result, "new_result") def test_lookup_task_status(self): - self.taskdb.cursor.execute("insert into task values ('6', 'arguments', 1, \ - 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')") + self.taskdb.cursor.execute( + "insert into task values ('6', 'arguments', 1, \ + 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')" + ) self.taskdb.conn.commit() - status_dict = self.taskdb.lookup_task_status('6') - self.assertEqual(status_dict['status'], 'pending') - self.assertEqual(status_dict['result'], 'result') + status_dict = self.taskdb.lookup_task_status("6") + self.assertEqual(status_dict["status"], "pending") + self.assertEqual(status_dict["result"], "result") def test_get_task_by_id(self): - self.taskdb.cursor.execute("insert into task values ('7', 'arguments', 1, \ - 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')") + self.taskdb.cursor.execute( + "insert into task values ('7', 'arguments', 1, \ + 'pending', 'script_url', 0, 'approach', 'requirement', 'result', 'q_model_path')" + ) self.taskdb.conn.commit() - task = self.taskdb.get_task_by_id('7') - self.assertEqual(task.task_id, '7') - self.assertEqual(task.arguments, 'arguments') + task = self.taskdb.get_task_by_id("7") + self.assertEqual(task.task_id, "7") + self.assertEqual(task.arguments, "arguments") self.assertEqual(task.workers, 1) - self.assertEqual(task.status, 'pending') - self.assertEqual(task.result, 'result') + self.assertEqual(task.status, "pending") + self.assertEqual(task.result, "result") def test_remove_task(self): - self.taskdb.remove_task('1') + self.taskdb.remove_task("1") # currently no garbage collection, so this function does nothing -if __name__ == '__main__': - unittest.main() \ No newline at end of file + + +if __name__ == "__main__": + unittest.main() diff --git a/neural_solution/test/backend/test_task_monitor.py b/neural_solution/test/backend/test_task_monitor.py index 68ba7ee40bb..1e4bb069726 100644 --- a/neural_solution/test/backend/test_task_monitor.py +++ b/neural_solution/test/backend/test_task_monitor.py @@ -1,8 +1,10 @@ +import threading import unittest -from unittest.mock import Mock, patch, MagicMock -from neural_solution.backend.task_monitor import TaskMonitor +from unittest.mock import MagicMock, Mock, patch + from neural_solution.backend.task import Task -import threading +from neural_solution.backend.task_monitor import TaskMonitor + class TestTaskMonitor(unittest.TestCase): def setUp(self): @@ -18,37 +20,66 @@ def test__start_listening(self): mock_socket.return_value.bind = mock_bind mock_socket.return_value.listen = mock_listen self.task_monitor._start_listening("localhost", 8888, 10) - + def test_receive_task(self): - self.mock_socket.accept.return_value = (Mock(), b'{"task_id": 123, "arguments": {}, "workers": 1, \ + self.mock_socket.accept.return_value = ( + Mock(), + b'{"task_id": 123, "arguments": {}, "workers": 1, \ "status": "pending", "script_url": "http://example.com", "optimized": True, \ - "approach": "static", "requirement": "neural_solution", "result": "", "q_model_path": ""}') + "approach": "static", "requirement": "neural_solution", "result": "", "q_model_path": ""}', + ) self.mock_task_db.get_task_by_id.return_value = Task( - task_id=123, arguments={}, workers=1, - status="pending", script_url="http://example.com", optimized=True, - approach="static", requirement="neural_solution", result="", q_model_path="") - + task_id=123, + arguments={}, + workers=1, + status="pending", + script_url="http://example.com", + optimized=True, + approach="static", + requirement="neural_solution", + result="", + q_model_path="", + ) + # Test normal task case - with patch('neural_solution.backend.task_monitor.deserialize', return_value={ - "task_id": 123, "arguments": {}, "workers": 1, - "status": "pending", "script_url": "http://example.com", - "optimized": True, "approach": "static", - "requirement": "neural_solution", "result": "", "q_model_path": ""}): + with patch( + "neural_solution.backend.task_monitor.deserialize", + return_value={ + "task_id": 123, + "arguments": {}, + "workers": 1, + "status": "pending", + "script_url": "http://example.com", + "optimized": True, + "approach": "static", + "requirement": "neural_solution", + "result": "", + "q_model_path": "", + }, + ): task = self.task_monitor._receive_task() self.assertEqual(task.task_id, 123) self.mock_task_db.get_task_by_id.assert_called_once_with(123) - + # Test ping case - with patch('neural_solution.backend.task_monitor.deserialize', return_value={"ping": "test"}): + with patch("neural_solution.backend.task_monitor.deserialize", return_value={"ping": "test"}): response = self.task_monitor._receive_task() self.assertEqual(response, False) self.mock_task_db.get_task_by_id.assert_called_once_with(123) def test_append_task(self): task = Task( - task_id=123, arguments={}, workers=1, - status="pending", script_url="http://example.com", optimized=True, - approach="static", requirement="neural_solution", result="", q_model_path="") + task_id=123, + arguments={}, + workers=1, + status="pending", + script_url="http://example.com", + optimized=True, + approach="static", + requirement="neural_solution", + result="", + q_model_path="", + ) self.task_monitor._append_task(task) self.mock_task_db.append_task.assert_called_once_with(task) @@ -61,7 +92,7 @@ def test_wait_new_task(self): self.task_monitor._receive_task = mock_receive_task self.task_monitor._append_task = mock_append_task self.task_monitor._start_listening = MagicMock() - + # Call the function to be tested adding_abort = threading.Thread( target=self.task_monitor.wait_new_task, @@ -70,13 +101,14 @@ def test_wait_new_task(self): ) adding_abort.start() adding_abort.join(timeout=1) - + # Test task is False mock_receive_task = MagicMock(return_value=False) mock_append_task = MagicMock() self.task_monitor._receive_task = mock_receive_task - + adding_abort.join(timeout=1) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file + + +if __name__ == "__main__": + unittest.main() diff --git a/neural_solution/test/backend/utils/test_utility.py b/neural_solution/test/backend/utils/test_utility.py index 8816179a2a6..296e2dc3e44 100644 --- a/neural_solution/test/backend/utils/test_utility.py +++ b/neural_solution/test/backend/utils/test_utility.py @@ -1,28 +1,36 @@ -import unittest import os import shutil -from unittest.mock import patch, MagicMock, mock_open -from neural_solution.backend.utils.utility import ( - serialize, deserialize, - dump_elapsed_time, get_task_log_path, build_local_cluster, - build_cluster, get_current_time, - synchronized, build_workspace, is_remote_url, create_dir, - get_q_model_path) +import unittest +from unittest.mock import MagicMock, mock_open, patch -from neural_solution.utils.utility import get_task_workspace, get_task_log_workspace, get_db_path +from neural_solution.backend.utils.utility import ( + build_cluster, + build_local_cluster, + build_workspace, + create_dir, + deserialize, + dump_elapsed_time, + get_current_time, + get_q_model_path, + get_task_log_path, + is_remote_url, + serialize, + synchronized, +) from neural_solution.config import config +from neural_solution.utils.utility import get_db_path, get_task_log_workspace, get_task_workspace NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace") DB_PATH = NEURAL_SOLUTION_WORKSPACE + "/db" -TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace" +TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace" TASK_LOG_path = NEURAL_SOLUTION_WORKSPACE + "/task_log" SERVE_LOG_PATH = NEURAL_SOLUTION_WORKSPACE + "/serve_log" config.workspace = NEURAL_SOLUTION_WORKSPACE db_path = get_db_path(config.workspace) -class TestUtils(unittest.TestCase): +class TestUtils(unittest.TestCase): @classmethod def tearDown(self) -> None: if os.path.exists("ns_workspace"): @@ -42,24 +50,28 @@ def test_dump_elapsed_time(self): @dump_elapsed_time("test function") def test_function(): return True - with patch('neural_solution.utils.logger') as mock_logger: + + with patch("neural_solution.utils.logger") as mock_logger: test_function() def test_get_task_log_path(self): task_id = 123 expected_output = f"{TASK_LOG_path}/task_{task_id}.txt" - self.assertEqual(get_task_log_path(log_path=get_task_log_workspace(config.workspace), task_id=task_id), expected_output) + self.assertEqual( + get_task_log_path(log_path=get_task_log_workspace(config.workspace), task_id=task_id), expected_output + ) def test_build_local_cluster(self): - with patch('neural_solution.backend.cluster.Node') as mock_node, \ - patch('neural_solution.backend.cluster.Cluster') as mock_cluster: + with patch("neural_solution.backend.cluster.Node") as mock_node, patch( + "neural_solution.backend.cluster.Cluster" + ) as mock_cluster: mock_node_obj = MagicMock() mock_node.return_value = mock_node_obj - mock_node_obj.name = 'localhost' + mock_node_obj.name = "localhost" mock_node_obj.num_sockets = 2 mock_node_obj.num_cores_per_socket = 5 build_local_cluster(db_path=db_path) - mock_node.assert_called_with(name='localhost', num_sockets=2, num_cores_per_socket=5) + mock_node.assert_called_with(name="localhost", num_sockets=2, num_cores_per_socket=5) mock_cluster.assert_called_once() def test_build_cluster(self): @@ -73,11 +85,9 @@ def test_build_cluster(self): os.remove("test.txt") file_path = "test_host_file" - with patch('neural_solution.backend.cluster.Node') as mock_node, \ - patch('neural_solution.backend.cluster.Cluster') as mock_cluster, \ - patch('builtins.open') as mock_open, \ - patch('os.path.exists') as mock_exists: - + with patch("neural_solution.backend.cluster.Node") as mock_node, patch( + "neural_solution.backend.cluster.Cluster" + ) as mock_cluster, patch("builtins.open") as mock_open, patch("os.path.exists") as mock_exists: # Test None cluster, _ = build_cluster(file_path=None, db_path=db_path) mock_cluster.assert_called() @@ -87,11 +97,9 @@ def test_build_cluster(self): # test_build_cluster_file_not_exist file_path = "test_file" - with patch('neural_solution.backend.cluster.Node'), \ - patch('neural_solution.backend.cluster.Cluster'), \ - patch('builtins.open'), \ - patch('os.path.exists') as mock_exists, \ - patch('neural_solution.utils.logger') as mock_logger: + with patch("neural_solution.backend.cluster.Node"), patch("neural_solution.backend.cluster.Cluster"), patch( + "builtins.open" + ), patch("os.path.exists") as mock_exists, patch("neural_solution.utils.logger") as mock_logger: mock_exists.return_value = False self.assertRaises(Exception, build_cluster, file_path) mock_logger.reset_mock() @@ -103,17 +111,19 @@ def test_synchronized(self): class TestClass: def __init__(self): self.lock = MagicMock() + @synchronized def test_function(self): return True + test_class = TestClass() - with patch.object(test_class, 'lock'): + with patch.object(test_class, "lock"): test_class.test_function() def test_build_workspace(self): task_id = 123 expected_output = os.path.abspath(f"{TASK_WORKSPACE}/{task_id}") - self.assertEqual(build_workspace(path=get_task_workspace(config.workspace) ,task_id=task_id), expected_output) + self.assertEqual(build_workspace(path=get_task_workspace(config.workspace), task_id=task_id), expected_output) def test_is_remote_url(self): self.assertTrue(is_remote_url("http://test.com")) @@ -125,17 +135,18 @@ def test_create_dir(self): create_dir(path) self.assertTrue(os.path.exists(os.path.dirname(path))) - @patch('builtins.open', mock_open(read_data='Save quantized model to /path/to/model.')) + @patch("builtins.open", mock_open(read_data="Save quantized model to /path/to/model.")) def test_get_q_model_path_success(self): - log_path = 'fake_log_path' + log_path = "fake_log_path" q_model_path = get_q_model_path(log_path) - self.assertEqual(q_model_path, '/path/to/model') + self.assertEqual(q_model_path, "/path/to/model") - @patch('builtins.open', mock_open(read_data='No quantized model saved.')) + @patch("builtins.open", mock_open(read_data="No quantized model saved.")) def test_get_q_model_path_failure(self): - log_path = 'fake_log_path' + log_path = "fake_log_path" q_model_path = get_q_model_path(log_path) - self.assertEqual(q_model_path, 'quantized model path not found') + self.assertEqual(q_model_path, "quantized model path not found") + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/neural_solution/test/frontend/fastapi/test_main_server.py b/neural_solution/test/frontend/fastapi/test_main_server.py index 0f3473019b2..c101e616216 100644 --- a/neural_solution/test/frontend/fastapi/test_main_server.py +++ b/neural_solution/test/frontend/fastapi/test_main_server.py @@ -1,18 +1,19 @@ -import unittest import asyncio -from unittest.mock import patch, Mock, MagicMock -from fastapi.testclient import TestClient -from fastapi import WebSocket -from neural_solution.frontend.fastapi.main_server import app, LogEventHandler, start_log_watcher, Observer -import sqlite3 import os import shutil +import sqlite3 +import unittest +from unittest.mock import MagicMock, Mock, patch + +from fastapi import WebSocket +from fastapi.testclient import TestClient from neural_solution.config import config +from neural_solution.frontend.fastapi.main_server import LogEventHandler, Observer, app, start_log_watcher NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace") DB_PATH = NEURAL_SOLUTION_WORKSPACE + "/db" -TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace" +TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace" TASK_LOG_path = NEURAL_SOLUTION_WORKSPACE + "/task_log" SERVE_LOG_PATH = NEURAL_SOLUTION_WORKSPACE + "/serve_log" @@ -22,38 +23,47 @@ def build_db(): if not os.path.exists(DB_PATH): os.makedirs(DB_PATH) - conn = sqlite3.connect(f'{DB_PATH}/task.db', check_same_thread=False) # sqlite should set this check_same_thread to False + conn = sqlite3.connect( + f"{DB_PATH}/task.db", check_same_thread=False + ) # sqlite should set this check_same_thread to False cursor = conn.cursor() cursor.execute( - 'create table if not exists task(id TEXT PRIMARY KEY, arguments varchar(100), ' + - 'workers int, status varchar(20), script_url varchar(500), optimized integer, ' + - 'approach varchar(20), requirements varchar(500), result varchar(500), q_model_path varchar(200))') - cursor.execute('drop table if exists cluster ') - cursor.execute(r'create table cluster(id INTEGER PRIMARY KEY AUTOINCREMENT,' + - 'node_info varchar(500),' + - 'status varchar(100),' + - 'free_sockets int,' + - 'busy_sockets int,' + - 'total_sockets int)') + "create table if not exists task(id TEXT PRIMARY KEY, arguments varchar(100), " + + "workers int, status varchar(20), script_url varchar(500), optimized integer, " + + "approach varchar(20), requirements varchar(500), result varchar(500), q_model_path varchar(200))" + ) + cursor.execute("drop table if exists cluster ") + cursor.execute( + r"create table cluster(id INTEGER PRIMARY KEY AUTOINCREMENT," + + "node_info varchar(500)," + + "status varchar(100)," + + "free_sockets int," + + "busy_sockets int," + + "total_sockets int)" + ) conn.commit() conn.close + def delete_db(): if os.path.exists(DB_PATH): shutil.rmtree(DB_PATH) + def use_db(): def f(func): def fi(*args, **kwargs): build_db() res = func(*args, **kwargs) delete_db() + return fi + return f -class TestMain(unittest.TestCase): +class TestMain(unittest.TestCase): @classmethod def setUpClass(self): if not os.path.exists(TASK_LOG_path): @@ -69,7 +79,7 @@ def test_read_root(self): assert response.status_code == 200 self.assertEqual(response.json(), {"message": "Welcome to Neural Solution!"}) - @patch('neural_solution.frontend.fastapi.main_server.socket') + @patch("neural_solution.frontend.fastapi.main_server.socket") def test_ping(self, mock_socket): response = client.get("/ping") self.assertEqual(response.status_code, 200) @@ -92,11 +102,12 @@ def test_get_description(self): data = { "description": "", } - path = "../../doc" + path = "../../doc" if not os.path.exists(path): os.makedirs(path) with open(os.path.join(path, "user_facing_api.json"), "w") as f: import json + json.dump(data, f) response = client.get("/description") assert response.status_code == 200 @@ -111,7 +122,7 @@ def test_submit_task(self, mock_submit_task): "arguments": ["arg1", "arg2"], "approach": "approach1", "requirements": ["req1", "req2"], - "workers": 3 + "workers": 3, } # test no db case @@ -132,7 +143,6 @@ def test_submit_task(self, mock_submit_task): self.assertIn("successfully", response.json()["status"]) mock_submit_task.assert_called_once() - # test ConnectionRefusedError case mock_submit_task.side_effect = ConnectionRefusedError response = client.post("/task/submit/", json=task) @@ -166,7 +176,7 @@ def test_get_task_by_id(self, mock_submit_task): "arguments": ["arg1", "arg2"], "approach": "approach1", "requirements": ["req1", "req2"], - "workers": 3 + "workers": 3, } response = client.post("/task/submit/", json=task) task_id = response.json()["task_id"] @@ -192,7 +202,7 @@ def test_get_task_status_by_id(self, mock_submit_task): "arguments": ["arg1", "arg2"], "approach": "approach1", "requirements": ["req1", "req2"], - "workers": 3 + "workers": 3, } response = client.post("/task/submit/", json=task) task_id = response.json()["task_id"] @@ -214,8 +224,8 @@ def test_read_logs(self): self.assertIn(task_id, response.text) os.remove(log_path) -class TestLogEventHandler(unittest.TestCase): +class TestLogEventHandler(unittest.TestCase): def setUp(self): self.loop = asyncio.get_event_loop() @@ -231,6 +241,7 @@ def test_init(self): def test_on_modified(self): from neural_solution.config import config + config.workspace = NEURAL_SOLUTION_WORKSPACE mock_websocket = MagicMock() mock_websocket.send_text = MagicMock() @@ -249,7 +260,7 @@ def test_on_modified(self): task_id = "1234" log_path = f"{TASK_LOG_path}/task_{task_id}.txt" event.src_path = log_path - with patch('builtins.open', MagicMock()) as mock_file: + with patch("builtins.open", MagicMock()) as mock_file: mock_file.return_value.__enter__.return_value.seek.return_value = None mock_file.return_value.__enter__.return_value.readlines.return_value = ["test line"] handler.on_modified(event) @@ -259,20 +270,20 @@ def test_on_modified(self): # handler.queue.put_nowait.assert_called_once_with("test line") os.remove(log_path) -class TestStartLogWatcher(unittest.TestCase): +class TestStartLogWatcher(unittest.TestCase): def setUp(self): self.loop = asyncio.get_event_loop() def test_start_log_watcher(self): mock_observer = MagicMock() mock_observer.schedule = MagicMock() - with patch('neural_solution.frontend.fastapi.main_server.Observer', MagicMock(return_value=mock_observer)): + with patch("neural_solution.frontend.fastapi.main_server.Observer", MagicMock(return_value=mock_observer)): observer = start_log_watcher("test_websocket", "1234", 0) self.assertIsInstance(observer, type(mock_observer)) -class TestWebsocketEndpoint(unittest.TestCase): +class TestWebsocketEndpoint(unittest.TestCase): def setUp(self): self.loop = asyncio.get_event_loop() self.client = TestClient(app) @@ -282,5 +293,6 @@ def test_websocket_endpoint(self): # with self.assertRaises(HTTPException): # asyncio.run(websocket_endpoint(WebSocket, "nonexistent_task")) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/neural_solution/test/frontend/fastapi/test_task_submitter.py b/neural_solution/test/frontend/fastapi/test_task_submitter.py index 48a0f312662..c08c2cd605e 100644 --- a/neural_solution/test/frontend/fastapi/test_task_submitter.py +++ b/neural_solution/test/frontend/fastapi/test_task_submitter.py @@ -1,7 +1,9 @@ +import socket import unittest from unittest.mock import patch -import socket -from neural_solution.frontend.task_submitter import TaskSubmitter, Task + +from neural_solution.frontend.task_submitter import Task, TaskSubmitter + class TestTask(unittest.TestCase): def test_task_creation(self): @@ -18,7 +20,7 @@ def test_task_creation(self): arguments=arguments, approach=approach, requirements=requirements, - workers=workers + workers=workers, ) self.assertEqual(task.script_url, script_url) @@ -28,15 +30,17 @@ def test_task_creation(self): self.assertEqual(task.requirements, requirements) self.assertEqual(task.workers, workers) + class TestTaskSubmitter(unittest.TestCase): - @patch('socket.socket') + @patch("socket.socket") def test_submit_task(self, mock_socket): task_submitter = TaskSubmitter() - task_id = '1234' + task_id = "1234" task_submitter.submit_task(task_id) - mock_socket.return_value.connect.assert_called_once_with(('localhost', 2222)) + mock_socket.return_value.connect.assert_called_once_with(("localhost", 2222)) mock_socket.return_value.send.assert_called_once_with(b'{"task_id": "1234"}') mock_socket.return_value.close.assert_called_once() -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/neural_solution/test/frontend/fastapi/test_utils.py b/neural_solution/test/frontend/fastapi/test_utils.py index d60364f0d26..abb57fff92b 100644 --- a/neural_solution/test/frontend/fastapi/test_utils.py +++ b/neural_solution/test/frontend/fastapi/test_utils.py @@ -1,21 +1,27 @@ -import unittest import os import shutil -from unittest.mock import patch, mock_open +import unittest +from unittest.mock import mock_open, patch from neural_solution.frontend.utility import ( - serialize, deserialize, - get_cluster_info,get_cluster_table, - get_res_during_tuning, get_baseline_during_tuning, - check_log_exists, list_to_string) + check_log_exists, + deserialize, + get_baseline_during_tuning, + get_cluster_info, + get_cluster_table, + get_res_during_tuning, + list_to_string, + serialize, +) + NEURAL_SOLUTION_WORKSPACE = os.path.join(os.getcwd(), "ns_workspace") DB_PATH = NEURAL_SOLUTION_WORKSPACE + "/db/task.db" -TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace" +TASK_WORKSPACE = NEURAL_SOLUTION_WORKSPACE + "/task_workspace" TASK_LOG_path = NEURAL_SOLUTION_WORKSPACE + "/task_log" SERVE_LOG_PATH = NEURAL_SOLUTION_WORKSPACE + "/serve_log" -class TestMyModule(unittest.TestCase): +class TestMyModule(unittest.TestCase): @classmethod def setUpClass(self): if not os.path.exists(TASK_LOG_path): @@ -35,39 +41,41 @@ def test_deserialize(self): expected_result = {"key": "value"} self.assertEqual(deserialize(request), expected_result) - @patch('sqlite3.connect') + @patch("sqlite3.connect") def test_get_cluster_info(self, mock_connect): mock_cursor = mock_connect().cursor.return_value mock_cursor.fetchall.return_value = [(1, "node info", "status", 1, 2, 3)] expected_result = {"Cluster info": [(1, "node info", "status", 1, 2, 3)]} self.assertEqual(get_cluster_info(TASK_LOG_path), expected_result) - @patch('sqlite3.connect') + @patch("sqlite3.connect") def test_get_cluster_table(self, mock_connect): mock_cursor = mock_connect().cursor.return_value mock_cursor.fetchall.return_value = [(1, "node info", "status", 1, 2, 3)] - expected_result = ('\n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - ' \n' - '
NodeNode infostatusfree workersbusy workerstotal workers
1node infostatus123
') + expected_result = ( + '\n' + " \n" + ' \n' + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
NodeNode infostatusfree workersbusy workerstotal workers
1node infostatus123
" + ) self.assertEqual(get_cluster_table(TASK_LOG_path), expected_result) def test_get_res_during_tuning(self): @@ -102,5 +110,6 @@ def test_list_to_string(self): expected_result = "Hello Neural Solution" self.assertEqual(list_to_string(lst), expected_result) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/neural_solution/test/test_logger.py b/neural_solution/test/test_logger.py index db1df1d4c78..cd184a12e48 100644 --- a/neural_solution/test/test_logger.py +++ b/neural_solution/test/test_logger.py @@ -1,7 +1,9 @@ """Tests for logging utilities.""" -from neural_solution.utils import logger import unittest +from neural_solution.utils import logger + + class TestLogger(unittest.TestCase): def test_logger(self): logger.log(0, "call logger log function.") @@ -20,15 +22,13 @@ def test_logger(self): logger.warning({"msg": "call logger warning function"}) logger.warning(["call logger warning function", "done"]) logger.warning(("call logger warning function", "done")) - logger.warning({"msg": {('bert', "embedding"): {'weight': {'dtype': ['unint8', 'int8']}}}}) - logger.warning({"msg": {('bert', "embedding"): {'op': ('a', 'b')}}}) + logger.warning({"msg": {("bert", "embedding"): {"weight": {"dtype": ["unint8", "int8"]}}}}) + logger.warning({"msg": {("bert", "embedding"): {"op": ("a", "b")}}}) # the following log will not be prettified logger.warning([{"msg": "call logger warning function"}, {"msg2": "done"}]) logger.warning(({"msg": "call logger warning function"}, {"msg2": "done"})) - logger.warning(({"msg": [{"sub_msg":"call logger"}, - {"sub_msg2":"call warning function"}]}, - {"msg2": "done"})) + logger.warning(({"msg": [{"sub_msg": "call logger"}, {"sub_msg2": "call warning function"}]}, {"msg2": "done"})) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/neural_solution/utils/__init__.py b/neural_solution/utils/__init__.py index 3f778a4dcd7..0415332cb3f 100644 --- a/neural_solution/utils/__init__.py +++ b/neural_solution/utils/__init__.py @@ -15,4 +15,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""All common functions for both backend and frontend.""" \ No newline at end of file +"""All common functions for both backend and frontend.""" diff --git a/neural_solution/utils/logger.py b/neural_solution/utils/logger.py index 2bfbbba7dc7..03c8b86be56 100644 --- a/neural_solution/utils/logger.py +++ b/neural_solution/utils/logger.py @@ -17,8 +17,8 @@ """Logger: handles logging functionalities.""" -import os import logging +import os class Logger(object): @@ -35,13 +35,11 @@ def __new__(cls): def _log(self): """Set up the logger format and handler.""" - LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper() + LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper() self._logger = logging.getLogger("neural_compressor") self._logger.handlers.clear() self._logger.setLevel(LOGLEVEL) - formatter = logging.Formatter( - '%(asctime)s [%(levelname)s] %(message)s', - "%Y-%m-%d %H:%M:%S") + formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", "%Y-%m-%d %H:%M:%S") streamHandler = logging.StreamHandler() streamHandler.setFormatter(formatter) self._logger.addHandler(streamHandler) @@ -54,25 +52,16 @@ def get_logger(self): def _pretty_dict(value, indent=0): """Make the logger dict pretty.""" - prefix = '\n' + ' ' * (indent + 4) + prefix = "\n" + " " * (indent + 4) if isinstance(value, dict): - items = [ - prefix + repr(key) + ': ' + _pretty_dict(value[key], indent + 4) - for key in value - ] - return '{%s}' % (','.join(items) + '\n' + ' ' * indent) + items = [prefix + repr(key) + ": " + _pretty_dict(value[key], indent + 4) for key in value] + return "{%s}" % (",".join(items) + "\n" + " " * indent) elif isinstance(value, list): - items = [ - prefix + _pretty_dict(item, indent + 4) - for item in value - ] - return '[%s]' % (','.join(items) + '\n' + ' ' * indent) + items = [prefix + _pretty_dict(item, indent + 4) for item in value] + return "[%s]" % (",".join(items) + "\n" + " " * indent) elif isinstance(value, tuple): - items = [ - prefix + _pretty_dict(item, indent + 4) - for item in value - ] - return '(%s)' % (','.join(items) + '\n' + ' ' * indent) + items = [prefix + _pretty_dict(item, indent + 4) for item in value] + return "(%s)" % (",".join(items) + "\n" + " " * indent) else: return repr(value) @@ -84,7 +73,7 @@ def _pretty_dict(value, indent=0): def log(level, msg, *args, **kwargs): """Output log with the level as a parameter.""" if isinstance(msg, dict): - for _, line in enumerate(_pretty_dict(msg).split('\n')): + for _, line in enumerate(_pretty_dict(msg).split("\n")): Logger().get_logger().log(level, line, *args, **kwargs) else: Logger().get_logger().log(level, msg, *args, **kwargs) @@ -93,7 +82,7 @@ def log(level, msg, *args, **kwargs): def debug(msg, *args, **kwargs): """Output log with the debug level.""" if isinstance(msg, dict): - for _, line in enumerate(_pretty_dict(msg).split('\n')): + for _, line in enumerate(_pretty_dict(msg).split("\n")): Logger().get_logger().debug(line, *args, **kwargs) else: Logger().get_logger().debug(msg, *args, **kwargs) @@ -102,7 +91,7 @@ def debug(msg, *args, **kwargs): def error(msg, *args, **kwargs): """Output log with the error level.""" if isinstance(msg, dict): - for _, line in enumerate(_pretty_dict(msg).split('\n')): + for _, line in enumerate(_pretty_dict(msg).split("\n")): Logger().get_logger().error(line, *args, **kwargs) else: Logger().get_logger().error(msg, *args, **kwargs) @@ -111,7 +100,7 @@ def error(msg, *args, **kwargs): def fatal(msg, *args, **kwargs): """Output log with the fatal level.""" if isinstance(msg, dict): - for _, line in enumerate(_pretty_dict(msg).split('\n')): + for _, line in enumerate(_pretty_dict(msg).split("\n")): Logger().get_logger().fatal(line, *args, **kwargs) else: Logger().get_logger().fatal(msg, *args, **kwargs) @@ -120,7 +109,7 @@ def fatal(msg, *args, **kwargs): def info(msg, *args, **kwargs): """Output log with the info level.""" if isinstance(msg, dict): - for _, line in enumerate(_pretty_dict(msg).split('\n')): + for _, line in enumerate(_pretty_dict(msg).split("\n")): Logger().get_logger().info(line, *args, **kwargs) else: Logger().get_logger().info(msg, *args, **kwargs) @@ -129,7 +118,7 @@ def info(msg, *args, **kwargs): def warn(msg, *args, **kwargs): """Output log with the warning level.""" if isinstance(msg, dict): - for _, line in enumerate(_pretty_dict(msg).split('\n')): + for _, line in enumerate(_pretty_dict(msg).split("\n")): Logger().get_logger().warning(line, *args, **kwargs) else: Logger().get_logger().warning(msg, *args, **kwargs) @@ -138,7 +127,7 @@ def warn(msg, *args, **kwargs): def warning(msg, *args, **kwargs): """Output log with the warining level (Alias of the method warn).""" if isinstance(msg, dict): - for _, line in enumerate(_pretty_dict(msg).split('\n')): + for _, line in enumerate(_pretty_dict(msg).split("\n")): Logger().get_logger().warning(line, *args, **kwargs) else: - Logger().get_logger().warning(msg, *args, **kwargs) \ No newline at end of file + Logger().get_logger().warning(msg, *args, **kwargs) diff --git a/neural_solution/utils/utility.py b/neural_solution/utils/utility.py index 8c3999e343d..13ef5da1558 100644 --- a/neural_solution/utils/utility.py +++ b/neural_solution/utils/utility.py @@ -14,8 +14,9 @@ """Neural Solution utility.""" -import os import json +import os + def get_db_path(workspace="./"): """Get the database path. @@ -29,6 +30,7 @@ def get_db_path(workspace="./"): db_path = os.path.join(workspace, "db", "task.db") return os.path.abspath(db_path) + def get_task_workspace(workspace="./"): """Get the workspace of task. @@ -40,6 +42,7 @@ def get_task_workspace(workspace="./"): """ return os.path.join(workspace, "task_workspace") + def get_task_log_workspace(workspace="./"): """Get the log workspace for task. @@ -51,6 +54,7 @@ def get_task_log_workspace(workspace="./"): """ return os.path.join(workspace, "task_log") + def get_serve_log_workspace(workspace="./"): """Get log workspace for service. @@ -62,6 +66,7 @@ def get_serve_log_workspace(workspace="./"): """ return os.path.join(workspace, "serve_log") + def dict_to_str(d): """Covert a dict object to a string object. @@ -72,4 +77,4 @@ def dict_to_str(d): str: string """ result = json.dumps(d) - return result \ No newline at end of file + return result