Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Use xarray internally and move relations/constraints/penaltys to glotaran.model #734

53 changes: 2 additions & 51 deletions glotaran/analysis/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def __init__(self, scheme: Scheme):

# all of the above are always not None

self._clp_labels = None
self._matrices = None
self._reduced_clp_labels = None
self._reduced_matrices = None
self._reduced_clps = None
self._clps = None
Expand Down Expand Up @@ -166,14 +164,6 @@ def groups(self) -> dict[str, list[str]]:
self.init_bag()
return self._groups

@property
def clp_labels(
self,
) -> dict[str, list[str] | list[list[str]]]:
if self._clp_labels is None:
self.calculate_matrices()
return self._clp_labels

@property
def matrices(
self,
Expand All @@ -182,14 +172,6 @@ def matrices(
self.calculate_matrices()
return self._matrices

@property
def reduced_clp_labels(
self,
) -> dict[str, list[str] | list[list[str]]]:
if self._reduced_clp_labels is None:
self.calculate_matrices()
return self._reduced_clp_labels

@property
def reduced_matrices(
self,
Expand Down Expand Up @@ -235,23 +217,12 @@ def additional_penalty(
self,
) -> dict[str, list[float]]:
if self._additional_penalty is None:
self.calculate_additional_penalty()
self.calculate_residual()
return self._additional_penalty

@property
def full_penalty(self) -> np.ndarray:
if self._full_penalty is None:
residuals = self.weighted_residuals
additional_penalty = self.additional_penalty
if not self.grouped:
residuals = [np.concatenate(residuals[label]) for label in residuals.keys()]

self._full_penalty = (
np.concatenate((np.concatenate(residuals), additional_penalty))
if additional_penalty is not None
else np.concatenate(residuals)
)
return self._full_penalty
raise NotImplementedError

@property
def cost(self) -> float:
Expand All @@ -272,9 +243,7 @@ def reset(self):
self._reset_results()

def _reset_results(self):
self._clp_labels = None
self._matrices = None
self._reduced_clp_labels = None
self._reduced_matrices = None
self._reduced_clps = None
self._clps = None
Expand Down Expand Up @@ -372,24 +341,6 @@ def calculate_matrices(self):
def calculate_residual(self):
raise NotImplementedError

def calculate_additional_penalty(self) -> np.ndarray | dict[str, np.ndarray]:
"""Calculates additional penalties by calling the model.additional_penalty function."""
if (
callable(self.model.has_additional_penalty_function)
and self.model.has_additional_penalty_function()
):
self._additional_penalty = self.model.additional_penalty_function(
self.parameters,
self.clp_labels,
self.clps,
self.matrices,
self.data,
self._scheme.group_tolerance,
)
else:
self._additional_penalty = None
return self._additional_penalty

def create_result_data(
self, copy: bool = True, history_index: int | None = None
) -> dict[str, xr.Dataset]:
Expand Down
Loading