Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add option to select backends TF/PT #1541

Closed
wants to merge 46 commits into from
Closed

Conversation

thangckt
Copy link
Contributor

@thangckt thangckt commented May 7, 2024

add a new key in param.json file

"train_backend": "pytorch"/"tensorflow",

relate to this issue #1462

Summary by CodeRabbit

  • New Features
    • Enhanced model training and development processes.
    • Introduced new functionalities for model deviation analysis with various system configurations.

Copy link
Contributor

coderabbitai bot commented May 7, 2024

Warning

Rate Limit Exceeded

@thangckt has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 1 minutes and 23 seconds before requesting another review.

How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.
Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.
Please see our FAQ for further information.

Commits Files that changed from the base of the PR and between 083ad45 and 35b2713.

Walkthrough

The recent updates to the dpgen software focus on enhancing the model training and evaluation pipeline. These improvements include new features like model copying, training setup and execution, post-training tasks, and model deviation analysis using various simulation tools such as GROMACS and AMBER. These changes are designed to streamline processes, boost compatibility, and enhance efficiency in managing diverse molecular dynamics scenarios.

Changes

File Path Change Summary
.../run.py Added functions for model suffix retrieval, model copying, training setup, and execution. Introduced methods for post-training tasks and model deviation analysis in different environments (REVMAT, native, GROMACS, AMBER). Enhanced model deviation analysis with support for multiple molecular dynamics tools.

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

Share
Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai generate interesting stats about this repository and render them as a table.
    • @coderabbitai show all the console.log statements in this repository.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (invoked as PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger a review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai help to get help.

Additionally, you can add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.

CodeRabbit Configration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Out of diff range and nitpick comments (2)
dpgen/generator/run.py (2)

128-128: Ensure the docstring ends with a period for consistency and clarity.


Line range hint 185-196: Use context managers when changing directories to ensure the directory context is restored even if an error occurs.

        prv_train_task = os.path.join(prv_train_path, train_task_fmt % ii)
+       with change_directory(cur_train_path):
-       os.chdir(cur_train_path)
-       os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
+           os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
-       os.chdir(cwd)

dpgen/simplify/simplify.py Outdated Show resolved Hide resolved
dpgen/generator/run.py Outdated Show resolved Hide resolved
Copy link

codecov bot commented May 7, 2024

Codecov Report

Attention: Patch coverage is 66.66667% with 16 lines in your changes are missing coverage. Please review.

Project coverage is 49.53%. Comparing base (8bf5778) to head (35b2713).

Files Patch % Lines
dpgen/generator/run.py 68.88% 14 Missing ⚠️
dpgen/simplify/simplify.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #1541      +/-   ##
==========================================
+ Coverage   49.49%   49.53%   +0.03%     
==========================================
  Files          83       83              
  Lines       14816    14841      +25     
==========================================
+ Hits         7333     7351      +18     
- Misses       7483     7490       +7     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Out of diff range and nitpick comments (2)
dpgen/generator/run.py (2)

Line range hint 185-198: Optimize the copy_model function for error handling.

+ try:
      os.symlink(
          os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
          "graph.%03d%s" % (ii, suffix),
      )
+ except OSError as e:
+     logging.error("Failed to create symlink: %s", e)
      os.chdir(cwd)

Add error handling to manage potential issues with symlink creation, such as permissions errors or existing links.


Line range hint 717-782: Enhance the run_train function with additional logging for better traceability.

+ logging.info("Starting training process for model %d", ii)
  command = f"{train_command} train {train_input_file}{extra_flags}"
  command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
+ logging.info("Training command: %s", command)

Add logging statements to provide more information about the training process and the commands being executed.

dpgen/simplify/simplify.py Outdated Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Out of diff range and nitpick comments (1)
dpgen/simplify/simplify.py (1)

Line range hint 261-261: Replace hardcoded suffix ".pb" with dynamic suffix to maintain consistency with backend selection.

- models = glob.glob(os.path.join(work_path, "graph*pb"))
+ suffix = _get_model_suffix(jdata)
+ models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))

@thangckt thangckt changed the title add option to select backends TF/PT add option to collect files based on selected backends TF/PT May 7, 2024
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 9

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 8

Out of diff range and nitpick comments (1)
dpgen/generator/run.py (1)

Line range hint 188-199: Use format specifiers instead of percent format for better performance and readability.

-        os.symlink(os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix), "graph.%03d%s" % (ii, suffix))
+        os.symlink(os.path.join(train_task_fmt.format(ii), f"frozen_model{suffix}"), f"graph.{ii:03d}{suffix}")

