diff --git a/profit/run/command.py b/profit/run/command.py index b8784df..7ea867c 100644 --- a/profit/run/command.py +++ b/profit/run/command.py @@ -22,6 +22,7 @@ import numpy as np import time import subprocess +import f90nml from .worker import Component, Worker, Interface @@ -200,7 +201,7 @@ def __init__( def template_path(self): return os.path.join(os.environ.get("PROFIT_BASE_DIR", "."), self.path) - def prepare(self, data: Mapping): + def prepare(self, data: np.ndarray): # No call to super()! overrides the default directory creation if os.path.exists(self.run_dir): self.logger.error(f"run directory '{self.run_dir}' already exists") @@ -228,7 +229,7 @@ def fill_run_dir_single( os.path.exists(run_dir_single) and not ignore_path_exists ): # ToDo: make ignore_path_exists default if overwrite: - rmtree(run_dir_single) + shutil.rmtree(run_dir_single) else: raise RuntimeError("Run directory not empty: {}".format(run_dir_single)) self.copy_template(template_dir, run_dir_single) @@ -321,6 +322,71 @@ def replace_template(content, params): return content.format_map(SafeDict.from_params(params, pre=pre, post=post)) +# --- Fortran Namelist Preprocessor --- # + + +class NamelistPreprocessor(TemplatePreprocessor, label="namelist"): + """Preprocessor for Fortran Namelists + + This Preprocessor copies a given template directory (specified by `template`) + for each run and modifies the namelist file (specified by `file`) by setting the + input parameters in the namelist specified by `namelist`. 
+ + Parameters: + run_dir (str/path): [Internal] path to the run directory to be filled + clean (bool): flag whether to remove the run directory upon completion + template (str/path): path to the template directory to copy for each run + (relative to the base directory) + file (str/path): path to / name of the namelist file into which the parameters + are inserted (relative to the template directory) + namelist (str): identifier for the namelist within the file + """ + + def __init__( + self, + run_dir: str, + *, + clean=True, + template="template", + file="input.nml", + namelist="indata", + logger_parent=None, + ): + Preprocessor.__init__( + self, run_dir=run_dir, clean=clean, logger_parent=logger_parent + ) + self.template = template + self.file = file + self.namelist = namelist + + @property + def path(self): + """required by TemplatePreprocessor""" + return self.template + + @property + def param_files(self): + """required by TemplatePreprocessor""" + return [self.file] + + def prepare(self, data: np.ndarray): + super().prepare(data) + if self.param_files is None: + return + if len(self.param_files) > 1: + raise ValueError( + "Namelist Preprocessor can only handle a single file for now!" 
+            ) +        with open(self.file) as f: +            nml = f90nml.read(f) +        for key in data.dtype.names: +            if data[key].size > 1: +                nml[self.namelist][key] = list(data[key]) +            else: +                nml[self.namelist][key] = data[key] +        f90nml.write(nml, self.file, force=True) + + # === Postprocessor Component === # @@ -434,3 +500,21 @@ def HDF5Postprocessor(self, data): for key in f.keys(): if key in data.dtype.names: data[key] = f[key][:] + + +# --- netCDF Postprocessor --- # + + +@Postprocessor.wrap("netcdf", config=dict(path="stdout")) +def NetCDFPostprocessor(self, data): +    """Postprocessor to read output from a netCDF file + +    - variables are assumed to be stored with the correct key and able to be converted immediately +    - not extensively tested +    """ +    import xarray as xr + +    ds = xr.open_dataset(self.path) +    for key in ds.keys(): +        if key in data.dtype.names: +            data[key] = ds[key].values.flatten() diff --git a/profit/run/interface.py b/profit/run/interface.py index 1c592c0..616cb62 100644 --- a/profit/run/interface.py +++ b/profit/run/interface.py @@ -33,7 +33,13 @@ def __init__( self.logger.parent = logger_parent self.input_vars = [ -            (variable, spec["dtype"].__name__) +            ( +                variable, +                spec["dtype"].__name__, +                () +                if "size" not in spec or spec["size"] == (1, 1) +                else (spec["size"][-1],), +            ) for variable, spec in input_config.items() ] self.output_vars = [ diff --git a/profit/run/slurm.py b/profit/run/slurm.py index ce45102..1454875 100644 --- a/profit/run/slurm.py +++ b/profit/run/slurm.py @@ -66,13 +66,12 @@ def __init__( def __repr__(self): return ( -            f"<{self.__class__.__name__} (" + f", {self.cpus} cpus" + ", OpenMP" -            if self.openmp -            else "" + ", debug" -            if self.debug -            else "" + ", custom script" -            if self.custom -            else "" + ")>" +            f"<{self.__class__.__name__} (" +            + f", {self.cpus} cpus" +            + (", OpenMP" if self.openmp else "") +            + (", debug" if self.debug else "") +            + (", custom script" if self.custom else "") +            + ")>" ) @property @@ -222,9 +221,6 @@ def 
generate_script(self): if value is not None: text += f"\n#SBATCH --{key}={value}" - text += """ -#SBATCH --ntasks=1 -""" if self.cpus == "all" or self.cpus == 0: text += """ #SBATCH --nodes=1 diff --git a/profit/run/zeromq.py b/profit/run/zeromq.py index b45f9e1..99fb96b 100644 --- a/profit/run/zeromq.py +++ b/profit/run/zeromq.py @@ -269,7 +269,8 @@ def request(self, request): if request == "READY": input_descr, input_data, output_descr = response input_descr = [ - tuple(column) for column in json.loads(input_descr.decode()) + tuple(column[:2] + [tuple(column[2])]) + for column in json.loads(input_descr.decode()) ] output_descr = [ tuple(column[:2] + [tuple(column[2])]) diff --git a/profit/util/util.py b/profit/util/util.py index 9e33e71..2d827d3 100644 --- a/profit/util/util.py +++ b/profit/util/util.py @@ -54,7 +54,10 @@ def params2map(params: Union[None, MutableMapping, np.ndarray, np.void]): if isinstance(params, MutableMapping): return params try: - return {key: params[key] for key in params.dtype.names} + return { + key: params[key].item() if params[key].size == 1 else params[key] + for key in params.dtype.names + } except AttributeError: pass raise TypeError("params are not a Mapping") diff --git a/setup.cfg b/setup.cfg index efca387..7ac264f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,7 +44,10 @@ install_requires = dash>=1.20.0 tqdm zmq + f90nml + xarray importlib_metadata; python_version<'3.8' + importlib_metadata; python_version>='3.10' # required by chaospy tests_require = pytest packages = find: include_package_data = True diff --git a/tests/unit_tests/run/hdf5.post b/tests/unit_tests/run/post.hdf5 similarity index 100% rename from tests/unit_tests/run/hdf5.post rename to tests/unit_tests/run/post.hdf5 diff --git a/tests/unit_tests/run/json.post b/tests/unit_tests/run/post.json similarity index 100% rename from tests/unit_tests/run/json.post rename to tests/unit_tests/run/post.json diff --git a/tests/unit_tests/run/post.nc 
b/tests/unit_tests/run/post.nc new file mode 100644 index 0000000..5ddd6c0 Binary files /dev/null and b/tests/unit_tests/run/post.nc differ diff --git a/tests/unit_tests/run/numpytxt.post b/tests/unit_tests/run/post.numpytxt similarity index 100% rename from tests/unit_tests/run/numpytxt.post rename to tests/unit_tests/run/post.numpytxt diff --git a/tests/unit_tests/run/template/input.nml b/tests/unit_tests/run/template/input.nml new file mode 100644 index 0000000..ee40b79 --- /dev/null +++ b/tests/unit_tests/run/template/input.nml @@ -0,0 +1,4 @@ +&indata +u = 1 +v = 2 +/ diff --git a/tests/unit_tests/run/test_command.py b/tests/unit_tests/run/test_command.py index 623d734..809c36f 100644 --- a/tests/unit_tests/run/test_command.py +++ b/tests/unit_tests/run/test_command.py @@ -9,6 +9,7 @@ import logging import json import os +import f90nml @pytest.fixture(autouse=True) @@ -24,16 +25,17 @@ def chdir_pytest(): # === initialization === # -POSTPROCESSORS = ["numpytxt", "json", "hdf5"] +POSTPROCESSORS = ["numpytxt", "json", "hdf5", "netcdf"] INPUT_DTYPE = [("u", float), ("v", float)] OUTPUTS = {"f": [1.4, 1.3, 1.2], "g": 10} OUTPUT_DTYPE = [("f", float, (3,)), ("g", float)] OPTIONS = {"numpytxt": {"names": ["f", "g"]}} +POST_EXTENSION = {"netcdf": "nc"} @pytest.fixture def inputs(): - return {key: np.random.random() for key, dtype in INPUT_DTYPE} + return np.random.random(1).astype(INPUT_DTYPE)[0] @pytest.fixture(params=POSTPROCESSORS) @@ -42,7 +44,9 @@ def postprocessor(request, logger): from profit.run.command import Postprocessor return Postprocessor[label]( - path=f"{label}.post", **OPTIONS.get(label, {}), logger_parent=logger + path=f"post.{POST_EXTENSION.get(label, label)}", + **OPTIONS.get(label, {}), + logger_parent=logger, ) @@ -58,7 +62,7 @@ def __init__(self, *args, **kwargs): def retrieve(self): self.input = np.zeros(1, dtype=INPUT_DTYPE)[0] - for key in inputs: + for key in inputs.dtype.names: self.input[key] = inputs[key] self.output = np.zeros(1, 
dtype=OUTPUT_DTYPE)[0] self.retrieved = True @@ -86,7 +90,7 @@ def __init__(self, *args, **kwargs): self.posted = False def prepare(self, data): - for key in inputs: + for key in inputs.dtype.names: assert np.all(data[key] == inputs[key]) self.prepared = True @@ -131,7 +135,7 @@ def test_register(): assert CommandWorker.label in Worker.labels assert Worker[CommandWorker.label] is CommandWorker # all Preprocessors should be tested - assert Preprocessor.labels == {"template"} + assert Preprocessor.labels == {"template", "namelist"} # all Postprocessors should be tested assert Postprocessor.labels == set(POSTPROCESSORS) @@ -161,9 +165,26 @@ def test_template(inputs, logger): data_csv = np.loadtxt("template.csv", delimiter=",", dtype=INPUT_DTYPE) with open("template.json") as f: data_json = json.load(f) - for key, value in inputs.items(): - assert np.all(value == data_csv[key]) - assert np.all(value == data_json[key]) + for key in inputs.dtype.names: + assert np.all(inputs[key] == data_csv[key]) + assert np.all(inputs[key] == data_json[key]) + finally: + preprocessor.post() + + +def test_namelist(inputs, logger): + from profit.run.command import Preprocessor + + preprocessor = Preprocessor["namelist"]( + "run_test", clean=True, logger_parent=logger + ) + try: + preprocessor.prepare(inputs) + assert os.path.basename(os.getcwd()) == "run_test" + with open("input.nml") as f: + data = f90nml.read(f)["indata"] + for key in inputs.dtype.names: + assert np.all(inputs[key] == data[key]) finally: preprocessor.post() diff --git a/tests/unit_tests/run/test_interface.py b/tests/unit_tests/run/test_interface.py index 051c2cd..f0dbb90 100644 --- a/tests/unit_tests/run/test_interface.py +++ b/tests/unit_tests/run/test_interface.py @@ -27,6 +27,8 @@ def chdir_pytest(): LABELS = ["memmap", "zeromq"] TIMEOUT = 2 # s +LOOPS = 20 +LOOP_SLEEP = 0.1 @pytest.fixture(params=LABELS) @@ -123,11 +125,11 @@ def test_interface( """send & receive with default values""" # send & receive def run(): - 
for i in range(5): + for i in range(LOOPS): runner_interface.poll() if runner_interface.internal["DONE"][runid]: break - sleep(0.1) + sleep(LOOP_SLEEP) else: raise RuntimeError("timeout")