Skip to content

Commit

Permalink
Merge pull request #10 from thangckt/add_gpaw
Browse files Browse the repository at this point in the history
fetch branch add_gpaw
  • Loading branch information
thangckt authored May 10, 2024
2 parents cb1d971 + d5f66f2 commit 4faa79e
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 96 deletions.
15 changes: 1 addition & 14 deletions dpgen/generator/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dargs import Argument, Variant

from dpgen.arginfo import general_mdata_arginfo
from dpgen.generator.lib.gpaw import fp_style_gpaw_args


def run_mdata_arginfo() -> Argument:
Expand Down Expand Up @@ -905,20 +906,6 @@ def fp_style_custom_args() -> list[Argument]:
]


# gpaw
def fp_style_gpaw_args() -> list[Argument]:
doc_fp_gpaw_runfile = "Input file to run GPAW."
return [
Argument(
"fp_gpaw_runfile",
str,
optional=True,
default="gpaw_singlepoint.py",
doc=doc_fp_gpaw_runfile,
)
]


def fp_style_variant_type_args() -> Variant:
doc_fp_style = "Software for First Principles."
doc_amber_diff = (
Expand Down
96 changes: 96 additions & 0 deletions dpgen/generator/lib/gpaw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Move all the GPAW related functions to here.
- functions for `arginfo.py`
- functions for `run.py`
"""


import os
import glob
from pathlib import Path

from dargs import Argument
from dpgen.generator.lib.utils import make_iter_name
from dpgen.util import set_directory
import dpdata
import numpy as np


### functinos for `arginfo.py`
def fp_style_gpaw_args() -> list[Argument]:
args = [Argument("fp_gpaw_runfile",
str,
optional=True,
default="gpaw_singlepoint.py",
doc="Input file to run GPAW.",
)
]
return args


### functions for `run.py`
def make_fp_gpaw(iter_index, jdata, fp_name):
"""Make input file for customized FP style.
Parameters
----------
iter_index : int
iter index
jdata : dict
Run parameters.
"""
## create symbolic link of the gpaw input file in the task directory
work_path = os.path.join(make_iter_name(iter_index), fp_name)
fp_tasks = glob.glob(os.path.join(work_path, "task.*"))
gpaw_runfile = jdata["fp_gpaw_runfile"]
gpaw_runfile_source = Path(gpaw_runfile).resolve()
assert os.path.exists(
gpaw_runfile_source
), f"Can not find gpaw runfile {gpaw_runfile_source}"
for ii in fp_tasks:
with set_directory(Path(ii)):
# create file `gpaw_runfile` in the current directory and symlink it to the source file
Path(gpaw_runfile).symlink_to(gpaw_runfile_source)


def post_fp_gpaw(iter_index, jdata, fp_name):
"""Post fp for custom fp. Collect data from user-defined `output_fn`.
Parameters
----------
iter_index : int
The index of the current iteration.
jdata : dict
The parameter data.
"""
model_devi_jobs = jdata["model_devi_jobs"]
assert iter_index < len(model_devi_jobs)

iter_name = make_iter_name(iter_index)
work_path = os.path.join(iter_name, fp_name)
fp_tasks = glob.glob(os.path.join(work_path, "task.*"))
fp_tasks.sort()
if len(fp_tasks) == 0:
return

system_index = []
for ii in fp_tasks:
system_index.append(os.path.basename(ii).split(".")[1])
system_index.sort()
set_tmp = set(system_index)
system_index = list(set_tmp)
system_index.sort()

output_fn = "conf_ase.traj"
output_fmt = "ase/traj"

for ss in system_index:
sys_output = glob.glob(os.path.join(work_path, f"task.{ss}.*"))
sys_output.sort()
all_sys = dpdata.MultiSystems(type_map=jdata["type_map"])
for oo in sys_output:
if os.path.exists(os.path.join(oo, output_fn)):
sys = dpdata.LabeledSystem(os.path.join(oo, output_fn), fmt=output_fmt)
all_sys.append(sys)
sys_data_path = os.path.join(work_path, f"data.{ss}")
all_sys.to_deepmd_raw(sys_data_path)
all_sys.to_deepmd_npy(sys_data_path, set_size=len(sys_output), prec=np.float64)
100 changes: 18 additions & 82 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
set_directory,
setup_ele_temp,
)
from dpgen.generator.lib.gpaw import (make_fp_gpaw, post_fp_gpaw)

from .arginfo import run_jdata_arginfo

Expand Down Expand Up @@ -749,10 +750,8 @@ def run_train(iter_index, jdata, mdata):
)

train_command = mdata.get("train_command", "dp").strip()
assert train_command == "dp", "The 'train_command' should be 'dp'"
if suffix == ".pb":
train_command += " --tf"
elif suffix == ".pth":
# assert train_command == "dp", "The 'train_command' should be 'dp'" # the tests should be updated to run this command
if suffix == ".pth":
train_command += " --pt"

train_resources = mdata["train_resources"]
Expand Down Expand Up @@ -816,11 +815,14 @@ def run_train(iter_index, jdata, mdata):
if "srtab_file_path" in jdata.keys():
forward_files.append(zbl_file)
if training_init_model:
forward_files += [
os.path.join("old", "model.ckpt.meta"),
os.path.join("old", "model.ckpt.index"),
os.path.join("old", "model.ckpt.data-00000-of-00001"),
]
if suffix == ".pb":
forward_files += [
os.path.join("old", "model.ckpt.meta"),
os.path.join("old", "model.ckpt.index"),
os.path.join("old", "model.ckpt.data-00000-of-00001"),
]
elif suffix == ".pth":
forward_files += [os.path.join("old", "model.ckpt.pt")]
elif training_init_frozen_model is not None or training_finetune_model is not None:
forward_files.append(os.path.join("old", f"init{suffix}"))

Expand Down Expand Up @@ -2043,7 +2045,9 @@ def run_md_model_devi(iter_index, jdata, mdata):
command += f'&& echo -e "{grp_name}\\n{grp_name}\\n" | {model_devi_exec} trjconv -s {ref_filename} -f {deffnm}.trr -o {traj_filename} -pbc mol -ur compact -center'
command += "&& if [ ! -d traj ]; then \n mkdir traj; fi\n"
command += f"python -c \"import dpdata;system = dpdata.System('{traj_filename}', fmt='gromacs/gro'); [system.to_gromacs_gro('traj/%d.gromacstrj' % (i * {trj_freq}), frame_idx=i) for i in range(system.get_nframes())]; system.to_deepmd_npy('traj_deepmd')\""
command += f"&& dp model-devi -m ../graph.000{suffix} ../graph.001{suffix} ../graph.002{suffix} ../graph.003{suffix} -s traj_deepmd -o model_devi.out -f {trj_freq}"
_rel_model_names = " ".join([str(os.path.join("..", ii)) for ii in model_names])
command += f"&& dp model-devi -m {_rel_model_names} -s traj_deepmd -o model_devi.out -f {trj_freq}"
del _rel_model_names
commands = [command]

forward_files = [
Expand Down Expand Up @@ -2218,7 +2222,7 @@ def _read_model_devi_file(
assert all(
model_devi_content.shape[0] == model_devi_contents[0].shape[0]
for model_devi_content in model_devi_contents
), "Not all beads generated the same number of lines in the model_devi$\{ibead\}.out file. Check your pimd task carefully."
), r"Not all beads generated the same number of lines in the model_devi${ibead}.out file. Check your pimd task carefully."
last_step = model_devi_contents[0][-1, 0]
for ibead in range(1, num_beads):
model_devi_contents[ibead][:, 0] = model_devi_contents[ibead][
Expand Down Expand Up @@ -3762,30 +3766,6 @@ def make_fp_custom(iter_index, jdata):
system.to(input_fmt, input_fn)


def make_fp_gpaw(iter_index, jdata):
"""Make input file for customized FP style.
Parameters
----------
iter_index : int
iter index
jdata : dict
Run parameters.
"""
## create symbolic link of the gpaw input file in the task directory
work_path = os.path.join(make_iter_name(iter_index), fp_name)
fp_tasks = glob.glob(os.path.join(work_path, "task.*"))
gpaw_runfile = jdata["fp_gpaw_runfile"]
gpaw_runfile_source = Path(gpaw_runfile).resolve()
assert os.path.exists(
gpaw_runfile_source
), f"Can not find gpaw runfile {gpaw_runfile_source}"
for ii in fp_tasks:
with set_directory(Path(ii)):
# create file `gpaw_runfile` in the current directory and symlink it to the source file
Path(gpaw_runfile).symlink_to(gpaw_runfile_source)


def make_fp(iter_index, jdata, mdata):
"""Select the candidate strutures and make the input file of FP calculation.
Expand Down Expand Up @@ -3836,7 +3816,7 @@ def make_fp_calculation(iter_index, jdata, mdata):
elif fp_style == "custom":
make_fp_custom(iter_index, jdata)
elif fp_style == "gpaw":
make_fp_gpaw(iter_index, jdata)
make_fp_gpaw(iter_index, jdata, fp_name)
else:
raise RuntimeError("unsupported fp style")
# Copy user defined forward_files
Expand Down Expand Up @@ -4161,7 +4141,7 @@ def run_fp(iter_index, jdata, mdata):
)
elif fp_style == "gpaw":
gpaw_runfile = jdata["fp_gpaw_runfile"]
forward_files = ["POSCAR"] + [gpaw_runfile]
forward_files = ["POSCAR", gpaw_runfile]
backward_files = ["conf_ase.traj", "calc.txt", "run.log"]
run_fp_inner(
iter_index,
Expand Down Expand Up @@ -4673,50 +4653,6 @@ def post_fp_custom(iter_index, jdata):
all_sys.to_deepmd_npy(sys_data_path, set_size=len(sys_output), prec=np.float64)


def post_fp_gpaw(iter_index, jdata):
"""Post fp for custom fp. Collect data from user-defined `output_fn`.
Parameters
----------
iter_index : int
The index of the current iteration.
jdata : dict
The parameter data.
"""
model_devi_jobs = jdata["model_devi_jobs"]
assert iter_index < len(model_devi_jobs)

iter_name = make_iter_name(iter_index)
work_path = os.path.join(iter_name, fp_name)
fp_tasks = glob.glob(os.path.join(work_path, "task.*"))
fp_tasks.sort()
if len(fp_tasks) == 0:
return

system_index = []
for ii in fp_tasks:
system_index.append(os.path.basename(ii).split(".")[1])
system_index.sort()
set_tmp = set(system_index)
system_index = list(set_tmp)
system_index.sort()

output_fn = "conf_ase.traj"
output_fmt = "ase/traj"

for ss in system_index:
sys_output = glob.glob(os.path.join(work_path, "task.%s.*" % ss))
sys_output.sort()
all_sys = dpdata.MultiSystems(type_map=jdata["type_map"])
for oo in sys_output:
if os.path.exists(os.path.join(oo, output_fn)):
sys = dpdata.LabeledSystem(os.path.join(oo, output_fn), fmt=output_fmt)
all_sys.append(sys)
sys_data_path = os.path.join(work_path, "data.%s" % ss)
all_sys.to_deepmd_raw(sys_data_path)
all_sys.to_deepmd_npy(sys_data_path, set_size=len(sys_output), prec=np.float64)


def post_fp(iter_index, jdata):
fp_style = jdata["fp_style"]
if fp_style == "vasp":
Expand All @@ -4738,7 +4674,7 @@ def post_fp(iter_index, jdata):
elif fp_style == "custom":
post_fp_custom(iter_index, jdata)
elif fp_style == "gpaw":
post_fp_gpaw(iter_index, jdata)
post_fp_gpaw(iter_index, jdata, fp_name)
else:
raise RuntimeError("unsupported fp style")
post_fp_check_fail(iter_index, jdata)
Expand Down

0 comments on commit 4faa79e

Please sign in to comment.