SurroBIER #182

Draft: wants to merge 8 commits into master

88 changes: 86 additions & 2 deletions profit/run/command.py
@@ -22,6 +22,7 @@
import numpy as np
import time
import subprocess
import f90nml

from .worker import Component, Worker, Interface

@@ -200,7 +201,7 @@ def __init__(
def template_path(self):
return os.path.join(os.environ.get("PROFIT_BASE_DIR", "."), self.path)

def prepare(self, data: Mapping):
def prepare(self, data: np.ndarray):
# No call to super()! overrides the default directory creation
if os.path.exists(self.run_dir):
self.logger.error(f"run directory '{self.run_dir}' already exists")
@@ -228,7 +229,7 @@ def fill_run_dir_single(
os.path.exists(run_dir_single) and not ignore_path_exists
): # ToDo: make ignore_path_exists default
if overwrite:
rmtree(run_dir_single)
shutil.rmtree(run_dir_single)
else:
raise RuntimeError("Run directory not empty: {}".format(run_dir_single))
self.copy_template(template_dir, run_dir_single)
@@ -321,6 +322,71 @@ def replace_template(content, params):
return content.format_map(SafeDict.from_params(params, pre=pre, post=post))


# --- Fortran Namelist Preprocessor --- #


class NamelistPreprocessor(TemplatePreprocessor, label="namelist"):
"""Preprocessor for Fortran Namelists

This Preprocessor copies a given template directory (specified by `template`)
for each run and modifies the namelist file (specified by `file`) by setting the
input parameters in the namelist specified by `namelist`.

Parameters:
run_dir (str/path): [Internal] path to the run directory to be filled
clean (bool): flag whether to remove the run directory upon completion
template (str/path): path to the template directory to copy for each run
(relative to the base directory)
file (str/path): path to / name of the namelist file into which the parameters
are inserted (relative to the template directory)
namelist (str): identifier for the namelist within the file
"""

def __init__(
self,
run_dir: str,
*,
clean=True,
template="template",
file="input.nml",
namelist="indata",
logger_parent=None,
):
Preprocessor.__init__(
self, run_dir=run_dir, clean=clean, logger_parent=logger_parent
)
self.template = template
self.file = file
self.namelist = namelist

@property
def path(self):
"""required by TemplatePreprocessor"""
return self.template

@property
def param_files(self):
"""required by TemplatePreprocessor"""
return [self.file]

def prepare(self, data: np.ndarray):
super().prepare(data)
if self.param_files is None:
return
if len(self.param_files) > 1:
raise ValueError(
"Namelist Preprocessor can only handle a single file for now!"
)
with open(self.file) as f:
nml = f90nml.read(f)
for key in data.dtype.names:
if data[key].size > 1:
nml[self.namelist][key] = list(data[key])
else:
nml[self.namelist][key] = data[key]
f90nml.write(nml, self.file, force=True)
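
Editorial aside (not part of the patch): the prepare step above is essentially a read-patch-write cycle with f90nml. A minimal standalone sketch, assuming a copied template containing an `indata` namelist; parameter names, values, and paths are illustrative only:

import f90nml
import numpy as np

# one structured-array record, as the runner would hand it over (names are illustrative)
data = np.array([(1.5, 2.5)], dtype=[("u", float), ("v", float)])[0]

nml = f90nml.read("input.nml")  # parse the copied template namelist
for key in data.dtype.names:
    value = data[key]
    # lists for vector parameters, plain Python scalars otherwise
    nml["indata"][key] = value.tolist() if value.size > 1 else value.item()
f90nml.write(nml, "input.nml", force=True)  # overwrite the file in the run directory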


# === Postprocessor Component === #


@@ -434,3 +500,21 @@ def HDF5Postprocessor(self, data):
for key in f.keys():
if key in data.dtype.names:
data[key] = f[key][:]


# --- netCDF Postprocessor --- #


@Postprocessor.wrap("netcdf", config=dict(path="stdout"))
def NetCDFPostprocessor(self, data):
"""Postprocessor to read output from a HDF5 file

- variables are assumed to be stored with the correct key and able to be converted immediately
- not extensively tested
"""
import xarray as xr

ds = xr.open_dataset(self.path)
for key in ds.keys():
if key in data.dtype.names:
data[key] = ds[key].values.flatten()
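
Editorial aside (not part of the patch): a sketch of an output file the netcdf postprocessor above could consume, written with xarray; the variable names f and g mirror the test fixtures further down and are otherwise illustrative:

import numpy as np
import xarray as xr

# hypothetical simulation output: variable names must match fields of the output dtype
ds = xr.Dataset({"f": (("x",), np.array([1.4, 1.3, 1.2])), "g": ((), 10.0)})
ds.to_netcdf("post.nc")  # the postprocessor then flattens each matching variable into data[key]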
8 changes: 7 additions & 1 deletion profit/run/interface.py
@@ -33,7 +33,13 @@ def __init__(
self.logger.parent = logger_parent

self.input_vars = [
(variable, spec["dtype"].__name__)
(
variable,
spec["dtype"].__name__,
()
if "size" not in spec or spec["size"] == (1, 1)
else (spec["size"][-1],),
)
for variable, spec in input_config.items()
]
self.output_vars = [
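
Editorial aside (not part of the patch): each input descriptor above now carries an explicit shape, so vector inputs map directly onto structured-array fields. A minimal sketch with illustrative names and sizes:

import numpy as np

# (name, dtype name, shape) triples as built in __init__ above; entries are illustrative
input_vars = [("u", "float64", ()), ("profile", "float64", (3,))]
dtype = np.dtype([(name, dt, shape) for name, dt, shape in input_vars])
row = np.zeros(1, dtype=dtype)[0]
row["profile"] = [0.1, 0.2, 0.3]  # a vector input fits into a single record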
16 changes: 6 additions & 10 deletions profit/run/slurm.py
@@ -66,13 +66,12 @@ def __init__(

def __repr__(self):
return (
f"<{self.__class__.__name__} (" + f", {self.cpus} cpus" + ", OpenMP"
if self.openmp
else "" + ", debug"
if self.debug
else "" + ", custom script"
if self.custom
else "" + ")>"
f"<{self.__class__.__name__} ("
+ f", {self.cpus} cpus"
+ (", OpenMP" if self.openmp else "")
+ (", debug" if self.debug else "")
+ (", custom script" if self.custom else "")
+ ")>"
)
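
Editorial aside (not part of the patch): the rewrite of __repr__ above fixes an operator-precedence bug. A conditional expression binds more loosely than +, so the old version collapsed into one big ternary and discarded most of the string whenever the first flag was False. A minimal illustration:

openmp, debug, custom = False, True, False
broken = "<Runner (" + ", 4 cpus" + ", OpenMP" if openmp else "" + ", debug" if debug else "" + ", custom script" if custom else "" + ")>"
fixed = "<Runner (" + ", 4 cpus" + (", OpenMP" if openmp else "") + (", debug" if debug else "") + (", custom script" if custom else "") + ")>"
print(broken)  # -> ", debug"  (everything before the first ternary is dropped)
print(fixed)   # -> "<Runner (, 4 cpus, debug)>"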

@property
@@ -222,9 +221,6 @@ def generate_script(self):
if value is not None:
text += f"\n#SBATCH --{key}={value}"

text += """
#SBATCH --ntasks=1
"""
if self.cpus == "all" or self.cpus == 0:
text += """
#SBATCH --nodes=1
3 changes: 2 additions & 1 deletion profit/run/zeromq.py
@@ -269,7 +269,8 @@ def request(self, request):
if request == "READY":
input_descr, input_data, output_descr = response
input_descr = [
tuple(column) for column in json.loads(input_descr.decode())
tuple(column[:2] + [tuple(column[2])])
for column in json.loads(input_descr.decode())
]
output_descr = [
tuple(column[:2] + [tuple(column[2])])
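
Editorial aside (not part of the patch): the extra tuple() calls above are needed because JSON has no tuple type, so the dtype descriptor arrives as nested lists after the round trip. A minimal illustration:

import json
import numpy as np

sent = [("u", "float64", ()), ("profile", "float64", (3,))]
received = json.loads(json.dumps(sent))  # -> [["u", "float64", []], ["profile", "float64", [3]]]
descr = [tuple(col[:2] + [tuple(col[2])]) for col in received]
dtype = np.dtype(descr)                  # np.dtype needs tuples, not lists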
5 changes: 4 additions & 1 deletion profit/util/util.py
@@ -54,7 +54,10 @@ def params2map(params: Union[None, MutableMapping, np.ndarray, np.void]):
if isinstance(params, MutableMapping):
return params
try:
return {key: params[key] for key in params.dtype.names}
return {
key: params[key].item() if params[key].size == 1 else params[key]
for key in params.dtype.names
}
except AttributeError:
pass
raise TypeError("params are not a Mapping")
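
Editorial aside (not part of the patch): with the change above, scalar fields are unwrapped to plain Python scalars while vector fields stay numpy arrays. A minimal sketch with illustrative field names:

import numpy as np

params = np.zeros(1, dtype=[("u", float), ("profile", float, (3,))])[0]
mapping = {
    key: params[key].item() if params[key].size == 1 else params[key]
    for key in params.dtype.names
}
# mapping["u"] is a plain float, mapping["profile"] stays a length-3 array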
3 changes: 3 additions & 0 deletions setup.cfg
@@ -44,7 +44,10 @@ install_requires =
dash>=1.20.0
tqdm
zmq
f90nml
xarray
importlib_metadata; python_version<'3.8'
importlib_metadata; python_version>='3.10' # required by chaospy
tests_require = pytest
packages = find:
include_package_data = True
File renamed without changes.
File renamed without changes.
Binary file added tests/unit_tests/run/post.nc
File renamed without changes.
4 changes: 4 additions & 0 deletions tests/unit_tests/run/template/input.nml
@@ -0,0 +1,4 @@
&indata
u = 1
v = 2
/
39 changes: 30 additions & 9 deletions tests/unit_tests/run/test_command.py
@@ -9,6 +9,7 @@
import logging
import json
import os
import f90nml


@pytest.fixture(autouse=True)
@@ -24,16 +25,17 @@ def chdir_pytest():
# === initialization === #


POSTPROCESSORS = ["numpytxt", "json", "hdf5"]
POSTPROCESSORS = ["numpytxt", "json", "hdf5", "netcdf"]
INPUT_DTYPE = [("u", float), ("v", float)]
OUTPUTS = {"f": [1.4, 1.3, 1.2], "g": 10}
OUTPUT_DTYPE = [("f", float, (3,)), ("g", float)]
OPTIONS = {"numpytxt": {"names": ["f", "g"]}}
POST_EXTENSION = {"netcdf": "nc"}


@pytest.fixture
def inputs():
return {key: np.random.random() for key, dtype in INPUT_DTYPE}
return np.random.random(1).astype(INPUT_DTYPE)[0]


@pytest.fixture(params=POSTPROCESSORS)
@@ -42,7 +44,9 @@ def postprocessor(request, logger):
from profit.run.command import Postprocessor

return Postprocessor[label](
path=f"{label}.post", **OPTIONS.get(label, {}), logger_parent=logger
path=f"post.{POST_EXTENSION.get(label, label)}",
**OPTIONS.get(label, {}),
logger_parent=logger,
)


@@ -58,7 +62,7 @@ def __init__(self, *args, **kwargs):

def retrieve(self):
self.input = np.zeros(1, dtype=INPUT_DTYPE)[0]
for key in inputs:
for key in inputs.dtype.names:
self.input[key] = inputs[key]
self.output = np.zeros(1, dtype=OUTPUT_DTYPE)[0]
self.retrieved = True
@@ -86,7 +90,7 @@ def __init__(self, *args, **kwargs):
self.posted = False

def prepare(self, data):
for key in inputs:
for key in inputs.dtype.names:
assert np.all(data[key] == inputs[key])
self.prepared = True

@@ -131,7 +135,7 @@ def test_register():
assert CommandWorker.label in Worker.labels
assert Worker[CommandWorker.label] is CommandWorker
# all Preprocessors should be tested
assert Preprocessor.labels == {"template"}
assert Preprocessor.labels == {"template", "namelist"}
# all Postprocessors should be tested
assert Postprocessor.labels == set(POSTPROCESSORS)

@@ -161,9 +165,26 @@ def test_template(inputs, logger):
data_csv = np.loadtxt("template.csv", delimiter=",", dtype=INPUT_DTYPE)
with open("template.json") as f:
data_json = json.load(f)
for key, value in inputs.items():
assert np.all(value == data_csv[key])
assert np.all(value == data_json[key])
for key in inputs.dtype.names:
assert np.all(inputs[key] == data_csv[key])
assert np.all(inputs[key] == data_json[key])
finally:
preprocessor.post()


def test_namelist(inputs, logger):
from profit.run.command import Preprocessor

preprocessor = Preprocessor["namelist"](
"run_test", clean=True, logger_parent=logger
)
try:
preprocessor.prepare(inputs)
assert os.path.basename(os.getcwd()) == "run_test"
with open("input.nml") as f:
data = f90nml.read(f)["indata"]
for key in inputs.dtype.names:
assert np.all(inputs[key] == data[key])
finally:
preprocessor.post()

6 changes: 4 additions & 2 deletions tests/unit_tests/run/test_interface.py
@@ -27,6 +27,8 @@ def chdir_pytest():

LABELS = ["memmap", "zeromq"]
TIMEOUT = 2 # s
LOOPS = 20
LOOP_SLEEP = 0.1


@pytest.fixture(params=LABELS)
@@ -123,11 +125,11 @@ def test_interface(
"""send & receive with default values"""
# send & receive
def run():
for i in range(5):
for i in range(LOOPS):
runner_interface.poll()
if runner_interface.internal["DONE"][runid]:
break
sleep(0.1)
sleep(LOOP_SLEEP)
else:
raise RuntimeError("timeout")