@@ -647,7 +659,7 @@ def make_train(iter_index, jdata, mdata):
)
if copied_models is not None:
for ii in range(len(copied_models)):
_link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb")
_link_old_models(work_path, [copied_models[ii]], ii, basename="init%s" % suffix)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-            _link_old_models(work_path, [copied_models[ii]], ii, basename="init%s" % suffix)
+            _link_old_models(work_path, [copied_models[ii]], ii, basename=f"init{suffix}")

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
_link_old_models(work_path, [copied_models[ii]], ii, basename="init%s" % suffix)
_link_old_models(work_path, [copied_models[ii]], ii, basename=f"init{suffix}")

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Out of diff range and nitpick comments (1)
dpgen/generator/run.py (1)

Line range hint 188-201: Optimize the copy_model function to reduce directory changes and improve path handling.

188a189,190
>    # Store the current working directory to revert back later
>    original_cwd = os.getcwd()
192a195
>        # Use absolute paths for operations to avoid changing directories
194c197
<        os.chdir(cur_train_path)
---
>        # Avoid changing the directory by using absolute paths
196,197c199,200
<        os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
<        os.symlink(
---
>        os.symlink(prv_train_task, os.path.join(cur_train_path, train_task_fmt % ii))
>        os.symlink(
199c202
<        os.chdir(cwd)
---
>        os.chdir(original_cwd)  # Revert to the original working directory

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 10

Out of diff range and nitpick comments (1)
dpgen/generator/run.py (1)

Line range hint 188-201: Optimize string formatting in the copy_model function.

-            os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
-            "graph.%03d%s" % (ii, suffix),
+            os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),
+            f"graph.{ii:03d}{suffix}",

Comment on lines 665 to 667
_link_old_models(
work_path, [copied_models[ii]], ii, basename="init%s" % suffix
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use f-string for better readability and performance in string formatting.

-                work_path, [copied_models[ii]], ii, basename="init%s" % suffix
+                work_path, [copied_models[ii]], ii, basename=f"init{suffix}"

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
_link_old_models(
work_path, [copied_models[ii]], ii, basename="init%s" % suffix
)
_link_old_models(
work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
)

Comment on lines 4710 to 4711
for ss in system_index:
sys_output = glob.glob(os.path.join(work_path, "task.%s.*" % ss))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use f-string for better readability and performance in string formatting.

-        sys_output = glob.glob(os.path.join(work_path, "task.%s.*" % ss))
+        sys_output = glob.glob(os.path.join(work_path, f"task.{ss}.*"))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
for ss in system_index:
sys_output = glob.glob(os.path.join(work_path, "task.%s.*" % ss))
for ss in system_index:
sys_output = glob.glob(os.path.join(work_path, f"task.{ss}.*"))

Comment on lines 4718 to 4719
sys_data_path = os.path.join(work_path, "data.%s" % ss)
all_sys.to_deepmd_raw(sys_data_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use f-string for better readability and performance in string formatting.

-        sys_data_path = os.path.join(work_path, "data.%s" % ss)
+        sys_data_path = os.path.join(work_path, f"data.{ss}")

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
sys_data_path = os.path.join(work_path, "data.%s" % ss)
all_sys.to_deepmd_raw(sys_data_path)
sys_data_path = os.path.join(work_path, f"data.{ss}")
all_sys.to_deepmd_raw(sys_data_path)

dpgen/generator/run.py Outdated Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 9

@@ -184,7 +197,8 @@ def copy_model(numb_model, prv_iter_index, cur_iter_index):
os.chdir(cur_train_path)
os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
os.symlink(
os.path.join(train_task_fmt % ii, "frozen_model.pb"), "graph.%03d.pb" % ii
os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-            os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
+            os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
os.path.join(train_task_fmt % ii, "frozen_model%s" % suffix),
os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),

@@ -647,7 +662,9 @@ def make_train(iter_index, jdata, mdata):
)
if copied_models is not None:
for ii in range(len(copied_models)):
_link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb")
_link_old_models(
work_path, [copied_models[ii]], ii, basename="init%s" % suffix
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-                work_path, [copied_models[ii]], ii, basename="init%s" % suffix
+                work_path, [copied_models[ii]], ii, basename=f"init{suffix}"

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
work_path, [copied_models[ii]], ii, basename="init%s" % suffix
work_path, [copied_models[ii]], ii, basename=f"init{suffix}"

@@ -1124,7 +1158,8 @@ def make_model_devi(iter_index, jdata, mdata):
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
train_path = os.path.abspath(train_path)
models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
suffix = _get_model_suffix(jdata)
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-    models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

@@ -1305,7 +1340,8 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
train_path = os.path.abspath(train_path)
models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
suffix = _get_model_suffix(jdata)
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-    models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

@@ -1502,7 +1538,8 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
train_path = os.path.abspath(train_path)
models = glob.glob(os.path.join(train_path, "graph*pb"))
suffix = _get_model_suffix(jdata)
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-    models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

@@ -1644,7 +1681,8 @@ def _make_model_devi_native_gromacs(iter_index, jdata, mdata, conf_systems):
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
train_path = os.path.abspath(train_path)
models = glob.glob(os.path.join(train_path, "graph*pb"))
suffix = _get_model_suffix(jdata)
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-    models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
models = sorted(glob.glob(os.path.join(train_path, "graph*%s" % suffix)))
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))

@@ -1827,7 +1865,8 @@ def _make_model_devi_amber(
.replace("@qm_theory@", jdata["low_level"])
.replace("@rcut@", str(jdata["cutoff"]))
)
models = sorted(glob.glob(os.path.join(train_path, "graph.*.pb")))
suffix = _get_model_suffix(jdata)
models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-            models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix)))
+            models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}")))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
models = sorted(glob.glob(os.path.join(train_path, "graph.*%s" % suffix)))
models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}")))

