Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix nnictl bugs and add new feature #75

Merged
merged 14 commits into from
Sep 19, 2018
4 changes: 2 additions & 2 deletions tools/nnicmd/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
'codeDir': os.path.exists,
'classFileName': str,
'className': str,
'classArgs': {
'optimize_mode': Or('maximize', 'minimize'),
Optional('classArgs'): {
Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('speed'): int
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
Expand Down
30 changes: 25 additions & 5 deletions tools/nnicmd/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick
from .url_utils import cluster_metadata_url, experiment_url
from .config_utils import Config
from .common_utils import get_yml_content, get_json_content, print_error, print_normal
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, detect_process
from .constants import EXPERIMENT_SUCCESS_INFO, STDOUT_FULL_PATH, STDERR_FULL_PATH, LOG_DIR, REST_PORT, ERROR_INFO, NORMAL_INFO
from .webui_utils import start_web_ui, check_web_ui

Expand All @@ -40,7 +40,8 @@ def start_rest_server(port, platform, mode, experiment_id=None):
print_normal('Checking experiment...')
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
if rest_port and check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if rest_port and running:
print_error('There is an experiment running, please stop it first...')
print_normal('You can use \'nnictl stop\' command to stop an experiment!')
exit(0)
Expand All @@ -66,7 +67,12 @@ def set_trial_config(experiment_config, port):
value_dict['gpuNum'] = experiment_config['trial']['gpuNum']
request_data['trial_config'] = value_dict
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20)
return True if response.status_code == 200 else False
if response.status_code == 200:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

200 is not the only succeed code per protocol, suggest to use a function wrap this (such as http_succeed(code)), this may change later,

https://www.restapitutorial.com/httpstatuscodes.html

return True
else:
with open(STDERR_FULL_PATH, 'a+') as fout:
fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
return False

def set_local_config(experiment_config, port):
'''set local configuration'''
Expand All @@ -82,6 +88,8 @@ def set_remote_config(experiment_config, port):
if not response or not response.status_code == 200:
if response is not None:
err_message = response.text
with open(STDERR_FULL_PATH, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message

#set trial_config
Expand Down Expand Up @@ -117,11 +125,22 @@ def set_experiment(experiment_config, mode, port):
{'key': 'trial_config', 'value': value_dict})

response = rest_post(experiment_url(port), json.dumps(request_data), 20)
return response if response.status_code == 200 else None
if response.status_code == 200:
return response
else:
with open(STDERR_FULL_PATH, 'a+') as fout:
fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
return None

def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=None):
'''follow steps to start rest server and start experiment'''
nni_config = Config()
#Check if there is an experiment running
origin_rest_pid = nni_config.get_config('restServerPid')
if origin_rest_pid and detect_process(origin_rest_pid):
print_error('There is an experiment running, please stop it first...')
print_normal('You can use \'nnictl stop\' command to stop an experiment!')
exit(0)
# start rest server
rest_process = start_rest_server(REST_PORT, experiment_config['trainingServicePlatform'], mode, experiment_id)
nni_config.set_config('restServerPid', rest_process.pid)
Expand All @@ -144,7 +163,8 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No

# check rest server
print_normal('Checking restful server...')
if check_rest_server(REST_PORT):
running, _ = check_rest_server(REST_PORT)
if running:
print_normal('Restful server start success!')
else:
print_error('Restful server start failed!')
Expand Down
3 changes: 2 additions & 1 deletion tools/nnicmd/launcher_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def parse_tuner_content(experiment_config):

