Reformat neural solution code (#1054)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Jul 3, 2023
1 parent 4d4c295 commit f2fec43
Showing 42 changed files with 1,252 additions and 871 deletions.
2 changes: 1 addition & 1 deletion neural_solution/__init__.py
@@ -13,4 +13,4 @@
# limitations under the License.

"""Neural Solution."""
-from neural_solution.utils import logger
+from neural_solution.utils import logger
4 changes: 2 additions & 2 deletions neural_solution/backend/__init__.py
@@ -13,8 +13,8 @@
# limitations under the License.

"""Neural Solution backend."""
+from neural_solution.backend.cluster import Cluster
from neural_solution.backend.result_monitor import ResultMonitor
from neural_solution.backend.scheduler import Scheduler
-from neural_solution.backend.task_monitor import TaskMonitor
-from neural_solution.backend.cluster import Cluster
from neural_solution.backend.task_db import TaskDB
+from neural_solution.backend.task_monitor import TaskMonitor
68 changes: 34 additions & 34 deletions neural_solution/backend/cluster.py
@@ -13,12 +13,14 @@
# limitations under the License.

"""Neural Solution cluster."""
-import threading
import sqlite3
+import threading
+from collections import Counter
from typing import List
-from neural_solution.backend.utils.utility import synchronized, create_dir
+
+from neural_solution.backend.utils.utility import create_dir, synchronized
from neural_solution.utils import logger
-from collections import Counter


class Cluster:
"""Cluster resource management based on sockets."""
@@ -35,7 +37,7 @@ def __init__(self, node_lst=[], db_path=None):
        self.socket_queue = []
        self.db_path = db_path
        create_dir(db_path)
-        self.conn = sqlite3.connect(f'{db_path}', check_same_thread=False)
+        self.conn = sqlite3.connect(f"{db_path}", check_same_thread=False)
        self.initial_cluster_from_node_lst(node_lst)
        self.lock = threading.Lock()

@@ -63,17 +65,16 @@ def reserve_resource(self, task):
        logger.info(f"[Cluster] Assign {reserved_resource_lst} to task {task.task_id}")
        return reserved_resource_lst

-
    @synchronized
    def free_resource(self, reserved_resource_lst):
        """Free the resource by adding the previously occupied resources to the socket queue."""
        self.socket_queue += reserved_resource_lst
        counts = Counter(int(item.split()[0]) for item in reserved_resource_lst)
        free_resources = {}
        for node_id, count in counts.items():
            free_resources[node_id] = count
        for node_id in free_resources:
            sql = """
                UPDATE cluster
@@ -105,43 +106,43 @@ def initial_cluster_from_node_lst(self, node_lst):
            node_lst (List): the node list.
        """
        # sqlite should set this check_same_thread to False
-        self.conn = sqlite3.connect(f'{self.db_path}', check_same_thread=False)
+        self.conn = sqlite3.connect(f"{self.db_path}", check_same_thread=False)
        self.cursor = self.conn.cursor()
-        self.cursor.execute('drop table if exists cluster ')
-        self.cursor.execute(r'create table cluster(id INTEGER PRIMARY KEY AUTOINCREMENT,' +
-                            'node_info varchar(500),' +
-                            'status varchar(100),' +
-                            'free_sockets int,' +
-                            'busy_sockets int,' +
-                            'total_sockets int)')
+        self.cursor.execute("drop table if exists cluster ")
+        self.cursor.execute(
+            r"create table cluster(id INTEGER PRIMARY KEY AUTOINCREMENT,"
+            + "node_info varchar(500),"
+            + "status varchar(100),"
+            + "free_sockets int,"
+            + "busy_sockets int,"
+            + "total_sockets int)"
+        )
        self.node_lst = node_lst
        for index, node in enumerate(self.node_lst):
-            self.socket_queue += [str(index+1) + " " + node.name] * node.num_sockets
-            self.cursor.execute(r"insert into cluster(node_info, status, free_sockets, busy_sockets, total_sockets)" +
-                                "values ('{}', '{}', {}, {}, {})".format(repr(node).replace("Node", f"Node{index+1}"),
-                                                                         "alive",
-                                                                         node.num_sockets,
-                                                                         0,
-                                                                         node.num_sockets))
+            self.socket_queue += [str(index + 1) + " " + node.name] * node.num_sockets
+            self.cursor.execute(
+                r"insert into cluster(node_info, status, free_sockets, busy_sockets, total_sockets)"
+                + "values ('{}', '{}', {}, {}, {})".format(
+                    repr(node).replace("Node", f"Node{index+1}"), "alive", node.num_sockets, 0, node.num_sockets
+                )
+            )

        self.conn.commit()
        logger.info(f"socket_queue: {self.socket_queue}")

class Node:
"""Node definition."""

name: str = "unknown_node"
ip: str = "unknown_ip"
num_sockets: int = 0
num_cores_per_socket: int = 0
num_gpus: int = 0 # For future use

def __init__(self,
name: str,
ip: str = "unknown_ip",
num_sockets: int = 0,
num_cores_per_socket: int = 0,
num_gpus: int = 0) -> None:
num_gpus: int = 0 # For future use

def __init__(
self, name: str, ip: str = "unknown_ip", num_sockets: int = 0, num_cores_per_socket: int = 0, num_gpus: int = 0
) -> None:
"""Init node.
hostfile template:
@@ -167,8 +168,7 @@ def __repr__(self) -> str:
        Returns:
            str: node info.
        """
-        return f"Node: {self.name}(ip: {self.ip}) has {self.num_sockets} socket(s) " \
-               f"and each socket has {self.num_cores_per_socket} cores."
+        return (
+            f"Node: {self.name}(ip: {self.ip}) has {self.num_sockets} socket(s) "
+            f"and each socket has {self.num_cores_per_socket} cores."
+        )
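
A note on @synchronized, used on Cluster.free_resource above: it is imported from neural_solution.backend.utils.utility, whose body is not part of this diff. A minimal sketch of such a decorator, assuming it simply guards the per-instance self.lock that Cluster.__init__ creates (an assumption, not the repository's actual implementation):

    import threading
    from functools import wraps

    def synchronized(method):
        """Serialize concurrent calls on the same object via its `self.lock`."""

        @wraps(method)
        def wrapper(self, *args, **kwargs):
            with self.lock:  # assumes the instance created `self.lock` beforehand
                return method(self, *args, **kwargs)

        return wrapper

    class SocketPool:
        """Toy stand-in for Cluster: a shared queue mutated under the lock."""

        def __init__(self):
            self.lock = threading.Lock()
            self.socket_queue = []

        @synchronized
        def free_resource(self, reserved):
            self.socket_queue += reserved

Under that reading, reserve_resource and free_resource never interleave on the same Cluster instance, which is what the shared socket_queue and the cluster table updates require.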
11 changes: 6 additions & 5 deletions neural_solution/backend/result_monitor.py
@@ -15,9 +15,11 @@
"""Neural Solution result monitor."""

import socket
-from neural_solution.backend.utils.utility import serialize, deserialize
-from neural_solution.utils import logger
+
from neural_solution.backend.task_db import TaskDB
+from neural_solution.backend.utils.utility import deserialize, serialize
+from neural_solution.utils import logger


class ResultMonitor:
"""ResultMonitor is a thread that monitors the coming task results and update the task collection in the TaskDb.
@@ -40,7 +42,7 @@ def __init__(self, port, task_db: TaskDB):

    def wait_result(self):
        """Monitor the task results, update them in the task db, and send them back to the studio."""
-        self.s.bind(("localhost", self.port)) # open a port as the serving port for results
+        self.s.bind(("localhost", self.port))  # open a port as the serving port for results
        self.s.listen(10)
        while True:
            logger.info("[ResultMonitor] waiting for results...")
@@ -53,7 +55,7 @@
                c.close()
                continue
            logger.info("[ResultMonitor] getting result: {}".format(result))
-            logger.info("[ResultMonitor] getting q_model path: {}".format(result['q_model_path']))
+            logger.info("[ResultMonitor] getting q_model path: {}".format(result["q_model_path"]))
            self.task_db.update_q_model_path_and_result(result["task_id"], result["q_model_path"], result["result"])
            c.close()
            # TODO send back the result to the studio
@@ -63,4 +65,3 @@ def query_task_status(self, task_id):
        """Synchronize query on the task status."""
        # TODO send back the result to the studio? RPC for query?
        logger.info(self.task_db.lookup_task_status(task_id))
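
wait_result above receives each payload via deserialize and reads the keys task_id, q_model_path, and result from it; serialize and deserialize live in neural_solution.backend.utils.utility and are outside this diff. A minimal client-side sketch, assuming a pickle-based wire format (an assumption; the real helpers may frame bytes differently), that sends one result to a ResultMonitor listening on the default port 3333 from runner.py:

    import pickle
    import socket

    def serialize(obj) -> bytes:
        """Hypothetical stand-in for utility.serialize; the real format may differ."""
        return pickle.dumps(obj)

    def report_result(port: int = 3333) -> None:
        """Send one task result to a ResultMonitor on localhost (sketch only)."""
        result = {
            "task_id": "task_0",             # keys match those read in wait_result
            "q_model_path": "/tmp/q_model",  # hypothetical path
            "result": {"acc": 0.76},         # hypothetical tuning result
        }
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect(("localhost", port))
            s.sendall(serialize(result))

    if __name__ == "__main__":
        report_result()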

57 changes: 33 additions & 24 deletions neural_solution/backend/runner.py
@@ -13,13 +13,14 @@
# limitations under the License.

"""Main backend runner."""
-import threading
import argparse
+import threading

-from neural_solution.backend import TaskDB, Scheduler, TaskMonitor, ResultMonitor
-from neural_solution.utils import logger
+from neural_solution.backend import ResultMonitor, Scheduler, TaskDB, TaskMonitor
from neural_solution.backend.utils.utility import build_cluster, get_db_path
from neural_solution.config import config
+from neural_solution.utils import logger


def parse_args(args=None):
"""Parse the command line options.
Expand All @@ -30,24 +31,25 @@ def parse_args(args=None):
    Returns:
        argparse.Namespace: arguments.
    """
-    parser = argparse.ArgumentParser(description="Neural Solution runner automatically schedules multiple inc tasks and\
-        executes multi-node distributed tuning.")
-
-    parser.add_argument("-H", "--hostfile", type=str, default=None, \
-        help="Path to the host file which contains all available nodes.")
-    parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, \
-        help="Port to monitor task.")
-    parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, \
-        help="Port to monitor result.")
-    parser.add_argument("-WS", "--workspace", type=str, default="./", \
-        help="Work space.")
-    parser.add_argument("-CEN", "--conda_env_name", type=str, default="inc", \
-        help="Conda environment for task execution")
-    parser.add_argument("-UP", "--upload_path", type=str, default="./examples", \
-        help="Custom example path.")
+    parser = argparse.ArgumentParser(
+        description="Neural Solution runner automatically schedules multiple inc tasks and\
+        executes multi-node distributed tuning."
+    )
+
+    parser.add_argument(
+        "-H", "--hostfile", type=str, default=None, help="Path to the host file which contains all available nodes."
+    )
+    parser.add_argument("-TMP", "--task_monitor_port", type=int, default=2222, help="Port to monitor task.")
+    parser.add_argument("-RMP", "--result_monitor_port", type=int, default=3333, help="Port to monitor result.")
+    parser.add_argument("-WS", "--workspace", type=str, default="./", help="Work space.")
+    parser.add_argument(
+        "-CEN", "--conda_env_name", type=str, default="inc", help="Conda environment for task execution"
+    )
+    parser.add_argument("-UP", "--upload_path", type=str, default="./examples", help="Custom example path.")

    return parser.parse_args(args=args)


def main(args=None):
"""Implement the main entry of backend.
Expand All @@ -72,9 +74,15 @@ def main(args=None):
    t_rm = threading.Thread(target=rm.wait_result)
    config.workspace = args.workspace

-    ts = Scheduler(cluster, task_db, args.result_monitor_port, \
-        conda_env_name=args.conda_env_name, upload_path=args.upload_path, config=config, \
-        num_threads_per_process=num_threads_per_process)
+    ts = Scheduler(
+        cluster,
+        task_db,
+        args.result_monitor_port,
+        conda_env_name=args.conda_env_name,
+        upload_path=args.upload_path,
+        config=config,
+        num_threads_per_process=num_threads_per_process,
+    )
    t_ts = threading.Thread(target=ts.schedule_tasks)

    tm = TaskMonitor(args.task_monitor_port, task_db)
@@ -83,14 +91,15 @@ def main(args=None):
    t_rm.start()
    t_ts.start()
    t_tm.start()
-    logger.info("task monitor port {} and result monitor port {}".\
-        format(args.task_monitor_port, args.result_monitor_port))
+    logger.info(
+        "task monitor port {} and result monitor port {}".format(args.task_monitor_port, args.result_monitor_port)
+    )
    logger.info("server start...")

    t_rm.join()
    t_ts.join()
    t_tm.join()


-if __name__ == '__main__':
+if __name__ == "__main__":
    main()
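
For reference, the reformatted entry point can be exercised directly. A short usage sketch (hosts.txt and the workspace path are hypothetical; the flags and defaults are the ones defined in parse_args above):

    from neural_solution.backend.runner import main, parse_args

    # Inspect the defaults without starting any threads.
    args = parse_args([])
    assert args.task_monitor_port == 2222 and args.result_monitor_port == 3333

    # Start the backend (blocks on the monitor/scheduler threads).
    main(["-H", "hosts.txt", "-WS", "/tmp/ns_workspace", "-CEN", "inc"])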
