Skip to content

Commit

Permalink
resolve conversations
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jun 10, 2024
1 parent ab04399 commit 4599213
Show file tree
Hide file tree
Showing 15 changed files with 373 additions and 41 deletions.
16 changes: 12 additions & 4 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,15 @@ def has_message_passing(self) -> bool:


def extend_descrpt_stat(des, type_map, des_with_stat=None):
"""
Extend the statistics of a descriptor block with types in type_map.
r"""
Extend the statistics of a descriptor block with types from newly provided `type_map`.
After extending, the type related dimension of the extended statistics will have a length of
`len(old_type_map) + len(type_map)`, where `old_type_map` represents the type map in `des`.
The `get_index_between_two_maps()` function can then be used to correctly select statistics for types
from `old_type_map` or `type_map`.
Positive indices from 0 to `len(old_type_map) - 1` will select old statistics of types in `old_type_map`,
while negative indices from `-len(type_map)` to -1 will select new statistics of types in `type_map`.
Parameters
----------
Expand All @@ -142,8 +149,9 @@ def extend_descrpt_stat(des, type_map, des_with_stat=None):
type_map : List[str]
The name of each type of atoms to be extended.
des_with_stat : DescriptorBlock, Optional
The descriptor block has additional statistics in type_map.
If None, the default statistics will be used. Otherwise, the statistics provided in this DescriptorBlock will be used.
        The descriptor block that has additional statistics of types from the newly provided `type_map`.
