Skip to content

Commit

Permalink
🩹 Fixed optimaization groups using datasets outside of their groups
Browse files Browse the repository at this point in the history
This bug led to result creation crashing, because of missing labels.

Co-authored-by: Jörn Weißenborn <[email protected]>
  • Loading branch information
s-weigand and joernweissenborn committed Oct 14, 2021
1 parent f96af6d commit cc4beb9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
10 changes: 7 additions & 3 deletions glotaran/analysis/optimization_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def __init__(

self._model.validate(raise_exception=True)

self._prepare_data(scheme)
self._prepare_data(scheme, list(dataset_group.dataset_models.keys()))
self._dataset_labels = list(self.data.keys())

link_clp = dataset_group.model.link_clp
if link_clp is None:
Expand Down Expand Up @@ -204,7 +205,8 @@ def reset(self):
"""Resets all results and `DatasetModels`. Use after updating parameters."""
self._dataset_models = {
label: dataset_model.fill(self._model, self._parameters).set_data(self.data[label])
for label, dataset_model in self._model.dataset.items()
for label, dataset_model in self.model.dataset.items()
if label in self._dataset_labels
}
if self._overwrite_index_dependent:
for d in self._dataset_models.values():
Expand All @@ -221,10 +223,12 @@ def _reset_results(self):
self._additional_penalty = None
self._full_penalty = None

def _prepare_data(self, scheme: Scheme):
def _prepare_data(self, scheme: Scheme, labels: list[str]):
self._data = {}
self._dataset_models = {}
for label, dataset in scheme.data.items():
if label not in labels:
continue
if isinstance(dataset, xr.DataArray):
dataset = dataset.to_dataset(name="data")

Expand Down
7 changes: 4 additions & 3 deletions glotaran/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,15 +322,16 @@ def need_index_dependent(self) -> bool:
return any(i.interval is not None for i in self.clp_constraints + self.clp_relations)

def is_groupable(self, parameters: ParameterGroup, data: dict[str, xr.DataArray]) -> bool:
if any(d.has_global_model() for d in self.dataset.values()):
dataset_models = {label: self.dataset[label] for label in data}
if any(d.has_global_model() for d in dataset_models.values()):
return False
global_dimensions = {
d.fill(self, parameters).set_data(data[k]).get_global_dimension()
for k, d in self.dataset.items()
for k, d in dataset_models.items()
}
model_dimensions = {
d.fill(self, parameters).set_data(data[k]).get_model_dimension()
for k, d in self.dataset.items()
for k, d in dataset_models.items()
}
return len(global_dimensions) == 1 and len(model_dimensions) == 1

Expand Down

0 comments on commit cc4beb9

Please sign in to comment.