Skip to content

Commit

Permalink
Merge pull request #1176 from MatthiasJE/issue1174_improved_callbacks
Browse files Browse the repository at this point in the history
Issue1174 improved callbacks
  • Loading branch information
Holger Kohr authored Oct 6, 2017
2 parents 3fddf3e + c0c4dbc commit 3f5934c
Showing 1 changed file with 64 additions and 24 deletions.
88 changes: 64 additions & 24 deletions odl/solvers/util/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

from odl.util import signature_string

__all__ = ('Callback', 'CallbackStore', 'CallbackApply',
'CallbackPrintTiming', 'CallbackPrintIteration',
'CallbackPrint', 'CallbackPrintNorm', 'CallbackShow',
'CallbackSaveToDisk', 'CallbackSleep', 'CallbackShowConvergence',
'CallbackPrintHardwareUsage', 'CallbackProgressBar')
__all__ = ('Callback', 'CallbackStore', 'CallbackApply', 'CallbackPrintTiming',
'CallbackPrintIteration', 'CallbackPrint', 'CallbackPrintNorm',
'CallbackShow', 'CallbackSaveToDisk', 'CallbackSleep',
'CallbackShowConvergence', 'CallbackPrintHardwareUsage',
'CallbackProgressBar')


class Callback(object):
Expand Down Expand Up @@ -335,7 +335,7 @@ class CallbackPrintIteration(Callback):

"""Callback for printing the iteration count."""

def __init__(self, fmt='iter = {}', step=1):
def __init__(self, fmt='iter = {}', step=1, **kwargs):
"""Initialize a new instance.
Parameters
Expand All @@ -349,6 +349,11 @@ def __init__(self, fmt='iter = {}', step=1):
step : positive int, optional
Number of iterations between output.
Other Parameters
----------------
kwargs :
Key word arguments passed to the print function.
Examples
--------
Create simple callback that prints iteration count:
Expand All @@ -373,11 +378,12 @@ def __init__(self, fmt='iter = {}', step=1):
self.fmt = str(fmt)
self.step = int(step)
self.iter = 0
self.kwargs = kwargs

def __call__(self, _):
"""Print the current iteration."""
if self.iter % self.step == 0:
print(self.fmt.format(self.iter))
print(self.fmt.format(self.iter), **self.kwargs)

self.iter += 1

Expand All @@ -403,7 +409,8 @@ class CallbackPrintTiming(Callback):

"""Callback for printing the time elapsed since the previous iteration."""

def __init__(self, fmt='Time elapsed = {:<5.03f} s', step=1):
def __init__(self, fmt='Time elapsed = {:<5.03f} s', step=1,
cumulative=False, **kwargs):
"""Initialize a new instance.
Parameters
Expand All @@ -416,31 +423,44 @@ def __init__(self, fmt='Time elapsed = {:<5.03f} s', step=1):
where ``runtime`` is the runtime since the last iterate.
step : positive int, optional
Number of iterations between prints.
cumulative : boolean, optional
Print the time since the initialization instead of the last call.
Other Parameters
----------------
kwargs :
Key word arguments passed to the print function.
"""
self.fmt = str(fmt)
self.step = int(step)
self.iter = 0

self.time = time.time()
self.cumulative = cumulative
self.start_time = time.time()
self.kwargs = kwargs

def __call__(self, _):
"""Print time elapsed from the previous iteration."""
if self.iter % self.step == 0:
t = time.time()
print(self.fmt.format(t - self.time))
self.time = t
current_time = time.time()

print(self.fmt.format(current_time - self.start_time),
**self.kwargs)

if not self.cumulative:
self.start_time = current_time

self.iter += 1

def reset(self):
"""Set `time` to the current time."""
self.time = time.time()
self.start_time = time.time()
self.iter = 0

def __repr__(self):
"""Return ``repr(self)``."""
optargs = [('fmt', self.fmt, 'Time elapsed = {:<5.03f} s'),
('step', self.step, 1)]
('step', self.step, 1),
('cumulative', self.cumulative, False)]
inner_str = signature_string([], optargs)
return '{}({})'.format(self.__class__.__name__, inner_str)

Expand All @@ -449,7 +469,7 @@ class CallbackPrint(Callback):

"""Callback for printing the current value."""

def __init__(self, func=None, fmt='{!r}', step=1):
def __init__(self, func=None, fmt='{!r}', step=1, **kwargs):
"""Initialize a new instance.
Parameters
Expand All @@ -467,6 +487,11 @@ def __init__(self, func=None, fmt='{!r}', step=1):
step : positive int, optional
Number of iterations between prints.
Other Parameters
----------------
kwargs :
Key word arguments passed to the print function.
Examples
--------
Callback for simply printing the current iterate:
Expand Down Expand Up @@ -499,14 +524,15 @@ def __init__(self, func=None, fmt='{!r}', step=1):
self.fmt = str(fmt)
self.step = int(step)
self.iter = 0
self.kwargs = kwargs

