diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index e262404762..b6c6b8460f 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -118,33 +118,47 @@ def serialize(self) -> dict: def deserialize(cls): pass - def do_grad( + def do_grad_r( self, var_name: Optional[str] = None, ) -> bool: - """Tell if the output variable `var_name` is differentiable. - if var_name is None, returns if any of the variable is differentiable. + """Tell if the output variable `var_name` is r_differentiable. + if var_name is None, returns if any of the variable is r_differentiable. """ odef = self.fitting_output_def() if var_name is None: require: List[bool] = [] for vv in odef.keys(): - require.append(self.do_grad_(vv)) + require.append(self.do_grad_(vv, "r")) return any(require) else: - return self.do_grad_(var_name) + return self.do_grad_(var_name, "r") - def do_grad_( + def do_grad_c( self, - var_name: str, + var_name: Optional[str] = None, ) -> bool: + """Tell if the output variable `var_name` is c_differentiable. + if var_name is None, returns if any of the variable is c_differentiable. + + """ + odef = self.fitting_output_def() + if var_name is None: + require: List[bool] = [] + for vv in odef.keys(): + require.append(self.do_grad_(vv, "c")) + return any(require) + else: + return self.do_grad_(var_name, "c") + + def do_grad_(self, var_name: str, base: str) -> bool: """Tell if the output variable `var_name` is differentiable.""" assert var_name is not None - return ( - self.fitting_output_def()[var_name].r_differentiable - or self.fitting_output_def()[var_name].c_differentiable - ) + assert base in ["c", "r"] + if base == "c": + return self.fitting_output_def()[var_name].c_differentiable + return self.fitting_output_def()[var_name].r_differentiable setattr(BAM, fwd_method_name, BAM.fwd) delattr(BAM, "fwd") diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index c210945e76..f5acabf7b1 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -68,9 +68,14 @@ class DipoleFitting(GeneralFitting): mixed_types If true, use a uniform fitting net for all atom types, otherwise use different fitting nets for different atom types. - exclude_types: List[int] + exclude_types Atomic contributions of the excluded atom types are set zero. - + r_differentiable + If the variable is differentiated with respect to coordinates of atoms. + Only reduciable variable are differentiable. + c_differentiable + If the variable is differentiated with respect to the cell tensor (pbc case). + Only reduciable variable are differentiable. """ def __init__( @@ -94,6 +99,8 @@ def __init__( spin: Any = None, mixed_types: bool = False, exclude_types: List[int] = [], + r_differentiable: bool = True, + c_differentiable: bool = True, old_impl=False, ): # seed, uniform_seed are not included @@ -109,6 +116,8 @@ def __init__( raise NotImplementedError("atom_ener is not implemented") self.embedding_width = embedding_width + self.r_differentiable = r_differentiable + self.c_differentiable = c_differentiable super().__init__( var_name=var_name, ntypes=ntypes, @@ -139,6 +148,8 @@ def serialize(self) -> dict: data = super().serialize() data["embedding_width"] = self.embedding_width data["old_impl"] = self.old_impl + data["r_differentiable"] = self.r_differentiable + data["c_differentiable"] = self.c_differentiable return data def output_def(self): @@ -148,8 +159,8 @@ def output_def(self): self.var_name, [3], reduciable=True, - r_differentiable=True, - c_differentiable=True, + r_differentiable=self.r_differentiable, + c_differentiable=self.c_differentiable, ), ] ) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index d828693fe0..c3cbe7bd1a 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -176,8 +176,8 @@ def output_def(self): self.var_name, [3, 3], reduciable=True, - r_differentiable=True, - c_differentiable=True, + r_differentiable=False, + c_differentiable=False, ), ] ) diff --git a/deepmd/infer/deep_tensor.py b/deepmd/infer/deep_tensor.py index a6cefa63c1..1bdc459920 100644 --- a/deepmd/infer/deep_tensor.py +++ b/deepmd/infer/deep_tensor.py @@ -105,6 +105,8 @@ def eval( **kwargs, ) sel_natoms = self._get_sel_natoms(atom_types[0]) + if sel_natoms == 0: + sel_natoms = atom_types.shape[-1] # set to natoms if atomic: return results[self.output_tensor_name].reshape(nframes, sel_natoms, -1) else: @@ -184,7 +186,10 @@ def eval_full( aparam=aparam, **kwargs, ) + sel_natoms = self._get_sel_natoms(atom_types[0]) + if sel_natoms == 0: + sel_natoms = atom_types.shape[-1] # set to natoms energy = results[f"{self.output_tensor_name}_redu"].reshape(nframes, -1) force = results[f"{self.output_tensor_name}_derv_r"].reshape( nframes, -1, natoms, 3 @@ -192,14 +197,13 @@ def eval_full( virial = results[f"{self.output_tensor_name}_derv_c_redu"].reshape( nframes, -1, 9 ) - atomic_energy = results[self.output_tensor_name].reshape( - nframes, sel_natoms, -1 - ) - atomic_virial = results[f"{self.output_tensor_name}_derv_c"].reshape( - nframes, -1, natoms, 9 - ) - if atomic: + atomic_energy = results[self.output_tensor_name].reshape( + nframes, sel_natoms, -1 + ) + atomic_virial = results[f"{self.output_tensor_name}_derv_c"].reshape( + nframes, -1, natoms, 9 + ) return ( energy, force, diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 601bd6f755..f642d34d61 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -373,6 +373,9 @@ def _eval_model( shape = self._get_output_shape(odef, nframes, natoms) out = batch_output[pt_name].reshape(shape).detach().cpu().numpy() results.append(out) + else: + shape = self._get_output_shape(odef, nframes, natoms) + results.append(np.full(np.abs(shape), np.nan)) # this is kinda hacky return tuple(results) def _get_output_shape(self, odef, nframes, natoms): diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index e6dc395500..ecac50737b 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -146,7 +146,7 @@ def forward_atomic( """ nframes, nloc, nnei = nlist.shape atype = extended_atype[:, :nloc] - if self.do_grad(): + if self.do_grad_r() or self.do_grad_c(): extended_coord.requires_grad_(True) descriptor, rot_mat, g2, h2, sw = self.descriptor( extended_coord, diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 29ca9c8f96..16b06b2211 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -145,7 +145,7 @@ def forward_atomic( the result dict, defined by the fitting net output def. """ nframes, nloc, nnei = nlist.shape - if self.do_grad(): + if self.do_grad_r() or self.do_grad_c(): extended_coord.requires_grad_(True) extended_coord = extended_coord.view(nframes, -1, 3) sorted_rcuts, sorted_sels = self._sort_rcuts_sels() diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index c79e742d63..d8b830d1eb 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -148,7 +148,7 @@ def forward_atomic( ) -> Dict[str, torch.Tensor]: nframes, nloc, nnei = nlist.shape extended_coord = extended_coord.view(nframes, -1, 3) - if self.do_grad(): + if self.do_grad_r() or self.do_grad_c(): extended_coord.requires_grad_(True) # this will mask all -1 in the nlist diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py new file mode 100644 index 0000000000..220fdbb273 --- /dev/null +++ b/deepmd/pt/model/model/dipole_model.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + Optional, +) + +import torch + +from .dp_model import ( + DPModel, +) + + +class DipoleModel(DPModel): + model_type = "dipole" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def forward( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if self.fitting_net is not None: + model_predict = {} + model_predict["dipole"] = model_ret["dipole"] + model_predict["global_dipole"] = model_ret["dipole_redu"] + if self.do_grad_r("dipole"): + model_predict["force"] = model_ret["dipole_derv_r"].squeeze(-2) + if self.do_grad_c("dipole"): + model_predict["virial"] = model_ret["dipole_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze( + -3 + ) + else: + model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + + def forward_lower( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ): + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if self.fitting_net is not None: + model_predict = {} + model_predict["dipole"] = model_ret["dipole"] + model_predict["global_dipole"] = model_ret["dipole_redu"] + if self.do_grad_r("dipole"): + model_predict["force"] = model_ret["dipole_derv_r"].squeeze(-2) + if self.do_grad_c("dipole"): + model_predict["virial"] = model_ret["dipole_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze( + -3 + ) + else: + model_predict = model_ret + return model_predict diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index 8d71157b60..4683f62466 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -48,11 +48,12 @@ def forward( model_predict = {} model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] - if self.do_grad("energy"): + if self.do_grad_r("energy"): model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) - model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) else: model_predict["force"] = model_ret["dforce"] return model_predict @@ -80,13 +81,12 @@ def forward_lower( model_predict = {} model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] - if self.do_grad("energy"): - model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: - model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( - -2 - ) + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) else: assert model_ret["dforce"] is not None model_predict["dforce"] = model_ret["dforce"] diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 2afeb2762b..946cfd20f8 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -42,13 +42,14 @@ def forward( model_predict = {} model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] - if self.do_grad("energy"): + if self.do_grad_r("energy"): model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze( -3 ) - model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) else: model_predict["force"] = model_ret["dforce"] else: @@ -79,13 +80,14 @@ def forward_lower( model_predict = {} model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] - if self.do_grad("energy"): + if self.do_grad_r("energy"): model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["extended_virial"] = model_ret[ "energy_derv_c" - ].squeeze(-2) + ].squeeze(-3) else: assert model_ret["dforce"] is not None model_predict["dforce"] = model_ret["dforce"] diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py new file mode 100644 index 0000000000..85aeebc0f5 --- /dev/null +++ b/deepmd/pt/model/model/polar_model.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + Optional, +) + +import torch + +from .dp_model import ( + DPModel, +) + + +class PolarModel(DPModel): + model_type = "polar" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def forward( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if self.fitting_net is not None: + model_predict = {} + model_predict["polar"] = model_ret["polar"] + model_predict["global_polar"] = model_ret["polar_redu"] + else: + model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + + def forward_lower( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ): + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if self.fitting_net is not None: + model_predict = {} + model_predict["polar"] = model_ret["polar"] + model_predict["global_polar"] = model_ret["polar_redu"] + else: + model_predict = model_ret + return model_predict diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 4ea66e2636..88391b1922 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -56,6 +56,12 @@ class DipoleFittingNet(GeneralFitting): The condition number for the regression of atomic energy. seed : int, optional Random seed. + r_differentiable + If the variable is differentiated with respect to coordinates of atoms. + Only reduciable variable are differentiable. + c_differentiable + If the variable is differentiated with respect to the cell tensor (pbc case). + Only reduciable variable are differentiable. """ def __init__( @@ -74,9 +80,13 @@ def __init__( rcond: Optional[float] = None, seed: Optional[int] = None, exclude_types: List[int] = [], + r_differentiable: bool = True, + c_differentiable: bool = True, **kwargs, ): self.embedding_width = embedding_width + self.r_differentiable = r_differentiable + self.c_differentiable = c_differentiable super().__init__( var_name=var_name, ntypes=ntypes, @@ -103,6 +113,8 @@ def serialize(self) -> dict: data = super().serialize() data["embedding_width"] = self.embedding_width data["old_impl"] = self.old_impl + data["r_differentiable"] = self.r_differentiable + data["c_differentiable"] = self.c_differentiable return data def output_def(self) -> FittingOutputDef: @@ -112,8 +124,8 @@ def output_def(self) -> FittingOutputDef: self.var_name, [3], reduciable=True, - r_differentiable=True, - c_differentiable=True, + r_differentiable=self.r_differentiable, + c_differentiable=self.c_differentiable, ), ] ) diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index dc8d13ee84..c240567903 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -144,8 +144,8 @@ def output_def(self) -> FittingOutputDef: self.var_name, [3, 3], reduciable=True, - r_differentiable=True, - c_differentiable=True, + r_differentiable=False, + c_differentiable=False, ), ] ) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 5ecfb72481..fb04e49484 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import itertools +import os import unittest import numpy as np @@ -9,9 +10,15 @@ ) from deepmd.dpmodel.fitting import DipoleFitting as DPDipoleFitting +from deepmd.infer.deep_dipole import ( + DeepDipole, +) from deepmd.pt.model.descriptor.se_a import ( DescrptSeA, ) +from deepmd.pt.model.model.dipole_model import ( + DipoleModel, +) from deepmd.pt.model.task.dipole import ( DipoleFittingNet, ) @@ -32,6 +39,20 @@ dtype = env.GLOBAL_PT_FLOAT_PRECISION +def finite_difference(f, x, a, delta=1e-6): + in_shape = x.shape + y0 = f(x, a) + out_shape = y0.shape + res = np.empty(out_shape + in_shape) + for idx in np.ndindex(*in_shape): + diff = np.zeros(in_shape) + diff[idx] += delta + y1p = f(x + diff, a) + y1n = f(x - diff, a) + res[(Ellipsis, *idx)] = (y1p - y1n) / (2 * delta) + return res + + class TestDipoleFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self): TestCaseSingleFrameWithNlist.setUp(self) @@ -269,5 +290,61 @@ def test_trans(self): np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) +class TestDipoleModel(unittest.TestCase): + def setUp(self): + self.natoms = 5 + self.rcut = 4.0 + self.nt = 3 + self.rcut_smth = 0.5 + self.sel = [46, 92, 4] + self.nf = 1 + self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype, device="cpu") + cell = torch.rand([3, 3], dtype=dtype, device="cpu") + self.cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") + self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu") + self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) + self.ft0 = DipoleFittingNet( + "dipole", + self.nt, + self.dd0.dim_out, + embedding_width=self.dd0.get_dim_emb(), + numb_fparam=0, + numb_aparam=0, + mixed_types=True, + ).to(env.DEVICE) + self.type_mapping = ["O", "H", "B"] + self.model = DipoleModel(self.dd0, self.ft0, self.type_mapping) + self.file_path = "model_output.pth" + + def test_auto_diff(self): + places = 5 + delta = 1e-5 + atype = self.atype.view(self.nf, self.natoms) + + def ff(coord, atype): + return self.model(coord, atype)["global_dipole"].detach().cpu().numpy() + + fdf = -finite_difference(ff, self.coord, atype, delta=delta) + rff = self.model(self.coord, atype)["force"].detach().cpu().numpy() + + np.testing.assert_almost_equal(fdf, rff.transpose(0, 2, 1, 3), decimal=places) + + def test_deepdipole_infer(self): + atype = self.atype.view(self.nf, self.natoms) + coord = self.coord.reshape(1, 5, 3) + cell = self.cell.reshape(1, 9) + jit_md = torch.jit.script(self.model) + torch.jit.save(jit_md, self.file_path) + load_md = DeepDipole(self.file_path) + load_md.eval(coords=coord, atom_types=atype, cells=cell, atomic=True) + load_md.eval(coords=coord, atom_types=atype, cells=cell, atomic=False) + load_md.eval_full(coords=coord, atom_types=atype, cells=cell, atomic=True) + load_md.eval_full(coords=coord, atom_types=atype, cells=cell, atomic=False) + + def tearDown(self) -> None: + if os.path.exists(self.file_path): + os.remove(self.file_path) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 548219627b..3f154383b2 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import itertools +import os import unittest import numpy as np @@ -9,9 +10,15 @@ ) from deepmd.dpmodel.fitting import PolarFitting as DPPolarFitting +from deepmd.infer.deep_polar import ( + DeepPolar, +) from deepmd.pt.model.descriptor.se_a import ( DescrptSeA, ) +from deepmd.pt.model.model.polar_model import ( + PolarModel, +) from deepmd.pt.model.task.polarizability import ( PolarFittingNet, ) @@ -308,5 +315,46 @@ def test_trans(self): np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) +class TestDipoleModel(unittest.TestCase): + def setUp(self): + self.natoms = 5 + self.rcut = 4.0 + self.nt = 3 + self.rcut_smth = 0.5 + self.sel = [46, 92, 4] + self.nf = 1 + self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype, device="cpu") + cell = torch.rand([3, 3], dtype=dtype, device="cpu") + self.cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") + self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu") + self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) + self.ft0 = PolarFittingNet( + "polar", + self.nt, + self.dd0.dim_out, + embedding_width=self.dd0.get_dim_emb(), + numb_fparam=0, + numb_aparam=0, + mixed_types=True, + ).to(env.DEVICE) + self.type_mapping = ["O", "H", "B"] + self.model = PolarModel(self.dd0, self.ft0, self.type_mapping) + self.file_path = "model_output.pth" + + def test_deepdipole_infer(self): + atype = self.atype.view(self.nf, self.natoms) + coord = self.coord.reshape(1, 5, 3) + cell = self.cell.reshape(1, 9) + jit_md = torch.jit.script(self.model) + torch.jit.save(jit_md, self.file_path) + load_md = DeepPolar(self.file_path) + load_md.eval(coords=coord, atom_types=atype, cells=cell, atomic=True) + load_md.eval(coords=coord, atom_types=atype, cells=cell, atomic=False) + + def tearDown(self) -> None: + if os.path.exists(self.file_path): + os.remove(self.file_path) + + if __name__ == "__main__": unittest.main()