if experiment_config['tuner'].get('builtinTunerName') and experiment_config['tuner'].get('classArgs'):
experiment_config['tuner']['className'] = tuner_class_name_dict.get(experiment_config['tuner']['builtinTunerName'])
experiment_config['tuner']['classArgs']['algorithm_name'] = tuner_algorithm_name_dict.get(experiment_config['tuner']['builtinTunerName'])
if tuner_algorithm_name_dict.get(experiment_config['tuner']['builtinTunerName']):
experiment_config['tuner']['classArgs']['algorithm_name'] = tuner_algorithm_name_dict.get(experiment_config['tuner']['builtinTunerName'])
elif experiment_config['tuner'].get('codeDir') and experiment_config['tuner'].get('classFileName') and experiment_config['tuner'].get('className'):
if not os.path.exists(os.path.join(experiment_config['tuner']['codeDir'], experiment_config['tuner']['classFileName'])):
raise ValueError('Tuner file directory is not valid!')
Expand Down
6 changes: 4 additions & 2 deletions tools/nnicmd/nnictl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from .nnictl_utils import *

def nni_help_info(*args):
print('please run "nnictl --help" to see nnictl guidance')
print('please run "nnictl {positional argument} --help" to see nnictl guidance')

def parse_args():
'''Definite the arguments users need to follow and input'''
parser = argparse.ArgumentParser(prog='nni ctl', description='use nni control')
parser = argparse.ArgumentParser(prog='nnictl', description='use nnictl command to control nni experiments')
parser.set_defaults(func=nni_help_info)

# create subparsers for args with sub values
Expand Down Expand Up @@ -95,6 +95,8 @@ def parse_args():
parser_experiment_subparsers = parser_experiment.add_subparsers()
parser_experiment_show = parser_experiment_subparsers.add_parser('show', help='show the information of experiment')
parser_experiment_show.set_defaults(func=list_experiment)
parser_experiment_status = parser_experiment_subparsers.add_parser('status', help='show the status of experiment')
parser_experiment_status.set_defaults(func=experiment_status)

#parse config command
parser_config = subparsers.add_parser('config', help='get config information')
Expand Down
25 changes: 20 additions & 5 deletions tools/nnicmd/nnictl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def check_rest(args):
'''check if restful server is running'''
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
if check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if not running:
print_normal('Restful server is running...')
else:
print_normal('Restful server is not running...')
Expand All @@ -62,7 +63,8 @@ def stop_experiment(args):
print_normal('Experiment is not running...')
stop_web_ui()
return
if check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_delete(experiment_url(rest_port), 20)
if not response or response.status_code != 200:
print_error('Stop experiment failed!')
Expand All @@ -82,7 +84,8 @@ def trial_ls(args):
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
if check_rest_server_quick(rest_port):
running, response = check_rest_server_quick(rest_port)
if running:
response = rest_get(trial_jobs_url(rest_port), 20)
if response and response.status_code == 200:
content = json.loads(response.text)
Expand All @@ -102,7 +105,8 @@ def trial_kill(args):
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
if check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_delete(trial_job_id_url(rest_port, args.trialid), 20)
if response and response.status_code == 200:
print(response.text)
Expand All @@ -119,7 +123,8 @@ def list_experiment(args):
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
if check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), 20)
if response and response.status_code == 200:
content = convert_time_stamp_to_date(json.loads(response.text))
Expand All @@ -129,6 +134,16 @@ def list_experiment(args):
else:
print_error('Restful server is not running...')

def experiment_status(args):
'''Show the status of experiment'''
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
result, response = check_rest_server_quick(rest_port)
if not result:
print_normal('Restful server is not running...')
else:
print(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))

def get_log_content(file_name, cmds):
'''use cmds to read config content'''
if os.path.exists(file_name):
Expand Down
10 changes: 5 additions & 5 deletions tools/nnicmd/rest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@ def check_rest_server(rest_port):
response = rest_get(check_status_url(rest_port), 20)
if response:
if response.status_code == 200:
return True
return True, response
else:
return False
return False, response
else:
time.sleep(3)
return False
return False, response

def check_rest_server_quick(rest_port):
'''Check if restful server is ready, only check once'''
response = rest_get(check_status_url(rest_port), 5)
if response and response.status_code == 200:
return True
return False
return True, response
return False, None
3 changes: 2 additions & 1 deletion tools/nnicmd/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def update_experiment_profile(key, value):
'''call restful server to update experiment profile'''
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
if check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), 20)
if response and response.status_code == 200:
experiment_profile = json.loads(response.text)
Expand Down