Skip to content

Commit

Permalink
Adds tests to convergence solver
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfullard committed May 29, 2024
1 parent e5e37ad commit bf5a062
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions tardis/simulation/tests/test_convergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from pathlib import Path

import numpy as np
import pytest

from tardis.io.configuration.config_reader import Configuration
from tardis.simulation.convergence import ConvergenceSolver


@pytest.fixture(scope="function")
def config(example_configuration_dir: Path):
return Configuration.from_yaml(
example_configuration_dir / "tardis_configv1_verysimple.yml"
)


@pytest.fixture(scope="function")
def strategy(config):
return config.montecarlo.convergence_strategy.t_rad

def test_convergence_solver_init_damped(strategy):
solver = ConvergenceSolver(strategy)
assert solver.damping_factor == 0.5
assert solver.threshold == 0.05
assert solver.converge == solver.damped_converge

def test_convergence_solver_init_custom(strategy):
strategy.type = 'custom'
with pytest.raises(NotImplementedError):
ConvergenceSolver(strategy)

def test_convergence_solver_init_invalid(strategy):
strategy.type = 'invalid'
with pytest.raises(ValueError):
ConvergenceSolver(strategy)

def test_damped_converge(strategy):
solver = ConvergenceSolver(strategy)
value = np.float64(10.0)
estimated_value = np.float64(20.0)
converged_value = solver.damped_converge(value, estimated_value)
assert converged_value == 15.0

def test_get_convergence_status(strategy):
solver = ConvergenceSolver(strategy)
value = np.array([1.0, 2.0, 3.0], dtype=np.float64)
estimated_value = np.array([1.01, 2.02, 3.03], dtype=np.float64)
no_of_cells = np.int64(3)
status = solver.get_convergence_status(value, estimated_value, no_of_cells)
assert status

value = np.array([1.0, 2.0, 3.0], dtype=np.float64)
estimated_value = np.array([2.0, 3.0, 4.0], dtype=np.float64)
status = solver.get_convergence_status(value, estimated_value, no_of_cells)
assert not status

0 comments on commit bf5a062

Please sign in to comment.