all_models = glob.glob(os.path.join(work_path, "graph*pb"))

suffix = _get_model_suffix(jdata)
all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format specifiers instead of percent format for better performance and readability.

-    all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix))
+    all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
all_models = glob.glob(os.path.join(work_path, "graph*%s" % suffix))
all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Out of diff range and nitpick comments (1)
dpgen/generator/run.py (1)

Line range hint 188-201: Ensure consistent formatting and improve readability in the copy_model function.

-            os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),
-            "graph.%03d%s" % (ii, suffix)
+            os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"), f"graph.{ii:03d}{suffix}"

Comment on lines +915 to +921
suffix = _get_model_suffix(jdata)
for ii in range(numb_models):
if not jdata.get("dp_compress", False):
model_name = "frozen_model.pb"
else:
model_name = "frozen_model_compressed.pb"
model_name = f"frozen_model{suffix}"
if jdata.get("dp_compress", False):
model_name = f"frozen_model_compressed{suffix}"

ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimize the symlink creation in post_train function.

-        os.symlink(task_file, ofile)
+        os.symlink(task_file, ofile, target_is_directory=True)

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
suffix = _get_model_suffix(jdata)
for ii in range(numb_models):
if not jdata.get("dp_compress", False):
model_name = "frozen_model.pb"
else:
model_name = "frozen_model_compressed.pb"
model_name = f"frozen_model{suffix}"
if jdata.get("dp_compress", False):
model_name = f"frozen_model_compressed{suffix}"
ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix))
suffix = _get_model_suffix(jdata)
for ii in range(numb_models):
model_name = f"frozen_model{suffix}"
if jdata.get("dp_compress", False):
model_name = f"frozen_model_compressed{suffix}"
ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix))
os.symlink(task_file, ofile, target_is_directory=True)

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Out of diff range and nitpick comments (1)
dpgen/generator/run.py (1)

Line range hint 188-201: Optimize the copy_model function to handle different suffixes more efficiently.

def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"):
    cwd = os.getcwd()
    prv_train_path = os.path.join(make_iter_name(prv_iter_index), train_name)
    cur_train_path = os.path.join(make_iter_name(cur_iter_index), train_name)
    prv_train_path = os.path.abspath(prv_train_path)
    cur_train_path = os.path.abspath(cur_train_path)
    create_path(cur_train_path)
    for ii in range(numb_model):
        prv_train_task = os.path.join(prv_train_path, train_task_fmt % ii)
        os.chdir(cur_train_path)
        os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
        os.symlink(
            os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),
            "graph.%03d%s" % (ii, suffix),
        )
        os.chdir(cwd)
    with open(os.path.join(cur_train_path, "copied"), "w") as fp:
        None

Comment on lines +128 to +139
def _get_model_suffix(jdata) -> str:
"""Return the model suffix based on the backend"""
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
backend = jdata.get("train_backend", "tensorflow")
if backend in suffix_map:
suffix = suffix_map[backend]
else:
raise ValueError(
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
)
return suffix

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor the function _get_model_suffix for clarity and maintainability.

def _get_model_suffix(jdata) -> str:
    """Return the model suffix based on the backend."""
    suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
    backend = jdata.get("train_backend", "tensorflow")
    try:
        return suffix_map[backend]
    except KeyError:
        raise ValueError(
            "The backend '{}' is not supported. Please choose from: 'tensorflow', 'pytorch'.".format(backend)
        )

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def _get_model_suffix(jdata) -> str:
"""Return the model suffix based on the backend"""
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
backend = jdata.get("train_backend", "tensorflow")
if backend in suffix_map:
suffix = suffix_map[backend]
else:
raise ValueError(
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
)
return suffix
def _get_model_suffix(jdata) -> str:
"""Return the model suffix based on the backend."""
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
backend = jdata.get("train_backend", "tensorflow")
try:
return suffix_map[backend]
except KeyError:
raise ValueError(
"The backend '{}' is not supported. Please choose from: 'tensorflow', 'pytorch'.".format(backend)
)

