Skip to content

Commit

Permalink
First working implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
yakutovicha committed Mar 6, 2024
1 parent 21eab70 commit f9e3fa7
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
15 changes: 12 additions & 3 deletions aiida_cp2k/utils/datatype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import Sequence

import numpy as np
from aiida import common, orm, plugins
from aiida import common, engine, orm, plugins


def _unpack(adict):
Expand Down Expand Up @@ -417,13 +417,16 @@ def write_pseudos(inp, pseudos, folder):
_write_gdt(inp, pseudos, folder, "POTENTIAL_FILE_NAME", "POTENTIAL")


@engine.calcfunction
def merge_trajectory_data(*trajectories):
if len(trajectories) < 0:
return None

final_trajectory = orm.TrajectoryData()
final_trajectory_dict = {}

array_names = trajectories[0].get_arraynames()
symbols = trajectories[0].symbols

for array_name in array_names:
if any(array_name not in traj.get_arraynames() for traj in trajectories):
raise ValueError(
Expand All @@ -432,6 +435,12 @@ def merge_trajectory_data(*trajectories):
merged_array = np.concatenate(
[traj.get_array(array_name) for traj in trajectories], axis=0
)
final_trajectory.set_array(array_name, merged_array)
final_trajectory_dict[array_name] = merged_array

final_trajectory.set_trajectory(
symbols=symbols, positions=final_trajectory_dict.pop("positions")
)
for array_name, array in final_trajectory_dict.items():
final_trajectory.set_array(array_name, array)

return final_trajectory
16 changes: 16 additions & 0 deletions aiida_cp2k/workchains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,27 @@ def setup(self):
super().setup()
self.ctx.inputs = common.AttributeDict(self.exposed_inputs(Cp2kCalculation, 'cp2k'))

def _collect_all_trajetories(self):
"""Collect all trajectories from the children calculations."""
trajectories = []
for called in self.ctx.children:
if isinstance(called, orm.CalcJobNode):
try:
trajectories.append(called.outputs.output_trajectory)
except AttributeError:
pass
return trajectories

def results(self):
super().results()
if self.inputs.cp2k.parameters != self.ctx.inputs.parameters:
self.out('final_input_parameters', self.ctx.inputs.parameters)

trajectories = self._collect_all_trajetories()
if trajectories:
self.report("Work chain completed successfully, collecting all trajectories")
self.out("output_trajectory", utils.merge_trajectory_data(*trajectories))

def overwrite_input_structure(self):
if "output_structure" in self.ctx.children[self.ctx.iteration-1].outputs:
self.ctx.inputs.structure = self.ctx.children[self.ctx.iteration-1].outputs.output_structure
Expand Down
29 changes: 29 additions & 0 deletions examples/workchains/example_base_md_reftraj_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def example_base(cp2k_code):
"ERROR, EXT_RESTART section is NOT present in the final_input_parameters."
)
sys.exit(1)

# Check stepids extracted from each individual calculation.
stepids = np.concatenate(
[
called.outputs.output_trajectory.get_stepids()
Expand All @@ -188,6 +190,33 @@ def example_base(cp2k_code):
)
sys.exit(1)

# Check the final trajectory.
final_trajectory = outputs["output_trajectory"]

if np.all(final_trajectory.get_stepids() == np.arange(1, steps + 1)):
print("OK, final trajectory stepids are correct.")
else:
print(
f"ERROR, final trajectory stepids are NOT correct. Expected: {np.arange(1, steps + 1)} but got: {final_trajectory.get_stepids()}"
)
sys.exit(1)

if final_trajectory.get_positions().shape == (steps, len(structure.sites), 3):
print("OK, the shape of the positions array is correct.")
else:
print(
f"ERROR, the shape of the positions array is NOT correct. Expected: {(steps, len(structure.sites), 3)} but got: {final_trajectory.get_positions().shape}"
)
sys.exit(1)

if final_trajectory.get_cells().shape == (steps, 3, 3):
print("OK, the shape of the cells array is correct.")
else:
print(
f"ERROR, the shape of the cells array is NOT correct. Expected: {(steps, 3, 3)} but got: {final_trajectory.get_cells().shape}"
)
sys.exit(1)


@click.command("cli")
@click.argument("codelabel")
Expand Down

0 comments on commit f9e3fa7

Please sign in to comment.