-
Notifications
You must be signed in to change notification settings - Fork 176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add option to select backends TF/PT #1541
Changes from 40 commits
575c5a9
3be325d
d3cc70d
547e13b
cfa904b
946561b
11103f5
0ddbf7c
0b94ff6
494b796
2a58c7e
01fbd2f
babd77e
28b8a49
2822be0
99efd85
bdecc9f
f4f5665
ef1df18
a1b3ff8
11dca54
a0684ca
220735d
8f7f491
05b2412
ed832e8
af3fc27
8a949d0
7def8ef
52a9989
0d99ea9
8bdea17
67fab2b
706146a
e72f4c8
98eccb8
702510c
6ffd1eb
624a48b
f3f49d3
0f9b1b0
5049771
47a766e
083ad45
a500987
35b2713
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -125,6 +125,19 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
run_opt_file = os.path.join(ROOT_PATH, "generator/lib/calypso_run_opt.py") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _get_model_suffix(jdata) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Return the model suffix based on the backend""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backend = jdata.get("train_backend", "tensorflow") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if backend in suffix_map: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = suffix_map[backend] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return suffix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+128
to
+139
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactor the function def _get_model_suffix(jdata) -> str:
"""Return the model suffix based on the backend."""
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
backend = jdata.get("train_backend", "tensorflow")
try:
return suffix_map[backend]
except KeyError:
raise ValueError(
"The backend '{}' is not supported. Please choose from: 'tensorflow', 'pytorch'.".format(backend)
) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_job_names(jdata): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
jobkeys = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in jdata.keys(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -172,7 +185,7 @@ def _check_empty_iter(iter_index, max_v=0): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return all(empty_sys) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def copy_model(numb_model, prv_iter_index, cur_iter_index): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cwd = os.getcwd() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
prv_train_path = os.path.join(make_iter_name(prv_iter_index), train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cur_train_path = os.path.join(make_iter_name(cur_iter_index), train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -184,7 +197,8 @@ def copy_model(numb_model, prv_iter_index, cur_iter_index): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.chdir(cur_train_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - "graph.%03d%s" % (ii, suffix),
+ "graph.{:03d}{}".format(ii, suffix), Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.symlink( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.path.join(train_task_fmt % ii, "frozen_model.pb"), "graph.%03d.pb" % ii | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
+ os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"), Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"graph.%03d%s" % (ii, suffix), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.chdir(cwd) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
with open(os.path.join(cur_train_path, "copied"), "w") as fp: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -315,18 +329,19 @@ def make_train(iter_index, jdata, mdata): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
number_old_frames = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
number_new_frames = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_engine = jdata.get("model_devi_engine", "lammps") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_task("prev data is empty, copy prev model") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copy_model(numb_models, iter_index - 1, iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copy_model(numb_models, iter_index - 1, iter_index, suffix) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_engine != "calypso" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
and iter_index > 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
and _check_skip_train(model_devi_jobs[iter_index - 1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_task("skip training at step %d " % (iter_index - 1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copy_model(numb_models, iter_index - 1, iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
copy_model(numb_models, iter_index - 1, iter_index, suffix) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+332
to
+344
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactor the conditional checks for clarity and maintainability. Consider using intermediate variables to simplify the conditions. - if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min):
+ previous_iter_empty = iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min)
+ if previous_iter_empty:
log_task("prev data is empty, copy prev model")
copy_model(numb_models, iter_index - 1, iter_index, suffix)
return Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_name = make_iter_name(iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -647,7 +662,9 @@ def make_train(iter_index, jdata, mdata): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if copied_models is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in range(len(copied_models)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
_link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
_link_old_models( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
work_path, [copied_models[ii]], ii, basename="init%s" % suffix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - work_path, [copied_models[ii]], ii, basename="init%s" % suffix
+ work_path, [copied_models[ii]], ii, basename=f"init{suffix}" Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use f-string for better readability and performance in string formatting. - work_path, [copied_models[ii]], ii, basename="init%s" % suffix
+ work_path, [copied_models[ii]], ii, basename=f"init{suffix}" Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Copy user defined forward files | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# HDF5 format for training data | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -699,6 +716,7 @@ def run_train(iter_index, jdata, mdata): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# print("debug:run_train:mdata", mdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# load json param | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
numb_models = jdata["numb_models"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# train_param = jdata['train_param'] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_input_file = default_train_input_file | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
training_reuse_iter = jdata.get("training_reuse_iter") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -731,6 +749,12 @@ def run_train(iter_index, jdata, mdata): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_command = mdata.get("train_command", "dp") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert train_command == "dp", "The 'train_command' should be 'dp'" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if suffix == ".pb": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_command += " --tf" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif suffix == ".pth": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_command += " --pt" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_resources = mdata["train_resources"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# paths | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -761,9 +785,9 @@ def run_train(iter_index, jdata, mdata): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if training_init_model: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_flag = " --init-model old/model.ckpt" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif training_init_frozen_model is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_flag = " --init-frz-model old/init.pb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_flag = f" --init-frz-model old/init{suffix}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif training_finetune_model is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_flag = " --finetune old/init.pb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_flag = f" --finetune old/init{suffix}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command = f"{train_command} train {train_input_file}{extra_flags}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command = f"/bin/sh -c {shlex.quote(command)}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -798,17 +822,26 @@ def run_train(iter_index, jdata, mdata): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.path.join("old", "model.ckpt.data-00000-of-00001"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif training_init_frozen_model is not None or training_finetune_model is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
forward_files.append(os.path.join("old", "init.pb")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
forward_files.append(os.path.join("old", f"init{suffix}")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files = ["frozen_model.pb", "lcurve.out", "train.log"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files += [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.meta", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.index", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.data-00000-of-00001", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"frozen_model{suffix}", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"lcurve.out", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"train.log", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"checkpoint", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if jdata.get("dp_compress", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files.append("frozen_model_compressed.pb") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files.append(f"frozen_model_compressed{suffix}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if suffix == ".pb": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files += [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.meta", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.index", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"model.ckpt.data-00000-of-00001", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif suffix == ".pth": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
backward_files += ["model.ckpt.pt"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not jdata.get("one_h5", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_data_sys_ = jdata["init_data_sys"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
init_data_sys = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -879,13 +912,14 @@ def post_train(iter_index, jdata, mdata): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_task("copied model, do not post train") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# symlink models | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in range(numb_models): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not jdata.get("dp_compress", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_name = "frozen_model.pb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_name = "frozen_model_compressed.pb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_name = f"frozen_model{suffix}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if jdata.get("dp_compress", False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_name = f"frozen_model_compressed{suffix}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+915
to
+921
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Optimize the symlink creation in - os.symlink(task_file, ofile)
+ os.symlink(task_file, ofile, target_is_directory=True) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_file = os.path.join(train_task_fmt % ii, model_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ofile = os.path.join(work_path, "graph.%03d.pb" % ii) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if os.path.isfile(ofile): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.remove(ofile) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.symlink(task_file, ofile) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1124,7 +1158,8 @@ def make_model_devi(iter_index, jdata, mdata): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_name = make_iter_name(iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.join(iter_name, train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.abspath(train_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*pb"))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
work_path = os.path.join(iter_name, model_devi_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
create_path(work_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if model_devi_engine == "calypso": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1305,7 +1340,8 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_name = make_iter_name(iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.join(iter_name, train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.abspath(train_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*pb"))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in models: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list.append(os.path.join("..", os.path.basename(ii))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1502,7 +1538,8 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_name = make_iter_name(iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.join(iter_name, train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.abspath(train_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = glob.glob(os.path.join(train_path, "graph*pb")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in models: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list.append(os.path.join("..", os.path.basename(ii))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1644,7 +1681,8 @@ def _make_model_devi_native_gromacs(iter_index, jdata, mdata, conf_systems): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_name = make_iter_name(iter_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.join(iter_name, train_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thangckt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_path = os.path.abspath(train_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = glob.glob(os.path.join(train_path, "graph*pb")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in models: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list.append(os.path.join("..", os.path.basename(ii))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1827,7 +1865,8 @@ def _make_model_devi_amber( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.replace("@qm_theory@", jdata["low_level"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.replace("@rcut@", str(jdata["cutoff"])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - "unknown atomic identifier", atom, 'if one want to use isotopes, or non-standard element names, chemical symbols, or atomic number in the type_map list, please customize the mass_map list instead of using "auto".'
+ "unknown atomic identifier {}. If one wants to use isotopes, or non-standard element names, chemical symbols, or atomic number in the type_map list, please customize the mass_map list instead of using 'auto'.".format(atom) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph.*.pb"))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}"))) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ii in models: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
task_model_list.append(os.path.join("..", os.path.basename(ii))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1935,7 +1974,9 @@ def run_md_model_devi(iter_index, jdata, mdata): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
run_tasks = [os.path.basename(ii) for ii in run_tasks_] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# dlog.info("all_task is ", all_task) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# dlog.info("run_tasks in run_model_deviation",run_tasks_) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - "should not get ele temp at setting: use_ele_temp == 0"
+ "should not get ele temp at setting: use_ele_temp == 0".format() Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
all_models = glob.glob(os.path.join(work_path, "graph*pb")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suffix = _get_model_suffix(jdata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use format specifiers instead of percent format for better performance and readability. - all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix))
+ all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}")) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_names = [os.path.basename(ii) for ii in all_models] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_engine = jdata.get("model_devi_engine", "lammps") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -2002,7 +2043,7 @@ def run_md_model_devi(iter_index, jdata, mdata): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += f'&& echo -e "{grp_name}\\n{grp_name}\\n" | {model_devi_exec} trjconv -s {ref_filename} -f {deffnm}.trr -o {traj_filename} -pbc mol -ur compact -center' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += "&& if [ ! -d traj ]; then \n mkdir traj; fi\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += f"python -c \"import dpdata;system = dpdata.System('{traj_filename}', fmt='gromacs/gro'); [system.to_gromacs_gro('traj/%d.gromacstrj' % (i * {trj_freq}), frame_idx=i) for i in range(system.get_nframes())]; system.to_deepmd_npy('traj_deepmd')\"" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += f"&& dp model-devi -m ../graph.000.pb ../graph.001.pb ../graph.002.pb ../graph.003.pb -s traj_deepmd -o model_devi.out -f {trj_freq}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
command += f"&& dp model-devi -m ../graph.000{suffix} ../graph.001{suffix} ../graph.002{suffix} ../graph.003{suffix} -s traj_deepmd -o model_devi.out -f {trj_freq}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
commands = [command] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
forward_files = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -2177,7 +2218,7 @@ def _read_model_devi_file( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert all( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_content.shape[0] == model_devi_contents[0].shape[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for model_devi_content in model_devi_contents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
), r"Not all beads generated the same number of lines in the model_devi${ibead}.out file. Check your pimd task carefully." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
), "Not all beads generated the same number of lines in the model_devi$\{ibead\}.out file. Check your pimd task carefully." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
last_step = model_devi_contents[0][-1, 0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for ibead in range(1, num_beads): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_devi_contents[ibead][:, 0] = model_devi_contents[ibead][ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
End the docstring with a period for consistency and proper grammar.
Committable suggestion