Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

_crd2frag: use OB native PBC implementation #1006

Merged
merged 7 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 14 additions & 25 deletions dpgen/generator/lib/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,11 @@
import dpdata
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from scipy.spatial import cKDTree
try:
# expect openbabel >= 3.1.0
from openbabel import openbabel
except ImportError:
try:
import openbabel
except ImportError:
pass
pass
try:
from ase import Atoms, Atom
from ase.data import atomic_numbers
Expand All @@ -26,21 +23,24 @@
def _crd2frag(symbols, crds, pbc=False, cell=None, return_bonds=False):
atomnumber = len(symbols)
all_atoms = Atoms(symbols = symbols, positions = crds, pbc=pbc, cell=cell)
if pbc:
repeated_atoms = all_atoms.repeat(2)[atomnumber:]
tree = cKDTree(crds)
d = tree.query(repeated_atoms.get_positions(), k=1)[0]
nearest = d < 5
ghost_atoms = repeated_atoms[nearest]
realnumber = np.where(nearest)[0] % atomnumber
all_atoms += ghost_atoms
# Use openbabel to connect atoms
mol = openbabel.OBMol()
mol.BeginModify()
for idx, (num, position) in enumerate(zip(all_atoms.get_atomic_numbers(), all_atoms.positions)):
atom = mol.NewAtom(idx)
atom.SetAtomicNum(int(num))
atom.SetVector(*position)
# Apply period boundry conditions
# openbabel#1853, supported in v3.1.0
if pbc:
uc = openbabel.OBUnitCell()
uc.SetData(
openbabel.vector3(cell[0][0], cell[0][1], cell[0][2]),
openbabel.vector3(cell[1][0], cell[1][1], cell[1][2]),
openbabel.vector3(cell[2][0], cell[2][1], cell[2][2]),
)
mol.CloneData(uc)
mol.SetPeriodicMol()
mol.ConnectTheDots()
mol.PerceiveBondOrders()
mol.EndModify()
Expand All @@ -50,13 +50,6 @@ def _crd2frag(symbols, crds, pbc=False, cell=None, return_bonds=False):
a = bond.GetBeginAtom().GetId()
b = bond.GetEndAtom().GetId()
bo = bond.GetBondOrder()
if a >= atomnumber and b >= atomnumber:
# duplicated
continue
elif a >= atomnumber:
a = realnumber[a-atomnumber]
elif b >= atomnumber:
b = realnumber[b-atomnumber]
bonds.extend([[a, b, bo], [b, a, bo]])
bonds = np.array(bonds, ndmin=2).reshape((-1, 3))
graph = csr_matrix(
Expand All @@ -77,11 +70,7 @@ def _crd2mul(symbols, crds):
mol = openbabel.OBMol()
conv.ReadString(mol, xyzstring)
gjfstring = conv.WriteString(mol)
try:
mul = int(gjfstring.split('\n')[4].split()[1])
except IndexError:
# openbabel 3.0
mul = int(gjfstring.split('\n')[5].split()[1])
mul = int(gjfstring.split('\n')[5].split()[1])
return mul


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
'dargs>=0.2.9',
'h5py',
'pymatgen-analysis-defects',
'openbabel-wheel',
]
requires-python = ">=3.8"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion tests/generator/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from dpgen.generator.run import *
from dpgen.generator.lib.gaussian import detect_multiplicity
from dpgen.generator.lib.gaussian import detect_multiplicity, _crd2frag
from dpgen.generator.lib.ele_temp import NBandsEsti
from dpgen.generator.lib.lammps import get_dumped_forces
from dpgen.generator.lib.lammps import get_all_dumped_forces
Expand Down
18 changes: 17 additions & 1 deletion tests/generator/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
__package__ = 'generator'
from .context import take_cluster
from .context import take_cluster, _crd2frag
from .context import setUpModule
from .comp_sys import CompSys

Expand Down Expand Up @@ -36,5 +36,21 @@ def setUp (self) :
self.system_2.data['cells'] = self.system_1['cells']
self.places=0


class TestCrd2Frag(unittest.TestCase):
def test_crd2frag_pbc(self):
crds = np.array([[0., 0., 0.], [19., 19., 19.]])
symbols = ["O", "O"]
cell = np.diag([20., 20., 20.])
frag_numb, _ = _crd2frag(symbols, crds, pbc=True, cell=cell)
self.assertEqual(frag_numb, 1)

def test_crd2frag_nopbc(self):
crds = np.array([[0., 0., 0.], [19., 19., 19.]])
symbols = ["O", "O"]
frag_numb, _ = _crd2frag(symbols, crds, pbc=False)
self.assertEqual(frag_numb, 2)


if __name__ == '__main__':
unittest.main()
11 changes: 8 additions & 3 deletions tests/generator/test_gromacs_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def _copy_outputs(self, path_1, path_2):
shutil.copy(os.path.join(path_1, "model_devi.out"), os.path.join(path_2, "model_devi.out"))
shutil.copytree(os.path.join(path_1, "traj"), os.path.join(path_2, "traj"))

def test_make_model_devi_gromacs(self):

@unittest.skipIf(importlib.util.find_spec("openbabel") != None, "when openbabel is found, this test will be skipped. ")
def test_make_model_devi_gromacs_without_openbabel(self):
flag = make_model_devi(iter_index=0,
jdata=self.jdata,
mdata={"deepmd_version": "2.0"})
Expand All @@ -98,10 +100,13 @@ def test_make_model_devi_gromacs(self):
self._check_dir(self.model_devi_task_path, post=True)

@unittest.skipIf(importlib.util.find_spec("openbabel") is None, "requires openbabel")
def test_make_fp_gaussian(self):
def test_make_model_devi_gromacs_with_openbabel(self):
flag = make_model_devi(iter_index=0,
jdata=self.jdata,
mdata={"deepmd_version": "2.0"})
self._copy_outputs(os.path.join(self.dirname, "outputs"), self.model_devi_task_path)
make_fp_gaussian(iter_index=0, jdata=self.jdata)
candi = np.loadtxt(os.path.join(self.fp_path, "candidate.shuffled.000.out"), dtype=np.str)
candi_ref = np.loadtxt(os.path.join(self.dirname, "outputs", "candidate.shuffled.000.out"), dtype=np.str)
self.assertEqual(sorted([int(i) for i in candi[:,1]]), [0,10,20,30,50])


Expand Down