Skip to content

Commit

Permalink
✨ Use xarray internally and move relations/constraints/penaltys to gl…
Browse files Browse the repository at this point in the history
…otaran.model (#734)

* Changed calculate_matrix function to return xarrays

* Added relation and constraint to basemodel

* Added functions to apply relations and constraints

* Removed LabelAndMatrix class

* Added relations and constraints tests to base model

* Added penalties to base model

* Adapted models to changes

* Address xarray deprecation warnings

Fixed DeprecationWarning: Using a DataArray object to construct a variable is ambiguous, please extract the data using the .data property. This will raise a TypeError in 0.19.0.

Co-authored-by: Jörn Weißenborn <[email protected]>
Co-authored-by: Joris Snellenburg <[email protected]>
  • Loading branch information
3 people authored Jul 4, 2021
1 parent b953319 commit eb2204e
Show file tree
Hide file tree
Showing 30 changed files with 977 additions and 787 deletions.
53 changes: 2 additions & 51 deletions glotaran/analysis/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def __init__(self, scheme: Scheme):

# all of the above are always not None

self._clp_labels = None
self._matrices = None
self._reduced_clp_labels = None
self._reduced_matrices = None
self._reduced_clps = None
self._clps = None
Expand Down Expand Up @@ -166,14 +164,6 @@ def groups(self) -> dict[str, list[str]]:
self.init_bag()
return self._groups

@property
def clp_labels(
self,
) -> dict[str, list[str] | list[list[str]]]:
if self._clp_labels is None:
self.calculate_matrices()
return self._clp_labels

@property
def matrices(
self,
Expand All @@ -182,14 +172,6 @@ def matrices(
self.calculate_matrices()
return self._matrices

@property
def reduced_clp_labels(
self,
) -> dict[str, list[str] | list[list[str]]]:
if self._reduced_clp_labels is None:
self.calculate_matrices()
return self._reduced_clp_labels

@property
def reduced_matrices(
self,
Expand Down Expand Up @@ -235,23 +217,12 @@ def additional_penalty(
self,
) -> dict[str, list[float]]:
if self._additional_penalty is None:
self.calculate_additional_penalty()
self.calculate_residual()
return self._additional_penalty

@property
def full_penalty(self) -> np.ndarray:
if self._full_penalty is None:
residuals = self.weighted_residuals
additional_penalty = self.additional_penalty
if not self.grouped:
residuals = [np.concatenate(residuals[label]) for label in residuals.keys()]

self._full_penalty = (
np.concatenate((np.concatenate(residuals), additional_penalty))
if additional_penalty is not None
else np.concatenate(residuals)
)
return self._full_penalty
raise NotImplementedError

@property
def cost(self) -> float:
Expand All @@ -272,9 +243,7 @@ def reset(self):
self._reset_results()

def _reset_results(self):
self._clp_labels = None
self._matrices = None
self._reduced_clp_labels = None
self._reduced_matrices = None
self._reduced_clps = None
self._clps = None
Expand Down Expand Up @@ -372,24 +341,6 @@ def calculate_matrices(self):
def calculate_residual(self):
raise NotImplementedError

def calculate_additional_penalty(self) -> np.ndarray | dict[str, np.ndarray]:
"""Calculates additional penalties by calling the model.additional_penalty function."""
if (
callable(self.model.has_additional_penalty_function)
and self.model.has_additional_penalty_function()
):
self._additional_penalty = self.model.additional_penalty_function(
self.parameters,
self.clp_labels,
self.clps,
self.matrices,
self.data,
self._scheme.group_tolerance,
)
else:
self._additional_penalty = None
return self._additional_penalty

def create_result_data(
self, copy: bool = True, history_index: int | None = None
) -> dict[str, xr.Dataset]:
Expand Down
Loading

0 comments on commit eb2204e

Please sign in to comment.