Skip to content

Commit

Permalink
Interface fix (#3)
Browse files Browse the repository at this point in the history
* Changed interface

* Removed archive from constructor

* Removed use of old variable

* Fixed test calls.
  • Loading branch information
lauri-codes authored May 3, 2024
1 parent dd94e69 commit 78eb29c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
24 changes: 12 additions & 12 deletions simulationworkflownormalizer/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ class SimulationWorkflowNormalizer(Normalizer):
This normalizer produces information specific to a simulation workflow.
"""

def __init__(self, entry_archive: EntryArchive):
super().__init__(entry_archive)
def __init__(self):
super().__init__()
self._elastic_programs = ['elastic']
self._phonon_programs = ['phonopy']
self._molecular_dynamics_programs = ['lammps']

def _resolve_workflow(self):
if not self.entry_archive.run:
def _resolve_workflow(self, archive: EntryArchive):
if not archive.run:
return

# resolve it from parser
workflow = None
try:
program_name = self.entry_archive.run[-1].program.name
program_name = archive.run[-1].program.name
except Exception:
program_name = None

Expand All @@ -66,20 +66,20 @@ def _resolve_workflow(self):
if workflow is None:
# workflow references always to the last run
# TODO decide if workflow should map to each run
if len(self.entry_archive.run[-1].calculation) == 1:
if len(archive.run[-1].calculation) == 1:
workflow = SinglePoint()
else:
workflow = GeometryOptimization()

return workflow

def normalize(self, logger=None) -> None:
def normalize(self, archive: EntryArchive, logger=None) -> None:
logger = logger if logger is not None else get_logger(__name__)
super().normalize(logger)
super().normalize(archive, logger)

# Do nothing if section_run is not present
if not self.entry_archive.run:
# Do nothing if run section is not present
if not archive.run:
return

if not self.entry_archive.workflow2:
self.entry_archive.workflow2 = self._resolve_workflow()
if not archive.workflow2:
archive.workflow2 = self._resolve_workflow(archive)
4 changes: 2 additions & 2 deletions tests/test_simulationworkflownormalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_resolve_workflow_from_program_name(
):
run = Run(program=Program(name=program_name))
entry_archive.run.append(run)
SimulationWorkflowNormalizer(entry_archive).normalize(get_logger(__name__))
SimulationWorkflowNormalizer().normalize(entry_archive, get_logger(__name__))
assert isinstance(entry_archive.workflow2, workflow_class)


Expand All @@ -58,5 +58,5 @@ def test_resolve_workflow_from_calculation(
):
run = Run(calculation=[Calculation() for _ in range(n_calculations)])
entry_archive.run.append(run)
SimulationWorkflowNormalizer(entry_archive).normalize(get_logger(__name__))
SimulationWorkflowNormalizer().normalize(entry_archive, get_logger(__name__))
assert isinstance(entry_archive.workflow2, workflow_class)

0 comments on commit 78eb29c

Please sign in to comment.