Skip to content

Commit

Permalink
Removed ElementResults in favor of xr.Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
jsnel committed Oct 13, 2024
1 parent 85ee779 commit 4d274dc
Show file tree
Hide file tree
Showing 12 changed files with 187 additions and 79 deletions.
11 changes: 3 additions & 8 deletions glotaran/builtin/elements/baseline/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@
from typing import Literal

import numpy as np
import xarray as xr

from glotaran.model.element import Element
from glotaran.model.element import ElementResult

if TYPE_CHECKING:
import xarray as xr

from glotaran.model.data_model import DataModel
from glotaran.typing.types import ArrayLike

Expand Down Expand Up @@ -42,8 +40,5 @@ def create_result(
model_dimension: str,
amplitudes: xr.Dataset,
concentrations: xr.Dataset,
) -> ElementResult:
return ElementResult(
amplitudes={"baseline": amplitudes.sel(amplitude_label=self.clp_label())},
concentrations={},
)
) -> xr.Dataset:
return xr.Dataset({"amplitudes": amplitudes.sel(amplitude_label=self.clp_label())})
11 changes: 11 additions & 0 deletions glotaran/builtin/elements/clp_guide/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Literal

import numpy as np
import xarray as xr

from glotaran.model.element import Element

Expand All @@ -27,3 +28,13 @@ def calculate_matrix(
**kwargs,
) -> tuple[list[str], ArrayLike]:
return [self.target], np.ones((1, 1), dtype=np.float64)

def create_result(
self,
model: DataModel,
global_dimension: str,
model_dimension: str,
amplitudes: xr.Dataset,
concentrations: xr.Dataset,
) -> xr.Dataset:
return xr.Dataset() # TODO: return correct data
11 changes: 3 additions & 8 deletions glotaran/builtin/elements/coherent_artifact/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,15 @@

import numba as nb
import numpy as np
import xarray as xr

from glotaran.builtin.items.activation import ActivationDataModel
from glotaran.builtin.items.activation import MultiGaussianActivation
from glotaran.model.element import Element
from glotaran.model.element import ElementResult
from glotaran.model.errors import GlotaranModelError
from glotaran.model.item import ParameterType # noqa: TCH001

if TYPE_CHECKING:
import xarray as xr

from glotaran.model.data_model import DataModel
from glotaran.typing.types import ArrayLike

Expand Down Expand Up @@ -96,7 +94,7 @@ def create_result(
model_dimension: str,
amplitudes: xr.Dataset,
concentrations: xr.Dataset,
) -> ElementResult:
) -> xr.Dataset:
amplitude = (
amplitudes.sel(amplitude_label=self.compartments)
.rename(amplitude_label="coherent_artifact_order")
Expand All @@ -107,10 +105,7 @@ def create_result(
.rename(amplitude_label="coherent_artifact_order")
.assign_coords({"coherent_artifact_order": range(1, self.order + 1)})
)
return ElementResult(
amplitudes={"coherent_artifact": amplitude},
concentrations={"coherent_artifact": concentration},
)
return xr.Dataset({"amplitudes": amplitude, "concentrations": concentration})


@nb.jit(nopython=True, parallel=False)
Expand Down
25 changes: 11 additions & 14 deletions glotaran/builtin/elements/damped_oscillation/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from glotaran.builtin.items.activation import MultiGaussianActivation
from glotaran.model.data_model import DataModel # noqa: TCH001
from glotaran.model.element import Element
from glotaran.model.element import ElementResult
from glotaran.model.item import Item
from glotaran.model.item import ParameterType

Expand Down Expand Up @@ -140,7 +139,7 @@ def create_result(
model_dimension: str,
amplitudes: xr.Dataset,
concentrations: xr.Dataset,
) -> ElementResult:
) -> xr.Dataset:
oscillations = list(self.oscillations)
frequencies = [self.oscillations[label].frequency for label in oscillations]
rates = [self.oscillations[label].rate for label in oscillations]
Expand Down Expand Up @@ -186,17 +185,15 @@ def create_result(
coords=doas_concentrations.coords,
)

