Skip to content

Commit

Permalink
supports dp convert-from 0.12 (#1685)
Browse files Browse the repository at this point in the history
Resolves #1583.
  • Loading branch information
njzjz authored May 6, 2022
1 parent e9b27a7 commit 087ae56
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 5 deletions.
6 changes: 4 additions & 2 deletions deepmd/entrypoints/convert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from deepmd.utils.convert import convert_10_to_21, convert_20_to_21, convert_13_to_21, convert_12_to_21
from deepmd.utils.convert import convert_012_to_21, convert_10_to_21, convert_20_to_21, convert_13_to_21, convert_12_to_21

def convert(
*,
Expand All @@ -7,7 +7,9 @@ def convert(
output_model: str,
**kwargs,
):
if FROM == '1.0':
if FROM == '0.12':
convert_012_to_21(input_model, output_model)
elif FROM == '1.0':
convert_10_to_21(input_model, output_model)
elif FROM in ['1.1', '1.2']:
# no difference between 1.1 and 1.2
Expand Down
3 changes: 1 addition & 2 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ def parse_args(args: Optional[List[str]] = None):
)

# * convert models
# supported: 1.2->2.0, 1.3->2.0
parser_transform = subparsers.add_parser(
'convert-from',
parents=[parser_log],
Expand All @@ -392,7 +391,7 @@ def parse_args(args: Optional[List[str]] = None):
parser_transform.add_argument(
'FROM',
type = str,
choices = ['1.0', '1.1', '1.2', '1.3', '2.0'],
choices = ['0.12', '1.0', '1.1', '1.2', '1.3', '2.0'],
help="The original model compatibility",
)
parser_transform.add_argument(
Expand Down
40 changes: 40 additions & 0 deletions deepmd/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,28 @@ def convert_10_to_21(input_model: str, output_model: str):
print("the converted output model (2.1 support) is saved in %s" % output_model)


def convert_012_to_21(input_model: str, output_model: str):
"""Convert DP 0.12 graph to 2.1 graph.
Parameters
----------
input_model : str
filename of the input graph
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt')
convert_dp012_to_dp10('frozen_model.pbtxt')
convert_dp10_to_dp11('frozen_model.pbtxt')
convert_dp12_to_dp13('frozen_model.pbtxt')
convert_dp13_to_dp20('frozen_model.pbtxt')
convert_dp20_to_dp21('frozen_model.pbtxt')
convert_pbtxt_to_pb('frozen_model.pbtxt', output_model)
if os.path.isfile('frozen_model.pbtxt'):
os.remove('frozen_model.pbtxt')
print("the converted output model (2.1 support) is saved in %s" % output_model)


def convert_20_to_21(input_model: str, output_model: str):
"""Convert DP 2.0 graph to 2.1 graph.
Expand Down Expand Up @@ -134,6 +156,24 @@ def convert_pbtxt_to_pb(pbtxtfile: str, pbfile: str):
tf.train.write_graph(graph_def, './', pbfile, as_text=False)


def convert_dp012_to_dp10(file: str):
"""Convert DP 1.0 graph text to 1.1 graph text.
Parameters
----------
file : str
filename of the graph text
"""
with open(file) as fp:
file_content = fp.read()
file_content = file_content\
.replace('DescrptNorot', 'DescrptSeA') \
.replace('ProdForceNorot', 'ProdForceSeA') \
.replace('ProdVirialNorot', 'ProdVirialSeA')
with open(file, 'w') as fp:
fp.write(file_content)


def convert_dp10_to_dp11(file: str):
"""Convert DP 1.0 graph text to 1.1 graph text.
Expand Down
2 changes: 1 addition & 1 deletion doc/troubleshooting/model-compatability.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ One can execute `dp convert-from` to convert an old model to a new one.

| Model version | v0.12 | v1.0 | v1.1 | v1.2 | v1.3 | v2.0 | v2.1 |
|:-:|:-----------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| Compatibility | 😢 | 😊 | 😊 | 😊 | 😊 | 😄 | 😄 |
| Compatibility | 😊 | 😊 | 😊 | 😊 | 😊 | 😄 | 😄 |

**Legend**:
- 😄: The model is compatible with the DeePMD-kit package.
Expand Down
26 changes: 26 additions & 0 deletions source/op/prod_env_mat_multi_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,26 @@ REGISTER_OP("DescrptSeA")
.Output("rij: T")
.Output("nlist: int32");

// alias of ProdEnvMatA -- compatible with v0.12
REGISTER_OP("DescrptNorot")
.Attr("T: {float, double} = DT_DOUBLE")
.Input("coord: T")
.Input("type: int32")
.Input("natoms: int32")
.Input("box : T")
.Input("mesh : int32")
.Input("davg: T")
.Input("dstd: T")
.Attr("rcut_a: float")
.Attr("rcut_r: float")
.Attr("rcut_r_smth: float")
.Attr("sel_a: list(int)")
.Attr("sel_r: list(int)")
.Output("descrpt: T")
.Output("descrpt_deriv: T")
.Output("rij: T")
.Output("nlist: int32");

REGISTER_OP("ProdEnvMatR")
.Attr("T: {float, double} = DT_DOUBLE")
.Input("coord: T")
Expand Down Expand Up @@ -1423,6 +1443,9 @@ REGISTER_KERNEL_BUILDER(
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeA").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatAOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptNorot").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatAOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeR").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdEnvMatROp<CPUDevice, T>);
Expand All @@ -1442,6 +1465,9 @@ REGISTER_KERNEL_BUILDER(
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeA").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatAOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptNorot").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatAOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("DescrptSeR").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms").HostMemory("box"), \
ProdEnvMatROp<GPUDevice, T>);
Expand Down
17 changes: 17 additions & 0 deletions source/op/prod_force_multi_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ REGISTER_OP("ProdForceSeA")
.Attr("n_r_sel: int")
.Output("force: T");

// compatible with v0.12
REGISTER_OP("ProdForceNorot")
.Attr("T: {float, double} = DT_DOUBLE")
.Input("net_deriv: T")
.Input("in_deriv: T")
.Input("nlist: int32")
.Input("natoms: int32")
.Attr("n_a_sel: int")
.Attr("n_r_sel: int")
.Output("force: T");

// rename temp op
REGISTER_OP("ParallelProdForceSeA")
.Attr("T: {float, double} = DT_DOUBLE")
Expand Down Expand Up @@ -235,6 +246,9 @@ class ProdForceSeROp : public OpKernel {
REGISTER_KERNEL_BUILDER( \
Name("ProdForceSeA").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdForceSeAOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdForceNorot").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdForceSeAOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ParallelProdForceSeA").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdForceSeAOp<CPUDevice, T>); \
Expand All @@ -249,6 +263,9 @@ REGISTER_CPU(double);
REGISTER_KERNEL_BUILDER( \
Name("ProdForceSeA").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms"), \
ProdForceSeAOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdForceNorot").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms"), \
ProdForceSeAOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdForceSeR").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms"), \
ProdForceSeROp<GPUDevice, T>);
Expand Down
18 changes: 18 additions & 0 deletions source/op/prod_virial_multi_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@ REGISTER_OP("ProdVirialSeA")
.Attr("n_r_sel: int")
.Output("virial: T")
.Output("atom_virial: T");
// compatible with v0.12
REGISTER_OP("ProdVirialNorot")
.Attr("T: {float, double} = DT_DOUBLE")
.Input("net_deriv: T")
.Input("in_deriv: T")
.Input("rij: T")
.Input("nlist: int32")
.Input("natoms: int32")
.Attr("n_a_sel: int")
.Attr("n_r_sel: int")
.Output("virial: T")
.Output("atom_virial: T");

REGISTER_OP("ProdVirialSeR")
.Attr("T: {float, double} = DT_DOUBLE")
Expand Down Expand Up @@ -220,6 +232,9 @@ class ProdVirialSeROp : public OpKernel {
REGISTER_KERNEL_BUILDER( \
Name("ProdVirialSeA").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdVirialSeAOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdVirialNorot").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdVirialSeAOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdVirialSeR").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ProdVirialSeROp<CPUDevice, T>);
Expand All @@ -231,6 +246,9 @@ REGISTER_CPU(double);
REGISTER_KERNEL_BUILDER( \
Name("ProdVirialSeA").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms"), \
ProdVirialSeAOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdVirialNorot").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms"), \
ProdVirialSeAOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ProdVirialSeR").Device(DEVICE_GPU).TypeConstraint<T>("T").HostMemory("natoms"), \
ProdVirialSeROp<GPUDevice, T>);
Expand Down

0 comments on commit 087ae56

Please sign in to comment.