Comment on lines +665 to +667
_link_old_models(
work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimize the make_train function to handle different training setups more efficiently.

def make_train(iter_index, jdata, mdata):
    # load json param
    # train_param = jdata['train_param']
    train_input_file = default_train_input_file
    numb_models = jdata["numb_models"]
    init_data_prefix = jdata["init_data_prefix"]
    init_data_prefix = os.path.abspath(init_data_prefix)
    init_data_sys_ = jdata["init_data_sys"]
    fp_task_min = jdata["fp_task_min"]
    model_devi_jobs = jdata["model_devi_jobs"]
    use_ele_temp = jdata.get("use_ele_temp", 0)
    training_iter0_model = jdata.get("training_iter0_model_path", [])
    training_init_model = jdata.get("training_init_model", False)
    training_reuse_iter = jdata.get("training_reuse_iter")
    training_reuse_old_ratio = jdata.get("training_reuse_old_ratio", "auto")
    # if you want to use DP-ZBL potential , you have to give the path of your energy potential file
    if "srtab_file_path" in jdata.keys():
        srtab_file_path = os.path.abspath(jdata.get("srtab_file_path", None))
    if "training_reuse_stop_batch" in jdata.keys():
        training_reuse_stop_batch = jdata["training_reuse_stop_batch"]
    elif "training_reuse_numb_steps" in jdata.keys():
        training_reuse_stop_batch = jdata["training_reuse_numb_steps"]
    else:
        training_reuse_stop_batch = None
    training_reuse_start_lr = jdata.get("training_reuse_start_lr")
    training_reuse_start_pref_e = jdata.get("training_reuse_start_pref_e")
    training_reuse_start_pref_f = jdata.get("training_reuse_start_pref_f")
    model_devi_activation_func = jdata.get("model_devi_activation_func", None)
    training_init_frozen_model = (
        jdata.get("training_init_frozen_model") if iter_index == 0 else None
    )
    training_finetune_model = (
        jdata.get("training_finetune_model") if iter_index == 0 else None
    )
    auto_ratio = False
    if (
        training_reuse_iter is not None
        and isinstance(training_reuse_old_ratio, str)
        and training_reuse_old_ratio.startswith("auto")
    ):
        s = training_reuse_old_ratio.split(":")
        if len(s) == 1:
            new_to_old_ratio = 10.0
        elif len(s) == 2:
            new_to_old_ratio = float(s[1])
        else:
            raise ValueError(
                f"training_reuse_old_ratio is not correct, got {training_reuse_old_ratio}"
            )
        dlog.info(
            "Use automatic training_reuse_old_ratio to make new-to-old ratio close to %d times of the default value.",
            new_to_old_ratio,
        )
        auto_ratio = True
        number_old_frames = 0
        number_new_frames = 0

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
_link_old_models(
work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
)
_link_old_models(
work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
)

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

@@ -125,6 +125,19 @@
run_opt_file = os.path.join(ROOT_PATH, "generator/lib/calypso_run_opt.py")


def _get_model_suffix(jdata) -> str:
"""Return the model suffix based on the backend"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

End the docstring with a period for consistency and proper grammar.

-    """Return the model suffix based on the backend"""
+    """Return the model suffix based on the backend."""

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
"""Return the model suffix based on the backend"""
"""Return the model suffix based on the backend."""

@thangckt thangckt closed this May 9, 2024
@thangckt thangckt deleted the PR branch May 9, 2024 08:25
@thangckt thangckt restored the PR branch May 9, 2024 08:26
@thangckt thangckt reopened this May 9, 2024
@thangckt thangckt closed this May 9, 2024
@thangckt thangckt deleted the PR branch May 9, 2024 08:31
wanghan-iapcm pushed a commit that referenced this pull request May 11, 2024
reopen PR #1541 due to branch is deleted

add a new key in `param.json` file

```
"train_backend": "pytorch"/"tensorflow",
```
relate to this issue #1462

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
- Improved model management by dynamically generating model suffixes
based on the selected backend, enhancing compatibility.
  
- **Enhancements**
- Updated model-related functions to incorporate backend-specific model
suffixes for accurate file handling during training processes.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: C. Thang Nguyen <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] Support different backends for DeePMD-kit
2 participants