def __call__(self, result):
"""Print the current value."""
if self.iter % self.step == 0:
if self.func is not None:
result = self.func(result)

print(self.fmt.format(result))
print(self.fmt.format(result), **self.kwargs)

self.iter += 1

Expand Down Expand Up @@ -796,8 +822,8 @@ class CallbackShowConvergence(Callback):

"""Displays a convergence plot."""

def __init__(self, functional, title='convergence',
logx=False, logy=False, **kwargs):
def __init__(self, functional, title='convergence', logx=False, logy=False,
**kwargs):
"""Initialize a new instance.
Parameters
Expand All @@ -811,6 +837,9 @@ def __init__(self, functional, title='convergence',
If true, the x axis is logarithmic.
logx : bool, optional
If true, the y axis is logarithmic.
Other Parameters
----------------
kwargs :
Additional parameters passed to the scatter-plotting function.
"""
Expand Down Expand Up @@ -863,7 +892,7 @@ class CallbackPrintHardwareUsage(Callback):
"""

def __init__(self, step=1, fmt_cpu='CPU usage (% each core): {}',
fmt_mem='RAM usage: {}', fmt_swap='SWAP usage: {}'):
fmt_mem='RAM usage: {}', fmt_swap='SWAP usage: {}', **kwargs):
"""Initialize a new instance.
Parameters
Expand Down Expand Up @@ -893,6 +922,11 @@ def __init__(self, step=1, fmt_cpu='CPU usage (% each core): {}',
where ``swap`` is the current SWAP memory usaged. An empty format
string disables printing of SWAP memory usage.
Other Parameters
----------------
kwargs :
Key word arguments passed to the print function.
Examples
--------
Print memory and CPU usage
Expand All @@ -914,7 +948,6 @@ def __init__(self, step=1, fmt_cpu='CPU usage (% each core): {}',
self.fmt_cpu = str(fmt_cpu)
self.fmt_mem = str(fmt_mem)
self.fmt_swap = str(fmt_swap)

self.iter = 0

def __call__(self, _):
Expand All @@ -924,11 +957,14 @@ def __call__(self, _):

if self.iter % self.step == 0:
if self.fmt_cpu:
print(self.fmt_cpu.format(psutil.cpu_percent(percpu=True)))
print(self.fmt_cpu.format(psutil.cpu_percent(percpu=True)),
**self.kwargs)
if self.fmt_mem:
print(self.fmt_mem.format(psutil.virtual_memory()))
print(self.fmt_mem.format(psutil.virtual_memory()),
**self.kwargs)
if self.fmt_swap:
print(self.fmt_swap.format(psutil.swap_memory()))
print(self.fmt_swap.format(psutil.swap_memory()),
**self.kwargs)

self.iter += 1

Expand Down Expand Up @@ -962,6 +998,9 @@ def __init__(self, niter, step=1, **kwargs):
Total number of iterations.
step : positive int, optional
Number of iterations between output.
Other Parameters
----------------
kwargs :
Further parameters passed to ``tqdm.tqdm``.
"""
Expand Down Expand Up @@ -995,6 +1034,7 @@ def __repr__(self):
return '{}({})'.format(self.__class__.__name__,
inner_str)


if __name__ == '__main__':
# pylint: disable=wrong-import-position
from odl.util.testutils import run_doctests
Expand Down

0 comments on commit 3f5934c

Please sign in to comment.