If None, the default statistics will be used.
Otherwise, the statistics provided in this DescriptorBlock will be used.
"""
if des_with_stat is not None:
Expand Down
16 changes: 12 additions & 4 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,15 @@ def make_default_type_embedding(


def extend_descrpt_stat(des, type_map, des_with_stat=None):
"""
Extend the statistics of a descriptor block with types in type_map.
r"""
Extend the statistics of a descriptor block with types from newly provided `type_map`.
After extending, the type related dimension of the extended statistics will have a length of
`len(old_type_map) + len(type_map)`, where `old_type_map` represents the type map in `des`.
The `get_index_between_two_maps()` function can then be used to correctly select statistics for types
from `old_type_map` or `type_map`.
Positive indices from 0 to `len(old_type_map) - 1` will select old statistics of types in `old_type_map`,
while negative indices from `-len(type_map)` to -1 will select new statistics of types in `type_map`.
Parameters
----------
Expand All @@ -193,8 +200,9 @@ def extend_descrpt_stat(des, type_map, des_with_stat=None):
type_map : List[str]
The name of each type of atoms to be extended.
des_with_stat : DescriptorBlock, Optional
The descriptor block has additional statistics in type_map.
If None, the default statistics will be used. Otherwise, the statistics provided in this DescriptorBlock will be used.
        The descriptor block that has additional statistics of types from the newly provided `type_map`.
If None, the default statistics will be used.
Otherwise, the statistics provided in this DescriptorBlock will be used.
"""
if des_with_stat is not None:
Expand Down
15 changes: 9 additions & 6 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
training_params = config["training"]
self.multi_task = "model_dict" in model_params
self.finetune_links = finetune_links
self.finetune_update_stat = False
self.model_keys = (
list(model_params["model_dict"]) if self.multi_task else ["Default"]
)
Expand Down Expand Up @@ -534,11 +535,10 @@ def get_model_for_wrapper(_model_params):
_model_key_from
].get_type_map()
):
model_with_new_type_stat = (
self.wrapper.model[model_key]
if finetune_rule_single.get_has_new_type()
else None
)
model_with_new_type_stat = None
if finetune_rule_single.get_has_new_type():
self.finetune_update_stat = True
model_with_new_type_stat = self.wrapper.model[model_key]
pretrained_model_wrapper.model[
_model_key_from
].change_type_map(
Expand Down Expand Up @@ -640,7 +640,10 @@ def single_model_finetune(

# Multi-task share params
if shared_links is not None:
self.wrapper.share_params(shared_links, resume=resuming or self.rank != 0)
self.wrapper.share_params(
shared_links,
resume=(resuming and not self.finetune_update_stat) or self.rank != 0,
)

if dist.is_available() and dist.is_initialized():
torch.cuda.set_device(LOCAL_RANK)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def to_torch_tensor(
if xx is None:
return None
assert xx is not None
if not isinstance(xx, np.ndarray):
return xx
# Create a reverse mapping of NP_PRECISION_DICT
reverse_precision_dict = {v: k for k, v in NP_PRECISION_DICT.items()}
# Use the reverse mapping to find keys with the desired value
Expand Down
8 changes: 6 additions & 2 deletions deepmd/utils/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,14 @@ def get_index_between_two_maps(
Returns
-------
index_map: List[int]
List contains len(new_map) indices, where index_map[i] is the index of new_map[i] in old_map.
If new_map[i] is not in the old_map, the index will be (i - len(new_map)).
List contains `len(new_map)` indices, where `index_map[i]` is the index of `new_map[i]` in `old_map`.
If `new_map[i]` is not in the `old_map`, the index will be `i - len(new_map)`.
has_new_type: bool
Whether there are unseen types in the new type_map.
If True, some type related params in the model, such as statistics, need to be extended
to have a length of `len(old_map) + len(new_map)` in the type related dimension.
Then positive indices from 0 to `len(old_map) - 1` will select old params of types in `old_map`,
while negative indices from `-len(new_map)` to -1 will select new params of types in `new_map`.
"""
missing_type = [i for i in new_map if i not in old_map]
has_new_type = False
Expand Down
152 changes: 152 additions & 0 deletions source/tests/common/test_type_index_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

from deepmd.utils.finetune import (
get_index_between_two_maps,
map_atom_exclude_types,
map_pair_exclude_types,
)


class TestTypeIndexMap(unittest.TestCase):
    """Unit tests for the type-map remapping helpers in ``deepmd.utils.finetune``.

    Covers ``get_index_between_two_maps`` (identity, reorder, slim and extend
    scenarios) and the exclude-type remapping helpers
    ``map_atom_exclude_types`` / ``map_pair_exclude_types``.
    """

    def test_get_index_between_two_maps(self):
        """Index map and has-new-type flag for four old/new type-map pairings."""
        sixteen_types = [
            "Al",
            "F",
            "N",
            "H",
            "S",
            "O",
            "He",
            "C",
            "Li",
            "Na",
            "Be",
            "Mg",
            "Si",
            "B",
            "Ne",
            "P",
        ]  # 16 elements
        reordered_types = [
            "P",
            "Na",
            "Si",
            "Mg",
            "C",
            "O",
            "Be",
            "B",
            "Li",
            "S",
            "Ne",
            "N",
            "H",
            "Al",
            "F",
            "He",
        ]  # same 16 elements, shuffled
        seven_types = ["O", "H", "Be", "C", "N", "B", "Li"]  # 7 elements

        # Each case: (old_map, new_map, expected index map, expected has_new_type).
        cases = [
            # self consistence: identity mapping, no new types
            (sixteen_types, sixteen_types, list(range(16)), False),
            # resort: every new type found in old map at a different index
            (
                sixteen_types,
                reordered_types,
                [15, 9, 12, 11, 7, 5, 10, 13, 8, 4, 14, 2, 3, 0, 1, 6],
                False,
            ),
            # slim: new map is a subset of the old map
            (sixteen_types, seven_types, [5, 3, 10, 7, 2, 13, 8], False),
            # extend: unseen types get negative indices (i - len(new_map))
            (
                seven_types,
                sixteen_types,
                [-16, -15, 4, 1, -12, 0, -10, 3, 6, -7, 2, -5, -4, 5, -2, -1],
                True,
            ),
        ]
        for old_map, new_map, want_map, want_has_new in cases:
            got_map, got_has_new = get_index_between_two_maps(old_map, new_map)
            self.assertEqual(len(got_map), len(new_map))
            self.assertEqual(want_map, got_map)
            self.assertEqual(want_has_new, got_has_new)

    def test_map_exclude_types(self):
        """Remapping of atom/pair exclude-type lists through an index map."""
        old_map = [
            "Al",
            "F",
            "N",
            "H",
            "S",
            "O",
            "He",
            "C",
            "Li",
            "Na",
            "Be",
            "Mg",
            "Si",
            "B",
            "Ne",
            "P",
        ]  # 16 elements
        new_map = ["O", "H", "Be", "C", "N", "B", "Li"]  # 7 elements
        # forward: old -> new (slimming); backward: new -> old (extending)
        forward_index, _ = get_index_between_two_maps(old_map, new_map)
        backward_index, _ = get_index_between_two_maps(new_map, old_map)

        # Each case: (exclude list, index map to apply, expected remapped list).
        atom_cases = [
            ([0], forward_index, []),  # dropped: type 0 absent from new_map
            ([0, 5], forward_index, [0]),  # only type 5 ("O") survives, at index 0
            ([7, 8, 11], forward_index, [3, 6]),  # "C" -> 3, "Li" -> 6; "Mg" dropped
            ([0], backward_index, [5]),  # reverse direction: "O" -> 5 in old_map
        ]
        for exclude, index_map, expected in atom_cases:
            self.assertEqual(expected, map_atom_exclude_types(exclude, index_map))

        pair_cases = [
            ([(0, 0), (0, 5)], forward_index, []),  # all pairs involve dropped types
            ([(0, 0), (0, 5), (5, 8)], forward_index, [(0, 6)]),
            ([(0, 0), (0, 5), (8, 7)], forward_index, [(6, 3)]),
            ([(0, 0), (0, 5)], backward_index, [(5, 5), (5, 13)]),
        ]
        for exclude, index_map, expected in pair_cases:
            self.assertEqual(expected, map_pair_exclude_types(exclude, index_map))
33 changes: 16 additions & 17 deletions source/tests/universal/common/cases/descriptor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from copy import (
deepcopy,
)
from random import (
shuffle,
)

import numpy as np

from deepmd.dpmodel.utils import (
PairExcludeMask,
)

from .....seed import (
GLOBAL_SEED,
)
from ..cases import (
TestCaseSingleFrameWithNlist,
)
Expand Down Expand Up @@ -132,31 +132,30 @@ def test_change_type_map(self):
"Cl",
"Ar",
] # 18 elements
rng = np.random.default_rng(GLOBAL_SEED)
for old_tm, new_tm, em, econf in itertools.product(
[
deepcopy(full_type_map_test), # 18 elements
deepcopy(
full_type_map_test[:16]
), # 16 elements, double of tebd default first dim
deepcopy(full_type_map_test[:8]), # 8 elements, tebd default first dim
full_type_map_test[:], # 18 elements
full_type_map_test[
:16
], # 16 elements, double of tebd default first dim
full_type_map_test[:8], # 8 elements, tebd default first dim
["H", "O"], # slimmed types
], # old_type_map
[
deepcopy(full_type_map_test), # 18 elements
deepcopy(
full_type_map_test[:16]
), # 16 elements, double of tebd default first dim
deepcopy(full_type_map_test[:8]), # 8 elements, tebd default first dim
full_type_map_test[:], # 18 elements
full_type_map_test[
:16
], # 16 elements, double of tebd default first dim
full_type_map_test[:8], # 8 elements, tebd default first dim
["H", "O"], # slimmed types
], # new_type_map
[[], [[0, 1]], [[1, 1]]], # exclude_types for original_type_map
[False, True], # use_econf_tebd
):
if len(old_tm) >= len(new_tm):
continue
# use shuffled type_map
shuffle(old_tm)
shuffle(new_tm)
rng.shuffle(old_tm)
rng.shuffle(new_tm)
old_tm_index = np.array(
[old_tm.index(i) for i in original_type_map], dtype=np.int32
)
Expand Down
18 changes: 10 additions & 8 deletions source/tests/universal/common/cases/fitting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from copy import (
deepcopy,
)
from random import (
shuffle,
)

import numpy as np

from deepmd.dpmodel.utils import (
AtomExcludeMask,
)

from .....seed import (
GLOBAL_SEED,
)
from ..cases import (
TestCaseSingleFrameWithNlist,
)
Expand Down Expand Up @@ -131,22 +131,25 @@ def test_change_type_map(self):
"Cl",
"Ar",
] # 18 elements
rng = np.random.default_rng(GLOBAL_SEED)
for old_tm, new_tm, em in itertools.product(
[
deepcopy(full_type_map_test[:8]), # 8 elements
full_type_map_test[:8], # 8 elements
["H", "O"], # slimmed types
], # large_type_map
[
deepcopy(full_type_map_test[:8]), # 8 elements
full_type_map_test[:8], # 8 elements
["H", "O"], # slimmed types
], # small_type_map
[
[],
[0],
[1],
], # exclude_types for original_type_map
):
# use shuffled type_map
shuffle(old_tm)
shuffle(new_tm)
rng.shuffle(old_tm)
rng.shuffle(new_tm)
old_tm_index = np.array(
[old_tm.index(i) for i in original_type_map], dtype=np.int32
)
Expand All @@ -161,7 +164,6 @@ def test_change_type_map(self):
old_tm_module = self.module_class(**old_tm_input)
serialize_dict = old_tm_module.serialize()
# set random bias
rng = np.random.default_rng()
serialize_dict["@variables"]["bias_atom_e"] = rng.random(
size=serialize_dict["@variables"]["bias_atom_e"].shape
)
Expand Down
1 change: 1 addition & 0 deletions source/tests/universal/common/cases/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
Loading

0 comments on commit 4599213

Please sign in to comment.