return ElementResult(
amplitudes={
"damped_oscillation": doas_amplitudes,
"damped_oscillation_phase": phase_amplitudes,
"damped_oscillation_sin": sin_amplitudes,
"damped_oscillation_cos": cos_amplitudes,
},
concentrations={
"damped_oscillation": doas_concentrations,
"damped_oscillation_phase": phase_concentrations,
"damped_oscillation_sin": sin_concentrations,
"damped_oscillation_cos": cos_concentrations,
return xr.Dataset(
{
"damped_oscillation_amplitudes": doas_amplitudes,
"damped_oscillation_phase_amplitudes": phase_amplitudes,
"damped_oscillation_sin_amplitudes": sin_amplitudes,
"damped_oscillation_cos_amplitudes": cos_amplitudes,
"damped_oscillation_concentrations": doas_concentrations,
"damped_oscillation_phase_concentrations": phase_concentrations,
"damped_oscillation_sin_concentrations": sin_concentrations,
"damped_oscillation_cos_concentrations": cos_concentrations,
},
)
17 changes: 7 additions & 10 deletions glotaran/builtin/elements/kinetic/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from glotaran.builtin.elements.kinetic.matrix import calculate_matrix_gaussian_activation_on_index
from glotaran.builtin.items.activation import ActivationDataModel
from glotaran.builtin.items.activation import MultiGaussianActivation
from glotaran.model.element import ElementResult
from glotaran.model.element import ExtendableElement

if TYPE_CHECKING:
Expand Down Expand Up @@ -159,7 +158,7 @@ def create_result(
model_dimension: str,
amplitudes: xr.Dataset,
concentrations: xr.Dataset,
) -> ElementResult:
) -> xr.Dataset:
species_amplitude = amplitudes.sel(amplitude_label=self.species).rename(
amplitude_label="species"
)
Expand Down Expand Up @@ -219,16 +218,14 @@ def create_result(
dims=("activation", global_dimension, "kinetic"),
)

return ElementResult(
amplitudes={
"species": species_amplitude,
"kinetic": kinetic_amplitude,
},
concentrations={"species": species_concentration},
extra={
return xr.Dataset(
{
"species_amplitude": species_amplitude,
"kinetic_amplitude": kinetic_amplitude,
"species_concentration": species_concentration,
"k_matrix": k_matrix,
"reduced_k_matrix": reduced_k_matrix,
"initial_concentration": initial_concentration,
"a_matrix": a_matrix,
},
}
)
14 changes: 7 additions & 7 deletions glotaran/builtin/elements/spectral/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
from typing import Literal

import numpy as np
import xarray as xr

from glotaran.builtin.elements.spectral.shape import SpectralShape # noqa: TCH001
from glotaran.model.data_model import DataModel
from glotaran.model.element import Element
from glotaran.model.element import ElementResult

if TYPE_CHECKING:
import xarray as xr

from glotaran.typing.types import ArrayLike


Expand Down Expand Up @@ -87,15 +85,17 @@ def create_result(
model_dimension: str,
amplitudes: xr.Dataset,
concentrations: xr.Dataset,
) -> ElementResult:
) -> xr.Dataset:
shapes = list(self.shapes.keys())

spectra_amplitude = amplitudes.sel(amplitude_label=shapes).rename(amplitude_label="shape")
spectra_concentration = concentrations.sel(amplitude_label=shapes).rename(
amplitude_label="shape"
)

return ElementResult(
amplitudes={"spectrum": spectra_amplitude},
concentrations={"spectrum": spectra_concentration},
return xr.Dataset(
{
"amplitudes": spectra_amplitude,
"concentrations": spectra_concentration,
}
)
14 changes: 3 additions & 11 deletions glotaran/model/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any
from typing import ClassVar

import xarray as xr
from pydantic import ConfigDict
from pydantic import Field

Expand All @@ -17,8 +18,6 @@
from glotaran.plugin_system.element_registration import register_element

if TYPE_CHECKING:
import xarray as xr

from glotaran.model.data_model import DataModel
from glotaran.typing.types import ArrayLike

Expand All @@ -35,13 +34,6 @@ def _sanitize_json_schema(json_schema: dict[str, Any]) -> None:
json_schema["required"].remove("label")


@dataclass
class ElementResult:
amplitudes: dict[str, xr.DataArray]
concentrations: dict[str, xr.DataArray]
extra: dict[str, xr.DataArray] = field(default_factory=dict)


class Element(TypedItem, abc.ABC):
"""Subclasses must overwrite :method:`glotaran.model.Element.calculate_matrix`."""

Expand Down Expand Up @@ -93,14 +85,15 @@ def calculate_matrix(
.. # noqa: DAR202
"""

@abc.abstractmethod
def create_result(
self,
model: DataModel,
global_dimension: str,
model_dimension: str,
amplitudes: xr.Dataset,
concentrations: xr.Dataset,
) -> ElementResult:
) -> xr.Dataset:
"""
Parameters
Expand All @@ -114,7 +107,6 @@ def create_result(
as_global: bool
Whether model is calculated as global model.
"""
return ElementResult({}, {})


class ExtendableElement(Element):
Expand Down
22 changes: 12 additions & 10 deletions glotaran/optimization/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from glotaran.model.data_model import DataModel
from glotaran.model.data_model import iterate_data_model_elements
from glotaran.model.element import Element
from glotaran.model.element import ElementResult
from glotaran.optimization.data import LinkedOptimizationData
from glotaran.optimization.data import OptimizationData
from glotaran.optimization.estimation import OptimizationEstimation
Expand Down Expand Up @@ -46,7 +45,7 @@ def add_svd_to_result_dataset(dataset: xr.Dataset, global_dim: str, model_dim: s
class DatasetResult(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

elements: dict[str, ElementResult] = Field(default_factory=dict)
elements: dict[str, xr.Dataset] = Field(default_factory=dict)
activations: xr.Dataset = Field(default_factory=dict)
input_data: xr.DataArray | xr.Dataset | None = None
residuals: xr.DataArray | xr.Dataset | None = None
Expand Down Expand Up @@ -194,17 +193,20 @@ def create_global_result(self) -> OptimizationObjectiveResult:
)
clp_size = len(matrix.clp_axis) + len(global_matrix.clp_axis)
self._data.unweight_result_dataset(result_dataset)
result_dataset["fit"] = result_dataset.data - result_dataset.residual

add_svd_to_result_dataset(result_dataset, global_dim, model_dim)
result = DatasetResult(
result_dataset,
{
label: ElementResult(
amplitudes={"clp": clp},
concentrations={"global": global_matrix, "model": matrix},
input_data=result_dataset.data,
residuals=result_dataset.residual,
elements={
label: xr.Dataset(
{
"amplitudes": clp,
"global_concentrations": global_matrix,
"model_concentrations": matrix,
}
)
},
{},
)
return OptimizationObjectiveResult(
data={label: result}, clp_size=clp_size, additional_penalty=0
Expand Down Expand Up @@ -385,7 +387,7 @@ def create_element_results(
model_dim: str,
amplitudes: xr.DataArray,
concentrations: xr.DataArray,
) -> dict[str, ElementResult]:
) -> dict[str, xr.Dataset]:
assert any(isinstance(element, str) for element in model.elements) is False
return {
element.label: element.create_result(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,31 @@ def test_coherent_artifact(activation: Activation):
"coherent_artifact_associated_amplitudes_coherent-artifact"
in optimized_data["coherent_artifact"]
)


if __name__ == "__main__":
test_coherent_artifact(
GaussianActivation(
type="gaussian",
compartments={"coherent-artifact": 1},
center="gaussian.center",
width="gaussian.width",
)
)
test_coherent_artifact(
GaussianActivation(
type="gaussian",
compartments={"coherent-artifact": 1},
center="gaussian.center",
width="gaussian.width",
shift=[0],
)
)
test_coherent_artifact(
MultiGaussianActivation(
type="multi-gaussian",
compartments={"coherent-artifact": 1},
center=["gaussian.center"],
width=["gaussian.width", "gaussian.width"],
)
)
Loading

0 comments on commit 4d274dc

Please sign in to comment.