diff --git a/glotaran/builtin/io/yml/test/test_model_parser.py b/glotaran/builtin/io/yml/test/test_model_parser.py
index b19a48dab..8fa942b55 100644
--- a/glotaran/builtin/io/yml/test/test_model_parser.py
+++ b/glotaran/builtin/io/yml/test/test_model_parser.py
@@ -8,7 +8,7 @@
 from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex
 from glotaran.builtin.megacomplexes.decay.initial_concentration import InitialConcentration
 from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian
-from glotaran.builtin.megacomplexes.spectral.shape import SpectralShapeSkewedGaussian
+from glotaran.builtin.megacomplexes.spectral.shape import SpectralShapeGaussian
 from glotaran.io import load_model
 from glotaran.model import DatasetModel
 from glotaran.model import Model
@@ -25,7 +25,7 @@ def model():
     spec_path = join(THIS_DIR, "test_model_spec.yml")
     m = load_model(spec_path)
-    print(m.markdown())
+    print(m.markdown())  # noqa
     return m
@@ -48,9 +48,20 @@ def test_dataset(model):
     assert dataset.irf == "irf1"
     assert dataset.scale == 1
+    assert "dataset2" in model.dataset
+    dataset = model.dataset["dataset2"]
+    assert isinstance(dataset, DatasetModel)
+    assert dataset.label == "dataset2"
+    assert dataset.megacomplex == ["cmplx2"]
+    assert dataset.initial_concentration == "inputD2"
+    assert dataset.irf == "irf2"
+    assert dataset.scale == 2
+    assert dataset.spectral_axis_scale == 1e7
+    assert dataset.spectral_axis_inverted
+
 
 def test_constraints(model):
-    print(model.constraints)
+    print(model.constraints)  # noqa
 
     assert len(model.constraints) == 2
 
     zero = model.constraints[0]
@@ -77,7 +88,7 @@ def test_penalties(model):
 
 
 def test_relations(model):
-    print(model.relations)
+    print(model.relations)  # noqa
 
     assert len(model.relations) == 1
 
     rel = model.relations[0]
@@ -154,7 +165,7 @@ def test_shapes(model):
     assert "shape1" in model.shape
     shape = model.shape["shape1"]
-    assert isinstance(shape, SpectralShapeSkewedGaussian)
+    assert isinstance(shape, SpectralShapeGaussian)
     assert shape.amplitude.full_label == "shape.1"
     assert shape.location.full_label == "shape.2"
     assert shape.width.full_label == "shape.3"
diff --git a/glotaran/builtin/io/yml/test/test_model_spec.yml b/glotaran/builtin/io/yml/test/test_model_spec.yml
index 499e61b8f..cf5dfce21 100644
--- a/glotaran/builtin/io/yml/test/test_model_spec.yml
+++ b/glotaran/builtin/io/yml/test/test_model_spec.yml
@@ -12,6 +12,8 @@ dataset:
     initial_concentration: inputD2
     irf: irf2
     scale: 2
+    spectral_axis_scale: 1e7
+    spectral_axis_inverted: true
 
 irf:
   irf1:
@@ -54,7 +56,7 @@ k_matrix:
 
 shape:
   shape1:
-    type: "skewed-gaussian"
+    type: "gaussian"
     amplitude: shape.1
     location: shape.2
     width: shape.3
diff --git a/glotaran/builtin/io/yml/yml.py b/glotaran/builtin/io/yml/yml.py
index 63fe29884..f0b0a42df 100644
--- a/glotaran/builtin/io/yml/yml.py
+++ b/glotaran/builtin/io/yml/yml.py
@@ -7,8 +7,6 @@
 import yaml
 
-from glotaran.builtin.io.yml.sanatize import check_deprecations
-from glotaran.builtin.io.yml.sanatize import sanitize_yaml
 from glotaran.io import ProjectIoInterface
 from glotaran.io import load_dataset
 from glotaran.io import load_model
@@ -21,6 +19,8 @@
 from glotaran.parameter import ParameterGroup
 from glotaran.project import SavingOptions
 from glotaran.project import Scheme
+from glotaran.utils.sanitize import check_deprecations
+from glotaran.utils.sanitize import sanitize_yaml
 
 if TYPE_CHECKING:
     from glotaran.project import Result
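Editorial aside (not part of the patch): the new `spectral_axis_scale: 1e7` entry in `test_model_spec.yml` is the reason several later hunks deal with scientific notation. PyYAML's default YAML 1.1 resolver only recognises exponential floats written like `1.0e+7` (a dot and a signed exponent), so a bare `1e7` is loaded as a string, while the test above expects `dataset.spectral_axis_scale == 1e7` as a float. A minimal sketch of the mismatch, assuming only PyYAML is installed; the conversion shown on the last line is what the sanitize utilities added further down in this diff take care of:

```python
import yaml

spec = yaml.safe_load("spectral_axis_scale: 1e7\nspectral_axis_inverted: true")
print(type(spec["spectral_axis_scale"]).__name__)  # 'str' -- YAML 1.1 does not treat 1e7 as a float
print(float(spec["spectral_axis_scale"]))          # 10000000.0 -- the conversion performed by the new utilities
```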
diff --git a/glotaran/builtin/megacomplexes/spectral/shape.py b/glotaran/builtin/megacomplexes/spectral/shape.py
index 7823590b0..f727b9ea6 100644
--- a/glotaran/builtin/megacomplexes/spectral/shape.py
+++ b/glotaran/builtin/megacomplexes/spectral/shape.py
@@ -9,52 +9,16 @@
 @model_item(
     properties={
-        "amplitude": Parameter,
+        "amplitude": {"type": Parameter, "allow_none": True},
         "location": Parameter,
         "width": Parameter,
-        "skewness": {"type": Parameter, "allow_none": True},
     },
     has_type=True,
 )
-class SpectralShapeSkewedGaussian:
-    """A (skewed) Gaussian spectral shape"""
+class SpectralShapeGaussian:
+    """A Gaussian spectral shape"""
 
     def calculate(self, axis: np.ndarray) -> np.ndarray:
-        r"""Calculate a (skewed) Gaussian shape for a given ``axis``.
-
-        If a non-zero ``skewness`` parameter was added
-        :func:`calculate_skewed_gaussian` will be used.
-        Otherwise it will use :func:`calculate_gaussian`.
-
-        Parameters
-        ----------
-        axis: np.ndarray
-            The axis to calculate the shape for.
-
-        Returns
-        -------
-        shape: numpy.ndarray
-            A Gaussian shape.
-
-        See Also
-        --------
-        calculate_gaussian
-        calculate_skewed_gaussian
-
-        Note
-        ----
-        Internally ``axis`` is converted from :math:`\mbox{nm}` to
-        :math:`1/\mbox{cm}`, thus ``location`` and ``width`` also need to
-        be provided in :math:`1/\mbox{cm}` (``1e7/value_in_nm``).
-
-        """
-        return (
-            self.calculate_skewed_gaussian(axis)
-            if self.skewness is not None and not np.allclose(self.skewness, 0)
-            else self.calculate_gaussian(axis)
-        )
-
-    def calculate_gaussian(self, axis: np.ndarray) -> np.ndarray:
         r"""Calculate a normal Gaussian shape for a given ``axis``.
 
         The following equation is used for the calculation:
@@ -91,11 +55,22 @@ def calculate_gaussian(self, axis: np.ndarray) -> np.ndarray:
         np.ndarray
             An array representing a Gaussian shape.
         """
-        return self.amplitude * np.exp(
-            -np.log(2) * np.square(2 * (axis - self.location) / self.width)
-        )
+        shape = np.exp(-np.log(2) * np.square(2 * (axis - self.location) / self.width))
+        if self.amplitude is not None:
+            shape *= self.amplitude
+        return shape
 
-    def calculate_skewed_gaussian(self, axis: np.ndarray) -> np.ndarray:
+
+@model_item(
+    properties={
+        "skewness": Parameter,
+    },
+    has_type=True,
+)
+class SpectralShapeSkewedGaussian(SpectralShapeGaussian):
+    """A skewed Gaussian spectral shape"""
+
+    def calculate(self, axis: np.ndarray) -> np.ndarray:
         r"""Calculate the skewed Gaussian shape for ``axis``.
 
         The following equation is used for the calculation:
@@ -134,7 +109,7 @@ def calculate_skewed_gaussian(self, axis: np.ndarray) -> np.ndarray:
         Note that in the limit of skewness parameter :math:`b` equal to zero
         :math:`f(x, x_0, A, \Delta, b)` simplifies to a normal gaussian
         (since :math:`\lim_{b \to 0} \frac{\ln(1+bx)}{b}=x`),
-        see the definition in :func:`calculate_gaussian`.
+        see the definition in :func:`SpectralShapeGaussian.calculate`.
 
         Parameters
         ----------
@@ -147,14 +122,17 @@ def calculate_skewed_gaussian(self, axis: np.ndarray) -> np.ndarray:
         np.ndarray
             An array representing a skewed Gaussian shape.
         """
+        if np.allclose(self.skewness, 0):
+            return super().calculate(axis)
         log_args = 1 + (2 * self.skewness * (axis - self.location) / self.width)
-        result = np.zeros(log_args.shape)
+        shape = np.zeros(log_args.shape)
         valid_arg_mask = np.where(log_args > 0)
-        result[valid_arg_mask] = self.amplitude * np.exp(
+        shape[valid_arg_mask] = np.exp(
             -np.log(2) * np.square(np.log(log_args[valid_arg_mask]) / self.skewness)
         )
-
-        return result
+        if self.amplitude is not None:
+            shape *= self.amplitude
+        return shape
 
 
 @model_item(properties={}, has_type=True)
@@ -201,6 +179,7 @@ def calculate(self, axis: np.ndarray) -> np.ndarray:
 
 @model_item_typed(
     types={
+        "gaussian": SpectralShapeGaussian,
         "skewed-gaussian": SpectralShapeSkewedGaussian,
         "one": SpectralShapeOne,
         "zero": SpectralShapeZero,
diff --git a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py
index b25ec7151..ab2f70900 100644
--- a/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py
+++ b/glotaran/builtin/megacomplexes/spectral/spectral_megacomplex.py
@@ -14,7 +14,10 @@
 @megacomplex(
     dimension="spectral",
-    properties={"energy_spectrum": {"type": bool, "default": False}},
+    dataset_properties={
+        "spectral_axis_inverted": {"type": bool, "default": False},
+        "spectral_axis_scale": {"type": float, "default": 1},
+    },
     model_items={
         "shape": Dict[str, SpectralShape],
     },
@@ -35,8 +38,10 @@ def calculate_matrix(
         compartments.append(compartment)
 
     model_axis = dataset_model.get_model_axis()
-    if self.energy_spectrum:
-        model_axis = 1e7 / model_axis
+    if dataset_model.spectral_axis_inverted:
+        model_axis = dataset_model.spectral_axis_scale / model_axis
+    elif dataset_model.spectral_axis_scale != 1:
+        model_axis = model_axis * dataset_model.spectral_axis_scale
 
     dim1 = model_axis.size
     dim2 = len(self.shape)
diff --git a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py
index 9e163292d..ce559a31d 100644
--- a/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py
+++ b/glotaran/builtin/megacomplexes/spectral/test/test_spectral_model.py
@@ -23,7 +23,7 @@ def from_dict(cls, model_dict):
     )
 
 
-class OneCompartmentModel:
+class OneCompartmentModelInvertedAxis:
     decay_model = DecayModel.from_dict(
         {
             "initial_concentration": {
                 "j1": {"compartments": ["s1"], "parameters": ["2"]},
             },
             "megacomplex": {
                 "mc1": {"k_matrix": ["k1"]},
             },
@@ -59,7 +59,7 @@ class OneCompartmentModel:
             },
             "shape": {
                 "sh1": {
-                    "type": "skewed-gaussian",
+                    "type": "gaussian",
                     "amplitude": "1",
                     "location": "2",
                     "width": "3",
@@ -68,12 +68,76 @@ class OneCompartmentModel:
             "dataset": {
                 "dataset1": {
                     "megacomplex": ["mc1"],
+                    "spectral_axis_scale": 1e7,
+                    "spectral_axis_inverted": True,
+                },
+            },
+        }
+    )
+
+    spectral_parameters = ParameterGroup.from_list([7, 1e7 / 10000, 800, -1])
+
+    time = np.arange(-10, 50, 1.5)
+    spectral = np.arange(5000, 15000, 20)
+    axis = {"time": time, "spectral": spectral}
+
+    decay_dataset_model = decay_model.dataset["dataset1"].fill(decay_model, decay_parameters)
+    decay_dataset_model.overwrite_global_dimension("spectral")
+    decay_dataset_model.set_coordinates(axis)
+    matrix = calculate_matrix(decay_dataset_model, {})
+    decay_compartments = matrix.clp_labels
+    clp = xr.DataArray(matrix.matrix, coords=[("time", time), ("clp_label", decay_compartments)])
+
+
+class OneCompartmentModelNegativeSkew:
+    decay_model = DecayModel.from_dict(
+        {
+            "initial_concentration": {
+                "j1": {"compartments": ["s1"], "parameters": ["2"]},
+            },
+            "megacomplex": {
+                "mc1": {"k_matrix": ["k1"]},
+            },
             "k_matrix": {
                 "k1": {
                     "matrix": {
                         ("s1", "s1"): "1",
                     }
                 }
             },
             "dataset": {
                 "dataset1": {
                     "initial_concentration": "j1",
                     "megacomplex": ["mc1"],
                 },
             },
         }
     )
 
-    spectral_parameters = ParameterGroup.from_list([7, 20000, 800])
+    decay_parameters = ParameterGroup.from_list(
+        [101e-4, [1, {"vary": False, "non-negative": False}]]
+    )
+
+    spectral_model = SpectralModel.from_dict(
+        {
+            "megacomplex": {
+                "mc1": {"shape": {"s1": "sh1"}},
+            },
+            "shape": {
+                "sh1": {
+                    "type": "skewed-gaussian",
+                    "location": "1",
+                    "width": "2",
+                    "skewness": "3",
+                }
+            },
+            "dataset": {
+                "dataset1": {"megacomplex": ["mc1"], "spectral_axis_scale": 2},
+            },
+        }
+    )
+
+    spectral_parameters = ParameterGroup.from_list([1000, 80, -1])
 
     time = np.arange(-10, 50, 1.5)
     spectral = np.arange(400, 600, 5)
@@ -87,6 +151,14 @@ class OneCompartmentModel:
     clp = xr.DataArray(matrix.matrix, coords=[("time", time), ("clp_label", decay_compartments)])
 
 
+class OneCompartmentModelPositivSkew(OneCompartmentModelNegativeSkew):
+    spectral_parameters = ParameterGroup.from_list([7, 20000, 800, 1])
+
+
+class OneCompartmentModelZeroSkew(OneCompartmentModelNegativeSkew):
+    spectral_parameters = ParameterGroup.from_list([7, 20000, 800, 0])
+
+
 class ThreeCompartmentModel:
     decay_model = DecayModel.from_dict(
         {
@@ -131,23 +203,22 @@ class ThreeCompartmentModel:
             },
             "shape": {
                 "sh1": {
-                    "type": "skewed-gaussian",
+                    "type": "gaussian",
                     "amplitude": "1",
                     "location": "2",
                     "width": "3",
                 },
                 "sh2": {
-                    "type": "skewed-gaussian",
+                    "type": "gaussian",
                     "amplitude": "4",
                     "location": "5",
                     "width": "6",
                 },
                 "sh3": {
-                    "type": "skewed-gaussian",
+                    "type": "gaussian",
                     "amplitude": "7",
                     "location": "8",
                     "width": "9",
-                    "skewness": "10",
                 },
             },
             "dataset": {
@@ -161,15 +232,14 @@ class ThreeCompartmentModel:
     spectral_parameters = ParameterGroup.from_list(
         [
             7,
-            20000,
-            800,
+            450,
+            80,
             20,
-            22000,
-            500,
+            550,
+            50,
+            10,
+            580,
             10,
-            18000,
-            650,
-            0.1,
         ]
     )
 
@@ -188,26 +258,28 @@
 @pytest.mark.parametrize(
     "suite",
     [
-        OneCompartmentModel,
+        OneCompartmentModelNegativeSkew,
+        OneCompartmentModelPositivSkew,
+        OneCompartmentModelZeroSkew,
         ThreeCompartmentModel,
     ],
 )
 def test_spectral_model(suite):
 
     model = suite.spectral_model
-    print(model.validate())
+    print(model.validate())  # noqa
     assert model.valid()
 
     wanted_parameters = suite.spectral_parameters
-    print(model.validate(wanted_parameters))
-    print(wanted_parameters)
+    print(model.validate(wanted_parameters))  # noqa
+    print(wanted_parameters)  # noqa
     assert model.valid(wanted_parameters)
 
     initial_parameters = suite.spectral_parameters
-    print(model.validate(initial_parameters))
+    print(model.validate(initial_parameters))  # noqa
     assert model.valid(initial_parameters)
-    print(model.markdown(initial_parameters))
+    print(model.markdown(initial_parameters))  # noqa
 
     dataset = simulate(model, "dataset1", wanted_parameters, suite.axis, suite.clp)
 
@@ -222,7 +294,7 @@ def test_spectral_model(suite):
         maximum_number_function_evaluations=20,
     )
     result = optimize(scheme)
-    print(result.optimized_parameters)
+    print(result.optimized_parameters)  # noqa
 
     for label, param in result.optimized_parameters.all():
         assert np.allclose(param.value, wanted_parameters.get(label).value, rtol=1e-1)
diff --git a/glotaran/examples/sequential.py b/glotaran/examples/sequential.py
index 328102d5b..d0f0f635e 100644
--- a/glotaran/examples/sequential.py
+++ b/glotaran/examples/sequential.py
@@ -39,19 +39,19 @@
         },
         "shape": {
             "sh1": {
-                "type": "skewed-gaussian",
+                "type": "gaussian",
                 "amplitude": "shapes.amps.1",
                 "location": "shapes.locs.1",
                 "width": "shapes.width.1",
             },
             "sh2": {
-                "type": "skewed-gaussian",
+                "type": "gaussian",
                 "amplitude": "shapes.amps.2",
                 "location": "shapes.locs.2",
                 "width": "shapes.width.2",
             },
             "sh3": {
-                "type": "skewed-gaussian",
+                "type": "gaussian",
                 "amplitude": "shapes.amps.3",
                 "location": "shapes.locs.3",
                 "width": "shapes.width.3",
diff --git a/glotaran/examples/test/test_example.py b/glotaran/examples/test/test_example.py
new file mode 100644
index 000000000..7bd6e74a4
--- /dev/null
+++ b/glotaran/examples/test/test_example.py
@@ -0,0 +1,7 @@
+import xarray as xr
+
+from glotaran.examples.sequential import dataset
+
+
+def test_dataset():
+    assert isinstance(dataset, xr.Dataset)
diff --git a/glotaran/model/item.py b/glotaran/model/item.py
index 82d1c9519..21dd6188f 100644
--- a/glotaran/model/item.py
+++ b/glotaran/model/item.py
@@ -109,9 +109,6 @@ def decorator(cls):
         from_dict = _create_from_dict_func(cls)
         setattr(cls, "from_dict", from_dict)
 
-        from_list = _create_from_list_func(cls)
-        setattr(cls, "from_list", from_list)
-
         validate = _create_validation_func(cls)
         setattr(cls, "validate", validate)
 
@@ -233,7 +230,13 @@ def from_dict(ncls, values: dict) -> cls:
 
         for name in ncls._glotaran_properties:
             if name in values:
-                setattr(item, name, values[name])
+                value = values[name]
+                prop = getattr(item.__class__, name)
+                if prop.property_type == float:
+                    value = float(value)
+                elif prop.property_type == int:
+                    value = int(value)
+                setattr(item, name, value)
             elif not getattr(ncls, name).allow_none and getattr(item, name) is None:
                 raise ValueError(f"Missing Property '{name}' For Item '{ncls.__name__}'")
 
@@ -242,32 +245,6 @@ def from_dict(ncls, values: dict) -> cls:
     return from_dict
 
 
-def _create_from_list_func(cls):
-    @classmethod
-    @wrap_func_as_method(cls)
-    def from_list(ncls, values: list) -> cls:
-        f"""Creates an instance of {cls.__name__} from a list of values. Intended only for internal use.
-
-        Parameters
-        ----------
-        values :
-            A list of values.
-        """
-        item = ncls()
-        if len(values) != len(ncls._glotaran_properties):
-            raise ValueError(
-                f"To few or much parameters for '{ncls.__name__}'"
-                f"\nGot: {values}\nWant: {ncls._glotaran_properties}"
-            )
-
-        for i, name in enumerate(ncls._glotaran_properties):
-            setattr(item, name, values[i])
-
-        return item
-
-    return from_list
-
-
 def _create_validation_func(cls):
     @wrap_func_as_method(cls)
     def validate(self, model: Model, parameters: ParameterGroup | None = None) -> list[str]:
diff --git a/glotaran/model/property.py b/glotaran/model/property.py
index 92518badc..41c6d4c0f 100644
--- a/glotaran/model/property.py
+++ b/glotaran/model/property.py
@@ -13,10 +13,10 @@ def __init__(self, cls, name, prop_type, doc, default, allow_none):
         self._allow_none = allow_none
         self._determine_if_parameter(prop_type)
 
-        set_type = prop_type if not self._is_parameter else typing.Union[str, prop_type]
+        self._type = prop_type if not self._is_parameter else typing.Union[str, prop_type]
 
         @wrap_func_as_method(cls, name=name)
-        def setter(that_self, value: set_type):
+        def setter(that_self, value: self._type):
             if value is None and not self._allow_none:
                 raise Exception(
                     f"Property '{name}' of '{cls.__name__}' is not allowed to set to None."
@@ -44,6 +44,10 @@ def getter(that_self) -> prop_type:
     def allow_none(self) -> bool:
         return self._allow_none
 
+    @property
+    def property_type(self) -> typing.Type:
+        return self._type
+
     def validate(self, value, model, parameters=None) -> typing.List[str]:
 
         if value is None and self.allow_none:
diff --git a/glotaran/parameter/parameter.py b/glotaran/parameter/parameter.py
index 3a0ea0f2b..7027da627 100644
--- a/glotaran/parameter/parameter.py
+++ b/glotaran/parameter/parameter.py
@@ -9,6 +9,8 @@
 import numpy as np
 from numpy.typing._array_like import _SupportsArray
 
+from glotaran.utils.sanitize import sanitize_parameter_list
+
 if TYPE_CHECKING:
     from typing import Any
@@ -113,7 +115,7 @@ def from_list_or_value(
             param.value = value
 
         else:
-            values = _sanatize_parameter_list(value)
+            values = sanitize_parameter_list(value)
             param.label = _retrieve_from_list_by_type(values, str, label)
             param.value = float(_retrieve_from_list_by_type(values, (int, float), 0))
             options = _retrieve_from_list_by_type(values, dict, None)
@@ -485,18 +487,6 @@ def _log_value(value: float):
     return np.log(value)
 
 
-# A reexp for ONLY matching scientific
-_match_scientific = re.compile(r"[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)")
-
-
-def _sanatize_parameter_list(li: list) -> list:
-    for i, value in enumerate(li):
-        if isinstance(value, str) and _match_scientific.match(value):
-            li[i] = float(value)
-
-    return li
-
-
 def _retrieve_from_list_by_type(li: list, t: type | tuple[type, ...], default: Any):
     tmp = list(filter(lambda x: isinstance(x, t), li))
     if not tmp:
diff --git a/glotaran/test/test_spectral_decay.py b/glotaran/test/test_spectral_decay.py
index d0f86cecf..a23ce515d 100644
--- a/glotaran/test/test_spectral_decay.py
+++ b/glotaran/test/test_spectral_decay.py
@@ -25,17 +25,17 @@
       s3: sh3
   shape:
     sh1:
-      type: skewed-gaussian
+      type: gaussian
       amplitude: shapes.amps.1
       location: shapes.locs.1
       width: shapes.width.1
     sh2:
-      type: skewed-gaussian
+      type: gaussian
      amplitude: shapes.amps.2
       location: shapes.locs.2
       width: shapes.width.2
     sh3:
-      type: skewed-gaussian
+      type: gaussian
       amplitude: shapes.amps.3
       location: shapes.locs.3
       width: shapes.width.3
@@ -75,17 +75,17 @@
       width: [irf.width]
   shape:
     sh1:
-      type: skewed-gaussian
+      type: gaussian
       amplitude: shapes.amps.1
       location: shapes.locs.1
       width: shapes.width.1
     sh2:
-      type: skewed-gaussian
+      type: gaussian
       amplitude: shapes.amps.2
       location: shapes.locs.2
       width: shapes.width.2
     sh3:
-      type: skewed-gaussian
+      type: gaussian
       amplitude: shapes.amps.3
       location: shapes.locs.3
       width: shapes.width.3
diff --git a/glotaran/test/test_spectral_decay_full_model.py b/glotaran/test/test_spectral_decay_full_model.py
index 8f2e7e532..28d441db9 100644
--- a/glotaran/test/test_spectral_decay_full_model.py
+++ b/glotaran/test/test_spectral_decay_full_model.py
@@ -42,17 +42,17 @@
       width: [irf.width]
   shape:
    sh1:
-      type: skewed-gaussian
+      type: gaussian
       amplitude: shapes.amps.1
       location: shapes.locs.1
       width: shapes.width.1
     sh2:
-      type: skewed-gaussian
+      type: gaussian
       amplitude: shapes.amps.2
       location: shapes.locs.2
       width: shapes.width.2
     sh3:
-      type: skewed-gaussian
+      type: gaussian
       amplitude: shapes.amps.3
       location: shapes.locs.3
       width: shapes.width.3
diff --git a/glotaran/test/test_spectral_penalties.py b/glotaran/test/test_spectral_penalties.py
index a00c9b853..5b055aedc 100644
--- a/glotaran/test/test_spectral_penalties.py
+++ b/glotaran/test/test_spectral_penalties.py
@@ -127,13 +127,13 @@ def test_equal_area_penalties(debug=False):
     shape = {
         "shape": {
             "sh1": {
-                "type": "skewed-gaussian",
+                "type": "gaussian",
                 "amplitude": "shapes.amps.1",
                 "location": "shapes.locs.1",
                 "width": "shapes.width.1",
             },
             "sh2": {
-                "type": "skewed-gaussian",
+                "type": "gaussian",
                 "amplitude": "shapes.amps.2",
                 "location": "shapes.locs.2",
                 "width": "shapes.width.2",
diff --git a/glotaran/utils/regex.py b/glotaran/utils/regex.py
new file mode 100644
index 000000000..0f701bad9
--- /dev/null
+++ b/glotaran/utils/regex.py
@@ -0,0 +1,16 @@
+"""Glotaran module with regular expression patterns and functions."""
+import re
+
+
+class RegexPattern:
+    """An 'Enum' of (compiled) regular expression patterns (rp)."""
+
+    # tuple = re.compile(r"(\(.*?,.*?\))")
+    elements_in_string_of_list: re.Pattern = re.compile(r"(\(.+?\)|[-+.\d]+)")
+    group: re.Pattern = re.compile(r"(\(.+?\))")
+    list_with_tuples: re.Pattern = re.compile(r"(\[.+\(.+\).+\])")
+    word: re.Pattern = re.compile(r"[\w]+")
+    number_scientific: re.Pattern = re.compile(r"[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)")
+    number: re.Pattern = re.compile(r"[\d.+-]+")
+    tuple_word: re.Pattern = re.compile(r"(\([.\s\w\d]+?[,.\s\w\d]*?\))")
+    tuple_number: re.Pattern = re.compile(r"(\([\s\d.+-]+?[,\s\d.+-]*?\))")
diff --git a/glotaran/builtin/io/yml/sanatize.py b/glotaran/utils/sanitize.py
similarity index 65%
rename from glotaran/builtin/io/yml/sanatize.py
rename to glotaran/utils/sanitize.py
index fe3084656..d26a97d3e 100644
--- a/glotaran/builtin/io/yml/sanatize.py
+++ b/glotaran/utils/sanitize.py
@@ -1,22 +1,14 @@
-import re
-from typing import List
-from typing import Tuple
-from typing import Union
+"""Glotaran module with utilities for sanitation of parsed content."""
+from __future__ import annotations
 
-from glotaran.deprecation import warn_deprecated
+from typing import Any
 
-# tuple_pattern = re.compile(r"(\(.*?,.*?\))")
-tuple_number_pattern = re.compile(r"(\([\s\d.+-]+?[,\s\d.+-]*?\))")
-number_pattern = re.compile(r"[\d.+-]+")
-tuple_name_pattern = re.compile(r"(\([.\s\w\d]+?[,.\s\w\d]*?\))")
-name_pattern = re.compile(r"[\w]+")
-group_pattern = re.compile(r"(\(.+?\))")
-match_list_with_tuples = re.compile(r"(\[.+\(.+\).+\])")
-match_elements_in_string_of_list = re.compile(r"(\(.+?\)|[-+.\d]+)")
+from glotaran.deprecation import warn_deprecated
+from glotaran.utils.regex import RegexPattern as rp
 
-def sanitize_list_with_broken_tuples(mangled_list: List[Union[str, float]]) -> List[str]:
-    """Sanitize a list with 'broken' tuples
+def sanitize_list_with_broken_tuples(mangled_list: list[str | float]) -> list[str]:
+    """Sanitize a list with 'broken' tuples.
 
     A list of broken tuples as returned by yaml when parsing tuples.
     e.g parsing the list of tuples [(3,100), (4,200)] results in
@@ -34,13 +26,12 @@ def sanitize_list_with_broken_tuples(mangled_list: List[Union[str, float]]) -> L
     A list containing the restores tuples (in string form) which can be
     converted back to numbered tuples using `list_string_to_tuple`
     """
-
     sanitized_string = str(mangled_list).replace("'", "")
-    return list(match_elements_in_string_of_list.findall(sanitized_string))
+    return list(rp.elements_in_string_of_list.findall(sanitized_string))
 
 
 def sanitize_dict_keys(d: dict) -> dict:
-    """Sanitize the stringified tuple dict keys in a yaml parsed dict
+    """Sanitize the stringified tuple dict keys in a yaml parsed dict.
 
     Keys representing a tuple, e.g. '(s1, s2)' are converted to a tuple of strings
     e.g. ('s1', 's2')
@@ -59,8 +50,8 @@ def sanitize_dict_keys(d: dict) -> dict:
         return {}
     d_new = {}
     for k, v in d.items() if isinstance(d, dict) else enumerate(d):
-        if isinstance(d, dict) and isinstance(k, str) and tuple_name_pattern.match(k):
-            k_new = tuple(map(str, name_pattern.findall(k)))
+        if isinstance(d, dict) and isinstance(k, str) and rp.tuple_word.match(k):
+            k_new = tuple(map(str, rp.word.findall(k)))
             d_new.update({k_new: v})
         elif isinstance(d, (dict, list)):
             new_v = sanitize_dict_keys(v)
@@ -69,18 +60,38 @@
     return d_new
 
 
-def sanitize_dict_values(d: dict):
-    """Sanitizes a dict with broken tuples inside modifying it in-place
+def sanity_scientific_notation_conversion(d: dict[str, Any] | list[Any]):
+    """Convert scientific notation string values to floats.
+
+    Parameters
+    ----------
+    d : dict[str, Any] | list[Any]
+        Iterable which should be checked for scientific notation values.
+    """
+    if not isinstance(d, (dict, list)):
+        return
+    for k, v in d.items() if isinstance(d, dict) else enumerate(d):  # type: ignore[attr-defined]
+        if isinstance(v, (list, dict)):
+            sanity_scientific_notation_conversion(v)
+        if isinstance(v, str):
+            d[k] = convert_scientific_to_float(v)
+
+
+def sanitize_dict_values(d: dict[str, Any] | list[Any]):
+    """Sanitizes a dict with broken tuples inside modifying it in-place.
+
     Broken tuples are tuples that are turned into strings by the yaml parser.
     This functions calls `sanitize_list_with_broken_tuples` to glue the broken strings together
     and then calls list_to_tuple to turn the list with tuple strings back to number tuples.
 
-    Args:
-        d (dict): A (complex) dict containing (possibly nested) values of broken tuple strings
+    Parameters
+    ----------
+    d : dict
+        A (complex) dict containing (possibly nested) values of broken tuple strings.
     """
     if not isinstance(d, (dict, list)):
         return
-    for k, v in d.items() if isinstance(d, dict) else enumerate(d):
+    for k, v in d.items() if isinstance(d, dict) else enumerate(d):  # type: ignore[attr-defined]
         if isinstance(v, list):
             leaf = all(isinstance(el, (str, tuple, float)) for el in v)
             if leaf:
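Editorial aside (not part of the patch): `sanity_scientific_notation_conversion` above walks the parsed spec recursively and replaces strings such as `"1e7"` with floats, which is what makes the new `spectral_axis_scale: 1e7` spec entry arrive as a number. A minimal self-contained sketch of the same idea; the helper names below are local stand-ins, and the pattern is a copy of `RegexPattern.number_scientific`:

```python
import re

SCIENTIFIC = re.compile(r"[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)")


def convert_scientific(value):
    return float(value) if isinstance(value, str) and SCIENTIFIC.match(value) else value


def convert_nested(container):
    """Recursively convert scientific-notation strings inside dicts and lists in place."""
    items = container.items() if isinstance(container, dict) else enumerate(container)
    for key, value in items:
        if isinstance(value, (dict, list)):
            convert_nested(value)
        else:
            container[key] = convert_scientific(value)


spec = {"dataset": {"dataset1": {"spectral_axis_scale": "1e7", "scale": 2}}}
convert_nested(spec)
print(spec)  # {'dataset': {'dataset1': {'spectral_axis_scale': 10000000.0, 'scale': 2}}}
```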
@@ -96,8 +107,8 @@ def sanitize_dict_values(d: dict):
 
 
 def string_to_tuple(
     tuple_str: str, from_list=False
-) -> Union[Tuple[float], Tuple[str], float, str]:
-    """[summary]
+) -> tuple[float, ...] | tuple[str, ...] | float | str:
+    """Convert a string to a tuple if it matches a tuple pattern.
 
     Parameters
     ----------
@@ -111,22 +122,23 @@ def string_to_tuple(
 
     Returns
     -------
-    Union[Tuple[float], Tuple[str], float, str]
+    tuple[float], tuple[str], float, str
         Returns the tuple intended by the string
     """
-
-    if tuple_number_pattern.match(tuple_str):
-        return tuple(map(float, number_pattern.findall(tuple_str)))
-    elif tuple_name_pattern.match(tuple_str):
-        return tuple(map(str, name_pattern.findall(tuple_str)))
-    elif from_list and number_pattern.match(tuple_str):
+    if rp.tuple_number.match(tuple_str):
+        return tuple(map(float, rp.number.findall(tuple_str)))
+    elif rp.tuple_word.match(tuple_str):
+        return tuple(map(str, rp.word.findall(tuple_str)))
+    elif from_list and rp.number.match(tuple_str):
         return float(tuple_str)
     else:
         return tuple_str
 
 
-def list_string_to_tuple(a_list: List[str]) -> List[Union[float, str]]:
-    """Converts a list of strings (representing tuples) to a list of tuples
+def list_string_to_tuple(
+    a_list: list[str],
+) -> list[tuple[float, ...] | tuple[str, ...] | float | str]:
+    """Convert a list of strings (representing tuples) to a list of tuples.
 
     Parameters
     ----------
@@ -138,18 +150,20 @@ def list_string_to_tuple(a_list: List[str]) -> List[Union[float, str]]:
     List[Union[float, str]]
         A list of the (numbered) tuples represted by the incoming a_list
     """
-    for i, v in enumerate(a_list):
-        a_list[i] = string_to_tuple(v, from_list=True)
-    return a_list
+    return [string_to_tuple(v, from_list=True) for v in a_list]
 
 
 def sanitize_yaml(d: dict, do_keys: bool = True, do_values: bool = False) -> dict:
-    """Sanitize a yaml-returned dict for key or (list) values containing tuples
+    """Sanitize a yaml-returned dict for key or (list) values containing tuples.
 
     Parameters
    ----------
     d : dict
         a dict resulting from parsing a pyglotaran model spec yml file
+    do_keys : bool
+        toggle sanitization of dict keys, by default True
+    do_values : bool
+        toggle sanitization of dict values, by default False
 
     Returns
     -------
@@ -161,10 +175,57 @@ def sanitize_yaml(d: dict, do_keys: bool = True, do_values: bool = False) -> dic
     if do_values:
         # this is only needed to allow for tuple parsing in specification
         sanitize_dict_values(d)
+    sanity_scientific_notation_conversion(d)
     return d
 
 
+def convert_scientific_to_float(value: str) -> float | str:
+    """Convert value to float if it matches scientific notation string.
+
+    Parameters
+    ----------
+    value : str
+        value to convert from string to float if it matches scientific notation
+
+    Returns
+    -------
+    float | string
+        return float if value was scientific notation string, else turn original value
+    """
+    if rp.number_scientific.match(value):
+        return float(value)
+    else:
+        return value
+
+
+def sanitize_parameter_list(parameter_list: list[str | float]) -> list[str | float]:
+    """Replace in a list strings matching scientific notation with floats.
+
+    Parameters
+    ----------
+    parameter_list : list
+        A list of parameters where some elements may be strings like 1E7
+
+    Returns
+    -------
+    list
+        A list where strings matching a scientific number have been converted to float
+    """
+    for i, value in enumerate(parameter_list):
+        if isinstance(value, str):
+            parameter_list[i] = convert_scientific_to_float(value)
+
+    return parameter_list
+
+
 def check_deprecations(spec: dict):
+    """Check deprecations in a `spec` dict.
+
+    Parameters
+    ----------
+    spec : dict
+        A specification dictionary
+    """
     if "type" in spec:
         if spec["type"] == "kinetic-spectrum":
             warn_deprecated(
diff --git a/glotaran/builtin/io/yml/test/test_util.py b/glotaran/utils/test/test_sanitize.py
similarity index 93%
rename from glotaran/builtin/io/yml/test/test_util.py
rename to glotaran/utils/test/test_sanitize.py
index 812b00b93..3c7e6b948 100644
--- a/glotaran/builtin/io/yml/test/test_util.py
+++ b/glotaran/utils/test/test_sanitize.py
@@ -5,12 +5,12 @@
 import pytest
 
-from glotaran.builtin.io.yml.sanatize import sanitize_list_with_broken_tuples
+from glotaran.utils.sanitize import sanitize_list_with_broken_tuples
 
 
 class MangledListTestData(NamedTuple):
     input: list[Any]
-    input_sanitized: list[str]
+    input_sanitized: list[str] | str
     output: list[str]
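Closing editorial aside (not part of the patch): `sanitize_parameter_list` is what `Parameter.from_list_or_value` relies on after the `glotaran/parameter/parameter.py` hunk earlier in this diff. Label, value and options are picked out of a parameter list purely by type, so a scientific-notation string such as `"5.7e10"` (which PyYAML delivers as a string) must become a float first, otherwise no numeric element is found and the value silently falls back to its default. A self-contained sketch of that failure mode and the fix; the `retrieve_by_type` helper and the label `rates.k1` are illustrative only:

```python
import re

SCIENTIFIC = re.compile(r"[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)")


def retrieve_by_type(values, types, default):
    """Mimic of the by-type retrieval used for label, value and options."""
    matches = [v for v in values if isinstance(v, types)]
    return matches[0] if matches else default


raw = ["rates.k1", "5.7e10", {"non-negative": True}]  # as it may arrive from a YAML spec
clean = [float(v) if isinstance(v, str) and SCIENTIFIC.match(v) else v for v in raw]

assert retrieve_by_type(raw, (int, float), 0) == 0          # value silently lost without conversion
assert retrieve_by_type(clean, (int, float), 0) == 5.7e10   # value recovered after conversion
assert retrieve_by_type(clean, str, None) == "rates.k1"     # the label is left untouched
```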