chore: add the mlp_engine option #1576

Merged: 4 commits, Jul 6, 2024
32 changes: 27 additions & 5 deletions dpgen/generator/arginfo.py
@@ -79,7 +79,14 @@ def data_args() -> list[Argument]:
 # Training


-def training_args() -> list[Argument]:
+def training_args_common() -> list[Argument]:
+    doc_numb_models = "Number of models to be trained in 00.train. 4 is recommended."
+    return [
+        Argument("numb_models", int, optional=False, doc=doc_numb_models),
+    ]
+
+
+def training_args_dp() -> list[Argument]:
     """Training arguments.

     Returns
@@ -90,7 +97,6 @@ def training_args() -> list[Argument]:
     doc_train_backend = (
         "The backend of the training. Currently only supports tensorflow and pytorch."
     )
-    doc_numb_models = "Number of models to be trained in 00.train. 4 is recommended."
     doc_training_iter0_model_path = "The model used to init the first iter training. Number of elements should be equal to numb_models."
     doc_training_init_model = "Iteration > 0, the model parameters will be initialized from the model trained at the previous iteration. Iteration == 0, the model parameters will be initialized from training_iter0_model_path."
     doc_default_training_param = "Training parameters for deepmd-kit in 00.train. You can find instructions from `DeePMD-kit documentation <https://docs.deepmodeling.org/projects/deepmd/>`_."
@@ -133,7 +139,6 @@ def training_args() -> list[Argument]:
             default="tensorflow",
             doc=doc_train_backend,
         ),
-        Argument("numb_models", int, optional=False, doc=doc_numb_models),
         Argument(
             "training_iter0_model_path",
             list[str],
@@ -224,6 +229,19 @@
     ]


+def training_args() -> Variant:
+    doc_mlp_engine = "Machine learning potential engine. Currently, only DeePMD-kit (default) is supported."
+    doc_dp = "DeePMD-kit."
+    return Variant(
+        "mlp_engine",
+        [
+            Argument("dp", dict, training_args_dp(), doc=doc_dp),
+        ],
+        default_tag="dp",
+        doc=doc_mlp_engine,
+    )
+
+
 # Exploration
 def model_devi_jobs_template_args() -> Argument:
     doc_template = (
@@ -987,7 +1005,11 @@ def run_jdata_arginfo() -> Argument:
     return Argument(
         "run_jdata",
         dict,
-        sub_fields=basic_args() + data_args() + training_args() + fp_args(),
-        sub_variants=model_devi_args() + [fp_style_variant_type_args()],
+        sub_fields=basic_args() + data_args() + training_args_common() + fp_args(),
+        sub_variants=[
+            training_args(),
+            *model_devi_args(),
+            fp_style_variant_type_args(),
+        ],
         doc=doc_run_jdata,
     )
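
In terms of user-facing input, mlp_engine becomes an optional key of the run parameter file that selects the machine learning potential engine; it defaults to the "dp" tag (DeePMD-kit), so existing parameter files keep working. A minimal sketch of the keys touched by this PR, with placeholder values (not a complete, working setup):

# Hypothetical excerpt of a run param dict (jdata); only keys touched by this
# PR are shown, and the values are placeholders.
jdata = {
    "mlp_engine": "dp",  # optional; "dp" is the default_tag of the new Variant
    "numb_models": 4,  # shared argument, now supplied by training_args_common()
    "train_backend": "tensorflow",  # DeePMD-kit-specific: "tensorflow" or "pytorch"
    "default_training_param": {},  # DeePMD-kit training input (placeholder)
}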
44 changes: 36 additions & 8 deletions dpgen/generator/run.py
@@ -128,15 +128,19 @@


 def _get_model_suffix(jdata) -> str:
     """Return the model suffix based on the backend."""
-    suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
-    backend = jdata.get("train_backend", "tensorflow")
-    if backend in suffix_map:
-        suffix = suffix_map[backend]
-    else:
-        raise ValueError(
-            f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
-        )
-    return suffix
+    mlp_engine = jdata.get("mlp_engine", "dp")
+    if mlp_engine == "dp":
+        suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
+        backend = jdata.get("train_backend", "tensorflow")
+        if backend in suffix_map:
+            suffix = suffix_map[backend]
+        else:
+            raise ValueError(
+                f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
+            )
+        return suffix
+    raise ValueError(f"Unsupported engine: {mlp_engine}")


 def get_job_names(jdata):
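
In effect, _get_model_suffix now keys on mlp_engine before looking at train_backend. A minimal behavioral sketch of the change (hypothetical jdata values, not taken from the repository's tests):

# Hypothetical calls illustrating the dispatch added above.
_get_model_suffix({"mlp_engine": "dp", "train_backend": "pytorch"})  # -> ".pth"
_get_model_suffix({})  # -> ".pb" (defaults: mlp_engine="dp", train_backend="tensorflow")
_get_model_suffix({"train_backend": "jax"})  # raises ValueError: unsupported backend
_get_model_suffix({"mlp_engine": "other"})  # raises ValueError: unsupported engine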
@@ -270,6 +274,14 @@


 def make_train(iter_index, jdata, mdata):
+    mlp_engine = jdata.get("mlp_engine", "dp")
+    if mlp_engine == "dp":
+        return make_train_dp(iter_index, jdata, mdata)
+    else:
+        raise ValueError(f"Unsupported engine: {mlp_engine}")
+
+
+def make_train_dp(iter_index, jdata, mdata):
     # load json param
     # train_param = jdata['train_param']
     train_input_file = default_train_input_file
@@ -714,6 +726,14 @@


 def run_train(iter_index, jdata, mdata):
+    mlp_engine = jdata.get("mlp_engine", "dp")
+    if mlp_engine == "dp":
+        return run_train_dp(iter_index, jdata, mdata)
+    else:
+        raise ValueError(f"Unsupported engine: {mlp_engine}")
+
+
+def run_train_dp(iter_index, jdata, mdata):
     # print("debug:run_train:mdata", mdata)
     # load json param
     numb_models = jdata["numb_models"]
@@ -899,6 +919,14 @@


 def post_train(iter_index, jdata, mdata):
+    mlp_engine = jdata.get("mlp_engine", "dp")
+    if mlp_engine == "dp":
+        return post_train_dp(iter_index, jdata, mdata)
+    else:
+        raise ValueError(f"Unsupported engine: {mlp_engine}")
+
+
+def post_train_dp(iter_index, jdata, mdata):
     # load json param
     numb_models = jdata["numb_models"]
     # paths
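
The three training stages in run.py (make_train, run_train, post_train) now share the same dispatch shape: read mlp_engine with a "dp" default, delegate to the engine-specific *_dp helper, and raise on anything else. A hypothetical sketch of how a second engine could be wired in later; the "xyz" engine and make_train_xyz are invented here for illustration only and are not part of this PR:

def make_train(iter_index, jdata, mdata):
    # Dispatch on the machine learning potential engine selected in jdata.
    mlp_engine = jdata.get("mlp_engine", "dp")
    if mlp_engine == "dp":
        return make_train_dp(iter_index, jdata, mdata)
    elif mlp_engine == "xyz":  # hypothetical future engine, not in this PR
        return make_train_xyz(iter_index, jdata, mdata)
    else:
        raise ValueError(f"Unsupported engine: {mlp_engine}")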
4 changes: 3 additions & 1 deletion dpgen/simplify/arginfo.py
@@ -12,6 +12,7 @@
     fp_style_siesta_args,
     fp_style_vasp_args,
     training_args,
+    training_args_common,
 )

@@ -201,10 +202,11 @@ def simplify_jdata_arginfo() -> Argument:
             *data_args(),
             *general_simplify_arginfo(),
             # simplify uses the same training method as run
-            *training_args(),
+            *training_args_common(),
             *fp_args(),
         ],
         sub_variants=[
+            training_args(),
             fp_style_variant_type_args(),
         ],
         doc=doc_run_jdata,
8 changes: 8 additions & 0 deletions dpgen/simplify/simplify.py
@@ -103,6 +103,14 @@


 def init_model(iter_index, jdata, mdata):
+    mlp_engine = jdata.get("mlp_engine", "dp")
+    if mlp_engine == "dp":
+        init_model_dp(iter_index, jdata, mdata)
+    else:
+        raise TypeError(f"unsupported engine {mlp_engine}")
+
+
+def init_model_dp(iter_index, jdata, mdata):
     training_init_model = jdata.get("training_init_model", False)
     if not training_init_model:
         return