Skip to content

Commit

Permalink
feat: plain text model format (#4025)
Browse files Browse the repository at this point in the history
Propose a plain text model format based on YAML, which humans can easily
read and might be easier to track changes in the git repository (which
is good for #2103).

Example:
[deeppot_dpa_sel.yaml](https://github.com/user-attachments/files/16384230/deeppot_dpa_sel.yaml.txt)


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Added support for additional file formats (.yaml and .yml) for model
saving and loading.
- Enhanced the ability to serialize and deserialize model data in
multiple formats.

- **Bug Fixes**
- Improved error handling for unsupported file formats during model
loading.

- **Documentation**
- Updated documentation to reflect new supported file formats and
clarify backend capabilities.

- **Tests**
- Introduced new test cases to ensure functionality for saving and
loading models in YAML format.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Jul 26, 2024
1 parent 561ff1b commit 7f61048
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 20 deletions.
2 changes: 1 addition & 1 deletion deepmd/backend/dpmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DPModelBackend(Backend):
Backend.Feature.DEEP_EVAL | Backend.Feature.NEIGHBOR_STAT | Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = [".dp"]
suffixes: ClassVar[List[str]] = [".dp", ".yaml", ".yml"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down
84 changes: 68 additions & 16 deletions deepmd/dpmodel/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
from datetime import (
datetime,
)
from pathlib import (
Path,
)
from typing import (
Callable,
)

import h5py
import numpy as np
import yaml

try:
from deepmd._version import version as __version__
Expand All @@ -33,6 +38,8 @@ def traverse_model_dict(model_obj, callback: Callable, is_variable: bool = False
The model object after traversing.
"""
if isinstance(model_obj, dict):
if model_obj.get("@is_variable", False):
return callback(model_obj)
for kk, vv in model_obj.items():
model_obj[kk] = traverse_model_dict(
vv, callback, is_variable=is_variable or kk == "@variables"
Expand Down Expand Up @@ -78,22 +85,48 @@ def save_dp_model(filename: str, model_dict: dict) -> None:
The model dict to save.
"""
model_dict = model_dict.copy()
variable_counter = Counter()
with h5py.File(filename, "w") as f:
filename_extension = Path(filename).suffix
extra_dict = {
"software": "deepmd-kit",
"version": __version__,
# use UTC+0 time
"time": str(datetime.utcnow()),
}
if filename_extension == ".dp":
variable_counter = Counter()
with h5py.File(filename, "w") as f:
model_dict = traverse_model_dict(
model_dict,
lambda x: f.create_dataset(
f"variable_{variable_counter():04d}", data=x
).name,
)
save_dict = {
**extra_dict,
**model_dict,
}
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))
elif filename_extension in {".yaml", ".yml"}:
model_dict = traverse_model_dict(
model_dict,
lambda x: f.create_dataset(
f"variable_{variable_counter():04d}", data=x
).name,
lambda x: {
"@class": "np.ndarray",
"@is_variable": True,
"@version": 1,
"dtype": x.dtype.name,
"value": x.tolist(),
},
)
save_dict = {
"software": "deepmd-kit",
"version": __version__,
# use UTC+0 time
"time": str(datetime.utcnow()),
**model_dict,
}
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))
with open(filename, "w") as f:
yaml.safe_dump(
{
**extra_dict,
**model_dict,
},
f,
)
else:
raise ValueError(f"Unknown filename extension: {filename_extension}")


def load_dp_model(filename: str) -> dict:
Expand All @@ -109,7 +142,26 @@ def load_dp_model(filename: str) -> dict:
dict
The loaded model dict, including meta information.
"""
with h5py.File(filename, "r") as f:
model_dict = json.loads(f.attrs["json"])
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
filename_extension = Path(filename).suffix
if filename_extension == ".dp":
with h5py.File(filename, "r") as f:
model_dict = json.loads(f.attrs["json"])
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
elif filename_extension in {".yaml", ".yml"}:

def convert_numpy_ndarray(x):
if isinstance(x, dict) and x.get("@class") == "np.ndarray":
dtype = np.dtype(x["dtype"])
value = np.asarray(x["value"], dtype=dtype)
return value
return x

with open(filename) as f:
model_dict = yaml.safe_load(f)
model_dict = traverse_model_dict(
model_dict,
convert_numpy_ndarray,
)
else:
raise ValueError(f"Unknown filename extension: {filename_extension}")
return model_dict
8 changes: 5 additions & 3 deletions doc/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different
This backend is only for development and should not take into production.
:::

- Model filename extension: `.dp`
- Model filename extension: `.dp`, `.yaml`, `.yml`

DP is a reference backend for development, which uses pure [NumPy](https://numpy.org/) to implement models without using any heavy deep-learning frameworks.
Due to the limitation of NumPy, it doesn't support gradient calculation and thus cannot be used for training.
As a reference backend, it is not aimed at the best performance, but only the correct results.
The DP backend uses [HDF5](https://docs.h5py.org/) to store model serialization data, which is backend-independent.
Only Python inference interface can load this format.
The DP backend has two formats, both of which are backend-independent:
The `.dp` format uses [HDF5](https://docs.h5py.org/) to store model serialization data, which has good performance.
The `.yaml` or `.yml` use [YAML](https://yaml.org/) to save the data as plain texts, which is easy to read for human beings.
Only Python inference interface can load these formats.

NumPy 1.21 or above is required.

Expand Down
10 changes: 10 additions & 0 deletions source/tests/common/dpmodel/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def setUp(self) -> None:
],
}
self.filename = "test_dp_dpmodel.dp"
self.filename_yaml = "test_dp_dpmodel.yaml"

def test_save_load_model(self):
save_dp_model(self.filename, {"model": deepcopy(self.model_dict)})
Expand All @@ -291,6 +292,15 @@ def test_save_load_model(self):
assert "software" in model
assert "version" in model

def test_save_load_model_yaml(self):
save_dp_model(self.filename_yaml, {"model": deepcopy(self.model_dict)})
model = load_dp_model(self.filename_yaml)
np.testing.assert_equal(model["model"], self.model_dict)
assert "software" in model
assert "version" in model

def tearDown(self) -> None:
if os.path.exists(self.filename):
os.remove(self.filename)
if os.path.exists(self.filename_yaml):
os.remove(self.filename_yaml)

0 comments on commit 7f61048

Please sign in to comment.