Commit
Merge pull request #2 from dingzhaohan/devel
Develop DP-GEN for Ali
AnguseZhang authored Dec 8, 2019
2 parents eaaa68a + 840b5a4 commit d0a40bd
Showing 4 changed files with 276 additions and 37 deletions.
57 changes: 57 additions & 0 deletions dpgen/dispatcher/ALI.py
@@ -0,0 +1,57 @@
from aliyunsdkecs.request.v20140526.DescribeInstancesRequest import DescribeInstancesRequest
from aliyunsdkcore.client import AcsClient
from aliyunsdkcore.acs_exception.exceptions import ClientException
from aliyunsdkcore.acs_exception.exceptions import ServerException
from aliyunsdkecs.request.v20140526.RunInstancesRequest import RunInstancesRequest
from aliyunsdkecs.request.v20140526.DeleteInstancesRequest import DeleteInstancesRequest
import time
import json
from dpgen.dispatcher.Batch import Batch
from dpgen.dispatcher.JobStatus import JobStatus
from dpgen.dispatcher.Shell import Shell
from dpgen.dispatcher.SSHContext import SSHContext, SSHSession

class ALI():
    def __init__(self, adata):
        self.ip_list = None
        self.regionID = None
        self.instance_list = None
        self.AccessKey_ID = adata["AccessKey_ID"]
        self.AccessKey_Secret = adata["AccessKey_Secret"]

    def create_machine(self, instance_number, instance_type):
        try:
            client = AcsClient(self.AccessKey_ID, self.AccessKey_Secret, 'cn-hangzhou')
            request = RunInstancesRequest()
            request.set_accept_format('json')
            request.set_UniqueSuffix(True)
            request.set_Password("975481DING!")
            request.set_Amount(instance_number)
            request.set_LaunchTemplateName(instance_type + '_cn-hangzhou_i')
            response = client.do_action_with_exception(request)
            response = json.loads(response)
            self.instance_list = response["InstanceIdSets"]["InstanceIdSet"]
            time.sleep(50)
            request = DescribeInstancesRequest()
            request.set_accept_format('json')
            request.set_InstanceIds(self.instance_list)
            response = client.do_action_with_exception(request)
            response = json.loads(response)

            ip = []
            for i in range(len(response["Instances"]["Instance"])):
                ip.append(response["Instances"]["Instance"][i]["PublicIpAddress"]['IpAddress'][0])
            self.ip_list = ip
            # print(self.ip_list, self.instance_list)
            return self.ip_list, self.instance_list
        except (ClientException, ServerException):
            return "create failed"

    def delete_machine(self, instance_id):
        client = AcsClient(self.AccessKey_ID, self.AccessKey_Secret, 'cn-hangzhou')
        request = DeleteInstancesRequest()
        request.set_accept_format('json')
        request.set_InstanceIds(instance_id)
        request.set_Force(True)
        response = client.do_action_with_exception(request)
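
The new ALI class reads only two keys from the credential dict it is handed, so the machine file needs an "ali_auth" block with the Aliyun access-key pair. Below is a minimal usage sketch with placeholder credentials; it assumes launch templates named "<instance_type>_cn-hangzhou_i" already exist in the cn-hangzhou region, which the class hard-codes.

from dpgen.dispatcher.ALI import ALI

adata = {
    "AccessKey_ID": "YOUR_ACCESS_KEY_ID",          # placeholder, not a real key
    "AccessKey_Secret": "YOUR_ACCESS_KEY_SECRET",  # placeholder, not a real secret
}
ali = ALI(adata)
# boot two GPU instances and collect their public IPs for the SSH dispatchers
ip_list, instance_ids = ali.create_machine(2, "ecs.gn5-c8g1.2xlarge")
# release the ECS instances once all jobs have finished
ali.delete_machine(instance_ids)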

36 changes: 33 additions & 3 deletions dpgen/dispatcher/Dispatcher.py
@@ -215,7 +215,7 @@ def all_finished(self,


class JobRecord(object):
    def __init__ (self, path, task_chunks, fname = 'job_record.json'):
    def __init__ (self, path, task_chunks, fname = 'job_record.json', ip=None):
        self.path = os.path.abspath(path)
        self.fname = os.path.join(self.path, fname)
        self.task_chunks = task_chunks
@@ -232,9 +232,13 @@ def record_remote_context(self,
                               chunk_hash,
                               local_root,
                               remote_root,
                               job_uuid):
                               job_uuid,
                               ip=None):
        self.valid_hash(chunk_hash)
        self.record[chunk_hash]['context'] = [local_root, remote_root, job_uuid]
        if ip:
            self.record[chunk_hash]['context'] = [local_root, remote_root, job_uuid, ip]
        else:
            self.record[chunk_hash]['context'] = [local_root, remote_root, job_uuid]

    def get_uuid(self, chunk_hash):
        self.valid_hash(chunk_hash)
@@ -305,3 +309,29 @@ def make_dispatcher(mdata):
        context_type = 'lazy-local'
    disp = Dispatcher(mdata, context_type=context_type, batch_type=batch_type)
    return disp

def make_dispatchers(num, mdata):
    dispatchers = []
    for i in range(num):
        try:
            hostname = mdata['hostname'][i]
            context_type = 'ssh'
        except:
            context_type = 'local'
        try:
            batch_type = mdata['batch']
        except:
            dlog.info('cannot find key "batch" in machine file, try to use deprecated key "machine_type"')
            batch_type = mdata['machine_type']
        try:
            lazy_local = mdata['lazy_local']
        except:
            lazy_local = False
        if lazy_local and context_type == 'local':
            dlog.info('Dispatcher switches to the lazy local mode')
            context_type = 'lazy-local'
        remote_profile = mdata.copy()
        remote_profile['hostname'] = hostname
        disp = Dispatcher(remote_profile, context_type=context_type, batch_type=batch_type, job_record='jr%d.json' %i)
        dispatchers.append(disp)
    return dispatchers
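
make_dispatchers builds one Dispatcher per task chunk instead of a single shared one. A rough sketch of the per-machine dict it expects is shown below; the values are illustrative rather than taken from the repository, and in the ALI workflow the "hostname" list is filled at runtime with the public IPs returned by ALI.create_machine.

# hypothetical machine profile; one dispatcher is built per entry in "hostname"
train_machine = {
    "hostname": ["1.2.3.4", "5.6.7.8"],  # filled from ALI.create_machine() at runtime
    "batch": "shell",                    # falls back to the deprecated "machine_type" key if absent
    # ... plus the usual SSH fields (username, password, work_path, ...)
}
dispatchers = make_dispatchers(len(train_machine["hostname"]), train_machine)
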
186 changes: 152 additions & 34 deletions dpgen/generator/run.py
@@ -52,7 +52,8 @@
from dpgen.remote.group_jobs import group_slurm_jobs
from dpgen.remote.group_jobs import group_local_jobs
from dpgen.remote.decide_machine import decide_train_machine, decide_fp_machine, decide_model_devi_machine
from dpgen.dispatcher.Dispatcher import Dispatcher, make_dispatcher
from dpgen.dispatcher.Dispatcher import Dispatcher, make_dispatcher, make_dispatchers, _split_tasks
from dpgen.dispatcher.ALI import ALI
from dpgen.util import sepline
from dpgen import ROOT_PATH
from pymatgen.io.vasp import Incar,Kpoints,Potcar
@@ -340,6 +341,20 @@ def detect_batch_size(batch_size, system=None):
    else:
        raise RuntimeError("Unsupported batch size")

def run_ALI(stage, num_of_instance, adata):
    if stage == "train":
        instance_type = "ecs.gn5-c8g1.2xlarge"
    elif stage == "model_devi":
        instance_type = "ecs.gn5-c8g1.2xlarge"
    elif stage == "fp":
        instance_type = "ecs.c6.2xlarge"
    ali = ALI(adata)
    return ali.create_machine(num_of_instance, instance_type)

def exit_ALI(instance_id, adata):
    ali = ALI(adata)
    ali.delete_machine(instance_id)

def run_train (iter_index,
               jdata,
               mdata,
@@ -444,16 +459,47 @@ def run_train (iter_index,
    except:
        train_group_size = 1

    dispatcher.run_jobs(mdata['train_resources'],
                        commands,
                        work_path,
                        run_tasks,
                        train_group_size,
                        trans_comm_data,
                        forward_files,
                        backward_files,
                        outlog = 'train.log',
                        errlog = 'train.log')
    if mdata['train_machine']['type'] == 'ALI':
        task_chunks = _split_tasks(run_tasks, train_group_size)
        nchunks = len(task_chunks)
        ip, instance_id = run_ALI('train', nchunks, mdata['ali_auth'])
        mdata['train_machine']['hostname'] = ip
        disp = make_dispatchers(nchunks, mdata['train_machine'])
        job_handlers = []
        for ii in range(nchunks):
            job_handler = disp[ii].submit_jobs(mdata['train_resources'],
                                               commands,
                                               work_path,
                                               task_chunks[ii],
                                               train_group_size,
                                               trans_comm_data,
                                               forward_files,
                                               backward_files,
                                               outlog = 'train.log',
                                               errlog = 'train.log')
            job_handlers.append(job_handler)

        while True:
            cnt = 0
            for ii in range(nchunks):
                if disp[ii].all_finished(job_handlers[ii]):
                    cnt += 1
            if cnt == nchunks:
                break
            else:
                time.sleep(10)
        exit_ALI(instance_id, mdata['ali_auth'])
    else:
        dispatcher.run_jobs(mdata['train_resources'],
                            commands,
                            work_path,
                            run_tasks,
                            train_group_size,
                            trans_comm_data,
                            forward_files,
                            backward_files,
                            outlog = 'train.log',
                            errlog = 'train.log')


def post_train (iter_index,
@@ -903,16 +949,47 @@ def run_model_devi (iter_index,
        forward_files += ['input.plumed']
        backward_files += ['output.plumed']

    dispatcher.run_jobs(mdata['model_devi_resources'],
                        commands,
                        work_path,
                        run_tasks,
                        model_devi_group_size,
                        model_names,
                        forward_files,
                        backward_files,
                        outlog = 'model_devi.log',
                        errlog = 'model_devi.log')
    if mdata['model_devi_machine']['type'] == 'ALI':
        task_chunks = _split_tasks(run_tasks, model_devi_group_size)
        nchunks = len(task_chunks)
        ip, instance_id = run_ALI('model_devi', nchunks, mdata['ali_auth'])
        mdata['model_devi_machine']['hostname'] = ip
        disp = make_dispatchers(nchunks, mdata['model_devi_machine'])
        job_handlers = []
        for ii in range(nchunks):
            job_handler = disp[ii].submit_jobs(mdata['model_devi_resources'],
                                               commands,
                                               work_path,
                                               task_chunks[ii],
                                               model_devi_group_size,
                                               model_names,
                                               forward_files,
                                               backward_files,
                                               outlog = 'model_devi.log',
                                               errlog = 'model_devi.log')
            job_handlers.append(job_handler)

        while True:
            cnt = 0
            for ii in range(nchunks):
                if disp[ii].all_finished(job_handlers[ii]):
                    cnt += 1
            if cnt == nchunks:
                break
            else:
                time.sleep(10)
        exit_ALI(instance_id, mdata['ali_auth'])
    else:
        dispatcher.run_jobs(mdata['model_devi_resources'],
                            commands,
                            work_path,
                            run_tasks,
                            model_devi_group_size,
                            model_names,
                            forward_files,
                            backward_files,
                            outlog = 'model_devi.log',
                            errlog = 'model_devi.log')


def post_model_devi (iter_index,
@@ -1485,16 +1562,48 @@ def run_fp_inner (iter_index,
        # fp_run_tasks.append(ii)
    run_tasks = [os.path.basename(ii) for ii in fp_run_tasks]

    dispatcher.run_jobs(mdata['fp_resources'],
                        [fp_command],
                        work_path,
                        run_tasks,
                        fp_group_size,
                        forward_common_files,
                        forward_files,
                        backward_files,
                        outlog = log_file,
                        errlog = log_file)
    if mdata['fp_machine']['type'] == 'ALI':
        task_chunks = _split_tasks(run_tasks, fp_group_size)
        nchunks = len(task_chunks)
        ip, instance_id = run_ALI('fp', nchunks, mdata['ali_auth'])
        mdata['fp_machine']['hostname'] = ip
        disp = make_dispatchers(nchunks, mdata['fp_machine'])
        job_handlers = []
        for ii in range(nchunks):
            job_handler = disp[ii].submit_jobs(mdata['fp_resources'],
                                               [fp_command],
                                               work_path,
                                               task_chunks[ii],
                                               fp_group_size,
                                               forward_common_files,
                                               forward_files,
                                               backward_files,
                                               outlog = log_file,
                                               errlog = log_file)
            job_handlers.append(job_handler)

        while True:
            cnt = 0
            for ii in range(nchunks):
                if disp[ii].all_finished(job_handlers[ii]):
                    cnt += 1
            if cnt == nchunks:
                break
            else:
                time.sleep(10)
        exit_ALI(instance_id, mdata['ali_auth'])

    else:
        dispatcher.run_jobs(mdata['fp_resources'],
                            [fp_command],
                            work_path,
                            run_tasks,
                            fp_group_size,
                            forward_common_files,
                            forward_files,
                            backward_files,
                            outlog = log_file,
                            errlog = log_file)


def run_fp (iter_index,
@@ -1906,7 +2015,10 @@ def run_iter (param_file, machine_file) :
        elif jj == 1 :
            log_iter ("run_train", ii, jj)
            mdata = decide_train_machine(mdata)
            disp = make_dispatcher(mdata['train_machine'])
            if mdata['train_machine']['type'] == 'ALI':
                disp = []
            else:
                disp = make_dispatcher(mdata['train_machine'])
            run_train (ii, jdata, mdata, disp)
        elif jj == 2 :
            log_iter ("post_train", ii, jj)
@@ -1919,7 +2031,10 @@ def run_iter (param_file, machine_file) :
        elif jj == 4 :
            log_iter ("run_model_devi", ii, jj)
            mdata = decide_model_devi_machine(mdata)
            disp = make_dispatcher(mdata['model_devi_machine'])
            if mdata['model_devi_machine']['type'] == 'ALI':
                disp = []
            else:
                disp = make_dispatcher(mdata['model_devi_machine'])
            run_model_devi (ii, jdata, mdata, disp)
        elif jj == 5 :
            log_iter ("post_model_devi", ii, jj)
@@ -1930,7 +2045,10 @@
        elif jj == 7 :
            log_iter ("run_fp", ii, jj)
            mdata = decide_fp_machine(mdata)
            disp = make_dispatcher(mdata['fp_machine'])
            if mdata['fp_machine']['type'] == 'ALI':
                disp = []
            else:
                disp = make_dispatcher(mdata['fp_machine'])
            run_fp (ii, jdata, mdata, disp)
        elif jj == 8 :
            log_iter ("post_fp", ii, jj)
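
Taken together, the run.py changes switch a stage to the cloud path whenever its machine entry declares type "ALI" and the machine file carries an "ali_auth" block: tasks are split into chunks, one ECS instance and one dispatcher are created per chunk, the chunks are polled until all finish, and the instances are then released. A minimal machine-file fragment for this path might look like the sketch below (shown as a Python dict; all values are placeholders, and only keys actually read by the new code are listed).

# hypothetical machine-file fragment with placeholder values
mdata = {
    "ali_auth": {
        "AccessKey_ID": "YOUR_ACCESS_KEY_ID",          # placeholder
        "AccessKey_Secret": "YOUR_ACCESS_KEY_SECRET",  # placeholder
    },
    "train_machine": {"type": "ALI", "batch": "shell"},  # "hostname" is filled in at runtime
    "train_resources": {},                               # passed through to submit_jobs unchanged
    # model_devi_machine / fp_machine and their resources follow the same pattern
}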
