Skip to content

Commit

Permalink
Updated docstring on history module (#1230)
Browse files Browse the repository at this point in the history
* Updated docstring on history module

* Apply suggestions from code review

Co-authored-by: Daniel Weindl <[email protected]>

* Updated references

---------

Co-authored-by: Daniel Weindl <[email protected]>
  • Loading branch information
PaulJonasJost and dweindl authored Dec 1, 2023
1 parent fd64c8f commit 00506f6
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 94 deletions.
24 changes: 11 additions & 13 deletions pypesto/history/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numbers
import time
from abc import ABC, abstractmethod
from typing import Dict, Sequence, Tuple, Union
from typing import Sequence, Union

import numpy as np

Expand Down Expand Up @@ -38,7 +38,7 @@ class HistoryBase(ABC):
# all possible history entries
ALL_KEYS = (X, *RESULT_KEYS, TIME)

def __init__(self, options: HistoryOptions = None):
def __init__(self, options: Union[HistoryOptions, None] = None):
if options is None:
options = HistoryOptions()
options = HistoryOptions.assert_instance(options)
Expand All @@ -48,7 +48,7 @@ def __init__(self, options: HistoryOptions = None):
def update(
self,
x: np.ndarray,
sensi_orders: Tuple[int, ...],
sensi_orders: tuple[int, ...],
mode: ModeType,
result: ResultDict,
) -> None:
Expand All @@ -70,8 +70,8 @@ def update(

def finalize(
self,
message: str = None,
exitflag: str = None,
message: Union[str, None] = None,
exitflag: Union[str, None] = None,
) -> None:
"""
Finalize history. Called after a run. Default: Do nothing.
Expand Down Expand Up @@ -281,7 +281,7 @@ class NoHistory(HistoryBase):
def update( # noqa: D102
self,
x: np.ndarray,
sensi_orders: Tuple[int, ...],
sensi_orders: tuple[int, ...],
mode: ModeType,
result: ResultDict,
) -> None:
Expand Down Expand Up @@ -364,7 +364,7 @@ class CountHistoryBase(HistoryBase):
Needs a separate implementation of trace.
"""

def __init__(self, options: Union[HistoryOptions, Dict] = None):
def __init__(self, options: Union[HistoryOptions, dict] = None):
super().__init__(options)
self._n_fval: int = 0
self._n_grad: int = 0
Expand All @@ -378,15 +378,15 @@ def __init__(self, options: Union[HistoryOptions, Dict] = None):
def update( # noqa: D102
self,
x: np.ndarray,
sensi_orders: Tuple[int, ...],
sensi_orders: tuple[int, ...],
mode: ModeType,
result: ResultDict,
) -> None:
self._update_counts(sensi_orders, mode)

def _update_counts(
self,
sensi_orders: Tuple[int, ...],
sensi_orders: tuple[int, ...],
mode: ModeType,
):
"""Update the counters."""
Expand Down Expand Up @@ -499,8 +499,7 @@ def add_fun_from_res(result: ResultDict) -> ResultDict:
Returns
-------
full_result:
Result dicionary, adding whatever is possible to calculate.
Result dictionary, adding whatever is possible to calculate.
"""
result = result.copy()

Expand Down Expand Up @@ -529,8 +528,7 @@ def reduce_result_via_options(
Returns
-------
result:
Result reduced to what is intended to be stored in history.
Result reduced to what is intended to be stored in history.
"""
result = result.copy()

Expand Down
34 changes: 17 additions & 17 deletions pypesto/history/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import os
import time
from typing import Dict, List, Sequence, Tuple, Union
from typing import Sequence, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -41,14 +41,14 @@ class CsvHistory(CountHistoryBase):
options:
History options.
load_from_file:
If True, history will be initialized from data in the specified file
If True, history will be initialized from data in the specified file.
"""

def __init__(
self,
file: str,
x_names: Sequence[str] = None,
options: Union[HistoryOptions, Dict] = None,
options: Union[HistoryOptions, dict] = None,
load_from_file: bool = False,
):
super().__init__(options=options)
Expand Down Expand Up @@ -87,16 +87,16 @@ def _update_counts_from_trace(self) -> None:
def update(
self,
x: np.ndarray,
sensi_orders: Tuple[int, ...],
sensi_orders: tuple[int, ...],
mode: ModeType,
result: ResultDict,
) -> None:
"""See `History` docstring."""
"""See :meth:`HistoryBase.update`."""
super().update(x, sensi_orders, mode, result)
self._update_trace(x, mode, result)

def finalize(self, message: str = None, exitflag: str = None):
"""See `HistoryBase` docstring."""
"""See :meth:`HistoryBase.finalize`."""
super().finalize(message=message, exitflag=exitflag)
self._save_trace(finalize=True)

Expand Down Expand Up @@ -167,7 +167,7 @@ def _init_trace(self, x: np.ndarray):
if self.x_names is None:
self.x_names = [f'x{i}' for i, _ in enumerate(x)]

columns: List[Tuple] = [
columns: list[tuple] = [
(c, np.nan)
for c in [
TIME,
Expand Down Expand Up @@ -213,7 +213,7 @@ def _init_trace(self, x: np.ndarray):

def _save_trace(self, finalize: bool = False):
"""
Save to file via pd.DataFrame.to_csv().
Save to file via :meth:`pandas.DataFrame.to_csv`.
Only done, if `self.storage_file` is not None and other conditions.
apply.
Expand Down Expand Up @@ -243,49 +243,49 @@ def get_x_trace(
ix: Union[int, Sequence[int], None] = None,
trim: bool = False,
) -> Union[Sequence[np.ndarray], np.ndarray]:
"""See `HistoryBase` docstring."""
"""See :meth:`HistoryBase.get_x_trace`."""
return list(self._trace[X].values[ix])

@trace_wrap
def get_fval_trace(
self, ix: Union[int, Sequence[int], None], trim: bool = False
) -> Union[Sequence[float], float]:
"""See `HistoryBase` docstring."""
"""See :meth:`HistoryBase.get_fval_trace`."""
return list(self._trace[(FVAL, np.nan)].values[ix])

@trace_wrap
def get_grad_trace(
self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
) -> Union[Sequence[MaybeArray], MaybeArray]:
"""See `HistoryBase` docstring."""
"""See :meth:`HistoryBase.get_grad_trace`."""
return list(self._trace[GRAD].values[ix])

@trace_wrap
def get_hess_trace(
self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
) -> Union[Sequence[MaybeArray], MaybeArray]:
"""See `HistoryBase` docstring."""
"""See :meth:`HistoryBase.get_hess_trace`."""
return list(self._trace[(HESS, np.nan)].values[ix])

@trace_wrap
def get_res_trace(
self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
) -> Union[Sequence[MaybeArray], MaybeArray]:
"""See `HistoryBase` docstring."""
"""See :meth:`HistoryBase.get_res_trace`."""
return list(self._trace[(RES, np.nan)].values[ix])

@trace_wrap
def get_sres_trace(
self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
) -> Union[Sequence[MaybeArray], MaybeArray]:
"""See `HistoryBase` docstring."""
"""See :meth:`HistoryBase.get_sres_trace`."""
return list(self._trace[(SRES, np.nan)].values[ix])

@trace_wrap
def get_time_trace(
self, ix: Union[int, Sequence[int], None] = None, trim: bool = False
) -> Union[Sequence[float], float]:
"""See `HistoryBase` docstring."""
"""See :meth:`HistoryBase.get_time_trace`."""
return list(self._trace[(TIME, np.nan)].values[ix])


Expand All @@ -301,7 +301,7 @@ def ndarray2string_full(x: Union[np.ndarray, None]) -> Union[str, None]:
Returns
-------
x: array as string.
Array as string.
"""
if not isinstance(x, np.ndarray):
return x
Expand All @@ -320,7 +320,7 @@ def string2ndarray(x: Union[str, float]) -> Union[np.ndarray, float]:
Returns
-------
x: array as np.ndarray.
Array as :class:`numpy.ndarray`.
"""
if not isinstance(x, str):
return x
Expand Down
3 changes: 1 addition & 2 deletions pypesto/history/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def create_history(
Returns
-------
history:
A history object corresponding to the inputs.
A history object corresponding to the inputs.
"""
# create different history types based on the inputs
if options.storage_file is None:
Expand Down
Loading

0 comments on commit 00506f6

Please sign in to comment.