✨ Added iteration to parameter history
s-weigand committed Oct 2, 2022
1 parent 8015f02 commit 789583d
Showing 7 changed files with 110 additions and 30 deletions.
46 changes: 26 additions & 20 deletions glotaran/optimization/optimizer.py
@@ -12,6 +12,8 @@
from glotaran.parameter import ParameterHistory
from glotaran.project import Result
from glotaran.project import Scheme
from glotaran.utils.tee import TeeContext
from glotaran.utils.tee import get_current_optimization_iteration

SUPPORTED_METHODS = {
"TrustRegionReflection": "trf",
@@ -103,6 +105,7 @@ def __init__(self, scheme: Scheme, verbose: bool = True, raise_exception: bool =
self._method = SUPPORTED_METHODS[scheme.optimization_method]

self._scheme = scheme
self._tee = TeeContext()
self._verbose = verbose
self._raise = raise_exception

@@ -131,25 +134,26 @@ def optimize(self):
lower_bounds,
upper_bounds,
) = self._scheme.parameters.get_label_value_and_bounds_arrays(exclude_non_vary=True)
try:
verbose = 2 if self._verbose else 0
self._optimization_result = least_squares(
self.objective_function,
initial_parameter,
bounds=(lower_bounds, upper_bounds),
method=self._method,
max_nfev=self._scheme.maximum_number_function_evaluations,
verbose=verbose,
ftol=self._scheme.ftol,
gtol=self._scheme.gtol,
xtol=self._scheme.xtol,
)
self._termination_reason = self._optimization_result.message
except Exception as e:
if self._raise:
raise e
warn(f"Optimization failed:\n\n{e}")
self._termination_reason = str(e)
with self._tee:
try:
verbose = 2 if self._verbose else 0
self._optimization_result = least_squares(
self.objective_function,
initial_parameter,
bounds=(lower_bounds, upper_bounds),
method=self._method,
max_nfev=self._scheme.maximum_number_function_evaluations,
verbose=verbose,
ftol=self._scheme.ftol,
gtol=self._scheme.gtol,
xtol=self._scheme.xtol,
)
self._termination_reason = self._optimization_result.message
except Exception as e:
if self._raise:
raise e
warn(f"Optimization failed:\n\n{e}")
self._termination_reason = str(e)

def objective_function(self, parameters: np.typing.ArrayLike) -> np.typing.ArrayLike:
"""Calculate the objective for the optimization.
@@ -177,7 +181,9 @@ def calculate_penalty(self) -> np.typing.ArrayLike:
"""
for group in self._optimization_groups:
group.calculate(self._parameters)
self._parameter_history.append(self._parameters)
self._parameter_history.append(
self._parameters, get_current_optimization_iteration(self._tee.read())
)

penalties = [group.get_full_penalty() for group in self._optimization_groups]

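How these pieces fit together: ``optimize`` now runs ``least_squares`` inside a ``TeeContext`` so that SciPy's ``verbose=2`` progress table is copied into a buffer while still reaching the terminal, and each ``calculate_penalty`` call parses that buffer to record the current iteration next to the parameters. Below is a condensed, self-contained sketch of the same mechanism; the ``Tee`` class, regex, and residual function are illustrative stand-ins, not glotaran's ``TeeContext``/``get_current_optimization_iteration``.

import re
import sys
from io import StringIO

import numpy as np
from scipy.optimize import least_squares

# First two columns of SciPy's verbose=2 table: iteration and total nfev.
ITERATION_LINE = re.compile(r"^\s+(?P<iteration>\d+)\s+\d+", re.MULTILINE)


class Tee:
    """Copy everything written to stdout into a buffer while passing it through."""

    def __init__(self):
        self.buffer = StringIO()
        self.stdout = sys.stdout

    def __enter__(self):
        sys.stdout = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout = self.stdout
        return None

    def write(self, data):
        self.buffer.write(data)
        self.stdout.write(data)

    def flush(self):
        self.buffer.flush()

    def read(self):
        return self.buffer.getvalue()


tee = Tee()


def residuals(parameters):
    # Mirrors what calculate_penalty does in this commit: inspect the captured
    # stdout on every objective evaluation to learn the current iteration.
    matches = ITERATION_LINE.findall(tee.read())
    iteration = int(matches[-1]) if matches else 0
    print(f"objective call at iteration {iteration}", file=sys.stderr)
    return np.asarray(parameters) - np.array([1.0, 2.0])


with tee:
    least_squares(residuals, np.array([10.0, 10.0]), verbose=2)
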
2 changes: 1 addition & 1 deletion glotaran/parameter/parameter_group.py
@@ -564,7 +564,7 @@ def set_from_history(self, history: ParameterHistory, index: int):
The history index.
"""
self.set_from_label_and_value_arrays(
history.parameter_labels, history.get_parameters(index)
history.parameter_labels[1:], history.get_parameters(index)[1:]
)

def update_parameter_expression(self):
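The ``[1:]`` slices above are needed because, after this commit, every history record leads with the iteration value (see the ``parameter_history.py`` change below), so both the label and value arrays have to drop that first entry before being written back into the group. A tiny illustration with made-up labels:

# Hypothetical record layout after this commit: iteration first, then values.
history_labels = ["iteration", "kinetic.k1", "kinetic.k2"]
history_values = [3.0, 0.1, 0.2]

# set_from_history forwards only the parameter part to the group:
assert history_labels[1:] == ["kinetic.k1", "kinetic.k2"]
assert history_values[1:] == [0.1, 0.2]
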
7 changes: 5 additions & 2 deletions glotaran/parameter/parameter_history.py
@@ -124,13 +124,15 @@ def to_csv(self, file_name: str | PathLike[str], delimiter: str = ","):
self.source_path = Path(file_name).as_posix()
self.to_dataframe().to_csv(file_name, sep=delimiter, index=False)

def append(self, parameter_group: ParameterGroup):
def append(self, parameter_group: ParameterGroup, current_iteration: int = 0):
"""Append a :class:`ParameterGroup` to the history.
Parameters
----------
parameter_group : ParameterGroup
The group to append.
current_iteration: int
Current iteration of the optimizer.
Raises
------
@@ -143,14 +145,15 @@ def append(self, parameter_group: ParameterGroup):
_,
_,
) = parameter_group.get_label_value_and_bounds_arrays()
parameter_labels = ["iteration", *parameter_labels]
if len(self._parameter_labels) == 0:
self._parameter_labels = parameter_labels
if parameter_labels != self.parameter_labels:
raise ValueError(
"Cannot append parameter group. Parameter labels do not match existing."
)

self._parameters.append(parameter_values)
self._parameters.append(np.array([current_iteration, *parameter_values]))

def get_parameters(self, index: int) -> np.ndarray:
"""Get parameters for a history index.
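Net effect on the stored history (illustrative values, bypassing ``ParameterGroup`` for brevity): each appended row is ``np.array([current_iteration, *parameter_values])`` and the label list gains a leading ``"iteration"`` entry, so a dataframe or CSV export picks up an iteration column, roughly like this:

import numpy as np
import pandas as pd

parameter_labels = ["a", "b"]
appended = [
    (0, np.array([1.0, 4.0])),  # first evaluation, current_iteration defaults to 0
    (1, np.array([2.0, 5.0])),
    (2, np.array([3.0, 6.0])),
]

labels = ["iteration", *parameter_labels]
rows = [np.array([iteration, *values]) for iteration, values in appended]

# Roughly what ParameterHistory.to_dataframe() / to_csv() would now expose:
print(pd.DataFrame(rows, columns=labels))
#    iteration    a    b
# 0        0.0  1.0  4.0
# 1        1.0  2.0  5.0
# 2        2.0  3.0  6.0
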
12 changes: 6 additions & 6 deletions glotaran/parameter/test/test_parameter_history.py
@@ -13,20 +13,20 @@ def test_parameter_history():

history.append(group0)

assert history.parameter_labels == ["1", "2"]
assert history.parameter_labels == ["iteration", "1", "2"]

assert history.number_of_records == 1
assert all(history.get_parameters(0) == [1, 4])
assert all(history.get_parameters(0) == [0, 1, 4])

history.append(group1)
history.append(group1, current_iteration=1)

assert history.number_of_records == 2
assert all(history.get_parameters(1) == [2, 5])
assert all(history.get_parameters(1) == [1, 2, 5])

history.append(group2)
history.append(group2, current_iteration=2)

assert history.number_of_records == 3
assert all(history.get_parameters(2) == [3, 6])
assert all(history.get_parameters(2) == [2, 3, 6])

df = history.to_dataframe()

7 changes: 7 additions & 0 deletions glotaran/utils/regex.py
@@ -14,3 +14,10 @@ class RegexPattern:
number: re.Pattern = re.compile(r"[\d.+-]+")
tuple_word: re.Pattern = re.compile(r"(\([.\s\w\d]+?[,.\s\w\d]*?\))")
tuple_number: re.Pattern = re.compile(r"(\([\s\d.+-]+?[,\s\d.+-]*?\))")
optimization_stdout: re.Pattern = re.compile(
r"^\s+(?P<iteration>\d+)\s+(?P<nfev>\d+)"
r"\s+(?P<cost>\d\.\d+e[+-]\d+)"
r"(\s+(?P<cost_reduction>\d\.\d+e[+-]\d+)\s+(?P<step_norm>\d\.\d+e[+-]\d+)|\s+)"
r"\s+(?P<optimality>\d\.\d+e[+-]\d+)\s*?$",
re.MULTILINE,
)
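A quick sanity check (not part of the commit) that the new pattern matches a SciPy ``least_squares`` ``verbose=2`` progress line and exposes the named groups; the sample line mimics SciPy's column layout, with approximate spacing.

import re

optimization_stdout = re.compile(
    r"^\s+(?P<iteration>\d+)\s+(?P<nfev>\d+)"
    r"\s+(?P<cost>\d\.\d+e[+-]\d+)"
    r"(\s+(?P<cost_reduction>\d\.\d+e[+-]\d+)\s+(?P<step_norm>\d\.\d+e[+-]\d+)|\s+)"
    r"\s+(?P<optimality>\d\.\d+e[+-]\d+)\s*?$",
    re.MULTILINE,
)

line = "           1              2         7.5833e+00      1.37e-04       4.55e-05       1.26e-01"
match = optimization_stdout.search(line)
assert match is not None
assert match.group("iteration") == "1"
assert match.group("cost_reduction") == "1.37e-04"
assert match.group("optimality") == "1.26e-01"
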
20 changes: 19 additions & 1 deletion glotaran/utils/tee.py
@@ -5,6 +5,8 @@
from io import StringIO
from types import TracebackType

from glotaran.utils.regex import RegexPattern


class TeeContext:
"""Context manager that allows to work with string written to stdout."""
@@ -33,7 +35,6 @@ def __exit__(
) -> bool | None:
"""Restore ``sys.stdout`` on exiting the context."""
sys.stdout = self.stdout
self.buffer.close()
return None

def write(self, data: str) -> None:
@@ -60,3 +61,20 @@ def read(self) -> str:
def flush(self) -> None:
"""Flush values in the buffer."""
self.buffer.flush()


def get_current_optimization_iteration(optimize_stdout: str) -> int:
"""Extract current iteration from ``optimize_stdout``.
Parameters
----------
optimize_stdout: str
Scipy optimization stdout string, read out via ``TeeContext.read()``.
Returns
-------
int
Current iteration (``0`` if pattern did not match).
"""
matches = RegexPattern.optimization_stdout.findall(optimize_stdout)
return 0 if len(matches) == 0 else int(matches[-1][0])
46 changes: 46 additions & 0 deletions glotaran/utils/test/test_tee.py
@@ -1,9 +1,13 @@
"""Test for ``glotaran.utils.tee``."""
from __future__ import annotations

from textwrap import dedent
from typing import TYPE_CHECKING

import pytest

from glotaran.utils.tee import TeeContext
from glotaran.utils.tee import get_current_optimization_iteration

if TYPE_CHECKING:
from _pytest.capture import CaptureFixture
@@ -21,3 +25,45 @@ def test_tee_context(capsys: CaptureFixture):

assert stdout == expected
assert result == expected


@pytest.mark.parametrize(
"optimize_stdout, expected",
(
("random string", 0),
(
dedent(
"""\
Iteration Total nfev Cost Cost reduction Step norm Optimality
""" # noqa: E501
),
0,
),
(
dedent(
"""\
Iteration Total nfev Cost Cost reduction Step norm Optimality
0 1 7.5834e+00 3.84e+01
1 2 7.5833e+00 1.37e-04 4.55e-05 1.26e-01
""" # noqa: E501
),
1,
),
(
dedent(
"""\
Iteration Total nfev Cost Cost reduction Step norm Optimality
0 1 7.5834e+00 3.84e+01
1 2 7.5833e+00 1.37e-04 4.55e-05 1.26e-01
2 3 7.5833e+00 6.02e-11 6.44e-09 1.64e-05
Both `ftol` and `xtol` termination conditions are satisfied.
Function evaluations 3, initial cost 7.5834e+00, final cost 7.5833e+00, first-order optimality 1.64e-05.
""" # noqa: E501
),
2,
),
),
)
def test_get_current_optimization_iteration(optimize_stdout: str, expected: int):
"""Test that the correct iteration is returned."""
assert get_current_optimization_iteration(optimize_stdout) == expected
