Skip to content

Commit

Permalink
🩹 Fix coherent artifact crash for index dependent models (#808)
Browse files Browse the repository at this point in the history
* 🩹 Fixed dimension mismatching of CoherentArtifactMegacomplex

If the dataset_model was index_dependent, CoherentArtifactMegacomplex.finalize_data crashed because it tried to cast data of dimensions (global_dimension, model_dimension, "coherent_artifact_order") to the dimensions (model_dimension, "coherent_artifact_order")

* ♻️ Refactored test_coherent_artifact to be more readable

* 🧪 Parametrized test_coherent_artifact to cover index dependent case

* ♻️ Renamed 'DatasetModel.index_dependent' -> 'is_index_dependent' (see #755)

* 🔧 Let mypy run normally in pre-commit

* 👌 Renamed 'coherent_artifact_concentration' to 'coherent_artifact_response'

Ref: #808 (comment)
  • Loading branch information
s-weigand authored Sep 11, 2021
1 parent 9892b6e commit ae63eb4
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 45 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ repos:
- id: mypy
files: "^glotaran/(plugin_system|utils|deprecation)"
exclude: "docs"
args: [glotaran]
pass_filenames: false
additional_dependencies: [types-all]

- repo: https://github.com/econchick/interrogate
Expand Down
2 changes: 1 addition & 1 deletion glotaran/analysis/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def create_result_dataset(self, label: str, copy: bool = True) -> xr.Dataset:
model_dimension = dataset_model.get_model_dimension()
if copy:
dataset = dataset.copy()
if dataset_model.index_dependent():
if dataset_model.is_index_dependent():
dataset = self.create_index_dependent_result_dataset(label, dataset)
else:
dataset = self.create_index_independent_result_dataset(label, dataset)
Expand Down
2 changes: 1 addition & 1 deletion glotaran/analysis/problem_grouped.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, scheme: Scheme):
raise ValueError(
f"Cannot group datasets. Model dimension '{model_dimensions}' do not match."
)
self._index_dependent = any(d.index_dependent() for d in self.dataset_models.values())
self._index_dependent = any(d.is_index_dependent() for d in self.dataset_models.values())
self._global_dimension = global_dimensions.pop()
self._model_dimension = model_dimensions.pop()
self._group_clp_labels = None
Expand Down
10 changes: 5 additions & 5 deletions glotaran/analysis/problem_ungrouped.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def calculate_matrices(

for label, dataset_model in self.dataset_models.items():

if dataset_model.index_dependent():
if dataset_model.is_index_dependent():
self._calculate_index_dependent_matrix(label, dataset_model)
else:
self._calculate_index_independent_matrix(label, dataset_model)
Expand Down Expand Up @@ -132,10 +132,10 @@ def _calculate_residual(self, label: str, dataset_model: DatasetModel):
for i, index in enumerate(global_axis):
reduced_clp_labels, reduced_matrix = (
self.reduced_matrices[label][i]
if dataset_model.index_dependent()
if dataset_model.is_index_dependent()
else self.reduced_matrices[label]
)
if not dataset_model.index_dependent():
if not dataset_model.is_index_dependent():
reduced_matrix = reduced_matrix.copy()

if dataset_model.scale is not None:
Expand Down Expand Up @@ -183,7 +183,7 @@ def _calculate_full_model_residual(self, label: str, dataset_model: DatasetModel
model_matrix = self.matrices[label]
global_matrix = self.global_matrices[label].matrix

if dataset_model.index_dependent():
if dataset_model.is_index_dependent():
matrix = np.concatenate(
[
np.kron(global_matrix[i, :], model_matrix[i].matrix)
Expand All @@ -205,7 +205,7 @@ def _calculate_full_model_residual(self, label: str, dataset_model: DatasetModel
def _get_clp_labels(self, label: str, index: int = 0):
return (
self.matrices[label][index].clp_labels
if self.dataset_models[label].index_dependent()
if self.dataset_models[label].is_index_dependent()
else self.matrices[label].clp_labels
)

Expand Down
4 changes: 2 additions & 2 deletions glotaran/analysis/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def simulate_clp(
)
for index, _ in enumerate(global_axis)
]
if dataset_model.index_dependent()
if dataset_model.is_index_dependent()
else calculate_matrix(dataset_model, {})
)

Expand All @@ -108,7 +108,7 @@ def simulate_clp(
)
result = result.to_dataset(name="data")
for i in range(global_axis.size):
index_matrix = matrices[i] if dataset_model.index_dependent() else matrices
index_matrix = matrices[i] if dataset_model.is_index_dependent() else matrices
result.data[:, i] = np.dot(
index_matrix.matrix,
clp.isel({global_dimension: i}).sel({"clp_label": index_matrix.clp_labels}),
Expand Down
14 changes: 7 additions & 7 deletions glotaran/analysis/test/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from glotaran.project import Scheme


@pytest.mark.parametrize("index_dependent", [True, False])
@pytest.mark.parametrize("is_index_dependent", [True, False])
@pytest.mark.parametrize("grouped", [True, False])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize(
Expand All @@ -27,16 +27,16 @@
"suite",
[OneCompartmentDecay, TwoCompartmentDecay, ThreeDatasetDecay, MultichannelMulticomponentDecay],
)
def test_optimization(suite, index_dependent, grouped, weight, method):
def test_optimization(suite, is_index_dependent, grouped, weight, method):
model = suite.model

model.megacomplex["m1"].is_index_dependent = index_dependent
model.megacomplex["m1"].is_index_dependent = is_index_dependent

print("Grouped:", grouped)
print("Index dependent:", index_dependent)
print("Index dependent:", is_index_dependent)

sim_model = suite.sim_model
sim_model.megacomplex["m1"].is_index_dependent = index_dependent
sim_model.megacomplex["m1"].is_index_dependent = is_index_dependent

print(model.validate())
assert model.valid()
Expand All @@ -54,8 +54,8 @@ def test_optimization(suite, index_dependent, grouped, weight, method):
print(model.validate(initial_parameters))
assert model.valid(initial_parameters)
assert (
model.dataset["dataset1"].fill(model, initial_parameters).index_dependent()
== index_dependent
model.dataset["dataset1"].fill(model, initial_parameters).is_index_dependent()
== is_index_dependent
)

nr_datasets = 3 if issubclass(suite, ThreeDatasetDecay) else 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ def finalize_data(
global_dimension = dataset_model.get_global_dimension()
model_dimension = dataset_model.get_model_dimension()
dataset.coords["coherent_artifact_order"] = np.arange(1, self.order + 1)
dataset["coherent_artifact_concentration"] = (
(model_dimension, "coherent_artifact_order"),
response_dimensions = (model_dimension, "coherent_artifact_order")
if dataset_model.is_index_dependent() is True:
response_dimensions = (global_dimension, *response_dimensions)
dataset["coherent_artifact_response"] = (
response_dimensions,
dataset.matrix.sel(clp_label=self.compartments()).values,
)
dataset["coherent_artifact_associated_spectra"] = (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
import xarray as xr

from glotaran.analysis.optimize import optimize
Expand All @@ -11,10 +12,14 @@
from glotaran.project import Scheme


def test_coherent_artifact():
@pytest.mark.parametrize(
"is_index_dependent",
(True, False),
)
def test_coherent_artifact(is_index_dependent: bool):
model_dict = {
"initial_concentration": {
"j1": {"compartments": ["s1"], "parameters": ["2"]},
"j1": {"compartments": ["s1"], "parameters": ["irf_center"]},
},
"megacomplex": {
"mc1": {"type": "decay", "k_matrix": ["k1"]},
Expand All @@ -23,15 +28,15 @@ def test_coherent_artifact():
"k_matrix": {
"k1": {
"matrix": {
("s1", "s1"): "1",
("s1", "s1"): "rate",
}
}
},
"irf": {
"irf1": {
"type": "spectral-gaussian",
"center": "2",
"width": "3",
"type": "spectral-multi-gaussian",
"center": ["irf_center"],
"width": ["irf_width"],
},
},
"dataset": {
Expand All @@ -42,6 +47,24 @@ def test_coherent_artifact():
},
},
}

parameter_list = [
["rate", 101e-4],
["irf_center", 10, {"vary": False, "non-negative": False}],
["irf_width", 20, {"vary": False, "non-negative": False}],
]

if is_index_dependent is True:
irf_spec = model_dict["irf"]["irf1"]
irf_spec["dispersion_center"] = "irf_dispc"
irf_spec["center_dispersion"] = ["irf_disp1", "irf_disp2"]

parameter_list += [
["irf_dispc", 300, {"vary": False, "non-negative": False}],
["irf_disp1", 0.01, {"vary": False, "non-negative": False}],
["irf_disp2", 0.001, {"vary": False, "non-negative": False}],
]

model = Model.from_dict(
model_dict.copy(),
megacomplex_types={
Expand All @@ -50,23 +73,16 @@ def test_coherent_artifact():
},
)

parameters = ParameterGroup.from_list(
[
101e-4,
[10, {"vary": False, "non-negative": False}],
[20, {"vary": False, "non-negative": False}],
[30, {"vary": False, "non-negative": False}],
]
)
parameters = ParameterGroup.from_list(parameter_list)

time = np.arange(0, 50, 1.5)
spectral = np.asarray([0])
spectral = np.asarray([200, 300, 400])
coords = {"time": time, "spectral": spectral}

dataset_model = model.dataset["dataset1"].fill(model, parameters)
dataset_model.overwrite_global_dimension("spectral")
dataset_model.set_coordinates(coords)
matrix = calculate_matrix(dataset_model, {})
matrix = calculate_matrix(dataset_model, {"spectral": [0, 1, 2]})
compartments = matrix.clp_labels

print(compartments)
Expand All @@ -77,9 +93,9 @@ def test_coherent_artifact():
assert matrix.matrix.shape == (time.size, 4)

clp = xr.DataArray(
[[1, 1, 1, 1]],
np.ones((3, 4)),
coords=[
("spectral", [0]),
("spectral", spectral),
(
"clp_label",
[
Expand All @@ -102,17 +118,20 @@ def test_coherent_artifact():
print(result.optimized_parameters)

for label, param in result.optimized_parameters.all():
assert np.allclose(param.value, parameters.get(label).value, rtol=1e-1)
assert np.allclose(param.value, parameters.get(label).value, rtol=1e-8)

resultdata = result.data["dataset1"]
assert np.array_equal(data.time, resultdata.time)
assert np.array_equal(data.spectral, resultdata.spectral)
assert data.data.shape == resultdata.data.shape
assert data.data.shape == resultdata.fitted_data.shape
assert np.allclose(data.data, resultdata.fitted_data, rtol=1e-2)
assert np.allclose(data.data, resultdata.fitted_data)

assert "coherent_artifact_concentration" in resultdata
assert resultdata["coherent_artifact_concentration"].shape == (time.size, 3)
assert "coherent_artifact_response" in resultdata
if is_index_dependent:
assert resultdata["coherent_artifact_response"].shape == (spectral.size, time.size, 3)
else:
assert resultdata["coherent_artifact_response"].shape == (time.size, 3)

assert "coherent_artifact_associated_spectra" in resultdata
assert resultdata["coherent_artifact_associated_spectra"].shape == (1, 3)
assert resultdata["coherent_artifact_associated_spectra"].shape == (3, 3)
2 changes: 1 addition & 1 deletion glotaran/model/dataset_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_weight(self) -> np.ndarray | None:
"""Gets the dataset model's weight."""
return self._weight

def index_dependent(self) -> bool:
def is_index_dependent(self) -> bool:
"""Indicates if the dataset model is index dependent."""
if hasattr(self, "_index_dependent"):
return self._index_dependent
Expand Down
2 changes: 1 addition & 1 deletion glotaran/model/dataset_model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DatasetModel:
def set_data(self, dataset: xr.Dataset) -> DatasetModel: ...
def get_data(self) -> np.ndarray: ...
def get_weight(self) -> np.ndarray | None: ...
def index_dependent(self) -> bool: ...
def is_index_dependent(self) -> bool: ...
def overwrite_index_dependent(self, index_dependent: bool): ...
def has_global_model(self) -> bool: ...
def set_coordinates(self, coords: dict[str, np.ndarray]): ...
Expand Down

0 comments on commit ae63eb4

Please sign in to comment.