Skip to content

Commit

Permalink
fix(tf): fix argcheck when compressing a model converted from other b…
Browse files Browse the repository at this point in the history
…ackends (#4331)

When the model is converted from other backends, the input script only
contains the `model` section. This PR sets the default for any necessary
argument.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced the data structure for model compression by adding default
keys for training steps and learning rate.
  
- **Bug Fixes**
- Improved error handling with more informative runtime exceptions for
missing training scripts.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: Your Name <[email protected]>
  • Loading branch information
njzjz and Your Name authored Nov 11, 2024
1 parent dcbf607 commit 02a3048
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepmd/tf/entrypoints/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def compress(
10 * step,
int(frequency),
]
jdata.setdefault("training", {"numb_steps": 0})
jdata.setdefault("learning_rate", {})
jdata["training"]["save_ckpt"] = os.path.join("model-compression", "model.ckpt")
jdata = update_deepmd_input(jdata)
jdata = normalize(jdata)
Expand Down

0 comments on commit 02a3048

Please sign in to comment.