Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add option to select backends TF/PT #1541

Closed
wants to merge 46 commits into from
Closed
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
575c5a9
begin add gpaw
thangckt Mar 16, 2024
3be325d
Create gpaw.py
thangckt Mar 29, 2024
d3cc70d
Merge pull request #1 from deepmodeling/devel
thangckt Mar 29, 2024
547e13b
Update gpaw.py
thangckt Apr 1, 2024
cfa904b
u
thangckt Apr 1, 2024
946561b
Update arginfo.py
thangckt Apr 2, 2024
11103f5
Merge pull request #2 from deepmodeling/devel
thangckt Apr 2, 2024
0ddbf7c
u
thangckt Apr 2, 2024
0b94ff6
u
thangckt Apr 3, 2024
494b796
u
thangckt Apr 4, 2024
2a58c7e
Merge branch 'devel' of https://github.com/thangckt/dpgen into devel
thangckt Apr 4, 2024
01fbd2f
Merge pull request #4 from deepmodeling/devel
thangckt May 2, 2024
babd77e
modify to use pytorch
thangckt May 5, 2024
28b8a49
option to choose between TF and PT
thangckt May 6, 2024
2822be0
Delete gpaw.py
thangckt May 6, 2024
99efd85
finish add option to select TF/PT
thangckt May 7, 2024
bdecc9f
Merge pull request #5 from deepmodeling/devel
thangckt May 7, 2024
f4f5665
Merge pull request #6 from thangckt/devel
thangckt May 7, 2024
ef1df18
remove GPAW to PR
thangckt May 7, 2024
a1b3ff8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2024
11dca54
Update simplify.py
thangckt May 7, 2024
a0684ca
Merge branch 'PR' of https://github.com/thangckt/dpgen into PR
thangckt May 7, 2024
220735d
Revert "Merge branch 'PR' of https://github.com/thangckt/dpgen into PR"
thangckt May 7, 2024
8f7f491
reset add GPAW from here
thangckt May 7, 2024
05b2412
u
thangckt May 7, 2024
ed832e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2024
af3fc27
Update arginfo.py
thangckt May 7, 2024
8a949d0
u
thangckt May 7, 2024
7def8ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2024
52a9989
Update run.py
thangckt May 7, 2024
0d99ea9
Merge branch 'PR' of https://github.com/thangckt/dpgen into PR
thangckt May 7, 2024
8bdea17
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2024
67fab2b
Merge branch 'devel' into PR
thangckt May 7, 2024
706146a
Update run.py
thangckt May 8, 2024
e72f4c8
Merge branch 'devel' into PR
thangckt May 8, 2024
98eccb8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2024
702510c
Update run.py
thangckt May 8, 2024
6ffd1eb
Merge branch 'PR' of https://github.com/thangckt/dpgen into PR
thangckt May 8, 2024
624a48b
Update arginfo.py
thangckt May 8, 2024
f3f49d3
remove gpaw
thangckt May 9, 2024
0f9b1b0
Update run.py
thangckt May 9, 2024
5049771
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2024
47a766e
Update run.py
thangckt May 9, 2024
083ad45
Update run.py
thangckt May 9, 2024
a500987
Update run.py
thangckt May 9, 2024
35b2713
Update run.py
thangckt May 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 57 additions & 27 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@
run_opt_file = os.path.join(ROOT_PATH, "generator/lib/calypso_run_opt.py")


