From 34379ec899bf1a27f5c3c6ac57615c20be58be68 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 4 Jul 2024 20:24:21 -0400 Subject: [PATCH 1/4] chore: add the `mlp_engine` option Signed-off-by: Jinzhe Zeng --- dpgen/generator/arginfo.py | 18 ++++++++++++++--- dpgen/generator/run.py | 40 ++++++++++++++++++++++++++++++-------- dpgen/simplify/arginfo.py | 2 +- 3 files changed, 48 insertions(+), 12 deletions(-) diff --git a/dpgen/generator/arginfo.py b/dpgen/generator/arginfo.py index 92097af89..d8c162694 100644 --- a/dpgen/generator/arginfo.py +++ b/dpgen/generator/arginfo.py @@ -79,7 +79,7 @@ def data_args() -> list[Argument]: # Training -def training_args() -> list[Argument]: +def training_args_dp() -> list[Argument]: """Traning arguments. Returns @@ -224,6 +224,18 @@ def training_args() -> list[Argument]: ] +def training_args() -> Variant: + doc_mlp_engine = "Machine learning potential engine. Currently, only DeePMD-kit (default) is supported." + doc_dp = "DeePMD-kit." + return Variant( + "mlp_engine", + [ + Argument("dp", dict, training_args_dp(), doc=doc_dp), + ], + default_tag="dp", + doc=doc_mlp_engine, + ) + # Exploration def model_devi_jobs_template_args() -> Argument: doc_template = ( @@ -987,7 +999,7 @@ def run_jdata_arginfo() -> Argument: return Argument( "run_jdata", dict, - sub_fields=basic_args() + data_args() + training_args() + fp_args(), - sub_variants=model_devi_args() + [fp_style_variant_type_args()], + sub_fields=basic_args() + data_args() + fp_args(), + sub_variants=[training_args(), *model_devi_args(), fp_style_variant_type_args()], doc=doc_run_jdata, ) diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 1e3e0e3fa..0470f97e9 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -128,15 +128,18 @@ def _get_model_suffix(jdata) -> str: """Return the model suffix based on the backend.""" - suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"} - backend = jdata.get("train_backend", "tensorflow") - if backend in 
suffix_map: - suffix = suffix_map[backend] + if jdata.get("mlp_engine", "dp") == "dp": + suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"} + backend = jdata.get("train_backend", "tensorflow") + if backend in suffix_map: + suffix = suffix_map[backend] + else: + raise ValueError( + f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'." + ) + return suffix else: - raise ValueError( - f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'." - ) - return suffix + raise ValueError("Unsupported engine: {}".format(jdata.get("mlp_engine"))) def get_job_names(jdata): @@ -270,6 +273,13 @@ def dump_to_deepmd_raw(dump, deepmd_raw, type_map, fmt="gromacs/gro", charge=Non def make_train(iter_index, jdata, mdata): + if jdata.get("mlp_engine", "dp"): + return make_train_dp(iter_index, jdata, mdata) + else: + raise ValueError("Unsupported engine: {}".format(jdata.get("mlp_engine"))) + + +def make_train_dp(iter_index, jdata, mdata): # load json param # train_param = jdata['train_param'] train_input_file = default_train_input_file @@ -714,6 +724,13 @@ def get_nframes(system): def run_train(iter_index, jdata, mdata): + if jdata.get("mlp_engine", "dp"): + return make_train_dp(iter_index, jdata, mdata) + else: + raise ValueError("Unsupported engine: {}".format(jdata.get("mlp_engine"))) + + +def run_train_dp(iter_index, jdata, mdata): # print("debug:run_train:mdata", mdata) # load json param numb_models = jdata["numb_models"] @@ -899,6 +916,13 @@ def run_train(iter_index, jdata, mdata): def post_train(iter_index, jdata, mdata): + if jdata.get("mlp_engine", "dp"): + return post_train_dp(iter_index, jdata, mdata) + else: + raise ValueError("Unsupported engine: {}".format(jdata.get("mlp_engine"))) + + +def post_train_dp(iter_index, jdata, mdata): # load json param numb_models = jdata["numb_models"] # paths diff --git a/dpgen/simplify/arginfo.py b/dpgen/simplify/arginfo.py index 516b27e60..45d74ae9a 100644 --- 
a/dpgen/simplify/arginfo.py +++ b/dpgen/simplify/arginfo.py @@ -201,10 +201,10 @@ def simplify_jdata_arginfo() -> Argument: *data_args(), *general_simplify_arginfo(), # simplify use the same training method as run - *training_args(), *fp_args(), ], sub_variants=[ + training_args(), fp_style_variant_type_args(), ], doc=doc_run_jdata, From 6957426e681e230537ce26715c8bf6c08b042136 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 4 Jul 2024 20:28:29 -0400 Subject: [PATCH 2/4] mlp_engine Signed-off-by: Jinzhe Zeng --- dpgen/generator/run.py | 20 ++++++++++++-------- dpgen/simplify/simplify.py | 8 ++++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 0470f97e9..976f9ef52 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -128,7 +128,8 @@ def _get_model_suffix(jdata) -> str: """Return the model suffix based on the backend.""" - if jdata.get("mlp_engine", "dp") == "dp": + mlp_engine = jdata.get("mlp_engine", "dp") + if mlp_engine == "dp": suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"} backend = jdata.get("train_backend", "tensorflow") if backend in suffix_map: @@ -139,7 +140,7 @@ def _get_model_suffix(jdata) -> str: ) return suffix else: - raise ValueError("Unsupported engine: {}".format(jdata.get("mlp_engine"))) + raise ValueError("Unsupported engine: {}".format(mlp_engine)) def get_job_names(jdata): @@ -273,10 +274,11 @@ def dump_to_deepmd_raw(dump, deepmd_raw, type_map, fmt="gromacs/gro", charge=Non def make_train(iter_index, jdata, mdata): - if jdata.get("mlp_engine", "dp"): + mlp_engine = jdata.get("mlp_engine", "dp") + if mlp_engine == "dp": return make_train_dp(iter_index, jdata, mdata) else: - raise ValueError("Unsupported engine: {}".format(jdata.get("mlp_engine"))) + raise ValueError("Unsupported engine: {}".format(mlp_engine)) def make_train_dp(iter_index, jdata, mdata): @@ -724,10 +726,11 @@ def get_nframes(system): def run_train(iter_index, jdata, mdata): - if 
jdata.get("mlp_engine", "dp"): + mlp_engine = jdata.get("mlp_engine", "dp") + if mlp_engine == "dp": return make_train_dp(iter_index, jdata, mdata) else: - raise ValueError("Unsupported engine: {}".format(jdata.get("mlp_engine"))) + raise ValueError("Unsupported engine: {}".format(mlp_engine)) def run_train_dp(iter_index, jdata, mdata): @@ -916,10 +919,11 @@ def run_train_dp(iter_index, jdata, mdata): def post_train(iter_index, jdata, mdata): - if jdata.get("mlp_engine", "dp"): + mlp_engine = jdata.get("mlp_engine", "dp") + if mlp_engine == "dp": return post_train_dp(iter_index, jdata, mdata) else: - raise ValueError("Unsupported engine: {}".format(jdata.get("mlp_engine"))) + raise ValueError("Unsupported engine: {}".format(mlp_engine)) def post_train_dp(iter_index, jdata, mdata): diff --git a/dpgen/simplify/simplify.py b/dpgen/simplify/simplify.py index 02fe54d79..24205fda3 100644 --- a/dpgen/simplify/simplify.py +++ b/dpgen/simplify/simplify.py @@ -103,6 +103,14 @@ def get_multi_system(path: Union[str, list[str]], jdata: dict) -> dpdata.MultiSy def init_model(iter_index, jdata, mdata): + mlp_engine = jdata.get("mlp_engine", "dp") + if mlp_engine == "dp": + init_model_dp(iter_index, jdata, mdata) + else: + raise TypeError(f"unsupported engine {mlp_engine}") + + +def init_model_dp(iter_index, jdata, mdata): training_init_model = jdata.get("training_init_model", False) if not training_init_model: return From d364a27030cf1bfee4e156a2acc2e35e08c8ca06 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 4 Jul 2024 20:31:23 -0400 Subject: [PATCH 3/4] make numb_models as common argument Signed-off-by: Jinzhe Zeng --- dpgen/generator/arginfo.py | 10 +++++++--- dpgen/simplify/arginfo.py | 2 ++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/dpgen/generator/arginfo.py b/dpgen/generator/arginfo.py index d8c162694..e8753fc47 100644 --- a/dpgen/generator/arginfo.py +++ b/dpgen/generator/arginfo.py @@ -78,6 +78,12 @@ def data_args() -> list[Argument]: # Training 
+def training_args_common() -> list[Argument]: + doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend." + return [ + Argument("numb_models", int, optional=False, doc=doc_numb_models), + ] + def training_args_dp() -> list[Argument]: """Traning arguments. @@ -90,7 +96,6 @@ def training_args_dp() -> list[Argument]: doc_train_backend = ( "The backend of the training. Currently only support tensorflow and pytorch." ) - doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend." doc_training_iter0_model_path = "The model used to init the first iter training. Number of element should be equal to numb_models." doc_training_init_model = "Iteration > 0, the model parameters will be initilized from the model trained at the previous iteration. Iteration == 0, the model parameters will be initialized from training_iter0_model_path." doc_default_training_param = "Training parameters for deepmd-kit in 00.train. You can find instructions from `DeePMD-kit documentation `_." 
@@ -133,7 +138,6 @@ def training_args_dp() -> list[Argument]: default="tensorflow", doc=doc_train_backend, ), - Argument("numb_models", int, optional=False, doc=doc_numb_models), Argument( "training_iter0_model_path", list[str], @@ -999,7 +1003,7 @@ def run_jdata_arginfo() -> Argument: return Argument( "run_jdata", dict, - sub_fields=basic_args() + data_args() + fp_args(), + sub_fields=basic_args() + data_args() + training_args_common() + fp_args(), sub_variants=[training_args(), *model_devi_args(), fp_style_variant_type_args()], doc=doc_run_jdata, ) diff --git a/dpgen/simplify/arginfo.py b/dpgen/simplify/arginfo.py index 45d74ae9a..53507b2f6 100644 --- a/dpgen/simplify/arginfo.py +++ b/dpgen/simplify/arginfo.py @@ -12,6 +12,7 @@ fp_style_siesta_args, fp_style_vasp_args, training_args, + training_args_common, ) @@ -201,6 +202,7 @@ def simplify_jdata_arginfo() -> Argument: *data_args(), *general_simplify_arginfo(), # simplify use the same training method as run + *training_args_common(), *fp_args(), ], sub_variants=[ From c1e84123e1adebef58263c725138ee2788feab16 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Jul 2024 00:33:51 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpgen/generator/arginfo.py | 8 +++++++- dpgen/generator/run.py | 8 ++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dpgen/generator/arginfo.py b/dpgen/generator/arginfo.py index e8753fc47..6cc38bbed 100644 --- a/dpgen/generator/arginfo.py +++ b/dpgen/generator/arginfo.py @@ -78,6 +78,7 @@ def data_args() -> list[Argument]: # Training + def training_args_common() -> list[Argument]: doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend." 
return [ @@ -240,6 +241,7 @@ def training_args() -> Variant: doc=doc_mlp_engine, ) + # Exploration def model_devi_jobs_template_args() -> Argument: doc_template = ( @@ -1004,6 +1006,10 @@ def run_jdata_arginfo() -> Argument: "run_jdata", dict, sub_fields=basic_args() + data_args() + training_args_common() + fp_args(), - sub_variants=[training_args(), *model_devi_args(), fp_style_variant_type_args()], + sub_variants=[ + training_args(), + *model_devi_args(), + fp_style_variant_type_args(), + ], doc=doc_run_jdata, ) diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 976f9ef52..d376d467a 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -140,7 +140,7 @@ def _get_model_suffix(jdata) -> str: ) return suffix else: - raise ValueError("Unsupported engine: {}".format(mlp_engine)) + raise ValueError(f"Unsupported engine: {mlp_engine}") def get_job_names(jdata): @@ -278,7 +278,7 @@ def make_train(iter_index, jdata, mdata): if mlp_engine == "dp": return make_train_dp(iter_index, jdata, mdata) else: - raise ValueError("Unsupported engine: {}".format(mlp_engine)) + raise ValueError(f"Unsupported engine: {mlp_engine}") def make_train_dp(iter_index, jdata, mdata): @@ -730,7 +730,7 @@ def run_train(iter_index, jdata, mdata): if mlp_engine == "dp": return make_train_dp(iter_index, jdata, mdata) else: - raise ValueError("Unsupported engine: {}".format(mlp_engine)) + raise ValueError(f"Unsupported engine: {mlp_engine}") def run_train_dp(iter_index, jdata, mdata): @@ -923,7 +923,7 @@ def post_train(iter_index, jdata, mdata): if mlp_engine == "dp": return post_train_dp(iter_index, jdata, mdata) else: - raise ValueError("Unsupported engine: {}".format(mlp_engine)) + raise ValueError(f"Unsupported engine: {mlp_engine}") def post_train_dp(iter_index, jdata, mdata):