From 3d548d5ab60e2d9efde7cad74dce739fd60a5154 Mon Sep 17 00:00:00 2001 From: Aliaksandr Yakutovich Date: Wed, 6 Mar 2024 16:17:03 +0000 Subject: [PATCH] Create add_first_snapshot_in_reftraj_section calcfunction. --- aiida_cp2k/workchains/base.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/aiida_cp2k/workchains/base.py b/aiida_cp2k/workchains/base.py index 97dba32..1ca625f 100644 --- a/aiida_cp2k/workchains/base.py +++ b/aiida_cp2k/workchains/base.py @@ -1,21 +1,13 @@ """Base work chain to run a CP2K calculation.""" -from aiida.common import AttributeDict -from aiida.engine import ( - BaseRestartWorkChain, - ProcessHandlerReport, - process_handler, - while_, -) -from aiida.orm import Bool, Dict -from aiida.plugins import CalculationFactory +from aiida import common, engine, orm, plugins from .. import utils -Cp2kCalculation = CalculationFactory('cp2k') +Cp2kCalculation = plugins.CalculationFactory('cp2k') -class Cp2kBaseWorkChain(BaseRestartWorkChain): +class Cp2kBaseWorkChain(engine.BaseRestartWorkChain): """Workchain to run a CP2K calculation with automated error handling and restarts.""" _process_class = Cp2kCalculation @@ -28,7 +20,7 @@ def define(cls, spec): spec.outline( cls.setup, - while_(cls.should_run_process)( + engine.while_(cls.should_run_process)( cls.run_process, cls.inspect_process, cls.overwrite_input_structure, @@ -37,7 +29,7 @@ def define(cls, spec): ) spec.expose_outputs(Cp2kCalculation) - spec.output('final_input_parameters', valid_type=Dict, required=False, + spec.output('final_input_parameters', valid_type=orm.Dict, required=False, help='The input parameters used for the final calculation.') spec.exit_code(400, 'NO_RESTART_DATA', message="The calculation didn't produce any data to restart from.") spec.exit_code(300, 'ERROR_UNRECOVERABLE_FAILURE', @@ -52,7 +44,7 @@ def setup(self): internal loop. """ super().setup() - self.ctx.inputs = AttributeDict(self.exposed_inputs(Cp2kCalculation, 'cp2k')) + self.ctx.inputs = common.AttributeDict(self.exposed_inputs(Cp2kCalculation, 'cp2k')) def results(self): super().results() @@ -63,7 +55,7 @@ def overwrite_input_structure(self): if "output_structure" in self.ctx.children[self.ctx.iteration-1].outputs: self.ctx.inputs.structure = self.ctx.children[self.ctx.iteration-1].outputs.output_structure - @process_handler(priority=401, exit_codes=[ + @engine.process_handler(priority=401, exit_codes=[ Cp2kCalculation.exit_codes.ERROR_OUT_OF_WALLTIME, Cp2kCalculation.exit_codes.ERROR_OUTPUT_INCOMPLETE, ], enabled=False) @@ -81,12 +73,12 @@ def restart_incomplete_calculation(self, calc): "Sending a signal to stop the Base work chain.") # Signaling to the base work chain that the problem could not be recovered. - return ProcessHandlerReport(True, self.exit_codes.NO_RESTART_DATA) + return engine.ProcessHandlerReport(True, self.exit_codes.NO_RESTART_DATA) self.ctx.inputs.parent_calc_folder = calc.outputs.remote_folder params = self.ctx.inputs.parameters - params = utils.add_wfn_restart_section(params, Bool('kpoints' in self.ctx.inputs)) + params = utils.add_wfn_restart_section(params, orm.Bool('kpoints' in self.ctx.inputs)) if restart_geometry_transformation: # Check if we need to fix restart snapshot in REFTRAJ MD @@ -103,4 +95,4 @@ def restart_incomplete_calculation(self, calc): self.report( "The CP2K calculation wasn't completed. The restart of the calculation might be able to " "fix the problem.") - return ProcessHandlerReport(False) + return engine.ProcessHandlerReport(False)