-
Notifications
You must be signed in to change notification settings - Fork 176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add option to select backends TF/PT #1541
Changes from 20 commits
575c5a9
3be325d
d3cc70d
547e13b
cfa904b
946561b
11103f5
0ddbf7c
0b94ff6
494b796
2a58c7e
01fbd2f
babd77e
28b8a49
2822be0
99efd85
bdecc9f
f4f5665
ef1df18
a1b3ff8
11dca54
a0684ca
220735d
8f7f491
05b2412
ed832e8
af3fc27
8a949d0
7def8ef
52a9989
0d99ea9
8bdea17
67fab2b
706146a
e72f4c8
98eccb8
702510c
6ffd1eb
624a48b
f3f49d3
0f9b1b0
5049771
47a766e
083ad45
a500987
35b2713
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -125,6 +125,16 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
run_opt_file = os.path.join(ROOT_PATH, "generator/lib/calypso_run_opt.py") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _get_model_suffix(jdata) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Return the model suffix based on the backend""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backend = jdata.get("train_backend", "tensorflow") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if backend == "tensorflow": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = ".pb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif backend == "pytorch": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = ".pth" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return suffix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_job_names(jdata): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
jobkeys = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in jdata.keys(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -172,7 +182,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return all(empty_sys) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def copy_model(numb_model, prv_iter_index, cur_iter_index): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cwd = os.getcwd() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
prv_train_path = os.path.join(make_iter_name(prv_iter_index), train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cur_train_path = os.path.join(make_iter_name(cur_iter_index), train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -184,7 +194,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.chdir(cur_train_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - "graph.%03d%s" % (ii, suffix),
+ "graph.{:03d}{}".format(ii, suffix), Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.symlink( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.path.join(train_task_fmt % ii, "frozen_model.pb"), "graph.%03d.pb" % ii | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
+ os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"), Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"graph.%03d%s" % (ii, suffix), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.chdir(cwd) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
with open(os.path.join(cur_train_path, "copied"), "w") as fp: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -316,18 +327,19 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
number_old_frames = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
number_new_frames = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_engine = jdata.get("model_devi_engine", "lammps") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_task("prev data is empty, copy prev model") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copy_model(numb_models, iter_index - 1, iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copy_model(numb_models, iter_index - 1, iter_index, suffix) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_engine != "calypso" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
and iter_index > 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
and _check_skip_train(model_devi_jobs[iter_index - 1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_task("skip training at step %d " % (iter_index - 1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copy_model(numb_models, iter_index - 1, iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copy_model(numb_models, iter_index - 1, iter_index, suffix) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+332
to
+344
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactor the conditional checks for clarity and maintainability. Consider using intermediate variables to simplify the conditions. - if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min):
+ previous_iter_empty = iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min)
+ if previous_iter_empty:
log_task("prev data is empty, copy prev model")
copy_model(numb_models, iter_index - 1, iter_index, suffix)
return Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_name = make_iter_name(iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -591,19 +603,19 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
len(np.array(model_devi_activation_func).shape) == 2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): # 2-dim list for emd/fitting net-resolved assignment of actF | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
jinput["model"]["descriptor"]["activation_function"] = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_activation_func[ii][0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
jinput["model"]["fitting_net"]["activation_function"] = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_activation_func[ii][1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
len(np.array(model_devi_activation_func).shape) == 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): # for backward compatibility, 1-dim list, not net-resolved | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
jinput["model"]["descriptor"]["activation_function"] = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_activation_func[ii] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
jinput["model"]["fitting_net"]["activation_function"] = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_activation_func[ii] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# dump the input.json | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -648,7 +660,9 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if copied_models is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in range(len(copied_models)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
_link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
_link_old_models( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
work_path, [copied_models[ii]], ii, basename="init%s" % suffix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - work_path, [copied_models[ii]], ii, basename="init%s" % suffix
+ work_path, [copied_models[ii]], ii, basename=f"init{suffix}" Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use f-string for better readability and performance in string formatting. - work_path, [copied_models[ii]], ii, basename="init%s" % suffix
+ work_path, [copied_models[ii]], ii, basename=f"init{suffix}" Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Copy user defined forward files | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# HDF5 format for training data | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -700,6 +714,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# print("debug:run_train:mdata", mdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# load json param | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
numb_models = jdata["numb_models"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# train_param = jdata['train_param'] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_input_file = default_train_input_file | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
training_reuse_iter = jdata.get("training_reuse_iter") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -762,9 +777,9 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if training_init_model: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_flag = " --init-model old/model.ckpt" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif training_init_frozen_model is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_flag = " --init-frz-model old/init.pb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_flag = " --init-frz-model old/init%s" % suffix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif training_finetune_model is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_flag = " --finetune old/init.pb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_flag = " --finetune old/init%s" % suffix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command = f"{train_command} train {train_input_file}{extra_flags}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command = "/bin/sh -c %s" % shlex.quote(command) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -799,17 +814,23 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.path.join("old", "model.ckpt.data-00000-of-00001"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif training_init_frozen_model is not None or training_finetune_model is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
forward_files.append(os.path.join("old", "init.pb")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
forward_files.append(os.path.join("old", "init%s" % suffix)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files = ["frozen_model.pb", "lcurve.out", "train.log"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files += [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.meta", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.index", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.data-00000-of-00001", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"frozen_model%s" % suffix, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"lcurve.out", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"train.log", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"checkpoint", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if jdata.get("dp_compress", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files.append("frozen_model_compressed.pb") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if suffix == ".pb": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files += [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.meta", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.index", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.data-00000-of-00001", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if jdata.get("dp_compress", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files.append("frozen_model_compressed%s" % suffix) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not jdata.get("one_h5", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_data_sys_ = jdata["init_data_sys"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_data_sys = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -880,13 +901,15 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_task("copied model, do not post train") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# symlink models | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in range(numb_models): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not jdata.get("dp_compress", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_name = "frozen_model.pb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_name = "frozen_model_compressed.pb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_name = "frozen_model%s" % suffix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if suffix == ".pb": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if jdata.get("dp_compress", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_name = "frozen_model_compressed%s" % suffix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_file = os.path.join(train_task_fmt % ii, model_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ofile = os.path.join(work_path, "graph.%03d.pb" % ii) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if os.path.isfile(ofile): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.remove(ofile) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.symlink(task_file, ofile) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1126,7 +1149,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_name = make_iter_name(iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.join(iter_name, train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.abspath(train_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*pb"))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
work_path = os.path.join(iter_name, model_devi_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
create_path(work_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if model_devi_engine == "calypso": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1307,7 +1331,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_name = make_iter_name(iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.join(iter_name, train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.abspath(train_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*pb"))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in models: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list.append(os.path.join("..", os.path.basename(ii))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1504,7 +1529,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_name = make_iter_name(iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.join(iter_name, train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.abspath(train_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = glob.glob(os.path.join(train_path, "graph*pb")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in models: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list.append(os.path.join("..", os.path.basename(ii))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1646,7 +1672,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_name = make_iter_name(iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.join(iter_name, train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.abspath(train_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = glob.glob(os.path.join(train_path, "graph*pb")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in models: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list.append(os.path.join("..", os.path.basename(ii))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1829,7 +1856,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.replace("@qm_theory@", jdata["low_level"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.replace("@rcut@", str(jdata["cutoff"])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - "unknown atomic identifier", atom, 'if one want to use isotopes, or non-standard element names, chemical symbols, or atomic number in the type_map list, please customize the mass_map list instead of using "auto".'
+ "unknown atomic identifier {}. If one wants to use isotopes, or non-standard element names, chemical symbols, or atomic number in the type_map list, please customize the mass_map list instead of using 'auto'.".format(atom) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph.*.pb"))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}"))) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in models: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list.append(os.path.join("..", os.path.basename(ii))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1937,7 +1965,9 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
run_tasks = [os.path.basename(ii) for ii in run_tasks_] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# dlog.info("all_task is ", all_task) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# dlog.info("run_tasks in run_model_deviation",run_tasks_) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - "should not get ele temp at setting: use_ele_temp == 0"
+ "should not get ele temp at setting: use_ele_temp == 0".format() Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
all_models = glob.glob(os.path.join(work_path, "graph*pb")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix))
+ all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}")) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_names = [os.path.basename(ii) for ii in all_models] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_engine = jdata.get("model_devi_engine", "lammps") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -2001,10 +2031,10 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ndx_filename: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += f'&& echo -e "{grp_name}\\n{grp_name}\\n" | {model_devi_exec} trjconv -s {ref_filename} -f {deffnm}.trr -n {ndx_filename} -o {traj_filename} -pbc mol -ur compact -center' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += f'&& echo -e "{grp_name}\\n{grp_name}\\n" | {model_devi_exec} trjconv -s {ref_filename} -f {deffnm}.trr -o {traj_filename} -pbc mol -ur compact -center' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += "&& if [ ! -d traj ]; then \n mkdir traj; fi\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += f"python -c \"import dpdata;system = dpdata.System('{traj_filename}', fmt='gromacs/gro'); [system.to_gromacs_gro('traj/%d.gromacstrj' % (i * {trj_freq}), frame_idx=i) for i in range(system.get_nframes())]; system.to_deepmd_npy('traj_deepmd')\"" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += f"&& dp model-devi -m ../graph.000.pb ../graph.001.pb ../graph.002.pb ../graph.003.pb -s traj_deepmd -o model_devi.out -f {trj_freq}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += f"&& dp model-devi -m ../graph.000{suffix} ../graph.001{suffix} ../graph.002{suffix} ../graph.003{suffix} -s traj_deepmd -o model_devi.out -f {trj_freq}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
commands = [command] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
forward_files = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
End the docstring with a period for consistency and proper grammar.
Committable suggestion