Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt/tf): add bias changing param/interface #3933

Merged
merged 8 commits on Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,72 @@ def main_parser() -> argparse.ArgumentParser:
help="treat all types as a single type. Used with se_atten descriptor.",
)

# change_bias
parser_change_bias = subparsers.add_parser(
"change-bias",
parents=[parser_log],
help="(Supported backend: PyTorch) Change model out bias according to the input data.",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
examples:
dp change-bias model.pt -s data -n 10 -m change
"""
),
)
parser_change_bias.add_argument(
"INPUT", help="The input checkpoint file or frozen model file"
)
parser_change_bias_source = parser_change_bias.add_mutually_exclusive_group()
parser_change_bias_source.add_argument(
"-s",
"--system",
default=".",
type=str,
help="The system dir. Recursively detect systems in this directory",
)
parser_change_bias_source.add_argument(
"-b",
"--bias-value",
default=None,
type=float,
nargs="+",
help="The user defined value for each type in the type_map of the model, split with spaces.\n"
"For example, '-93.57 -187.1' for energy bias of two elements. "
"Only supports energy bias changing.",
)
parser_change_bias.add_argument(
"-n",
"--numb-batch",
default=0,
type=int,
help="The number of frames for bias changing in one data system. 0 means all data.",
)
parser_change_bias.add_argument(
"-m",
"--mode",
type=str,
default="change",
choices=["change", "set"],
help="The mode for changing energy bias: \n"
"change (default) : perform predictions using input model on target dataset, "
"and do least square on the errors to obtain the target shift as bias.\n"
"set : directly use the statistic bias in the target dataset.",
)
parser_change_bias.add_argument(
"-o",
"--output",
default=None,
type=str,
help="The model after changing bias.",
)
parser_change_bias.add_argument(
"--model-branch",
type=str,
default=None,
help="Model branch chosen for changing bias if multi-task model.",
)

iProzd marked this conversation as resolved.
Show resolved Hide resolved
# --version
parser.add_argument(
"--version", action="version", version=f"DeePMD-kit v{__version__}"
Expand Down Expand Up @@ -831,6 +897,7 @@ def main():
"convert-from",
"train-nvnmd",
"show",
"change-bias",
):
deepmd_main = BACKENDS[args.backend]().entry_point_hook
elif args.command is None:
Expand Down
137 changes: 137 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import argparse
import copy
iProzd marked this conversation as resolved.
Show resolved Hide resolved
import json
import logging
import os
Expand All @@ -23,6 +24,9 @@
from deepmd import (
__version__,
)
from deepmd.common import (
expand_sys_str,
)
from deepmd.env import (
GLOBAL_CONFIG,
)
Expand All @@ -44,6 +48,9 @@
from deepmd.pt.train import (
training,
)
from deepmd.pt.train.wrapper import (
ModelWrapper,
)
from deepmd.pt.utils import (
env,
)
Expand All @@ -59,6 +66,12 @@
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.argcheck import (
normalize,
)
Expand Down Expand Up @@ -377,6 +390,128 @@
log.info(f"The fitting_net parameter is {fitting_net}")


def change_bias(FLAGS):
    """Change the output bias of a trained model according to data or user input.

    Entry point for ``dp change-bias``. Loads either a checkpoint (``.pt``) or a
    frozen TorchScript model (``.pth``), replaces its output bias — either with
    user-supplied per-type values (``-b``) or with statistics computed on the
    given data systems (``-s``) — and saves the updated model next to the input
    (or to ``FLAGS.output`` when given).

    Parameters
    ----------
    FLAGS : argparse.Namespace
        Parsed command-line arguments; uses ``INPUT``, ``model_branch``,
        ``mode``, ``bias_value``, ``system``, ``numb_batch`` and ``output``.

    Raises
    ------
    RuntimeError
        If ``FLAGS.INPUT`` is neither a ``.pt`` checkpoint nor a ``.pth``
        frozen model.
    """
    if FLAGS.INPUT.endswith(".pt"):
        old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
        # Checkpoints may store the model under a "model" key; keep a deep copy
        # so the original "_extra_state" survives the in-place updates below.
        model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
        model_params = model_state_dict["_extra_state"]["model_params"]
    elif FLAGS.INPUT.endswith(".pth"):
        old_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE)
        # Frozen models embed their definition script as a JSON string.
        model_params_string = old_model.get_model_def_script()
        model_params = json.loads(model_params_string)
        old_state_dict = old_model.state_dict()
        model_state_dict = old_state_dict
    else:
        raise RuntimeError(
            "The model provided must be a checkpoint file with a .pt extension "
            "or a frozen model with a .pth extension"
        )
    multi_task = "model_dict" in model_params
    model_branch = FLAGS.model_branch
    bias_adjust_mode = (
        "change-by-statistic" if FLAGS.mode == "change" else "set-by-statistic"
    )
    if multi_task:
        assert (
            model_branch is not None
        ), "For multitask model, the model branch must be set!"
        assert model_branch in model_params["model_dict"], (
            f"For multitask model, the model branch must be in the 'model_dict'! "
            f"Available options are : {list(model_params['model_dict'].keys())}."
        )
        log.info(f"Changing out bias for model {model_branch}.")
    model = training.get_model_for_wrapper(model_params)
    type_map = (
        model_params["type_map"]
        if not multi_task
        else model_params["model_dict"][model_branch]["type_map"]
    )
    # For a multi-task model only the selected branch has its bias changed.
    model_to_change = model if not multi_task else model[model_branch]
    if FLAGS.INPUT.endswith(".pt"):
        wrapper = ModelWrapper(model)
        wrapper.load_state_dict(old_state_dict["model"])
    else:
        # for .pth
        model.load_state_dict(old_state_dict)

    if FLAGS.bias_value is not None:
        # use user-defined bias
        assert model_to_change.model_type in [
            "ener"
        ], "User-defined bias is only available for energy model!"
        assert (
            len(FLAGS.bias_value) == len(type_map)
        ), f"The number of elements in the bias should be the same as that in the type_map: {type_map}."
        old_bias = model_to_change.get_out_bias()
        bias_to_set = torch.tensor(
            FLAGS.bias_value, dtype=old_bias.dtype, device=old_bias.device
        ).view(old_bias.shape)
        model_to_change.set_out_bias(bias_to_set)
        log.info(
            f"Change output bias of {type_map!s} "
            f"from {to_numpy_array(old_bias).reshape(-1)!s} "
            f"to {to_numpy_array(bias_to_set).reshape(-1)!s}."
        )
        updated_model = model_to_change
    else:
        # calculate bias on given systems
        data_systems = process_systems(expand_sys_str(FLAGS.system))
        data_single = DpLoaderSet(
            data_systems,
            1,
            type_map,
        )
        # An inference-mode loss only declares the labels needed for statistics.
        mock_loss = training.get_loss(
            {"inference": True}, 1.0, len(type_map), model_to_change
        )
        data_requirement = mock_loss.label_requirement
        data_requirement += training.get_additional_data_requirement(model_to_change)
        data_single.add_data_requirement(data_requirement)
        # numb_batch == 0 means "use all frames in each system".
        nbatches = FLAGS.numb_batch if FLAGS.numb_batch != 0 else float("inf")
        sampled_data = make_stat_input(
            data_single.systems,
            data_single.dataloaders,
            nbatches,
        )
        updated_model = training.model_change_out_bias(
            model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
        )

    if not multi_task:
        model = updated_model
    else:
        model[model_branch] = updated_model

    if FLAGS.INPUT.endswith(".pt"):
        # Strip only the trailing ".pt" suffix; str.replace would also mangle
        # any ".pt" occurring earlier in the path.
        output_path = (
            FLAGS.output
            if FLAGS.output is not None
            else FLAGS.INPUT[: -len(".pt")] + "_updated.pt"
        )
        wrapper = ModelWrapper(model)
        if "model" in old_state_dict:
            old_state_dict["model"] = wrapper.state_dict()
            old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]
        else:
            old_state_dict = wrapper.state_dict()
            old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
        torch.save(old_state_dict, output_path)
    else:
        # for .pth: re-script the updated eager model and save it frozen.
        output_path = (
            FLAGS.output
            if FLAGS.output is not None
            else FLAGS.INPUT[: -len(".pth")] + "_updated.pth"
        )
        model = torch.jit.script(model)
        torch.jit.save(
            model,
            output_path,
            {},
        )
    log.info(f"Saved model to {output_path}")


iProzd marked this conversation as resolved.
Show resolved Hide resolved
@record
def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
if not isinstance(args, argparse.Namespace):
Expand All @@ -401,6 +536,8 @@
freeze(FLAGS)
elif FLAGS.command == "show":
show(FLAGS)
elif FLAGS.command == "change-bias":
change_bias(FLAGS)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference
self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference
self.has_gf = (start_pref_gf != 0.0 and limit_pref_gf != 0.0) or inference
self.has_gf = start_pref_gf != 0.0 and limit_pref_gf != 0.0

self.start_pref_e = start_pref_e
self.limit_pref_e = limit_pref_e
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def init_out_stat(self):
self.register_buffer("out_bias", out_bias_data)
self.register_buffer("out_std", out_std_data)

def set_out_bias(self, out_bias: torch.Tensor) -> None:
    """Replace the stored output bias with *out_bias*.

    Parameters
    ----------
    out_bias : torch.Tensor
        New bias tensor; assigned to the ``out_bias`` buffer registered in
        ``init_out_stat``.  Presumably it must match that buffer's shape —
        no check is performed here; verify at the caller.
    """
    # Plain attribute assignment rebinds the registered buffer to the new tensor.
    self.out_bias = out_bias

def __setitem__(self, key, value):
if key in ["out_bias"]:
self.out_bias = value
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def forward_common(
def get_out_bias(self) -> torch.Tensor:
    """Return the output bias tensor, delegating to the wrapped atomic model."""
    return self.atomic_model.get_out_bias()

def set_out_bias(self, out_bias: torch.Tensor) -> None:
    """Set the output bias tensor, delegating to the wrapped atomic model.

    Parameters
    ----------
    out_bias : torch.Tensor
        New bias tensor forwarded unchanged to
        ``self.atomic_model.set_out_bias``.
    """
    self.atomic_model.set_out_bias(out_bias)

def change_out_bias(
self,
merged,
Expand Down
Loading