diff --git a/glotaran/analysis/optimization_group.py b/glotaran/analysis/optimization_group.py
index c11d0e878..42bb4b38d 100644
--- a/glotaran/analysis/optimization_group.py
+++ b/glotaran/analysis/optimization_group.py
@@ -79,7 +79,8 @@ def __init__(
 
         self._model.validate(raise_exception=True)
 
-        self._prepare_data(scheme)
+        self._prepare_data(scheme, list(dataset_group.dataset_models.keys()))
+        self._dataset_labels = list(self.data.keys())
 
         link_clp = dataset_group.model.link_clp
         if link_clp is None:
@@ -204,7 +205,8 @@ def reset(self):
         """Resets all results and `DatasetModels`. Use after updating parameters."""
         self._dataset_models = {
             label: dataset_model.fill(self._model, self._parameters).set_data(self.data[label])
-            for label, dataset_model in self._model.dataset.items()
+            for label, dataset_model in self.model.dataset.items()
+            if label in self._dataset_labels
         }
         if self._overwrite_index_dependent:
             for d in self._dataset_models.values():
@@ -221,10 +223,12 @@ def _reset_results(self):
         self._additional_penalty = None
         self._full_penalty = None
 
-    def _prepare_data(self, scheme: Scheme):
+    def _prepare_data(self, scheme: Scheme, labels: list[str]):
         self._data = {}
         self._dataset_models = {}
         for label, dataset in scheme.data.items():
+            if label not in labels:
+                continue
             if isinstance(dataset, xr.DataArray):
                 dataset = dataset.to_dataset(name="data")
 
diff --git a/glotaran/model/model.py b/glotaran/model/model.py
index e1cf2b1ac..80095fef7 100644
--- a/glotaran/model/model.py
+++ b/glotaran/model/model.py
@@ -322,15 +322,16 @@ def need_index_dependent(self) -> bool:
         return any(i.interval is not None for i in self.clp_constraints + self.clp_relations)
 
     def is_groupable(self, parameters: ParameterGroup, data: dict[str, xr.DataArray]) -> bool:
-        if any(d.has_global_model() for d in self.dataset.values()):
+        dataset_models = {label: self.dataset[label] for label in data}
+        if any(d.has_global_model() for d in dataset_models.values()):
             return False
         global_dimensions = {
             d.fill(self, parameters).set_data(data[k]).get_global_dimension()
-            for k, d in self.dataset.items()
+            for k, d in dataset_models.items()
         }
         model_dimensions = {
             d.fill(self, parameters).set_data(data[k]).get_model_dimension()
-            for k, d in self.dataset.items()
+            for k, d in dataset_models.items()
         }
         return len(global_dimensions) == 1 and len(model_dimensions) == 1
 
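For illustration only, a minimal sketch of the filtering idea the two hunks share: restrict the datasets handled in one place to the labels that belong to the dataset group being optimized. The names scheme_data and group_labels are hypothetical stand-ins, not part of the patch or the glotaran API.

    # Hypothetical stand-ins for scheme.data and a dataset group's model labels.
    scheme_data = {"dataset1": object(), "dataset2": object(), "dataset3": object()}
    group_labels = ["dataset1", "dataset3"]

    # Keep only the datasets belonging to this group, mirroring the label
    # checks added in _prepare_data, reset, and is_groupable.
    prepared = {label: data for label, data in scheme_data.items() if label in group_labels}
    assert set(prepared) == {"dataset1", "dataset3"}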