def _get_model_suffix(jdata) -> str:
"""Return the model suffix based on the backend"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

End the docstring with a period for consistency and proper grammar.

-    """Return the model suffix based on the backend"""
+    """Return the model suffix based on the backend."""

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
"""Return the model suffix based on the backend"""
"""Return the model suffix based on the backend."""

backend = jdata.get("train_backend", "tensorflow")
thangckt marked this conversation as resolved.
Show resolved Hide resolved
if backend == "tensorflow":
suffix = ".pb"
elif backend == "pytorch":
suffix = ".pth"

Check warning on line 134 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L133-L134

Added lines #L133 - L134 were not covered by tests
thangckt marked this conversation as resolved.
Show resolved Hide resolved
return suffix


def get_job_names(jdata):
jobkeys = []
for ii in jdata.keys():
Expand Down Expand Up @@ -172,7 +182,7 @@
return all(empty_sys)


def copy_model(numb_model, prv_iter_index, cur_iter_index):
def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"):
cwd = os.getcwd()
prv_train_path = os.path.join(make_iter_name(prv_iter_index), train_name)
cur_train_path = os.path.join(make_iter_name(cur_iter_index), train_name)
Expand All @@ -184,7 +194,8 @@
os.chdir(cur_train_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-            "graph.%03d%s" % (ii, suffix),
+            "graph.{:03d}{}".format(ii, suffix),

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
os.chdir(cur_train_path)
os.chdir(cur_train_path)

os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
os.symlink(
os.path.join(train_task_fmt % ii, "frozen_model.pb"), "graph.%03d.pb" % ii
os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-            os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
+            os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),

"graph.%03d%s" % (ii, suffix),
)
os.chdir(cwd)
with open(os.path.join(cur_train_path, "copied"), "w") as fp:
Expand Down Expand Up @@ -316,18 +327,19 @@
number_old_frames = 0
number_new_frames = 0

suffix = _get_model_suffix(jdata)
model_devi_engine = jdata.get("model_devi_engine", "lammps")
if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min):
log_task("prev data is empty, copy prev model")
copy_model(numb_models, iter_index - 1, iter_index)
copy_model(numb_models, iter_index - 1, iter_index, suffix)

Check warning on line 334 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L334

Added line #L334 was not covered by tests
return
elif (
model_devi_engine != "calypso"
and iter_index > 0
and _check_skip_train(model_devi_jobs[iter_index - 1])
):
log_task("skip training at step %d " % (iter_index - 1))
copy_model(numb_models, iter_index - 1, iter_index)
copy_model(numb_models, iter_index - 1, iter_index, suffix)

Check warning on line 342 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L342

Added line #L342 was not covered by tests
Comment on lines +332 to +344
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor the conditional checks for clarity and maintainability. Consider using intermediate variables to simplify the conditions.

- if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min):
+ previous_iter_empty = iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min)
+ if previous_iter_empty:
    log_task("prev data is empty, copy prev model")
    copy_model(numb_models, iter_index - 1, iter_index, suffix)
    return

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
suffix = _get_model_suffix(jdata)
model_devi_engine = jdata.get("model_devi_engine", "lammps")
if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min):
log_task("prev data is empty, copy prev model")
copy_model(numb_models, iter_index - 1, iter_index)
copy_model(numb_models, iter_index - 1, iter_index, suffix)
return
elif (
model_devi_engine != "calypso"
and iter_index > 0
and _check_skip_train(model_devi_jobs[iter_index - 1])
):
log_task("skip training at step %d " % (iter_index - 1))
copy_model(numb_models, iter_index - 1, iter_index)
copy_model(numb_models, iter_index - 1, iter_index, suffix)
suffix = _get_model_suffix(jdata)
model_devi_engine = jdata.get("model_devi_engine", "lammps")
previous_iter_empty = iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min)
if previous_iter_empty:
log_task("prev data is empty, copy prev model")
copy_model(numb_models, iter_index - 1, iter_index, suffix)
return
elif (
model_devi_engine != "calypso"
and iter_index > 0
and _check_skip_train(model_devi_jobs[iter_index - 1])
):
log_task("skip training at step %d " % (iter_index - 1))
copy_model(numb_models, iter_index - 1, iter_index, suffix)

return
else:
iter_name = make_iter_name(iter_index)
Expand Down Expand Up @@ -591,19 +603,19 @@
if (
len(np.array(model_devi_activation_func).shape) == 2
): # 2-dim list for emd/fitting net-resolved assignment of actF
jinput["model"]["descriptor"]["activation_function"] = (

Check warning on line 606 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L606

Added line #L606 was not covered by tests
model_devi_activation_func[ii][0]
)
jinput["model"]["fitting_net"]["activation_function"] = (

Check warning on line 609 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L609

Added line #L609 was not covered by tests
model_devi_activation_func[ii][1]
)
if (
len(np.array(model_devi_activation_func).shape) == 1
): # for backward compatibility, 1-dim list, not net-resolved
jinput["model"]["descriptor"]["activation_function"] = (

Check warning on line 615 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L615

Added line #L615 was not covered by tests
model_devi_activation_func[ii]
)
jinput["model"]["fitting_net"]["activation_function"] = (

Check warning on line 618 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L618

Added line #L618 was not covered by tests
model_devi_activation_func[ii]
)
# dump the input.json
Expand Down Expand Up @@ -648,7 +660,9 @@
)
if copied_models is not None:
thangckt marked this conversation as resolved.
Show resolved Hide resolved
for ii in range(len(copied_models)):
_link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb")
_link_old_models(
work_path, [copied_models[ii]], ii, basename="init%s" % suffix
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-                work_path, [copied_models[ii]], ii, basename="init%s" % suffix
+                work_path, [copied_models[ii]], ii, basename=f"init{suffix}"

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
work_path, [copied_models[ii]], ii, basename="init%s" % suffix
work_path, [copied_models[ii]], ii, basename=f"init{suffix}"

)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use f-string for better readability and performance in string formatting.

-                work_path, [copied_models[ii]], ii, basename="init%s" % suffix
+                work_path, [copied_models[ii]], ii, basename=f"init{suffix}"

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
_link_old_models(
work_path, [copied_models[ii]], ii, basename="init%s" % suffix
)
_link_old_models(
work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
)

# Copy user defined forward files
symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path)
# HDF5 format for training data
Expand Down Expand Up @@ -700,6 +714,7 @@
# print("debug:run_train:mdata", mdata)
# load json param
numb_models = jdata["numb_models"]
suffix = _get_model_suffix(jdata)
# train_param = jdata['train_param']
train_input_file = default_train_input_file
training_reuse_iter = jdata.get("training_reuse_iter")
Expand Down Expand Up @@ -762,9 +777,9 @@
if training_init_model:
init_flag = " --init-model old/model.ckpt"
elif training_init_frozen_model is not None:
init_flag = " --init-frz-model old/init.pb"
init_flag = " --init-frz-model old/init%s" % suffix
elif training_finetune_model is not None:
init_flag = " --finetune old/init.pb"
init_flag = " --finetune old/init%s" % suffix
thangckt marked this conversation as resolved.
Show resolved Hide resolved
command = f"{train_command} train {train_input_file}{extra_flags}"
thangckt marked this conversation as resolved.
Show resolved Hide resolved
command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
command = "/bin/sh -c %s" % shlex.quote(command)
Expand Down Expand Up @@ -799,17 +814,23 @@
os.path.join("old", "model.ckpt.data-00000-of-00001"),
]
elif training_init_frozen_model is not None or training_finetune_model is not None:
forward_files.append(os.path.join("old", "init.pb"))
forward_files.append(os.path.join("old", "init%s" % suffix))

backward_files = ["frozen_model.pb", "lcurve.out", "train.log"]
backward_files += [
"model.ckpt.meta",
"model.ckpt.index",
"model.ckpt.data-00000-of-00001",
backward_files = [
"frozen_model%s" % suffix,
"lcurve.out",
"train.log",
"checkpoint",
]
if jdata.get("dp_compress", False):
backward_files.append("frozen_model_compressed.pb")
if suffix == ".pb":
backward_files += [
"model.ckpt.meta",
"model.ckpt.index",
"model.ckpt.data-00000-of-00001",
]
if jdata.get("dp_compress", False):
backward_files.append("frozen_model_compressed%s" % suffix)

Check warning on line 832 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L832

Added line #L832 was not covered by tests
thangckt marked this conversation as resolved.
Show resolved Hide resolved

if not jdata.get("one_h5", False):
init_data_sys_ = jdata["init_data_sys"]
init_data_sys = []
Expand Down Expand Up @@ -880,13 +901,15 @@
log_task("copied model, do not post train")
return
# symlink models
suffix = _get_model_suffix(jdata)

Check warning on line 904 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L904

Added line #L904 was not covered by tests
for ii in range(numb_models):
if not jdata.get("dp_compress", False):
model_name = "frozen_model.pb"
else:
model_name = "frozen_model_compressed.pb"
model_name = "frozen_model%s" % suffix
if suffix == ".pb":
thangckt marked this conversation as resolved.
Show resolved Hide resolved
if jdata.get("dp_compress", False):
model_name = "frozen_model_compressed%s" % suffix

Check warning on line 909 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L906-L909

Added lines #L906 - L909 were not covered by tests

ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix))

Check warning on line 911 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L911

Added line #L911 was not covered by tests
task_file = os.path.join(train_task_fmt % ii, model_name)
ofile = os.path.join(work_path, "graph.%03d.pb" % ii)
if os.path.isfile(ofile):
os.remove(ofile)
os.symlink(task_file, ofile)
Expand Down Expand Up @@ -1126,7 +1149,8 @@
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
thangckt marked this conversation as resolved.
Show resolved Hide resolved
train_path = os.path.abspath(train_path)
models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
suffix = _get_model_suffix(jdata)
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-    models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

work_path = os.path.join(iter_name, model_devi_name)
create_path(work_path)
if model_devi_engine == "calypso":
Expand Down Expand Up @@ -1307,7 +1331,8 @@
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
thangckt marked this conversation as resolved.
Show resolved Hide resolved
train_path = os.path.abspath(train_path)
models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
suffix = _get_model_suffix(jdata)
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-    models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

task_model_list = []
for ii in models:
task_model_list.append(os.path.join("..", os.path.basename(ii)))
Expand Down Expand Up @@ -1504,7 +1529,8 @@
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
thangckt marked this conversation as resolved.
Show resolved Hide resolved
train_path = os.path.abspath(train_path)
models = glob.glob(os.path.join(train_path, "graph*pb"))
suffix = _get_model_suffix(jdata)
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-    models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

task_model_list = []
for ii in models:
task_model_list.append(os.path.join("..", os.path.basename(ii)))
Expand Down Expand Up @@ -1646,7 +1672,8 @@
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
thangckt marked this conversation as resolved.
Show resolved Hide resolved
train_path = os.path.abspath(train_path)
models = glob.glob(os.path.join(train_path, "graph*pb"))
suffix = _get_model_suffix(jdata)
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-    models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

task_model_list = []
for ii in models:
task_model_list.append(os.path.join("..", os.path.basename(ii)))
Expand Down Expand Up @@ -1829,7 +1856,8 @@
.replace("@qm_theory@", jdata["low_level"])
.replace("@rcut@", str(jdata["cutoff"]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-            "unknown atomic identifier", atom, 'if one want to use isotopes, or non-standard element names, chemical symbols, or atomic number in the type_map list, please customize the mass_map list instead of using "auto".'
+            "unknown atomic identifier {}. If one wants to use isotopes, or non-standard element names, chemical symbols, or atomic number in the type_map list, please customize the mass_map list instead of using 'auto'.".format(atom)

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
.replace("@rcut@", str(jdata["cutoff"]))
.replace("@rcut@", str(jdata["cutoff"]))

)
models = sorted(glob.glob(os.path.join(train_path, "graph.*.pb")))
suffix = _get_model_suffix(jdata)
models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-            models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix)))
+            models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}")))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix)))
models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}")))

task_model_list = []
for ii in models:
task_model_list.append(os.path.join("..", os.path.basename(ii)))
Expand Down Expand Up @@ -1937,7 +1965,9 @@
run_tasks = [os.path.basename(ii) for ii in run_tasks_]
# dlog.info("all_task is ", all_task)
# dlog.info("run_tasks in run_model_deviation",run_tasks_)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-            "should not get ele temp at setting: use_ele_temp == 0"
+            "should not get ele temp at setting: use_ele_temp == 0".format()

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
# dlog.info("run_tasks in run_model_deviation",run_tasks_)
# dlog.info("run_tasks in run_model_deviation",run_tasks_)

all_models = glob.glob(os.path.join(work_path, "graph*pb"))

suffix = _get_model_suffix(jdata)
all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-    all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix))
+    all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix))
all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))

model_names = [os.path.basename(ii) for ii in all_models]

model_devi_engine = jdata.get("model_devi_engine", "lammps")
Expand Down Expand Up @@ -2001,10 +2031,10 @@
if ndx_filename:
command += f'&& echo -e "{grp_name}\\n{grp_name}\\n" | {model_devi_exec} trjconv -s {ref_filename} -f {deffnm}.trr -n {ndx_filename} -o {traj_filename} -pbc mol -ur compact -center'
else:
command += f'&& echo -e "{grp_name}\\n{grp_name}\\n" | {model_devi_exec} trjconv -s {ref_filename} -f {deffnm}.trr -o {traj_filename} -pbc mol -ur compact -center'

Check warning on line 2034 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L2034

Added line #L2034 was not covered by tests
command += "&& if [ ! -d traj ]; then \n mkdir traj; fi\n"
command += f"python -c \"import dpdata;system = dpdata.System('{traj_filename}', fmt='gromacs/gro'); [system.to_gromacs_gro('traj/%d.gromacstrj' % (i * {trj_freq}), frame_idx=i) for i in range(system.get_nframes())]; system.to_deepmd_npy('traj_deepmd')\""
command += f"&& dp model-devi -m ../graph.000.pb ../graph.001.pb ../graph.002.pb ../graph.003.pb -s traj_deepmd -o model_devi.out -f {trj_freq}"
command += f"&& dp model-devi -m ../graph.000{suffix} ../graph.001{suffix} ../graph.002{suffix} ../graph.003{suffix} -s traj_deepmd -o model_devi.out -f {trj_freq}"

Check warning on line 2037 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L2037

Added line #L2037 was not covered by tests
commands = [command]

forward_files = [
Expand Down
5 changes: 4 additions & 1 deletion dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
record_iter,
)
from dpgen.generator.run import (
_get_model_suffix,
data_system_fmt,
fp_name,
fp_task_fmt,
Expand Down Expand Up @@ -186,7 +187,9 @@
# link the model
train_path = os.path.join(iter_name, train_name)
train_path = os.path.abspath(train_path)
models = glob.glob(os.path.join(train_path, "graph*pb"))
suffix = _get_model_suffix(jdata)
models = glob.glob(os.path.join(train_path, "graph*%s" % suffix))

Check warning on line 191 in dpgen/simplify/simplify.py

View check run for this annotation

Codecov / codecov/patch

dpgen/simplify/simplify.py#L190-L191

Added lines #L190 - L191 were not covered by tests
thangckt marked this conversation as resolved.
Show resolved Hide resolved

thangckt marked this conversation as resolved.
Show resolved Hide resolved
for mm in models:
model_name = os.path.basename(mm)
os.symlink(mm, os.path.join(work_path, model_name))
Expand Down
Loading