diff --git a/backend/read_env.py b/backend/read_env.py
index c3fe2d5127..ae82778f4e 100644
--- a/backend/read_env.py
+++ b/backend/read_env.py
@@ -60,6 +60,8 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str, str]:
         cmake_minimum_required_version = "3.21"
         cmake_args.append("-DUSE_ROCM_TOOLKIT:BOOL=TRUE")
         rocm_root = os.environ.get("ROCM_ROOT")
+        if not rocm_root:
+            rocm_root = os.environ.get("ROCM_PATH")
         if rocm_root:
             cmake_args.append(f"-DCMAKE_HIP_COMPILER_ROCM_ROOT:STRING={rocm_root}")
         hipcc_flags = os.environ.get("HIP_HIPCC_FLAGS")
diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py
index 353109d650..d5eae71731 100644
--- a/deepmd/pt/infer/deep_eval.py
+++ b/deepmd/pt/infer/deep_eval.py
@@ -602,230 +602,3 @@ def eval_typeebd(self) -> np.ndarray:
     def get_model_def_script(self) -> str:
         """Get model defination script."""
         return self.model_def_script
-
-
-# For tests only
-def eval_model(
-    model,
-    coords: Union[np.ndarray, torch.Tensor],
-    cells: Optional[Union[np.ndarray, torch.Tensor]],
-    atom_types: Union[np.ndarray, torch.Tensor, List[int]],
-    spins: Optional[Union[np.ndarray, torch.Tensor]] = None,
-    atomic: bool = False,
-    infer_batch_size: int = 2,
-    denoise: bool = False,
-):
-    model = model.to(DEVICE)
-    energy_out = []
-    atomic_energy_out = []
-    force_out = []
-    force_mag_out = []
-    virial_out = []
-    atomic_virial_out = []
-    updated_coord_out = []
-    logits_out = []
-    err_msg = (
-        f"All inputs should be the same format, "
-        f"but found {type(coords)}, {type(cells)}, {type(atom_types)} instead! "
-    )
-    return_tensor = True
-    if isinstance(coords, torch.Tensor):
-        if cells is not None:
-            assert isinstance(cells, torch.Tensor), err_msg
-        if spins is not None:
-            assert isinstance(spins, torch.Tensor), err_msg
-        assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
-        atom_types = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
-    elif isinstance(coords, np.ndarray):
-        if cells is not None:
-            assert isinstance(cells, np.ndarray), err_msg
-        if spins is not None:
-            assert isinstance(spins, np.ndarray), err_msg
-        assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list)
-        atom_types = np.array(atom_types, dtype=np.int32)
-        return_tensor = False
-
-    nframes = coords.shape[0]
-    if len(atom_types.shape) == 1:
-        natoms = len(atom_types)
-        if isinstance(atom_types, torch.Tensor):
-            atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape(
-                nframes, -1
-            )
-        else:
-            atom_types = np.tile(atom_types, nframes).reshape(nframes, -1)
-    else:
-        natoms = len(atom_types[0])
-
-    coord_input = torch.tensor(
-        coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
-    )
-    spin_input = None
-    if spins is not None:
-        spin_input = torch.tensor(
-            spins.reshape([-1, natoms, 3]),
-            dtype=GLOBAL_PT_FLOAT_PRECISION,
-            device=DEVICE,
-        )
-    has_spin = getattr(model, "has_spin", False)
-    if callable(has_spin):
-        has_spin = has_spin()
-    type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
-    box_input = None
-    if cells is None:
-        pbc = False
-    else:
-        pbc = True
-        box_input = torch.tensor(
-            cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
-        )
-    num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)
-
-    for ii in range(num_iter):
-        batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
-        batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
-        batch_box = None
-        batch_spin = None
-        if spin_input is not None:
-            batch_spin = spin_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
-        if pbc:
-            batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
-        input_dict = {
-            "coord": batch_coord,
-            "atype": batch_atype,
-            "box": batch_box,
-            "do_atomic_virial": atomic,
-        }
-        if has_spin:
-            input_dict["spin"] = batch_spin
-        batch_output = model(**input_dict)
-        if isinstance(batch_output, tuple):
-            batch_output = batch_output[0]
-        if not return_tensor:
-            if "energy" in batch_output:
-                energy_out.append(batch_output["energy"].detach().cpu().numpy())
-            if "atom_energy" in batch_output:
-                atomic_energy_out.append(
-                    batch_output["atom_energy"].detach().cpu().numpy()
-                )
-            if "force" in batch_output:
-                force_out.append(batch_output["force"].detach().cpu().numpy())
-            if "force_mag" in batch_output:
-                force_mag_out.append(batch_output["force_mag"].detach().cpu().numpy())
-            if "virial" in batch_output:
-                virial_out.append(batch_output["virial"].detach().cpu().numpy())
-            if "atom_virial" in batch_output:
-                atomic_virial_out.append(
-                    batch_output["atom_virial"].detach().cpu().numpy()
-                )
-            if "updated_coord" in batch_output:
-                updated_coord_out.append(
-                    batch_output["updated_coord"].detach().cpu().numpy()
-                )
-            if "logits" in batch_output:
-                logits_out.append(batch_output["logits"].detach().cpu().numpy())
-        else:
-            if "energy" in batch_output:
-                energy_out.append(batch_output["energy"])
-            if "atom_energy" in batch_output:
-                atomic_energy_out.append(batch_output["atom_energy"])
-            if "force" in batch_output:
-                force_out.append(batch_output["force"])
-            if "force_mag" in batch_output:
-                force_mag_out.append(batch_output["force_mag"])
-            if "virial" in batch_output:
-                virial_out.append(batch_output["virial"])
-            if "atom_virial" in batch_output:
-                atomic_virial_out.append(batch_output["atom_virial"])
-            if "updated_coord" in batch_output:
-                updated_coord_out.append(batch_output["updated_coord"])
-            if "logits" in batch_output:
-                logits_out.append(batch_output["logits"])
-    if not return_tensor:
-        energy_out = (
-            np.concatenate(energy_out) if energy_out else np.zeros([nframes, 1])  # pylint: disable=no-explicit-dtype
-        )
-        atomic_energy_out = (
-            np.concatenate(atomic_energy_out)
-            if atomic_energy_out
-            else np.zeros([nframes, natoms, 1])  # pylint: disable=no-explicit-dtype
-        )
-        force_out = (
-            np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3])  # pylint: disable=no-explicit-dtype
-        )
-        force_mag_out = (
-            np.concatenate(force_mag_out)
-            if force_mag_out
-            else np.zeros([nframes, natoms, 3])  # pylint: disable=no-explicit-dtype
-        )
-        virial_out = (
-            np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3])  # pylint: disable=no-explicit-dtype
-        )
-        atomic_virial_out = (
-            np.concatenate(atomic_virial_out)
-            if atomic_virial_out
-            else np.zeros([nframes, natoms, 3, 3])  # pylint: disable=no-explicit-dtype
-        )
-        updated_coord_out = (
-            np.concatenate(updated_coord_out) if updated_coord_out else None
-        )
-        logits_out = np.concatenate(logits_out) if logits_out else None
-    else:
-        energy_out = (
-            torch.cat(energy_out)
-            if energy_out
-            else torch.zeros(
-                [nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
-            )
-        )
-        atomic_energy_out = (
-            torch.cat(atomic_energy_out)
-            if atomic_energy_out
-            else torch.zeros(
-                [nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
-            )
-        )
-        force_out = (
-            torch.cat(force_out)
-            if force_out
-            else torch.zeros(
-                [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
-            )
-        )
-        force_mag_out = (
-            torch.cat(force_mag_out)
-            if force_mag_out
-            else torch.zeros(
-                [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
-            )
-        )
-        virial_out = (
-            torch.cat(virial_out)
-            if virial_out
-            else torch.zeros(
-                [nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
-            )
-        )
-        atomic_virial_out = (
-            torch.cat(atomic_virial_out)
-            if atomic_virial_out
-            else torch.zeros(
-                [nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
-            )
-        )
-        updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None
-        logits_out = torch.cat(logits_out) if logits_out else None
-    if denoise:
-        return updated_coord_out, logits_out
-    else:
-        results_dict = {
-            "energy": energy_out,
-            "force": force_out,
-            "virial": virial_out,
-        }
-        if has_spin:
-            results_dict["force_mag"] = force_mag_out
-        if atomic:
-            results_dict["atom_energy"] = atomic_energy_out
-            results_dict["atom_virial"] = atomic_virial_out
-        return results_dict
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 91d32d0ea6..fb8d921076 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -1034,10 +1034,13 @@ def save_model(self, save_path, lr=0.0, step=0):
             if dist.is_available() and dist.is_initialized()
             else self.wrapper
         )
-        module.train_infos["lr"] = lr
+        module.train_infos["lr"] = float(lr)
         module.train_infos["step"] = step
+        optim_state_dict = deepcopy(self.optimizer.state_dict())
+        for item in optim_state_dict["param_groups"]:
+            item["lr"] = float(item["lr"])
         torch.save(
-            {"model": module.state_dict(), "optimizer": self.optimizer.state_dict()},
+            {"model": module.state_dict(), "optimizer": optim_state_dict},
             save_path,
         )
         checkpoint_dir = save_path.parent
diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py
index 377953cc35..e794a36cab 100644
--- a/deepmd/utils/path.py
+++ b/deepmd/utils/path.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import itertools
 import os
 from abc import (
     ABC,
@@ -373,7 +374,11 @@ def glob(self, pattern: str) -> List["DPPath"]:
             list of paths
         """
         # got paths starts with current path first, which is faster
-        subpaths = [ii for ii in self._keys if ii.startswith(self._name)]
+        subpaths = [
+            ii
+            for ii in itertools.chain(self._keys, self._new_keys)
+            if ii.startswith(self._name)
+        ]
         return [
             type(self)(f"{self.root_path}#{pp}", mode=self.mode)
             for pp in globfilter(subpaths, self._connect_path(pattern))
diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md
index 6f17a272c6..a725be0133 100644
--- a/doc/install/install-from-source.md
+++ b/doc/install/install-from-source.md
@@ -155,7 +155,8 @@ The path to the CUDA toolkit directory. CUDA 9.0 or later is supported. NVCC is
 
 **Type**: Path; **Default**: Detected automatically
 
-The path to the ROCM toolkit directory.
+The path to the ROCM toolkit directory. If `ROCM_ROOT` is not set, it will look for `ROCM_PATH`; if `ROCM_PATH` is also not set, it will be detected using `hipconfig --rocmpath`.
+
 :::
 
 :::{envvar} DP_ENABLE_TENSORFLOW
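Note on the ROCm-related changes above: together with the documented behavior, the toolkit root is resolved from `ROCM_ROOT`, then `ROCM_PATH`, then `hipconfig --rocmpath`. A minimal sketch of that lookup order, assuming a hypothetical helper name `find_rocm_root` and that `hipconfig` is on the `PATH`:

```python
import os
import shutil
import subprocess
from typing import Optional


def find_rocm_root() -> Optional[str]:
    """Hypothetical helper mirroring the documented lookup order."""
    # 1. ROCM_ROOT, 2. ROCM_PATH, 3. ask hipconfig if it is available.
    root = os.environ.get("ROCM_ROOT") or os.environ.get("ROCM_PATH")
    if root:
        return root
    if shutil.which("hipconfig"):
        return subprocess.check_output(["hipconfig", "--rocmpath"], text=True).strip()
    return None
```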
diff --git a/source/tests/pt/common.py b/source/tests/pt/common.py
index 8886522360..16b343be8a 100644
--- a/source/tests/pt/common.py
+++ b/source/tests/pt/common.py
@@ -1,7 +1,20 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    List,
+    Optional,
+    Union,
+)
+
+import numpy as np
+import torch
+
 from deepmd.main import (
     main,
 )
+from deepmd.pt.utils.env import (
+    DEVICE,
+    GLOBAL_PT_FLOAT_PRECISION,
+)
 
 
 def run_dp(cmd: str) -> int:
@@ -27,3 +40,229 @@ def run_dp(cmd: str) -> int:
     main(cmds)
 
     return 0
+
+
+def eval_model(
+    model,
+    coords: Union[np.ndarray, torch.Tensor],
+    cells: Optional[Union[np.ndarray, torch.Tensor]],
+    atom_types: Union[np.ndarray, torch.Tensor, List[int]],
+    spins: Optional[Union[np.ndarray, torch.Tensor]] = None,
+    atomic: bool = False,
+    infer_batch_size: int = 2,
+    denoise: bool = False,
+):
+    model = model.to(DEVICE)
+    energy_out = []
+    atomic_energy_out = []
+    force_out = []
+    force_mag_out = []
+    virial_out = []
+    atomic_virial_out = []
+    updated_coord_out = []
+    logits_out = []
+    err_msg = (
+        f"All inputs should be the same format, "
+        f"but found {type(coords)}, {type(cells)}, {type(atom_types)} instead! "
+    )
+    return_tensor = True
+    if isinstance(coords, torch.Tensor):
+        if cells is not None:
+            assert isinstance(cells, torch.Tensor), err_msg
+        if spins is not None:
+            assert isinstance(spins, torch.Tensor), err_msg
+        assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
+        atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE)
+    elif isinstance(coords, np.ndarray):
+        if cells is not None:
+            assert isinstance(cells, np.ndarray), err_msg
+        if spins is not None:
+            assert isinstance(spins, np.ndarray), err_msg
+        assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list)
+        atom_types = np.array(atom_types, dtype=np.int32)
+        return_tensor = False
+
+    nframes = coords.shape[0]
+    if len(atom_types.shape) == 1:
+        natoms = len(atom_types)
+        if isinstance(atom_types, torch.Tensor):
+            atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape(
+                nframes, -1
+            )
+        else:
+            atom_types = np.tile(atom_types, nframes).reshape(nframes, -1)
+    else:
+        natoms = len(atom_types[0])
+
+    coord_input = torch.tensor(
+        coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+    )
+    spin_input = None
+    if spins is not None:
+        spin_input = torch.tensor(
+            spins.reshape([-1, natoms, 3]),
+            dtype=GLOBAL_PT_FLOAT_PRECISION,
+            device=DEVICE,
+        )
+    has_spin = getattr(model, "has_spin", False)
+    if callable(has_spin):
+        has_spin = has_spin()
+    type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
+    box_input = None
+    if cells is None:
+        pbc = False
+    else:
+        pbc = True
+        box_input = torch.tensor(
+            cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+        )
+    num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)
+
+    for ii in range(num_iter):
+        batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
+        batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
+        batch_box = None
+        batch_spin = None
+        if spin_input is not None:
+            batch_spin = spin_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
+        if pbc:
+            batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
+        input_dict = {
+            "coord": batch_coord,
+            "atype": batch_atype,
+            "box": batch_box,
+            "do_atomic_virial": atomic,
+        }
+        if has_spin:
+            input_dict["spin"] = batch_spin
+        batch_output = model(**input_dict)
+        if isinstance(batch_output, tuple):
+            batch_output = batch_output[0]
+        if not return_tensor:
+            if "energy" in batch_output:
+                energy_out.append(batch_output["energy"].detach().cpu().numpy())
+            if "atom_energy" in batch_output:
+                atomic_energy_out.append(
+                    batch_output["atom_energy"].detach().cpu().numpy()
+                )
+            if "force" in batch_output:
+                force_out.append(batch_output["force"].detach().cpu().numpy())
+            if "force_mag" in batch_output:
+                force_mag_out.append(batch_output["force_mag"].detach().cpu().numpy())
+            if "virial" in batch_output:
+                virial_out.append(batch_output["virial"].detach().cpu().numpy())
+            if "atom_virial" in batch_output:
+                atomic_virial_out.append(
+                    batch_output["atom_virial"].detach().cpu().numpy()
+                )
+            if "updated_coord" in batch_output:
+                updated_coord_out.append(
+                    batch_output["updated_coord"].detach().cpu().numpy()
+                )
+            if "logits" in batch_output:
+                logits_out.append(batch_output["logits"].detach().cpu().numpy())
+        else:
+            if "energy" in batch_output:
+                energy_out.append(batch_output["energy"])
+            if "atom_energy" in batch_output:
+                atomic_energy_out.append(batch_output["atom_energy"])
+            if "force" in batch_output:
+                force_out.append(batch_output["force"])
+            if "force_mag" in batch_output:
+                force_mag_out.append(batch_output["force_mag"])
+            if "virial" in batch_output:
+                virial_out.append(batch_output["virial"])
+            if "atom_virial" in batch_output:
+                atomic_virial_out.append(batch_output["atom_virial"])
+            if "updated_coord" in batch_output:
+                updated_coord_out.append(batch_output["updated_coord"])
+            if "logits" in batch_output:
+                logits_out.append(batch_output["logits"])
+    if not return_tensor:
+        energy_out = (
+            np.concatenate(energy_out) if energy_out else np.zeros([nframes, 1])  # pylint: disable=no-explicit-dtype
+        )
+        atomic_energy_out = (
+            np.concatenate(atomic_energy_out)
+            if atomic_energy_out
+            else np.zeros([nframes, natoms, 1])  # pylint: disable=no-explicit-dtype
+        )
+        force_out = (
+            np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3])  # pylint: disable=no-explicit-dtype
+        )
+        force_mag_out = (
+            np.concatenate(force_mag_out)
+            if force_mag_out
+            else np.zeros([nframes, natoms, 3])  # pylint: disable=no-explicit-dtype
+        )
+        virial_out = (
+            np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3])  # pylint: disable=no-explicit-dtype
+        )
+        atomic_virial_out = (
+            np.concatenate(atomic_virial_out)
+            if atomic_virial_out
+            else np.zeros([nframes, natoms, 3, 3])  # pylint: disable=no-explicit-dtype
+        )
+        updated_coord_out = (
+            np.concatenate(updated_coord_out) if updated_coord_out else None
+        )
+        logits_out = np.concatenate(logits_out) if logits_out else None
+    else:
+        energy_out = (
+            torch.cat(energy_out)
+            if energy_out
+            else torch.zeros(
+                [nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+            )
+        )
+        atomic_energy_out = (
+            torch.cat(atomic_energy_out)
+            if atomic_energy_out
+            else torch.zeros(
+                [nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+            )
+        )
+        force_out = (
+            torch.cat(force_out)
+            if force_out
+            else torch.zeros(
+                [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+            )
+        )
+        force_mag_out = (
+            torch.cat(force_mag_out)
+            if force_mag_out
+            else torch.zeros(
+                [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+            )
+        )
+        virial_out = (
+            torch.cat(virial_out)
+            if virial_out
+            else torch.zeros(
+                [nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+            )
+        )
+        atomic_virial_out = (
+            torch.cat(atomic_virial_out)
+            if atomic_virial_out
+            else torch.zeros(
+                [nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+            )
+        )
+        updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None
+        logits_out = torch.cat(logits_out) if logits_out else None
+    if denoise:
+        return updated_coord_out, logits_out
+    else:
+        results_dict = {
+            "energy": energy_out,
+            "force": force_out,
+            "virial": virial_out,
+        }
+        if has_spin:
+            results_dict["force_mag"] = force_mag_out
+        if atomic:
+            results_dict["atom_energy"] = atomic_energy_out
+            results_dict["atom_virial"] = atomic_virial_out
+        return results_dict
diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py
index d891583491..1adcff55fc 100644
--- a/source/tests/pt/model/test_autodiff.py
+++ b/source/tests/pt/model/test_autodiff.py
@@ -21,8 +21,10 @@
 
 dtype = torch.float64
 
-from .test_permutation import (
+from ..common import (
     eval_model,
+)
+from .test_permutation import (
     model_dpa1,
     model_dpa2,
     model_hybrid,
diff --git a/source/tests/pt/model/test_forward_lower.py b/source/tests/pt/model/test_forward_lower.py
index c9857a6343..87a3f5b06e 100644
--- a/source/tests/pt/model/test_forward_lower.py
+++ b/source/tests/pt/model/test_forward_lower.py
@@ -4,9 +4,6 @@
 
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -20,6 +17,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (  # model_dpau,
     model_dpa1,
     model_dpa2,
diff --git a/source/tests/pt/model/test_null_input.py b/source/tests/pt/model/test_null_input.py
index 1dca7ee119..a2e0fa66db 100644
--- a/source/tests/pt/model/test_null_input.py
+++ b/source/tests/pt/model/test_null_input.py
@@ -5,9 +5,6 @@
 import numpy as np
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
     get_zbl_model,
 )
@@ -22,6 +19,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (
     model_dpa1,
     model_dpa2,
diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py
index f5edc6ef64..2fbc5fde3c 100644
--- a/source/tests/pt/model/test_permutation.py
+++ b/source/tests/pt/model/test_permutation.py
@@ -5,9 +5,6 @@
 
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -18,6 +15,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 
 CUR_DIR = os.path.dirname(__file__)
diff --git a/source/tests/pt/model/test_permutation_denoise.py b/source/tests/pt/model/test_permutation_denoise.py
index 133c48f551..53bf55fb0f 100644
--- a/source/tests/pt/model/test_permutation_denoise.py
+++ b/source/tests/pt/model/test_permutation_denoise.py
@@ -4,9 +4,6 @@
 
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -17,6 +14,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (  # model_dpau,
     model_dpa1,
     model_dpa2,
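The test modules touched above now import `eval_model` from the shared test helper `source/tests/pt/common.py` instead of `deepmd.pt.infer.deep_eval`. A minimal usage sketch with hypothetical inputs (one frame, three atoms in a cubic box); the `model` object is assumed to be built elsewhere, e.g. via `get_model(...)` with one of the configs from `test_permutation.py`:

```python
import numpy as np

from ..common import eval_model  # same relative import as in the tests above


def check_basic_shapes(model) -> None:
    """Sketch: run eval_model on a tiny hypothetical system and check output shapes."""
    rng = np.random.default_rng(0)
    coords = rng.uniform(0.0, 10.0, size=(1, 3, 3))  # 1 frame, 3 atoms
    cells = 10.0 * np.eye(3).reshape(1, 3, 3)
    atom_types = np.array([[0, 1, 1]], dtype=np.int32)
    result = eval_model(model, coords, cells, atom_types, atomic=True)
    # NumPy inputs give NumPy outputs; torch.Tensor inputs would return tensors instead.
    assert result["energy"].shape[0] == 1
    assert result["force"].shape == (1, 3, 3)
    assert "atom_energy" in result  # present because atomic=True
```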
diff --git a/source/tests/pt/model/test_rot.py b/source/tests/pt/model/test_rot.py
index 23bdede923..ca6a6375c8 100644
--- a/source/tests/pt/model/test_rot.py
+++ b/source/tests/pt/model/test_rot.py
@@ -4,9 +4,6 @@
 
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -17,6 +14,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (  # model_dpau,
     model_dos,
     model_dpa1,
diff --git a/source/tests/pt/model/test_rot_denoise.py b/source/tests/pt/model/test_rot_denoise.py
index 5fe99a0d7a..9828ba5225 100644
--- a/source/tests/pt/model/test_rot_denoise.py
+++ b/source/tests/pt/model/test_rot_denoise.py
@@ -4,9 +4,6 @@
 
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -17,6 +14,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation_denoise import (
     model_dpa1,
     model_dpa2,
diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py
index c33dddfab5..9a7040f9cc 100644
--- a/source/tests/pt/model/test_smooth.py
+++ b/source/tests/pt/model/test_smooth.py
@@ -4,9 +4,6 @@
 
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -17,6 +14,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (  # model_dpau,
     model_dos,
     model_dpa1,
diff --git a/source/tests/pt/model/test_smooth_denoise.py b/source/tests/pt/model/test_smooth_denoise.py
index 069c578d52..faa892c5d0 100644
--- a/source/tests/pt/model/test_smooth_denoise.py
+++ b/source/tests/pt/model/test_smooth_denoise.py
@@ -4,9 +4,6 @@
 
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -17,6 +14,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation_denoise import (
     model_dpa2,
 )
diff --git a/source/tests/pt/model/test_trans.py b/source/tests/pt/model/test_trans.py
index afd70f8995..b62fac1312 100644
--- a/source/tests/pt/model/test_trans.py
+++ b/source/tests/pt/model/test_trans.py
@@ -4,9 +4,6 @@
 
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -17,6 +14,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (  # model_dpau,
     model_dos,
     model_dpa1,
diff --git a/source/tests/pt/model/test_trans_denoise.py b/source/tests/pt/model/test_trans_denoise.py
index 2d31d5de50..84ec21929c 100644
--- a/source/tests/pt/model/test_trans_denoise.py
+++ b/source/tests/pt/model/test_trans_denoise.py
@@ -4,9 +4,6 @@
 
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -17,6 +14,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation_denoise import (
     model_dpa1,
     model_dpa2,
diff --git a/source/tests/pt/model/test_unused_params.py b/source/tests/pt/model/test_unused_params.py
index e225719e7f..3f068d5e5b 100644
--- a/source/tests/pt/model/test_unused_params.py
+++ b/source/tests/pt/model/test_unused_params.py
@@ -4,9 +4,6 @@
 
 import torch
 
-from deepmd.pt.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -17,6 +14,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (
     model_dpa2,
 )
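The change-bias tests that follow load checkpoints with `torch.load(..., weights_only=True)`, which only accepts plain containers, tensors, and Python scalars; that is why `save_model` above casts the learning rate to `float` both in `train_infos` and in the copied optimizer state. A small, self-contained round-trip sketch of that pattern, using a toy module and a hypothetical file name:

```python
from copy import deepcopy

import torch


def save_and_reload_lr(path: str = "toy_ckpt.pt") -> float:
    """Sketch: keep the saved checkpoint loadable with weights_only=True."""
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    optim_state = deepcopy(optimizer.state_dict())
    for group in optim_state["param_groups"]:
        group["lr"] = float(group["lr"])  # e.g. a scheduler may have stored a non-scalar here
    torch.save({"model": model.state_dict(), "optimizer": optim_state}, path)
    state = torch.load(path, map_location="cpu", weights_only=True)
    return state["optimizer"]["param_groups"][0]["lr"]
```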
diff --git a/source/tests/pt/test_change_bias.py b/source/tests/pt/test_change_bias.py
index f76be40b3f..febc439f50 100644
--- a/source/tests/pt/test_change_bias.py
+++ b/source/tests/pt/test_change_bias.py
@@ -92,7 +92,9 @@ def test_change_bias_with_data(self):
         run_dp(
             f"dp --pt change-bias {self.model_path!s} -s {self.data_file[0]} -o {self.model_path_data_bias!s}"
         )
-        state_dict = torch.load(str(self.model_path_data_bias), map_location=DEVICE)
+        state_dict = torch.load(
+            str(self.model_path_data_bias), map_location=DEVICE, weights_only=True
+        )
         model_params = state_dict["model"]["_extra_state"]["model_params"]
         model_for_wrapper = get_model_for_wrapper(model_params)
         wrapper = ModelWrapper(model_for_wrapper)
@@ -114,7 +116,7 @@ def test_change_bias_with_data_sys_file(self):
             f"dp --pt change-bias {self.model_path!s} -f {tmp_file.name} -o {self.model_path_data_file_bias!s}"
         )
         state_dict = torch.load(
-            str(self.model_path_data_file_bias), map_location=DEVICE
+            str(self.model_path_data_file_bias), map_location=DEVICE, weights_only=True
         )
         model_params = state_dict["model"]["_extra_state"]["model_params"]
         model_for_wrapper = get_model_for_wrapper(model_params)
@@ -134,7 +136,7 @@ def test_change_bias_with_user_defined(self):
         run_dp(
             f"dp --pt change-bias {self.model_path!s} -b {' '.join([str(_) for _ in user_bias])} -o {self.model_path_user_bias!s}"
         )
-        state_dict = torch.load(str(self.model_path_user_bias), map_location=DEVICE)
+        state_dict = torch.load(
+            str(self.model_path_user_bias), map_location=DEVICE, weights_only=True
+        )
         model_params = state_dict["model"]["_extra_state"]["model_params"]
         model_for_wrapper = get_model_for_wrapper(model_params)
         wrapper = ModelWrapper(model_for_wrapper)
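For completeness, the `deepmd/utils/path.py` hunk earlier in this patch makes `DPH5Path.glob` match keys created after the HDF5 file was opened, not just the keys read at open time. A simplified, self-contained sketch of that chaining idea, using `fnmatch` as a stand-in for the `globfilter` helper the real class relies on:

```python
import itertools
from fnmatch import fnmatch
from typing import List


def glob_keys(
    existing: List[str], newly_created: List[str], prefix: str, pattern: str
) -> List[str]:
    """Simplified stand-in for DPH5Path.glob: search old and new keys alike."""
    subpaths = [
        key
        for key in itertools.chain(existing, newly_created)
        if key.startswith(prefix)
    ]
    return [key for key in subpaths if fnmatch(key, pattern)]


# A key written after the file was opened is still found:
print(glob_keys(["/set.000/coord.npy"], ["/set.001/coord.npy"], "/", "/set.*/coord.npy"))
```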