diff --git a/glotaran/io/preprocessor/__init__.py b/glotaran/io/preprocessor/__init__.py new file mode 100644 index 000000000..f419c7669 --- /dev/null +++ b/glotaran/io/preprocessor/__init__.py @@ -0,0 +1,2 @@ +"""Tools for data pre-processing.""" +from glotaran.io.preprocessor.pipeline import PreProcessingPipeline diff --git a/glotaran/io/preprocessor/pipeline.py b/glotaran/io/preprocessor/pipeline.py new file mode 100644 index 000000000..c331b8008 --- /dev/null +++ b/glotaran/io/preprocessor/pipeline.py @@ -0,0 +1,88 @@ +"""A pre-processor pipeline for data.""" +from __future__ import annotations + +from typing import Annotated + +import xarray as xr +from pydantic import BaseModel +from pydantic import Field + +from glotaran.io.preprocessor.preprocessor import CorrectBaselineAverage +from glotaran.io.preprocessor.preprocessor import CorrectBaselineValue + +PipelineAction = Annotated[ + CorrectBaselineValue | CorrectBaselineAverage, + Field(discriminator="action"), +] + + +class PreProcessingPipeline(BaseModel): + """A pipeline for pre-processors.""" + + actions: list[PipelineAction] = Field(default_factory=list) + + def apply(self, original: xr.DataArray) -> xr.DataArray: + """Apply all pre-processors on data. + + Parameters + ---------- + original: xr.DataArray + The data to process. + + Returns + ------- + xr.DataArray + """ + result = original.copy() + + for action in self.actions: + result = action.apply(result) + return result + + def _push_action(self, action: PipelineAction): + """Push an action. + + Parameters + ---------- + action: PipelineAction + The action to push. + """ + self.actions.append(action) + + def correct_baseline_value(self, value: float) -> PreProcessingPipeline: + """Correct a dataset by subtracting baseline value. + + Parameters + ---------- + value: float + The value to subtract. + + Returns + ------- + PreProcessingPipeline + """ + self._push_action(CorrectBaselineValue(value=value)) + return self + + def correct_baseline_average( + self, + select: dict[str, slice | list[int] | int] | None = None, + exclude: dict[str, slice | list[int] | int] | None = None, + ) -> PreProcessingPipeline: + """Correct a dataset by subtracting the average over a part of the data. + + Parameters + ---------- + select: dict[str, slice | list[int] | int] | None + The selection to average as dictionary of dimension and indexer. + The indexer can be a slice, a list or an integer value. + exclude: dict[str, slice | list[int] | int] | None + Excluded regions from the average as dictionary of dimension and indexer. + The indexer can be a slice, a list or an integer value. + + Returns + ------- + PreProcessingPipeline + """ + self._push_action(CorrectBaselineAverage(exclude=exclude, select=select)) + return self diff --git a/glotaran/io/preprocessor/preprocessor.py b/glotaran/io/preprocessor/preprocessor.py new file mode 100644 index 000000000..6d918a96e --- /dev/null +++ b/glotaran/io/preprocessor/preprocessor.py @@ -0,0 +1,76 @@ +"""A pre-processor pipeline for data.""" +from __future__ import annotations + +import abc +from typing import Literal + +import xarray as xr +from pydantic import BaseModel + + +class PreProcessor(BaseModel, abc.ABC): + """A base class for pre=processors.""" + + class Config: + """Config for BaseModel.""" + + arbitrary_types_allowed = True + + @abc.abstractmethod + def apply(self, data: xr.DataArray) -> xr.DataArray: + """Apply the pre-processor. + + Parameters + ---------- + data: xr.DataArray + The data to process. + + Returns + ------- + xr.DataArray + + .. # noqa: DAR202 + """ + + +class CorrectBaselineValue(PreProcessor): + """Corrects a dataset by subtracting baseline value.""" + + action: Literal["baseline-value"] = "baseline-value" + value: float + + def apply(self, data: xr.DataArray) -> xr.DataArray: + """Apply the pre-processor. + + Parameters + ---------- + data: xr.DataArray + The data to process. + + Returns + ------- + xr.DataArray + """ + return data - self.value + + +class CorrectBaselineAverage(PreProcessor): + """Corrects a dataset by subtracting the average over a part of the data.""" + + action: Literal["baseline-average"] = "baseline-average" + select: dict[str, slice | list[int] | int] | None = None + exclude: dict[str, slice | list[int] | int] | None = None + + def apply(self, data: xr.DataArray) -> xr.DataArray: + """Apply the pre-processor. + + Parameters + ---------- + data: xr.DataArray + The data to process. + + Returns + ------- + xr.DataArray + """ + return data - data.sel(self.select or {}).drop_sel(self.exclude or {}).mean() diff --git a/glotaran/io/preprocessor/test/test_preprocessor.py b/glotaran/io/preprocessor/test/test_preprocessor.py new file mode 100644 index 000000000..f6a6e283e --- /dev/null +++ b/glotaran/io/preprocessor/test/test_preprocessor.py @@ -0,0 +1,44 @@ +import pytest +import xarray as xr + +from glotaran.io.preprocessor import PreProcessingPipeline + + +def test_correct_baseline_value(): + pl = PreProcessingPipeline() + pl.correct_baseline_value(1) + data = xr.DataArray([[1]]) + result = pl.apply(data) + assert result == data - 1 + + +@pytest.mark.parametrize("indexer", (slice(0, 2), [0, 1])) +def test_correct_baseline_average(indexer: slice | list[int]): + pl = PreProcessingPipeline() + pl.correct_baseline_average(select={"dim_0": 0, "dim_1": indexer}) + data = xr.DataArray([[1.1, 0.9]]) + result = pl.apply(data) + assert (result == data - 1).all() + + +def test_correct_baseline_average_exclude(): + pl = PreProcessingPipeline() + pl.correct_baseline_average(select={"dim_0": 0}, exclude={"dim_1": 1}) + data = xr.DataArray([[1.1, 0.9]]) + result = pl.apply(data) + print(result) + assert (result == data - 1.1).all() + + +def test_to_from_dict(): + pl = PreProcessingPipeline() + pl.correct_baseline_value(1) + pl.correct_baseline_average({"dim_1": slice(0, 2)}) + pl_dict = pl.dict() + assert pl_dict == { + "actions": [ + {"action": "baseline-value", "value": 1.0}, + {"action": "baseline-average", "select": {"dim_1": slice(0, 2)}, "exclude": None}, + ] + } + assert PreProcessingPipeline.parse_obj(pl_dict) == pl diff --git a/requirements_dev.txt b/requirements_dev.txt index 5982a1092..6502bdc94 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -12,6 +12,7 @@ numpy==1.23.5 odfpy==1.4.1 openpyxl==3.1.1 pandas==1.5.3 +pydantic==1.10.2 rich==13.3.1 ruamel.yaml==0.17.21 scipy==1.10.1 diff --git a/setup.cfg b/setup.cfg index b08492677..3080b3f3c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,6 +40,7 @@ install_requires = odfpy>=1.4.1 openpyxl>=3.0.10 pandas>=1.3.4 + pydantic>=1.10.2 rich>=10.9.0 ruamel.yaml>=0.17.17 scipy>=1.7.2