Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 19, 2024
1 parent c42fa4b commit 1dac68c
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,13 +489,23 @@ def collect_single_finetune_params(
_new_state_dict[item_key] = (
_random_state_dict[item_key].clone().detach()
)
elif _new_fitting and ((".out_bias" in item_key) or (".out_std" in item_key)):
elif _new_fitting and (
(".out_bias" in item_key) or (".out_std" in item_key)
):
new_key = item_key.replace(
f".{_model_key}.", f".{_model_key_from}."
)
if _random_state_dict[item_key].shape[-1] != _origin_state_dict[new_key].shape[-1]:
assert _random_state_dict[item_key].shape[:-1] == _origin_state_dict[new_key].shape[:-1]
_origin_state_dict[new_key] = _origin_state_dict[new_key].expand(_random_state_dict[item_key].shape)
if (
_random_state_dict[item_key].shape[-1]
!= _origin_state_dict[new_key].shape[-1]
):
assert (
_random_state_dict[item_key].shape[:-1]
== _origin_state_dict[new_key].shape[:-1]
)
_origin_state_dict[new_key] = _origin_state_dict[
new_key
].expand(_random_state_dict[item_key].shape)
_new_state_dict[item_key] = (
_origin_state_dict[new_key].clone().detach()
)
Expand Down

0 comments on commit 1dac68c

Please sign in to comment.