Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 7, 2024
1 parent ef1df18 commit a1b3ff8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
26 changes: 19 additions & 7 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@


def _get_model_suffix(jdata) -> str:
"""return the model suffix based on the backend"""
"""Return the model suffix based on the backend"""
backend = jdata.get("train_backend", "tensorflow")
if backend == "tensorflow":
suffix = ".pb"
Expand Down Expand Up @@ -193,7 +193,10 @@ def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"):
prv_train_task = os.path.join(prv_train_path, train_task_fmt % ii)
os.chdir(cur_train_path)
os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
os.symlink(os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), "graph.%03d%s" % (ii, suffix))
os.symlink(
os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
"graph.%03d%s" % (ii, suffix),
)
os.chdir(cwd)
with open(os.path.join(cur_train_path, "copied"), "w") as fp:
None
Expand Down Expand Up @@ -657,7 +660,9 @@ def make_train(iter_index, jdata, mdata):
)
if copied_models is not None:
for ii in range(len(copied_models)):
_link_old_models(work_path, [copied_models[ii]], ii, basename="init%s" % suffix)
_link_old_models(
work_path, [copied_models[ii]], ii, basename="init%s" % suffix
)
# Copy user defined forward files
symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path)
# HDF5 format for training data
Expand Down Expand Up @@ -811,11 +816,18 @@ def run_train(iter_index, jdata, mdata):
elif training_init_frozen_model is not None or training_finetune_model is not None:
forward_files.append(os.path.join("old", "init%s" % suffix))

backward_files = ["frozen_model%s" % suffix, "lcurve.out", "train.log", "checkpoint"]
backward_files = [
"frozen_model%s" % suffix,
"lcurve.out",
"train.log",
"checkpoint",
]
if suffix == ".pb":
backward_files += ["model.ckpt.meta",
"model.ckpt.index",
"model.ckpt.data-00000-of-00001"]
backward_files += [
"model.ckpt.meta",
"model.ckpt.index",
"model.ckpt.data-00000-of-00001",
]
if jdata.get("dp_compress", False):
backward_files.append("frozen_model_compressed%s" % suffix)

Check warning on line 832 in dpgen/generator/run.py

View check run for this annotation

Codecov / codecov/patch

dpgen/generator/run.py#L832

Added line #L832 was not covered by tests

Expand Down
2 changes: 1 addition & 1 deletion dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
record_iter,
)
from dpgen.generator.run import (
_get_model_suffix,
data_system_fmt,
fp_name,
fp_task_fmt,
Expand All @@ -43,7 +44,6 @@
run_train,
train_name,
train_task_fmt,
_get_model_suffix,
)
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import expand_sys_str, load_file, normalize, sepline, setup_ele_temp
Expand Down

0 comments on commit a1b3ff8

Please sign in to comment.