diff --git a/tardis/simulation/convergence.py b/tardis/simulation/convergence.py index a769256fd09..537eb3337a1 100644 --- a/tardis/simulation/convergence.py +++ b/tardis/simulation/convergence.py @@ -1,8 +1,25 @@ import numpy as np +from numba import njit + +from tardis.montecarlo.montecarlo_numba import njit_dict_no_parallel class ConvergenceSolver: def __init__(self, strategy): + """_summary_ + + Parameters + ---------- + strategy : _type_ + Convergence strategy for the physical property + + Raises + ------ + NotImplementedError + Custom convergence type specified + ValueError + Unknown convergence type specified + """ self.convergence_strategy = strategy self.damping_factor = self.convergence_strategy.damping_constant self.threshold = self.convergence_strategy.threshold @@ -21,10 +38,41 @@ def __init__(self, strategy): f"- input is {self.convergence_strategy.type}" ) + @njit(**njit_dict_no_parallel) def damped_converge(self, value, estimated_value): + """Damped convergence solver + + Parameters + ---------- + value : np.float64 + The current value of the physical property + estimated_value : np.float64 + The estimated value of the physical property + + Returns + ------- + np.float64 + The converged value + """ return value + self.damping_factor * (estimated_value - value) def get_convergence_status(self, value, estimated_value, no_of_cells): + """Get the status of convergence for the physical property + + Parameters + ---------- + value : np.float64, Quantity + The current value of the physical property + estimated_value : np.float64, Quantity + The estimated value of the physical property + no_of_cells : np.int64 + The number of cells to measure convergence over + + Returns + ------- + bool + True if convergence is reached + """ convergence = abs(value - estimated_value) / estimated_value fraction_converged = (