-
Notifications
You must be signed in to change notification settings - Fork 176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add option to select backends TF/PT #1541
Conversation
fetch branch devel
Devel update
fetch devel_stream
add Option to choose backend TF/PT
for more information, see https://pre-commit.ci
Warning Rate Limit Exceeded@thangckt has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 1 minutes and 23 seconds before requesting another review. How to resolve this issue?After the wait time has elapsed, a review can be triggered using the We recommend that you space out your commits to avoid hitting the rate limit. How do rate limits work?CodeRabbit enforces hourly rate limits for each developer per organization. WalkthroughThe recent updates to the Changes
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (invoked as PR comments)
Additionally, you can add CodeRabbit Configration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
Out of diff range and nitpick comments (2)
dpgen/generator/run.py (2)
128-128
: Ensure the docstring ends with a period for consistency and clarity.
Line range hint
185-196
: Use context managers when changing directories to ensure the directory context is restored even if an error occurs.prv_train_task = os.path.join(prv_train_path, train_task_fmt % ii) + with change_directory(cur_train_path): - os.chdir(cur_train_path) - os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii) + os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii) - os.chdir(cwd)
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #1541 +/- ##
==========================================
+ Coverage 49.49% 49.53% +0.03%
==========================================
Files 83 83
Lines 14816 14841 +25
==========================================
+ Hits 7333 7351 +18
- Misses 7483 7490 +7 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Out of diff range and nitpick comments (2)
dpgen/generator/run.py (2)
Line range hint
185-198
: Optimize thecopy_model
function for error handling.+ try: os.symlink( os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), "graph.%03d%s" % (ii, suffix), ) + except OSError as e: + logging.error("Failed to create symlink: %s", e) os.chdir(cwd)Add error handling to manage potential issues with symlink creation, such as permissions errors or existing links.
Line range hint
717-782
: Enhance therun_train
function with additional logging for better traceability.+ logging.info("Starting training process for model %d", ii) command = f"{train_command} train {train_input_file}{extra_flags}" command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}" + logging.info("Training command: %s", command)Add logging statements to provide more information about the training process and the commands being executed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Out of diff range and nitpick comments (1)
dpgen/simplify/simplify.py (1)
Line range hint
261-261
: Replace hardcoded suffix ".pb" with dynamic suffix to maintain consistency with backend selection.- models = glob.glob(os.path.join(work_path, "graph*pb")) + suffix = _get_model_suffix(jdata) + models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 9
Signed-off-by: C. Thang Nguyen <[email protected]>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 8
Out of diff range and nitpick comments (1)
dpgen/generator/run.py (1)
Line range hint
188-199
: Use format specifiers instead of percent format for better performance and readability.- os.symlink(os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), "graph.%03d%s" % (ii, suffix)) + os.symlink(os.path.join(train_task_fmt.format(ii), f"frozen_model{suffix}"), f"graph.{ii:03d}{suffix}")
dpgen/generator/run.py
Outdated
@@ -647,7 +659,7 @@ def make_train(iter_index, jdata, mdata): | |||
) | |||
if copied_models is not None: | |||
for ii in range(len(copied_models)): | |||
_link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb") | |||
_link_old_models(work_path, [copied_models[ii]], ii, basename="init%s" % suffix) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use format specifiers instead of percent format for better performance and readability.
- _link_old_models(work_path, [copied_models[ii]], ii, basename="init%s" % suffix)
+ _link_old_models(work_path, [copied_models[ii]], ii, basename=f"init{suffix}")
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
_link_old_models(work_path, [copied_models[ii]], ii, basename="init%s" % suffix) | |
_link_old_models(work_path, [copied_models[ii]], ii, basename=f"init{suffix}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Out of diff range and nitpick comments (1)
dpgen/generator/run.py (1)
Line range hint
188-201
: Optimize thecopy_model
function to reduce directory changes and improve path handling.188a189,190 > # Store the current working directory to revert back later > original_cwd = os.getcwd() 192a195 > # Use absolute paths for operations to avoid changing directories 194c197 < os.chdir(cur_train_path) --- > # Avoid changing the directory by using absolute paths 196,197c199,200 < os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii) < os.symlink( --- > os.symlink(prv_train_task, os.path.join(cur_train_path, train_task_fmt % ii)) > os.symlink( 199c202 < os.chdir(cwd) --- > os.chdir(original_cwd) # Revert to the original working directory
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 10
Out of diff range and nitpick comments (1)
dpgen/generator/run.py (1)
Line range hint
188-201
: Optimize string formatting in thecopy_model
function.- os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), - "graph.%03d%s" % (ii, suffix), + os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"), + f"graph.{ii:03d}{suffix}",
dpgen/generator/run.py
Outdated
_link_old_models( | ||
work_path, [copied_models[ii]], ii, basename="init%s" % suffix | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use f-string for better readability and performance in string formatting.
- work_path, [copied_models[ii]], ii, basename="init%s" % suffix
+ work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
_link_old_models( | |
work_path, [copied_models[ii]], ii, basename="init%s" % suffix | |
) | |
_link_old_models( | |
work_path, [copied_models[ii]], ii, basename=f"init{suffix}" | |
) |
dpgen/generator/run.py
Outdated
for ss in system_index: | ||
sys_output = glob.glob(os.path.join(work_path, "task.%s.*" % ss)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use f-string for better readability and performance in string formatting.
- sys_output = glob.glob(os.path.join(work_path, "task.%s.*" % ss))
+ sys_output = glob.glob(os.path.join(work_path, f"task.{ss}.*"))
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
for ss in system_index: | |
sys_output = glob.glob(os.path.join(work_path, "task.%s.*" % ss)) | |
for ss in system_index: | |
sys_output = glob.glob(os.path.join(work_path, f"task.{ss}.*")) |
dpgen/generator/run.py
Outdated
sys_data_path = os.path.join(work_path, "data.%s" % ss) | ||
all_sys.to_deepmd_raw(sys_data_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use f-string for better readability and performance in string formatting.
- sys_data_path = os.path.join(work_path, "data.%s" % ss)
+ sys_data_path = os.path.join(work_path, f"data.{ss}")
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
sys_data_path = os.path.join(work_path, "data.%s" % ss) | |
all_sys.to_deepmd_raw(sys_data_path) | |
sys_data_path = os.path.join(work_path, f"data.{ss}") | |
all_sys.to_deepmd_raw(sys_data_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 9
dpgen/generator/run.py
Outdated
@@ -184,7 +197,8 @@ def copy_model(numb_model, prv_iter_index, cur_iter_index): | |||
os.chdir(cur_train_path) | |||
os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii) | |||
os.symlink( | |||
os.path.join(train_task_fmt % ii, "frozen_model.pb"), "graph.%03d.pb" % ii | |||
os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use format specifiers instead of percent format for better performance and readability.
- os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
+ os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), | |
os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"), |
dpgen/generator/run.py
Outdated
@@ -647,7 +662,9 @@ def make_train(iter_index, jdata, mdata): | |||
) | |||
if copied_models is not None: | |||
for ii in range(len(copied_models)): | |||
_link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb") | |||
_link_old_models( | |||
work_path, [copied_models[ii]], ii, basename="init%s" % suffix |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use format specifiers instead of percent format for better performance and readability.
- work_path, [copied_models[ii]], ii, basename="init%s" % suffix
+ work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
work_path, [copied_models[ii]], ii, basename="init%s" % suffix | |
work_path, [copied_models[ii]], ii, basename=f"init{suffix}" |
dpgen/generator/run.py
Outdated
@@ -1124,7 +1158,8 @@ def make_model_devi(iter_index, jdata, mdata): | |||
iter_name = make_iter_name(iter_index) | |||
train_path = os.path.join(iter_name, train_name) | |||
train_path = os.path.abspath(train_path) | |||
models = sorted(glob.glob(os.path.join(train_path, "graph*pb"))) | |||
suffix = _get_model_suffix(jdata) | |||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use format specifiers instead of percent format for better performance and readability.
- models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | |
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) |
dpgen/generator/run.py
Outdated
@@ -1305,7 +1340,8 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems): | |||
iter_name = make_iter_name(iter_index) | |||
train_path = os.path.join(iter_name, train_name) | |||
train_path = os.path.abspath(train_path) | |||
models = sorted(glob.glob(os.path.join(train_path, "graph*pb"))) | |||
suffix = _get_model_suffix(jdata) | |||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use format specifiers instead of percent format for better performance and readability.
- models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | |
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) |
dpgen/generator/run.py
Outdated
@@ -1502,7 +1538,8 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems): | |||
iter_name = make_iter_name(iter_index) | |||
train_path = os.path.join(iter_name, train_name) | |||
train_path = os.path.abspath(train_path) | |||
models = glob.glob(os.path.join(train_path, "graph*pb")) | |||
suffix = _get_model_suffix(jdata) | |||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use format specifiers instead of percent format for better performance and readability.
- models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | |
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) |
dpgen/generator/run.py
Outdated
@@ -1644,7 +1681,8 @@ def _make_model_devi_native_gromacs(iter_index, jdata, mdata, conf_systems): | |||
iter_name = make_iter_name(iter_index) | |||
train_path = os.path.join(iter_name, train_name) | |||
train_path = os.path.abspath(train_path) | |||
models = glob.glob(os.path.join(train_path, "graph*pb")) | |||
suffix = _get_model_suffix(jdata) | |||
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use format specifiers instead of percent format for better performance and readability.
- models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix))) | |
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}"))) |
dpgen/generator/run.py
Outdated
@@ -1827,7 +1865,8 @@ def _make_model_devi_amber( | |||
.replace("@qm_theory@", jdata["low_level"]) | |||
.replace("@rcut@", str(jdata["cutoff"])) | |||
) | |||
models = sorted(glob.glob(os.path.join(train_path, "graph.*.pb"))) | |||
suffix = _get_model_suffix(jdata) | |||
models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use format specifiers instead of percent format for better performance and readability.
- models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix)))
+ models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}")))
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix))) | |
models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}"))) |
dpgen/generator/run.py
Outdated
all_models = glob.glob(os.path.join(work_path, "graph*pb")) | ||
|
||
suffix = _get_model_suffix(jdata) | ||
all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use format specifiers instead of percent format for better performance and readability.
- all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix))
+ all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix)) | |
all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}")) |
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Out of diff range and nitpick comments (1)
dpgen/generator/run.py (1)
Line range hint
188-201
: Ensure consistent formatting and improve readability in thecopy_model
function.- os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"), - "graph.%03d%s" % (ii, suffix) + os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"), f"graph.{ii:03d}{suffix}"
suffix = _get_model_suffix(jdata) | ||
for ii in range(numb_models): | ||
if not jdata.get("dp_compress", False): | ||
model_name = "frozen_model.pb" | ||
else: | ||
model_name = "frozen_model_compressed.pb" | ||
model_name = f"frozen_model{suffix}" | ||
if jdata.get("dp_compress", False): | ||
model_name = f"frozen_model_compressed{suffix}" | ||
|
||
ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Optimize the symlink creation in post_train
function.
- os.symlink(task_file, ofile)
+ os.symlink(task_file, ofile, target_is_directory=True)
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
suffix = _get_model_suffix(jdata) | |
for ii in range(numb_models): | |
if not jdata.get("dp_compress", False): | |
model_name = "frozen_model.pb" | |
else: | |
model_name = "frozen_model_compressed.pb" | |
model_name = f"frozen_model{suffix}" | |
if jdata.get("dp_compress", False): | |
model_name = f"frozen_model_compressed{suffix}" | |
ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix)) | |
suffix = _get_model_suffix(jdata) | |
for ii in range(numb_models): | |
model_name = f"frozen_model{suffix}" | |
if jdata.get("dp_compress", False): | |
model_name = f"frozen_model_compressed{suffix}" | |
ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix)) | |
os.symlink(task_file, ofile, target_is_directory=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Out of diff range and nitpick comments (1)
dpgen/generator/run.py (1)
Line range hint
188-201
: Optimize thecopy_model
function to handle different suffixes more efficiently.def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"): cwd = os.getcwd() prv_train_path = os.path.join(make_iter_name(prv_iter_index), train_name) cur_train_path = os.path.join(make_iter_name(cur_iter_index), train_name) prv_train_path = os.path.abspath(prv_train_path) cur_train_path = os.path.abspath(cur_train_path) create_path(cur_train_path) for ii in range(numb_model): prv_train_task = os.path.join(prv_train_path, train_task_fmt % ii) os.chdir(cur_train_path) os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii) os.symlink( os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"), "graph.%03d%s" % (ii, suffix), ) os.chdir(cwd) with open(os.path.join(cur_train_path, "copied"), "w") as fp: None
def _get_model_suffix(jdata) -> str: | ||
"""Return the model suffix based on the backend""" | ||
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"} | ||
backend = jdata.get("train_backend", "tensorflow") | ||
if backend in suffix_map: | ||
suffix = suffix_map[backend] | ||
else: | ||
raise ValueError( | ||
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'." | ||
) | ||
return suffix | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactor the function _get_model_suffix
for clarity and maintainability.
def _get_model_suffix(jdata) -> str:
"""Return the model suffix based on the backend."""
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
backend = jdata.get("train_backend", "tensorflow")
try:
return suffix_map[backend]
except KeyError:
raise ValueError(
"The backend '{}' is not supported. Please choose from: 'tensorflow', 'pytorch'.".format(backend)
)
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
def _get_model_suffix(jdata) -> str: | |
"""Return the model suffix based on the backend""" | |
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"} | |
backend = jdata.get("train_backend", "tensorflow") | |
if backend in suffix_map: | |
suffix = suffix_map[backend] | |
else: | |
raise ValueError( | |
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'." | |
) | |
return suffix | |
def _get_model_suffix(jdata) -> str: | |
"""Return the model suffix based on the backend.""" | |
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"} | |
backend = jdata.get("train_backend", "tensorflow") | |
try: | |
return suffix_map[backend] | |
except KeyError: | |
raise ValueError( | |
"The backend '{}' is not supported. Please choose from: 'tensorflow', 'pytorch'.".format(backend) | |
) |
_link_old_models( | ||
work_path, [copied_models[ii]], ii, basename=f"init{suffix}" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Optimize the make_train
function to handle different training setups more efficiently.
def make_train(iter_index, jdata, mdata):
# load json param
# train_param = jdata['train_param']
train_input_file = default_train_input_file
numb_models = jdata["numb_models"]
init_data_prefix = jdata["init_data_prefix"]
init_data_prefix = os.path.abspath(init_data_prefix)
init_data_sys_ = jdata["init_data_sys"]
fp_task_min = jdata["fp_task_min"]
model_devi_jobs = jdata["model_devi_jobs"]
use_ele_temp = jdata.get("use_ele_temp", 0)
training_iter0_model = jdata.get("training_iter0_model_path", [])
training_init_model = jdata.get("training_init_model", False)
training_reuse_iter = jdata.get("training_reuse_iter")
training_reuse_old_ratio = jdata.get("training_reuse_old_ratio", "auto")
# if you want to use DP-ZBL potential , you have to give the path of your energy potential file
if "srtab_file_path" in jdata.keys():
srtab_file_path = os.path.abspath(jdata.get("srtab_file_path", None))
if "training_reuse_stop_batch" in jdata.keys():
training_reuse_stop_batch = jdata["training_reuse_stop_batch"]
elif "training_reuse_numb_steps" in jdata.keys():
training_reuse_stop_batch = jdata["training_reuse_numb_steps"]
else:
training_reuse_stop_batch = None
training_reuse_start_lr = jdata.get("training_reuse_start_lr")
training_reuse_start_pref_e = jdata.get("training_reuse_start_pref_e")
training_reuse_start_pref_f = jdata.get("training_reuse_start_pref_f")
model_devi_activation_func = jdata.get("model_devi_activation_func", None)
training_init_frozen_model = (
jdata.get("training_init_frozen_model") if iter_index == 0 else None
)
training_finetune_model = (
jdata.get("training_finetune_model") if iter_index == 0 else None
)
auto_ratio = False
if (
training_reuse_iter is not None
and isinstance(training_reuse_old_ratio, str)
and training_reuse_old_ratio.startswith("auto")
):
s = training_reuse_old_ratio.split(":")
if len(s) == 1:
new_to_old_ratio = 10.0
elif len(s) == 2:
new_to_old_ratio = float(s[1])
else:
raise ValueError(
f"training_reuse_old_ratio is not correct, got {training_reuse_old_ratio}"
)
dlog.info(
"Use automatic training_reuse_old_ratio to make new-to-old ratio close to %d times of the default value.",
new_to_old_ratio,
)
auto_ratio = True
number_old_frames = 0
number_new_frames = 0
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
_link_old_models( | |
work_path, [copied_models[ii]], ii, basename=f"init{suffix}" | |
) | |
_link_old_models( | |
work_path, [copied_models[ii]], ii, basename=f"init{suffix}" | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
@@ -125,6 +125,19 @@ | |||
run_opt_file = os.path.join(ROOT_PATH, "generator/lib/calypso_run_opt.py") | |||
|
|||
|
|||
def _get_model_suffix(jdata) -> str: | |||
"""Return the model suffix based on the backend""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
End the docstring with a period for consistency and proper grammar.
- """Return the model suffix based on the backend"""
+ """Return the model suffix based on the backend."""
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
"""Return the model suffix based on the backend""" | |
"""Return the model suffix based on the backend.""" |
reopen PR #1541 due to branch is deleted add a new key in `param.json` file ``` "train_backend": "pytorch"/"tensorflow", ``` relate to this issue #1462 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Improved model management by dynamically generating model suffixes based on the selected backend, enhancing compatibility. - **Enhancements** - Updated model-related functions to incorporate backend-specific model suffixes for accurate file handling during training processes. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: C. Thang Nguyen <[email protected]> Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <[email protected]>
add a new key in
param.json
filerelate to this issue #1462
Summary by CodeRabbit