Skip to content

Commit

Permalink
Finish refactoring of activations
Browse files Browse the repository at this point in the history
  • Loading branch information
jsnel committed Oct 13, 2024
1 parent 3138c49 commit 5da2676
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 131 deletions.
2 changes: 1 addition & 1 deletion glotaran/builtin/elements/coherent_artifact/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def calculate_matrix( # type:ignore[override]

activations = {
key: a
for key,a in model.activations.items()
for key, a in model.activations.items()
if isinstance(a, MultiGaussianActivation) and self.label in a.compartments
}

Expand Down
9 changes: 6 additions & 3 deletions glotaran/builtin/elements/kinetic/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def calculate_matrix( # type:ignore[override]
) -> tuple[list[str], ArrayLike]:
compartments = self.compartments
matrices = []
for _,activation in model.activations.items():
for _, activation in model.activations.items():
initial_concentrations = np.array(
[float(activation.compartments.get(label, 0)) for label in compartments]
)
Expand Down Expand Up @@ -185,7 +185,7 @@ def create_result(
initial_concentrations = []
a_matrices = []
kinetic_amplitudes = []
for _,activation in model.activations.items():
for _, activation in model.activations.items():
initial_concentration = np.array(
[float(activation.compartments.get(label, 0)) for label in self.compartments]
)
Expand All @@ -196,7 +196,10 @@ def create_result(

initial_concentration = xr.DataArray(
initial_concentrations,
coords={"activation": range(len(initial_concentrations)), "compartment": self.compartments},
coords={
"activation": range(len(initial_concentrations)),
"compartment": self.compartments,
},
)
a_matrix = xr.DataArray(
a_matrices,
Expand Down
97 changes: 30 additions & 67 deletions glotaran/builtin/items/activation/data_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

from dataclasses import asdict
from typing import TYPE_CHECKING
from typing import cast

import attr
import numpy as np
import xarray as xr

Expand Down Expand Up @@ -30,7 +32,7 @@ def to_string(self) -> str:


def validate_activations(
value: dict[str,Activation],
value: dict[str, Activation],
activation: Activation,
parameters: Parameters | None,
) -> list[ItemIssue]:
Expand All @@ -41,7 +43,7 @@ def validate_activations(


class ActivationDataModel(DataModel):
activations: dict[str,Activation.get_annotated_type()] = Attribute( # type:ignore[valid-type]
activations: dict[str, Activation.get_annotated_type()] = Attribute( # type:ignore[valid-type]
validator=validate_activations,
description="The activation(s) of the dataset.",
)
Expand All @@ -53,81 +55,42 @@ def create_result(
model_dimension: str,
amplitudes: xr.DataArray,
concentrations: xr.DataArray,
) -> dict[str, xr.DataArray]:
) -> dict[str, xr.Dataset]:
gaussian_activations = {
key:a for key, a in model.activations.items() if isinstance(a, MultiGaussianActivation)
key: a
for key, a in model.activations.items()
if isinstance(a, MultiGaussianActivation)
}
if not len(gaussian_activations):
return {}

global_axis = amplitudes.coords[global_dimension]
model_axis = concentrations.coords[model_dimension]

activations = []
activation_parameters: list[list[GaussianActivationParameters]] = []
activation_shifts = []
activation_dispersions = []
result: dict[str, xr.Dataset] = {}

has_shifts = any(a.shift is not None for a in gaussian_activations.values())
has_dispersions = any(a.dispersion_center is not None for a in gaussian_activations.values())

for _, activation in gaussian_activations.items():
activations.append(activation.calculate_function(model_axis))
activation_parameters.append(
cast(list[GaussianActivationParameters], activation.parameters())
)
if has_shifts:
activation_shifts.append(
activation.shift if activation.shift is not None else [0] * global_axis.size
)
if has_dispersions:
activation_dispersions.append(
activation.calculate_dispersion(global_axis)
if activation.dispersion_center is not None
else activation.center * global_axis.size
)

result = {}

activation_coords = {"gaussian_activation": np.arange(1, len(gaussian_activations) + 1)}
result["gaussian_activation_function"] = xr.DataArray(
activations,
coords=activation_coords | {model_dimension: model_axis},
dims=("gaussian_activation", model_dimension),
)

if has_shifts:
result["activation_shift"] = xr.DataArray(
activation_shifts,
coords=activation_coords | {global_dimension: global_axis},
dims=("gaussian_activation", global_dimension),
for key, activation in gaussian_activations.items():
trace = activation.calculate_function(model_axis)
shift = activation.shift if activation.shift is not None else [0] * global_axis.size
center = (
np.sum(activation.calculate_dispersion(global_axis), axis=0)
if activation.dispersion_center is not None
else activation.center * global_axis.size
)

activation_coords = activation_coords | {
"gaussian_activation_part": np.arange(max([len(ps) for ps in activation_parameters]))
}

result["activation_center"] = xr.DataArray(
[[p.center for p in ps] for ps in activation_parameters],
coords=activation_coords,
dims=("gaussian_activation", "gaussian_activation_part"),
)
result["activation_width"] = xr.DataArray(
[[p.width for p in ps] for ps in activation_parameters],
coords=activation_coords,
dims=("gaussian_activation", "gaussian_activation_part"),
)
result["activation_scale"] = xr.DataArray(
[[p.scale for p in ps] for ps in activation_parameters],
coords=activation_coords,
dims=("gaussian_activation", "gaussian_activation_part"),
)

if has_dispersions:
result["activation_dispersion"] = xr.DataArray(
activation_dispersions,
coords=activation_coords | {global_dimension: global_axis},
dims=("gaussian_activation", "gaussian_activation_part", global_dimension),
props = [asdict(p) for p in activation.parameters()]
result[key] = xr.Dataset(
{
"trace": xr.DataArray(
trace, coords={model_dimension: model_axis}, dims=(model_dimension,)
),
"shift": xr.DataArray(
shift, coords={global_dimension: global_axis}, dims=(global_dimension,)
),
"center": xr.DataArray(
center, coords={global_dimension: global_axis}, dims=(global_dimension,)
),
},
attrs={"activation": props},
)

return result
8 changes: 5 additions & 3 deletions glotaran/optimization/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class OptimizationResult(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

elements: dict[str, xr.Dataset] = Field(default_factory=dict)
activations: xr.Dataset = Field(default_factory=dict)
activations: dict[str, xr.Dataset] = Field(default_factory=dict)
input_data: xr.DataArray | xr.Dataset | None = None
residuals: xr.DataArray | xr.Dataset | None = None

Expand Down Expand Up @@ -279,7 +279,9 @@ def create_single_dataset_result(self) -> OptimizationObjectiveResult:
activations=activations,
)
return OptimizationObjectiveResult(
optimization_results={label: result}, clp_size=clp_size, additional_penalty=additional_penalty
optimization_results={label: result},
clp_size=clp_size,
additional_penalty=additional_penalty,
)

def create_multi_dataset_result(self) -> OptimizationObjectiveResult:
Expand Down Expand Up @@ -470,7 +472,7 @@ def create_data_model_results(
amplitudes,
concentrations,
)
return xr.Dataset(result)
return result

def get_result(self) -> OptimizationObjectiveResult:
return (
Expand Down
4 changes: 2 additions & 2 deletions glotaran/testing/simulated_data/parallel_spectral_decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
KineticSpectrumDataModel(
elements=["parallel"],
global_elements=["spectral"],
activations={"irf":GaussianActivation.model_validate(ACTIVATION)}, # type:ignore[call-arg]
activations={"irf": GaussianActivation.model_validate(ACTIVATION)}, # type:ignore[call-arg]
),
ModelLibrary.from_dict(LIBRARY),
SIMULATION_PARAMETERS,
Expand All @@ -35,7 +35,7 @@
"datasets": {
"parallel-decay": {
"elements": ["parallel"],
"activations": {"irf":ACTIVATION},
"activations": {"irf": ACTIVATION},
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions glotaran/testing/simulated_data/sequential_spectral_decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
KineticSpectrumDataModel(
elements=["sequential"],
global_elements=["spectral"],
activations={"irf":GaussianActivation.model_validate(ACTIVATION)}, # type:ignore[call-arg]
activations={"irf": GaussianActivation.model_validate(ACTIVATION)}, # type:ignore[call-arg]
),
ModelLibrary.from_dict(LIBRARY),
SIMULATION_PARAMETERS,
Expand All @@ -35,7 +35,7 @@
"datasets": {
"sequential-decay": {
"elements": ["sequential"],
"activations": {"irf":ACTIVATION},
"activations": {"irf": ACTIVATION},
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
),
)
def test_coherent_artifact(activation: Activation):
data_model = ActivationDataModel(elements=["coherent-artifact"], activations={"irf":activation})
data_model = ActivationDataModel(
elements=["coherent-artifact"], activations={"irf": activation}
)
data_model.data = simulate(
data_model, test_library, test_parameters_simulation, test_axies, clp=test_clp
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
test_parameters_simulation = Parameters.from_dict(
{
"osc": [["frequency", 3], ["rate", 1]],
"gaussian": [["center", 0], ["width", 10]],
"irf": [["center", 0], ["width", 10]],
}
)
test_parameters = Parameters.from_dict(
{
"osc": [["frequency", 3], ["rate", 1, {"min": 0}]],
"gaussian": [["center", 0], ["width", 10]],
"irf": [["center", 0], ["width", 10]],
}
)

Expand Down Expand Up @@ -92,7 +92,7 @@
),
)
def test_coherent_artifact(activation: Activation):
data_model = ActivationDataModel(elements=["does"], activations={"irf": activation})
data_model = ActivationDataModel(elements=["doas"], activations={"irf": activation})
data_model.data = simulate(
data_model, test_library, test_parameters_simulation, test_axies, clp=test_clp
)
Expand Down Expand Up @@ -124,9 +124,12 @@ def test_coherent_artifact(activation: Activation):
in optimization_results["damped_oscillation"]
)
assert (
"damped_oscillation_frequency_damped-oscillation" in optimization_results["damped_oscillation"]
"damped_oscillation_frequency_damped-oscillation"
in optimization_results["damped_oscillation"]
)
assert (
"damped_oscillation_rate_damped-oscillation" in optimization_results["damped_oscillation"]
)
assert "damped_oscillation_rate_damped-oscillation" in optimization_results["damped_oscillation"]
assert (
"damped_oscillation_phase_associated_amplitudes_damped-oscillation"
in optimization_results["damped_oscillation"]
Expand All @@ -148,27 +151,36 @@ def test_coherent_artifact(activation: Activation):
in optimization_results["damped_oscillation"]
)


if __name__ == "__main__":
test_coherent_artifact(InstantActivation(
type="instant",
compartments={"osc": 1},
))
test_coherent_artifact(GaussianActivation(
type="gaussian",
compartments={"osc": 1},
center="gaussian.center",
width="gaussian.width",
))
test_coherent_artifact(GaussianActivation(
type="gaussian",
compartments={"osc": 1},
center="gaussian.center",
width="gaussian.width",
shift=[0],
))
test_coherent_artifact(MultiGaussianActivation(
type="multi-gaussian",
compartments={},
center=["gaussian.center"],
width=["gaussian.width", "gaussian.width"],
))
test_coherent_artifact(
InstantActivation(
type="instant",
compartments={"osc": 1},
)
)
test_coherent_artifact(
GaussianActivation(
type="gaussian",
compartments={"osc": 1},
center="irf.center",
width="irf.width",
)
)
test_coherent_artifact(
GaussianActivation(
type="gaussian",
compartments={"osc": 1},
center="irf.center",
width="irf.width",
shift=[0],
)
)
test_coherent_artifact(
MultiGaussianActivation(
type="multi-gaussian",
compartments={},
center=["irf.center"],
width=["irf.width", "irf.width"],
)
)
Loading

0 comments on commit 5da2676

Please sign in to comment.