diff --git a/odl/solvers/util/callback.py b/odl/solvers/util/callback.py index 6309f1bcf33..84a897c917f 100644 --- a/odl/solvers/util/callback.py +++ b/odl/solvers/util/callback.py @@ -19,11 +19,11 @@ from odl.util import signature_string -__all__ = ('Callback', 'CallbackStore', 'CallbackApply', - 'CallbackPrintTiming', 'CallbackPrintIteration', - 'CallbackPrint', 'CallbackPrintNorm', 'CallbackShow', - 'CallbackSaveToDisk', 'CallbackSleep', 'CallbackShowConvergence', - 'CallbackPrintHardwareUsage', 'CallbackProgressBar') +__all__ = ('Callback', 'CallbackStore', 'CallbackApply', 'CallbackPrintTiming', + 'CallbackPrintIteration', 'CallbackPrint', 'CallbackPrintNorm', + 'CallbackShow', 'CallbackSaveToDisk', 'CallbackSleep', + 'CallbackShowConvergence', 'CallbackPrintHardwareUsage', + 'CallbackProgressBar') class Callback(object): @@ -335,7 +335,7 @@ class CallbackPrintIteration(Callback): """Callback for printing the iteration count.""" - def __init__(self, fmt='iter = {}', step=1): + def __init__(self, fmt='iter = {}', step=1, **kwargs): """Initialize a new instance. Parameters @@ -349,6 +349,11 @@ def __init__(self, fmt='iter = {}', step=1): step : positive int, optional Number of iterations between output. + Other Parameters + ---------------- + kwargs : + Key word arguments passed to the print function. + Examples -------- Create simple callback that prints iteration count: @@ -373,11 +378,12 @@ def __init__(self, fmt='iter = {}', step=1): self.fmt = str(fmt) self.step = int(step) self.iter = 0 + self.kwargs = kwargs def __call__(self, _): """Print the current iteration.""" if self.iter % self.step == 0: - print(self.fmt.format(self.iter)) + print(self.fmt.format(self.iter), **self.kwargs) self.iter += 1 @@ -403,7 +409,8 @@ class CallbackPrintTiming(Callback): """Callback for printing the time elapsed since the previous iteration.""" - def __init__(self, fmt='Time elapsed = {:<5.03f} s', step=1): + def __init__(self, fmt='Time elapsed = {:<5.03f} s', step=1, + cumulative=False, **kwargs): """Initialize a new instance. Parameters @@ -416,31 +423,44 @@ def __init__(self, fmt='Time elapsed = {:<5.03f} s', step=1): where ``runtime`` is the runtime since the last iterate. step : positive int, optional Number of iterations between prints. + cumulative : boolean, optional + Print the time since the initialization instead of the last call. + + Other Parameters + ---------------- + kwargs : + Key word arguments passed to the print function. """ self.fmt = str(fmt) self.step = int(step) self.iter = 0 - - self.time = time.time() + self.cumulative = cumulative + self.start_time = time.time() + self.kwargs = kwargs def __call__(self, _): """Print time elapsed from the previous iteration.""" if self.iter % self.step == 0: - t = time.time() - print(self.fmt.format(t - self.time)) - self.time = t + current_time = time.time() + + print(self.fmt.format(current_time - self.start_time), + **self.kwargs) + + if not self.cumulative: + self.start_time = current_time self.iter += 1 def reset(self): """Set `time` to the current time.""" - self.time = time.time() + self.start_time = time.time() self.iter = 0 def __repr__(self): """Return ``repr(self)``.""" optargs = [('fmt', self.fmt, 'Time elapsed = {:<5.03f} s'), - ('step', self.step, 1)] + ('step', self.step, 1), + ('cumulative', self.cumulative, False)] inner_str = signature_string([], optargs) return '{}({})'.format(self.__class__.__name__, inner_str) @@ -449,7 +469,7 @@ class CallbackPrint(Callback): """Callback for printing the current value.""" - def __init__(self, func=None, fmt='{!r}', step=1): + def __init__(self, func=None, fmt='{!r}', step=1, **kwargs): """Initialize a new instance. Parameters @@ -467,6 +487,11 @@ def __init__(self, func=None, fmt='{!r}', step=1): step : positive int, optional Number of iterations between prints. + Other Parameters + ---------------- + kwargs : + Key word arguments passed to the print function. + Examples -------- Callback for simply printing the current iterate: @@ -499,6 +524,7 @@ def __init__(self, func=None, fmt='{!r}', step=1): self.fmt = str(fmt) self.step = int(step) self.iter = 0 + self.kwargs = kwargs def __call__(self, result): """Print the current value.""" @@ -506,7 +532,7 @@ def __call__(self, result): if self.func is not None: result = self.func(result) - print(self.fmt.format(result)) + print(self.fmt.format(result), **self.kwargs) self.iter += 1 @@ -796,8 +822,8 @@ class CallbackShowConvergence(Callback): """Displays a convergence plot.""" - def __init__(self, functional, title='convergence', - logx=False, logy=False, **kwargs): + def __init__(self, functional, title='convergence', logx=False, logy=False, + **kwargs): """Initialize a new instance. Parameters @@ -811,6 +837,9 @@ def __init__(self, functional, title='convergence', If true, the x axis is logarithmic. logx : bool, optional If true, the y axis is logarithmic. + + Other Parameters + ---------------- kwargs : Additional parameters passed to the scatter-plotting function. """ @@ -863,7 +892,7 @@ class CallbackPrintHardwareUsage(Callback): """ def __init__(self, step=1, fmt_cpu='CPU usage (% each core): {}', - fmt_mem='RAM usage: {}', fmt_swap='SWAP usage: {}'): + fmt_mem='RAM usage: {}', fmt_swap='SWAP usage: {}', **kwargs): """Initialize a new instance. Parameters @@ -893,6 +922,11 @@ def __init__(self, step=1, fmt_cpu='CPU usage (% each core): {}', where ``swap`` is the current SWAP memory usaged. An empty format string disables printing of SWAP memory usage. + Other Parameters + ---------------- + kwargs : + Key word arguments passed to the print function. + Examples -------- Print memory and CPU usage @@ -914,7 +948,6 @@ def __init__(self, step=1, fmt_cpu='CPU usage (% each core): {}', self.fmt_cpu = str(fmt_cpu) self.fmt_mem = str(fmt_mem) self.fmt_swap = str(fmt_swap) - self.iter = 0 def __call__(self, _): @@ -924,11 +957,14 @@ def __call__(self, _): if self.iter % self.step == 0: if self.fmt_cpu: - print(self.fmt_cpu.format(psutil.cpu_percent(percpu=True))) + print(self.fmt_cpu.format(psutil.cpu_percent(percpu=True)), + **self.kwargs) if self.fmt_mem: - print(self.fmt_mem.format(psutil.virtual_memory())) + print(self.fmt_mem.format(psutil.virtual_memory()), + **self.kwargs) if self.fmt_swap: - print(self.fmt_swap.format(psutil.swap_memory())) + print(self.fmt_swap.format(psutil.swap_memory()), + **self.kwargs) self.iter += 1 @@ -962,6 +998,9 @@ def __init__(self, niter, step=1, **kwargs): Total number of iterations. step : positive int, optional Number of iterations between output. + + Other Parameters + ---------------- kwargs : Further parameters passed to ``tqdm.tqdm``. """ @@ -995,6 +1034,7 @@ def __repr__(self): return '{}({})'.format(self.__class__.__name__, inner_str) + if __name__ == '__main__': # pylint: disable=wrong-import-position from odl.util.testutils import run_doctests