Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt/tf): add bias changing param/interface #3933

Merged
merged 8 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,72 @@ def main_parser() -> argparse.ArgumentParser:
help="treat all types as a single type. Used with se_atten descriptor.",
)

# change_bias
parser_change_bias = subparsers.add_parser(
"change-bias",
parents=[parser_log],
help="(Supported backend: PyTorch) Change model out bias according to the input data.",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
examples:
dp change-bias model.pt -s data -n 10 -m change
"""
),
)
parser_change_bias.add_argument(
"INPUT", help="The input checkpoint file or frozen model file"
)
parser_change_bias_source = parser_change_bias.add_mutually_exclusive_group()
parser_change_bias_source.add_argument(
"-s",
"--system",
default=".",
type=str,
help="The system dir. Recursively detect systems in this directory",
)
parser_change_bias_source.add_argument(
"-b",
"--bias-value",
default=None,
type=float,
nargs="+",
help="The user defined value for each type in the type_map of the model, split with spaces.\n"
"For example, '-93.57 -187.1' for energy bias of two elements. "
"Only supports energy bias changing.",
)
parser_change_bias.add_argument(
"-n",
"--numb-batch",
default=0,
type=int,
help="The number of frames for bias changing in one data system. 0 means all data.",
)
parser_change_bias.add_argument(
"-m",
"--mode",
type=str,
default="change",
choices=["change", "set"],
help="The mode for changing energy bias: \n"
"change (default) : perform predictions using input model on target dataset, "
"and do least square on the errors to obtain the target shift as bias.\n"
"set : directly use the statistic bias in the target dataset.",
)
parser_change_bias.add_argument(
"-o",
"--output",
default=None,
type=str,
help="The model after changing bias.",
)
parser_change_bias.add_argument(
"--model-branch",
type=str,
default=None,
help="Model branch chosen for changing bias if multi-task model.",
)

iProzd marked this conversation as resolved.
Show resolved Hide resolved
# --version
parser.add_argument(
"--version", action="version", version=f"DeePMD-kit v{__version__}"
Expand Down Expand Up @@ -831,6 +897,7 @@ def main():
"convert-from",
"train-nvnmd",
"show",
"change-bias",
):
deepmd_main = BACKENDS[args.backend]().entry_point_hook
elif args.command is None:
Expand Down
137 changes: 137 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import argparse
import copy
iProzd marked this conversation as resolved.
Show resolved Hide resolved
import json
import logging
import os
Expand All @@ -23,6 +24,9 @@
from deepmd import (
__version__,
)
from deepmd.common import (
expand_sys_str,
)
from deepmd.env import (
GLOBAL_CONFIG,
)
Expand All @@ -44,6 +48,9 @@
from deepmd.pt.train import (
training,
)
from deepmd.pt.train.wrapper import (
ModelWrapper,
)
from deepmd.pt.utils import (
env,
)
Expand All @@ -59,6 +66,12 @@
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.argcheck import (
normalize,
)
Expand Down Expand Up @@ -377,6 +390,128 @@
log.info(f"The fitting_net parameter is {fitting_net}")


def change_bias(FLAGS):
iProzd marked this conversation as resolved.
Show resolved Hide resolved
iProzd marked this conversation as resolved.
Show resolved Hide resolved
if FLAGS.INPUT.endswith(".pt"):
iProzd marked this conversation as resolved.
Show resolved Hide resolved
old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
iProzd marked this conversation as resolved.
Show resolved Hide resolved
iProzd marked this conversation as resolved.
Show resolved Hide resolved
model_params = model_state_dict["_extra_state"]["model_params"]
elif FLAGS.INPUT.endswith(".pth"):
old_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE)
model_params_string = old_model.get_model_def_script()
model_params = json.loads(model_params_string)
old_state_dict = old_model.state_dict()
model_state_dict = old_state_dict

Check warning on line 403 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L394-L403

Added lines #L394 - L403 were not covered by tests
else:
raise RuntimeError(

Check warning on line 405 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L405

Added line #L405 was not covered by tests
"The model provided must be a checkpoint file with a .pt extension "
"or a frozen model with a .pth extension"
)
multi_task = "model_dict" in model_params
iProzd marked this conversation as resolved.
Show resolved Hide resolved
model_branch = FLAGS.model_branch
bias_adjust_mode = (

Check warning on line 411 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L409-L411

Added lines #L409 - L411 were not covered by tests
iProzd marked this conversation as resolved.
Show resolved Hide resolved
"change-by-statistic" if FLAGS.mode == "change" else "set-by-statistic"
)
if multi_task:
assert (

Check warning on line 415 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L414-L415

Added lines #L414 - L415 were not covered by tests
model_branch is not None
), "For multitask model, the model branch must be set!"
assert model_branch in model_params["model_dict"], (

Check warning on line 418 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L418

Added line #L418 was not covered by tests
f"For multitask model, the model branch must be in the 'model_dict'! "
f"Available options are : {list(model_params['model_dict'].keys())}."
)
log.info(f"Changing out bias for model {model_branch}.")
iProzd marked this conversation as resolved.
Show resolved Hide resolved
model = training.get_model_for_wrapper(model_params)
type_map = (

Check warning on line 424 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L422-L424

Added lines #L422 - L424 were not covered by tests
model_params["type_map"]
if not multi_task
else model_params["model_dict"][model_branch]["type_map"]
)
model_to_change = model if not multi_task else model[model_branch]
if FLAGS.INPUT.endswith(".pt"):
wrapper = ModelWrapper(model)
wrapper.load_state_dict(old_state_dict["model"])

Check warning on line 432 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L429-L432

Added lines #L429 - L432 were not covered by tests
else:
# for .pth
model.load_state_dict(old_state_dict)

Check warning on line 435 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L435

Added line #L435 was not covered by tests

if FLAGS.bias_value is not None:

Check warning on line 437 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L437

Added line #L437 was not covered by tests
# use user-defined bias
assert model_to_change.model_type in [

Check warning on line 439 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L439

Added line #L439 was not covered by tests
iProzd marked this conversation as resolved.
Show resolved Hide resolved
"ener"
], "User-defined bias is only available for energy model!"
assert (

Check warning on line 442 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L442

Added line #L442 was not covered by tests
len(FLAGS.bias_value) == len(type_map)
), f"The number of elements in the bias should be the same as that in the type_map: {type_map}."
old_bias = model_to_change.get_out_bias()
bias_to_set = torch.tensor(

Check warning on line 446 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L445-L446

Added lines #L445 - L446 were not covered by tests
FLAGS.bias_value, dtype=old_bias.dtype, device=old_bias.device
).view(old_bias.shape)
model_to_change.set_out_bias(bias_to_set)
log.info(

Check warning on line 450 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L449-L450

Added lines #L449 - L450 were not covered by tests
f"Change output bias of {type_map!s} "
f"from {to_numpy_array(old_bias).reshape(-1)!s} "
f"to {to_numpy_array(bias_to_set).reshape(-1)!s}."
)
updated_model = model_to_change

Check warning on line 455 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L455

Added line #L455 was not covered by tests
else:
# calculate bias on given systems
data_systems = process_systems(expand_sys_str(FLAGS.system))
data_single = DpLoaderSet(

Check warning on line 459 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L458-L459

Added lines #L458 - L459 were not covered by tests
data_systems,
1,
type_map,
)
mock_loss = training.get_loss(

Check warning on line 464 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L464

Added line #L464 was not covered by tests
{"inference": True}, 1.0, len(type_map), model_to_change
)
data_requirement = mock_loss.label_requirement
data_requirement += training.get_additional_data_requirement(model_to_change)
data_single.add_data_requirement(data_requirement)
nbatches = FLAGS.numb_batch if FLAGS.numb_batch != 0 else float("inf")
sampled_data = make_stat_input(

Check warning on line 471 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L467-L471

Added lines #L467 - L471 were not covered by tests
data_single.systems,
data_single.dataloaders,
nbatches,
)
updated_model = training.model_change_out_bias(

Check warning on line 476 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L476

Added line #L476 was not covered by tests
model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
)

if not multi_task:
model = updated_model

Check warning on line 481 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L480-L481

Added lines #L480 - L481 were not covered by tests
else:
model[model_branch] = updated_model

Check warning on line 483 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L483

Added line #L483 was not covered by tests

if FLAGS.INPUT.endswith(".pt"):
output_path = (

Check warning on line 486 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L485-L486

Added lines #L485 - L486 were not covered by tests
FLAGS.output
if FLAGS.output is not None
else FLAGS.INPUT.replace(".pt", "_updated.pt")
)
wrapper = ModelWrapper(model)
if "model" in old_state_dict:
old_state_dict["model"] = wrapper.state_dict()
old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]

Check warning on line 494 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L491-L494

Added lines #L491 - L494 were not covered by tests
else:
old_state_dict = wrapper.state_dict()
old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
torch.save(old_state_dict, output_path)

Check warning on line 498 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L496-L498

Added lines #L496 - L498 were not covered by tests
iProzd marked this conversation as resolved.
Show resolved Hide resolved
else:
# for .pth
output_path = (

Check warning on line 501 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L501

Added line #L501 was not covered by tests
FLAGS.output
if FLAGS.output is not None
else FLAGS.INPUT.replace(".pth", "_updated.pth")
)
model = torch.jit.script(model)
torch.jit.save(

Check warning on line 507 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L506-L507

Added lines #L506 - L507 were not covered by tests
model,
output_path,
{},
)
log.info(f"Saved model to {output_path}")

Check warning on line 512 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L512

Added line #L512 was not covered by tests


iProzd marked this conversation as resolved.
Show resolved Hide resolved
@record
def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
if not isinstance(args, argparse.Namespace):
Expand All @@ -401,6 +536,8 @@
freeze(FLAGS)
elif FLAGS.command == "show":
show(FLAGS)
elif FLAGS.command == "change-bias":
change_bias(FLAGS)

Check warning on line 540 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L539-L540

Added lines #L539 - L540 were not covered by tests
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference
self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference
self.has_gf = (start_pref_gf != 0.0 and limit_pref_gf != 0.0) or inference
self.has_gf = start_pref_gf != 0.0 and limit_pref_gf != 0.0

self.start_pref_e = start_pref_e
self.limit_pref_e = limit_pref_e
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@
self.register_buffer("out_bias", out_bias_data)
self.register_buffer("out_std", out_std_data)

def set_out_bias(self, out_bias: torch.Tensor) -> None:
self.out_bias = out_bias

Check warning on line 107 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L107

Added line #L107 was not covered by tests

def __setitem__(self, key, value):
if key in ["out_bias"]:
self.out_bias = value
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@
def get_out_bias(self) -> torch.Tensor:
return self.atomic_model.get_out_bias()

def set_out_bias(self, out_bias: torch.Tensor) -> None:
self.atomic_model.set_out_bias(out_bias)

Check warning on line 179 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L179

Added line #L179 was not covered by tests

def change_out_bias(
self,
merged,
Expand Down
Loading