diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index d6411348c5..18e62c5dfa 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -145,8 +145,8 @@ Dependency `pre-commit `__ \- dev For managing and maintaining multi-language pre-commit hooks. `ruff `__ \- dev For code formatting. `nox `__ \- dev For running testing sessions in multiple environments. +`pytest-subtests `__ \- dev For subtests pytest fixture. `pytest-cov `__ \- dev For calculating test coverage. -`parameterized `__ \- dev For test parameterization. `pytest `__ 6.0.0 dev For running the test suites. `pytest-doctestplus `__ \- dev For running doctests. `pytest-xdist `__ \- dev For running tests in parallel across distributed workers. diff --git a/pyproject.toml b/pyproject.toml index 7fb1a5ce95..d2c487ede4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,12 +105,11 @@ dev = [ "pytest-cov", # For doctest "pytest-doctestplus", - # For test parameterization - "parameterized>=0.9", # pytest and its plugins "pytest>=6", "pytest-xdist", "pytest-mock", + "pytest-subtests", # For testing Jupyter notebooks "nbmake", # To access the metadata for python packages @@ -230,6 +229,7 @@ minversion = "8" required_plugins = [ "pytest-xdist", "pytest-mock", + "pytest-subtests", ] norecursedirs = 'pybind11*' addopts = [ diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py index 4f981ba04c..394d64b257 100644 --- a/tests/unit/test_experiments/test_simulation_with_experiment.py +++ b/tests/unit/test_experiments/test_simulation_with_experiment.py @@ -1,11 +1,11 @@ # # Test setting up a simulation with an experiment # +import pytest import casadi import pybamm import numpy as np import os -import unittest from datetime import datetime @@ -15,7 +15,7 @@ def default_duration(self, value): return 1 -class TestSimulationExperiment(unittest.TestCase): +class TestSimulationExperiment: def test_set_up(self): experiment = pybamm.Experiment( [ @@ -29,24 +29,22 @@ def test_set_up(self): sim = pybamm.Simulation(model, experiment=experiment) sim.build_for_experiment() - self.assertEqual(sim.experiment.args, experiment.args) + assert sim.experiment.args == experiment.args steps = sim.experiment.steps model_I = sim.experiment_unique_steps_to_model[ steps[1].basic_repr() ] # CC charge model_V = sim.experiment_unique_steps_to_model[steps[2].basic_repr()] # CV hold - self.assertIn( - "Current cut-off [A] [experiment]", - [event.name for event in model_V.events], - ) - self.assertIn( - "Charge voltage cut-off [V] [experiment]", - [event.name for event in model_I.events], - ) + assert "Current cut-off [A] [experiment]" in [ + event.name for event in model_V.events + ] + assert "Charge voltage cut-off [V] [experiment]" in [ + event.name for event in model_I.events + ] # fails if trying to set up with something that isn't an experiment - with self.assertRaisesRegex(TypeError, "experiment must be"): + with pytest.raises(TypeError, match="experiment must be"): pybamm.Simulation(model, experiment=0) def test_setup_experiment_string_or_list(self): @@ -54,17 +52,14 @@ def test_setup_experiment_string_or_list(self): sim = pybamm.Simulation(model, experiment="Discharge at C/20 for 1 hour") sim.build_for_experiment() - self.assertEqual(len(sim.experiment.steps), 1) - self.assertEqual( - sim.experiment.steps[0].description, - "Discharge at C/20 for 1 hour", - ) + assert len(sim.experiment.steps) == 1 + assert sim.experiment.steps[0].description == "Discharge at C/20 for 1 hour" sim = pybamm.Simulation( model, experiment=["Discharge at C/20 for 1 hour", pybamm.step.rest(60)], ) sim.build_for_experiment() - self.assertEqual(len(sim.experiment.steps), 2) + assert len(sim.experiment.steps) == 2 def test_run_experiment(self): s = pybamm.step.string @@ -84,8 +79,8 @@ def test_run_experiment(self): sim = pybamm.Simulation(model, experiment=experiment) # test the callback here sol = sim.solve(callbacks=pybamm.callbacks.Callback()) - self.assertEqual(sol.termination, "final time") - self.assertEqual(len(sol.cycles), 1) + assert sol.termination == "final time" + assert len(sol.cycles) == 1 # Test outputs np.testing.assert_array_equal(sol.cycles[0].steps[0]["C-rate"].data, 1 / 20) @@ -128,23 +123,23 @@ def test_run_experiment(self): # Solve again starting from solution sol2 = sim.solve(starting_solution=sol) - self.assertEqual(sol2.termination, "final time") - self.assertGreater(sol2.t[-1], sol.t[-1]) - self.assertEqual(sol2.cycles[0], sol.cycles[0]) - self.assertEqual(len(sol2.cycles), 2) + assert sol2.termination == "final time" + assert sol2.t[-1] > sol.t[-1] + assert sol2.cycles[0] == sol.cycles[0] + assert len(sol2.cycles) == 2 # Solve again starting from solution but only inputting the cycle sol2 = sim.solve(starting_solution=sol.cycles[-1]) - self.assertEqual(sol2.termination, "final time") - self.assertGreater(sol2.t[-1], sol.t[-1]) - self.assertEqual(len(sol2.cycles), 2) + assert sol2.termination == "final time" + assert sol2.t[-1] > sol.t[-1] + assert len(sol2.cycles) == 2 # Check starting solution is unchanged - self.assertEqual(len(sol.cycles), 1) + assert len(sol.cycles) == 1 # save sol2.save("test_experiment.sav") sol3 = pybamm.load("test_experiment.sav") - self.assertEqual(len(sol3.cycles), 2) + assert len(sol3.cycles) == 2 os.remove("test_experiment.sav") def test_run_experiment_multiple_times(self): @@ -168,7 +163,9 @@ def test_run_experiment_multiple_times(self): sol1["Voltage [V]"].data, sol2["Voltage [V]"].data ) - @unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed") + @pytest.mark.skipif( + not pybamm.has_idaklu(), reason="idaklu solver is not installed" + ) def test_run_experiment_cccv_solvers(self): experiment_2step = pybamm.Experiment( [ @@ -199,9 +196,11 @@ def test_run_experiment_cccv_solvers(self): solutions[1]["Current [A]"](solutions[0].t), decimal=0, ) - self.assertEqual(solutions[1].termination, "final time") + assert solutions[1].termination == "final time" - @unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed") + @pytest.mark.skipif( + not pybamm.has_idaklu(), reason="idaklu solver is not installed" + ) def test_solve_with_sensitivities_and_experiment(self): experiment_2step = pybamm.Experiment( [ @@ -276,7 +275,7 @@ def test_solve_with_sensitivities_and_experiment(self): ) / len(sens_casadi) ) - self.assertLess(error, 1.0) + assert error < 1.0 def test_run_experiment_drive_cycle(self): drive_cycle = np.array([np.arange(10), np.arange(10)]).T @@ -292,9 +291,8 @@ def test_run_experiment_drive_cycle(self): model = pybamm.lithium_ion.SPM() sim = pybamm.Simulation(model, experiment=experiment) sim.build_for_experiment() - self.assertEqual( - sorted([step.basic_repr() for step in experiment.steps]), - sorted(list(sim.experiment_unique_steps_to_model.keys())), + assert sorted([step.basic_repr() for step in experiment.steps]) == sorted( + list(sim.experiment_unique_steps_to_model.keys()) ) def test_run_experiment_breaks_early_infeasible(self): @@ -308,7 +306,7 @@ def test_run_experiment_breaks_early_infeasible(self): t_eval, solver=pybamm.CasadiSolver(), callbacks=pybamm.callbacks.Callback() ) pybamm.set_logging_level("WARNING") - self.assertEqual(sim._solution.termination, "event: Minimum voltage [V]") + assert sim._solution.termination == "event: Minimum voltage [V]" def test_run_experiment_breaks_early_error(self): s = pybamm.step.string @@ -331,8 +329,8 @@ def test_run_experiment_breaks_early_error(self): solver=solver, ) sol = sim.solve() - self.assertEqual(len(sol.cycles), 1) - self.assertEqual(len(sol.cycles[0].steps), 1) + assert len(sol.cycles) == 1 + assert len(sol.cycles[0].steps) == 1 # Different experiment setup style experiment = pybamm.Experiment( @@ -348,8 +346,8 @@ def test_run_experiment_breaks_early_error(self): solver=solver, ) sol = sim.solve() - self.assertEqual(len(sol.cycles), 1) - self.assertEqual(len(sol.cycles[0].steps), 1) + assert len(sol.cycles) == 1 + assert len(sol.cycles[0].steps) == 1 # Different callback - this is for coverage on the `Callback` class sol = sim.solve(callbacks=pybamm.callbacks.Callback()) @@ -364,8 +362,8 @@ def test_run_experiment_infeasible_time(self): model, parameter_values=parameter_values, experiment=experiment ) sol = sim.solve() - self.assertEqual(len(sol.cycles), 1) - self.assertEqual(len(sol.cycles[0].steps), 1) + assert len(sol.cycles) == 1 + assert len(sol.cycles[0].steps) == 1 def test_run_experiment_termination_capacity(self): # with percent @@ -442,7 +440,7 @@ def test_run_experiment_termination_voltage(self): # Only two cycles should be completed, only 2nd cycle should go below 4V np.testing.assert_array_less(4, np.min(sol.cycles[0]["Voltage [V]"].data)) np.testing.assert_array_less(np.min(sol.cycles[1]["Voltage [V]"].data), 4) - self.assertEqual(len(sol.cycles), 2) + assert len(sol.cycles) == 2 def test_run_experiment_termination_time_min(self): experiment = pybamm.Experiment( @@ -460,7 +458,7 @@ def test_run_experiment_termination_time_min(self): # Only two cycles should be completed, only 2nd cycle should go below 4V np.testing.assert_array_less(np.max(sol.cycles[0]["Time [s]"].data), 1500) np.testing.assert_array_equal(np.max(sol.cycles[1]["Time [s]"].data), 1500) - self.assertEqual(len(sol.cycles), 2) + assert len(sol.cycles) == 2 def test_run_experiment_termination_time_s(self): experiment = pybamm.Experiment( @@ -478,7 +476,7 @@ def test_run_experiment_termination_time_s(self): # Only two cycles should be completed, only 2nd cycle should go below 4V np.testing.assert_array_less(np.max(sol.cycles[0]["Time [s]"].data), 1500) np.testing.assert_array_equal(np.max(sol.cycles[1]["Time [s]"].data), 1500) - self.assertEqual(len(sol.cycles), 2) + assert len(sol.cycles) == 2 def test_run_experiment_termination_time_h(self): experiment = pybamm.Experiment( @@ -496,7 +494,7 @@ def test_run_experiment_termination_time_h(self): # Only two cycles should be completed, only 2nd cycle should go below 4V np.testing.assert_array_less(np.max(sol.cycles[0]["Time [s]"].data), 1800) np.testing.assert_array_equal(np.max(sol.cycles[1]["Time [s]"].data), 1800) - self.assertEqual(len(sol.cycles), 2) + assert len(sol.cycles) == 2 def test_save_at_cycles(self): experiment = pybamm.Experiment( @@ -516,22 +514,22 @@ def test_save_at_cycles(self): ) # Solution saves "None" for the cycles that are not saved for cycle_num in [2, 4, 6, 8]: - self.assertIsNone(sol.cycles[cycle_num]) + assert sol.cycles[cycle_num] is None for cycle_num in [0, 1, 3, 5, 7, 9]: - self.assertIsNotNone(sol.cycles[cycle_num]) + assert sol.cycles[cycle_num] is not None # Summary variables are not None - self.assertIsNotNone(sol.summary_variables["Capacity [A.h]"]) + assert sol.summary_variables["Capacity [A.h]"] is not None sol = sim.solve( solver=pybamm.CasadiSolver("fast with events"), save_at_cycles=[3, 4, 5, 9] ) # Note offset by 1 (0th cycle is cycle 1) for cycle_num in [1, 5, 6, 7]: - self.assertIsNone(sol.cycles[cycle_num]) + assert sol.cycles[cycle_num] is None for cycle_num in [0, 2, 3, 4, 8, 9]: # first & last cycle always saved - self.assertIsNotNone(sol.cycles[cycle_num]) + assert sol.cycles[cycle_num] is not None # Summary variables are not None - self.assertIsNotNone(sol.summary_variables["Capacity [A.h]"]) + assert sol.summary_variables["Capacity [A.h]"] is not None def test_cycle_summary_variables(self): # Test cycle_summary_variables works for different combinations of data and @@ -633,8 +631,8 @@ def test_run_experiment_skip_steps(self): model, parameter_values=parameter_values, experiment=experiment ) sol = sim.solve() - self.assertIsInstance(sol.cycles[0].steps[0], pybamm.EmptySolution) - self.assertIsInstance(sol.cycles[0].steps[3], pybamm.EmptySolution) + assert isinstance(sol.cycles[0].steps[0], pybamm.EmptySolution) + assert isinstance(sol.cycles[0].steps[3], pybamm.EmptySolution) # Should get the same result if we run without the charge steps # since they are skipped @@ -689,9 +687,9 @@ def test_all_empty_solution_errors(self): sim = pybamm.Simulation( model, parameter_values=parameter_values, experiment=experiment ) - with self.assertRaisesRegex( + with pytest.raises( pybamm.SolverError, - "Step 'Charge at 1C until 4.2V' is infeasible due to exceeded bounds", + match="Step 'Charge at 1C until 4.2V' is infeasible due to exceeded bounds", ): sim.solve() @@ -702,7 +700,7 @@ def test_all_empty_solution_errors(self): sim = pybamm.Simulation( model, parameter_values=parameter_values, experiment=experiment ) - with self.assertRaisesRegex(pybamm.SolverError, "All steps in the cycle"): + with pytest.raises(pybamm.SolverError, match="All steps in the cycle"): sim.solve() def test_solver_error(self): @@ -719,7 +717,7 @@ def test_solver_error(self): solver=pybamm.CasadiSolver(mode="fast"), ) - with self.assertRaisesRegex(pybamm.SolverError, "IDA_CONV_FAIL"): + with pytest.raises(pybamm.SolverError, match="IDA_CONV_FAIL"): sim.solve() def test_run_experiment_half_cell(self): @@ -749,9 +747,7 @@ def test_padding_rest_model(self): experiment = pybamm.Experiment(["Rest for 1 hour"]) sim = pybamm.Simulation(model, experiment=experiment) sim.build_for_experiment() - self.assertNotIn( - "Rest for padding", sim.experiment_unique_steps_to_model.keys() - ) + assert "Rest for padding" not in sim.experiment_unique_steps_to_model.keys() # Test padding rest model exists if there are start_times experiment = pybamm.step.string( @@ -759,13 +755,13 @@ def test_padding_rest_model(self): ) sim = pybamm.Simulation(model, experiment=experiment) sim.build_for_experiment() - self.assertIn("Rest for padding", sim.experiment_unique_steps_to_model.keys()) + assert "Rest for padding" in sim.experiment_unique_steps_to_model.keys() # Check at least there is an input parameter (temperature) - self.assertGreater( - len(sim.experiment_unique_steps_to_model["Rest for padding"].parameters), 0 + assert ( + len(sim.experiment_unique_steps_to_model["Rest for padding"].parameters) > 0 ) # Check the model is the same - self.assertIsInstance( + assert isinstance( sim.experiment_unique_steps_to_model["Rest for padding"], pybamm.lithium_ion.SPM, ) @@ -787,7 +783,7 @@ def test_run_start_time_experiment(self): ) sim = pybamm.Simulation(model, experiment=experiment) sol = sim.solve(calc_esoh=False) - self.assertEqual(sol["Time [s]"].entries[-1], 5400) + assert sol["Time [s]"].entries[-1] == 5400 # Test padding rest is added if time stamp is late experiment = pybamm.Experiment( @@ -803,7 +799,7 @@ def test_run_start_time_experiment(self): ) sim = pybamm.Simulation(model, experiment=experiment) sol = sim.solve(calc_esoh=False) - self.assertEqual(sol["Time [s]"].entries[-1], 10800) + assert sol["Time [s]"].entries[-1] == 10800 def test_starting_solution(self): model = pybamm.lithium_ion.SPM() @@ -820,7 +816,7 @@ def test_starting_solution(self): solution = sim.solve(save_at_cycles=[1]) # test that the last state is correct (i.e. final cycle is saved) - self.assertEqual(solution.last_state.t[-1], 1200) + assert solution.last_state.t[-1] == 1200 experiment = pybamm.Experiment( [ @@ -833,7 +829,7 @@ def test_starting_solution(self): new_solution = sim.solve(calc_esoh=False, starting_solution=solution) # test that the final time is correct (i.e. starting solution correctly set) - self.assertEqual(new_solution["Time [s]"].entries[-1], 3600) + assert new_solution["Time [s]"].entries[-1] == 3600 def test_experiment_start_time_starting_solution(self): model = pybamm.lithium_ion.SPM() @@ -855,7 +851,7 @@ def test_experiment_start_time_starting_solution(self): ) sim = pybamm.Simulation(model, experiment=experiment) - with self.assertRaisesRegex(ValueError, "experiments with `start_time`"): + with pytest.raises(ValueError, match="experiments with `start_time`"): sim.solve(starting_solution=solution) # Test starting_solution works well with start_time @@ -892,7 +888,7 @@ def test_experiment_start_time_starting_solution(self): new_solution = sim.solve(starting_solution=solution) # test that the final time is correct (i.e. starting solution correctly set) - self.assertEqual(new_solution["Time [s]"].entries[-1], 5400) + assert new_solution["Time [s]"].entries[-1] == 5400 def test_experiment_start_time_identical_steps(self): # Test that if we have the same step twice, with different start times, @@ -918,15 +914,15 @@ def test_experiment_start_time_identical_steps(self): sim.solve(calc_esoh=False) # Check that there are 4 steps - self.assertEqual(len(experiment.steps), 4) + assert len(experiment.steps) == 4 # Check that there are only 2 unique steps - self.assertEqual(len(sim.experiment.unique_steps), 2) + assert len(sim.experiment.unique_steps) == 2 # Check that there are only 3 built models (unique steps + padding rest) - self.assertEqual(len(sim.steps_to_built_models), 3) + assert len(sim.steps_to_built_models) == 3 - def test_experiment_custom_steps(self): + def test_experiment_custom_steps(self, subtests): model = pybamm.lithium_ion.SPM() # Explicit control @@ -947,7 +943,7 @@ def custom_step_voltage(variables): return 100 * (variables["Voltage [V]"] - 4.2) for control in ["differential"]: - with self.subTest(control=control): + with subtests.test(control=control): custom_step_alg = pybamm.step.CustomStepImplicit( custom_step_voltage, control=control, duration=100, period=10 ) @@ -974,20 +970,10 @@ def neg_stoich_cutoff(variables): ) sim = pybamm.Simulation(model, experiment=experiment) sol = sim.solve(calc_esoh=False) - self.assertEqual( - sol.cycles[0].steps[0].termination, - "event: Negative stoichiometry cut-off [experiment]", + assert ( + sol.cycles[0].steps[0].termination + == "event: Negative stoichiometry cut-off [experiment]" ) neg_stoich = sol["Negative electrode stoichiometry"].data - self.assertAlmostEqual(neg_stoich[-1], 0.5, places=4) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert neg_stoich[-1] == pytest.approx(0.5, abs=0.0001) diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index e1a14206b4..206aa0799c 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -2,8 +2,8 @@ # Tests for the Binary Operator classes # -import unittest -import unittest.mock as mock +import pytest + import numpy as np from scipy.sparse import coo_matrix @@ -19,19 +19,19 @@ } -class TestBinaryOperators(unittest.TestCase): +class TestBinaryOperators: def test_binary_operator(self): a = pybamm.Symbol("a") b = pybamm.Symbol("b") bin = pybamm.BinaryOperator("binary test", a, b) - self.assertEqual(bin.children[0].name, a.name) - self.assertEqual(bin.children[1].name, b.name) + assert bin.children[0].name == a.name + assert bin.children[1].name == b.name c = pybamm.Scalar(1) d = pybamm.Scalar(2) bin2 = pybamm.BinaryOperator("binary test", c, d) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): bin2.evaluate() - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): bin2._binary_jac(a, b) def test_binary_operator_domains(self): @@ -39,28 +39,28 @@ def test_binary_operator_domains(self): a = pybamm.Symbol("a", domain=["negative electrode"]) b = pybamm.Symbol("b", domain=["negative electrode"]) bin1 = pybamm.BinaryOperator("binary test", a, b) - self.assertEqual(bin1.domain, ["negative electrode"]) + assert bin1.domain == ["negative electrode"] # one empty domain c = pybamm.Symbol("c", domain=[]) bin2 = pybamm.BinaryOperator("binary test", a, c) - self.assertEqual(bin2.domain, ["negative electrode"]) + assert bin2.domain == ["negative electrode"] bin3 = pybamm.BinaryOperator("binary test", c, b) - self.assertEqual(bin3.domain, ["negative electrode"]) + assert bin3.domain == ["negative electrode"] # mismatched domains d = pybamm.Symbol("d", domain=["positive electrode"]) - with self.assertRaises(pybamm.DomainError): + with pytest.raises(pybamm.DomainError): pybamm.BinaryOperator("binary test", a, d) def test_addition(self): a = pybamm.Symbol("a") b = pybamm.Symbol("b") summ = pybamm.Addition(a, b) - self.assertEqual(summ.children[0].name, a.name) - self.assertEqual(summ.children[1].name, b.name) + assert summ.children[0].name == a.name + assert summ.children[1].name == b.name # test simplifying summ2 = pybamm.Scalar(1) + pybamm.Scalar(3) - self.assertEqual(summ2, pybamm.Scalar(4)) + assert summ2 == pybamm.Scalar(4) def test_addition_numpy_array(self): a = pybamm.Symbol("a") @@ -68,32 +68,32 @@ def test_addition_numpy_array(self): # converts numpy array to vector array = np.array([1, 2, 3]) summ3 = pybamm.Addition(a, array) - self.assertIsInstance(summ3, pybamm.Addition) - self.assertIsInstance(summ3.children[0], pybamm.Symbol) - self.assertIsInstance(summ3.children[1], pybamm.Vector) + assert isinstance(summ3, pybamm.Addition) + assert isinstance(summ3.children[0], pybamm.Symbol) + assert isinstance(summ3.children[1], pybamm.Vector) summ4 = array + a - self.assertIsInstance(summ4.children[0], pybamm.Vector) + assert isinstance(summ4.children[0], pybamm.Vector) # should error if numpy array is not 1D array = np.array([[1, 2, 3], [4, 5, 6]]) - with self.assertRaisesRegex(ValueError, "left must be a 1D array"): + with pytest.raises(ValueError, match="left must be a 1D array"): pybamm.Addition(array, a) - with self.assertRaisesRegex(ValueError, "right must be a 1D array"): + with pytest.raises(ValueError, match="right must be a 1D array"): pybamm.Addition(a, array) def test_power(self): a = pybamm.Symbol("a") b = pybamm.Symbol("b") pow1 = pybamm.Power(a, b) - self.assertEqual(pow1.name, "**") - self.assertEqual(pow1.children[0].name, a.name) - self.assertEqual(pow1.children[1].name, b.name) + assert pow1.name == "**" + assert pow1.children[0].name == a.name + assert pow1.children[1].name == b.name a = pybamm.Scalar(4) b = pybamm.Scalar(2) pow2 = pybamm.Power(a, b) - self.assertEqual(pow2.evaluate(), 16) + assert pow2.evaluate() == 16 def test_diff(self): a = pybamm.StateVector(slice(0, 1)) @@ -101,51 +101,51 @@ def test_diff(self): y = np.array([5, 3]) # power - self.assertEqual((a**b).diff(b).evaluate(y=y), 5**3 * np.log(5)) - self.assertEqual((a**b).diff(a).evaluate(y=y), 3 * 5**2) - self.assertEqual((a**b).diff(a**b).evaluate(), 1) - self.assertEqual((a**a).diff(a).evaluate(y=y), 5**5 * np.log(5) + 5 * 5**4) - self.assertEqual((a**a).diff(b).evaluate(y=y), 0) + assert (a**b).diff(b).evaluate(y=y) == 5**3 * np.log(5) + assert (a**b).diff(a).evaluate(y=y) == 3 * 5**2 + assert (a**b).diff(a**b).evaluate() == 1 + assert (a**a).diff(a).evaluate(y=y) == 5**5 * np.log(5) + 5 * 5**4 + assert (a**a).diff(b).evaluate(y=y) == 0 # addition - self.assertEqual((a + b).diff(a).evaluate(), 1) - self.assertEqual((a + b).diff(b).evaluate(), 1) - self.assertEqual((a + b).diff(a + b).evaluate(), 1) - self.assertEqual((a + a).diff(a).evaluate(), 2) - self.assertEqual((a + a).diff(b).evaluate(), 0) + assert (a + b).diff(a).evaluate() == 1 + assert (a + b).diff(b).evaluate() == 1 + assert (a + b).diff(a + b).evaluate() == 1 + assert (a + a).diff(a).evaluate() == 2 + assert (a + a).diff(b).evaluate() == 0 # subtraction - self.assertEqual((a - b).diff(a).evaluate(), 1) - self.assertEqual((a - b).diff(b).evaluate(), -1) - self.assertEqual((a - b).diff(a - b).evaluate(), 1) - self.assertEqual((a - a).diff(a).evaluate(), 0) - self.assertEqual((a + a).diff(b).evaluate(), 0) + assert (a - b).diff(a).evaluate() == 1 + assert (a - b).diff(b).evaluate() == -1 + assert (a - b).diff(a - b).evaluate() == 1 + assert (a - a).diff(a).evaluate() == 0 + assert (a + a).diff(b).evaluate() == 0 # multiplication - self.assertEqual((a * b).diff(a).evaluate(y=y), 3) - self.assertEqual((a * b).diff(b).evaluate(y=y), 5) - self.assertEqual((a * b).diff(a * b).evaluate(y=y), 1) - self.assertEqual((a * a).diff(a).evaluate(y=y), 10) - self.assertEqual((a * a).diff(b).evaluate(y=y), 0) + assert (a * b).diff(a).evaluate(y=y) == 3 + assert (a * b).diff(b).evaluate(y=y) == 5 + assert (a * b).diff(a * b).evaluate(y=y) == 1 + assert (a * a).diff(a).evaluate(y=y) == 10 + assert (a * a).diff(b).evaluate(y=y) == 0 # matrix multiplication (not implemented) matmul = a @ b - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): matmul.diff(a) # inner - self.assertEqual(pybamm.inner(a, b).diff(a).evaluate(y=y), 3) - self.assertEqual(pybamm.inner(a, b).diff(b).evaluate(y=y), 5) - self.assertEqual(pybamm.inner(a, b).diff(pybamm.inner(a, b)).evaluate(y=y), 1) - self.assertEqual(pybamm.inner(a, a).diff(a).evaluate(y=y), 10) - self.assertEqual(pybamm.inner(a, a).diff(b).evaluate(y=y), 0) + assert pybamm.inner(a, b).diff(a).evaluate(y=y) == 3 + assert pybamm.inner(a, b).diff(b).evaluate(y=y) == 5 + assert pybamm.inner(a, b).diff(pybamm.inner(a, b)).evaluate(y=y) == 1 + assert pybamm.inner(a, a).diff(a).evaluate(y=y) == 10 + assert pybamm.inner(a, a).diff(b).evaluate(y=y) == 0 # division - self.assertEqual((a / b).diff(a).evaluate(y=y), 1 / 3) - self.assertEqual((a / b).diff(b).evaluate(y=y), -5 / 9) - self.assertEqual((a / b).diff(a / b).evaluate(y=y), 1) - self.assertEqual((a / a).diff(a).evaluate(y=y), 0) - self.assertEqual((a / a).diff(b).evaluate(y=y), 0) + assert (a / b).diff(a).evaluate(y=y) == 1 / 3 + assert (a / b).diff(b).evaluate(y=y) == -5 / 9 + assert (a / b).diff(a / b).evaluate(y=y) == 1 + assert (a / a).diff(a).evaluate(y=y) == 0 + assert (a / a).diff(b).evaluate(y=y) == 0 def test_printing(self): # This in not an exhaustive list of all cases. More test cases may need to @@ -154,23 +154,23 @@ def test_printing(self): b = pybamm.Parameter("b") c = pybamm.Parameter("c") d = pybamm.Parameter("d") - self.assertEqual(str(a + b), "a + b") - self.assertEqual(str(a + b + c + d), "a + b + c + d") - self.assertEqual(str((a + b) + (c + d)), "a + b + c + d") - self.assertEqual(str(a + b - c), "a + b - c") - self.assertEqual(str(a + b - c + d), "a + b - c + d") - self.assertEqual(str((a + b) - (c + d)), "a + b - (c + d)") - self.assertEqual(str((a + b) - (c - d)), "a + b - (c - d)") - - self.assertEqual(str((a + b) * (c + d)), "(a + b) * (c + d)") - self.assertEqual(str(a * b * (c + d)), "a * b * (c + d)") - self.assertEqual(str((a * b) * (c + d)), "a * b * (c + d)") - self.assertEqual(str(a * (b * (c + d))), "a * b * (c + d)") - self.assertEqual(str((a + b) / (c + d)), "(a + b) / (c + d)") - self.assertEqual(str(a + b / (c + d)), "a + b / (c + d)") - self.assertEqual(str(a * b / (c + d)), "a * b / (c + d)") - self.assertEqual(str((a * b) / (c + d)), "a * b / (c + d)") - self.assertEqual(str(a * (b / (c + d))), "a * b / (c + d)") + assert str(a + b) == "a + b" + assert str(a + b + c + d) == "a + b + c + d" + assert str((a + b) + (c + d)) == "a + b + c + d" + assert str(a + b - c) == "a + b - c" + assert str(a + b - c + d) == "a + b - c + d" + assert str((a + b) - (c + d)) == "a + b - (c + d)" + assert str((a + b) - (c - d)) == "a + b - (c - d)" + + assert str((a + b) * (c + d)) == "(a + b) * (c + d)" + assert str(a * b * (c + d)) == "a * b * (c + d)" + assert str((a * b) * (c + d)) == "a * b * (c + d)" + assert str(a * (b * (c + d))) == "a * b * (c + d)" + assert str((a + b) / (c + d)) == "(a + b) / (c + d)" + assert str(a + b / (c + d)) == "a + b / (c + d)" + assert str(a * b / (c + d)) == "a * b / (c + d)" + assert str((a * b) / (c + d)) == "a * b / (c + d)" + assert str(a * (b / (c + d))) == "a * b / (c + d)" def test_eq(self): a = pybamm.Scalar(4) @@ -178,20 +178,20 @@ def test_eq(self): bin1 = pybamm.BinaryOperator("test", a, b) bin2 = pybamm.BinaryOperator("test", a, b) bin3 = pybamm.BinaryOperator("new test", a, b) - self.assertEqual(bin1, bin2) - self.assertNotEqual(bin1, bin3) + assert bin1 == bin2 + assert bin1 != bin3 c = pybamm.Scalar(5) bin4 = pybamm.BinaryOperator("test", a, c) - self.assertEqual(bin1, bin4) + assert bin1 == bin4 d = pybamm.Scalar(42) bin5 = pybamm.BinaryOperator("test", a, d) - self.assertNotEqual(bin1, bin5) + assert bin1 != bin5 def test_number_overloading(self): a = pybamm.Scalar(4) prod = a * 3 - self.assertIsInstance(prod, pybamm.Scalar) - self.assertEqual(prod.evaluate(), 12) + assert isinstance(prod, pybamm.Scalar) + assert prod.evaluate() == 12 def test_sparse_multiply(self): row = np.array([0, 3, 1, 0]) @@ -225,11 +225,11 @@ def test_sparse_multiply(self): np.testing.assert_array_equal( (pybammS2 * pybammD2).evaluate().toarray(), S2.toarray() * D2 ) - with self.assertRaisesRegex(pybamm.ShapeError, "inconsistent shapes"): + with pytest.raises(pybamm.ShapeError, match="inconsistent shapes"): (pybammS1 * pybammS2).test_shape() - with self.assertRaisesRegex(pybamm.ShapeError, "inconsistent shapes"): + with pytest.raises(pybamm.ShapeError, match="inconsistent shapes"): (pybammS2 * pybammS1).test_shape() - with self.assertRaisesRegex(pybamm.ShapeError, "inconsistent shapes"): + with pytest.raises(pybamm.ShapeError, match="inconsistent shapes"): (pybammS2 * pybammS1).evaluate_ignoring_errors() # Matrix multiplication is normal matrix multiplication @@ -243,9 +243,9 @@ def test_sparse_multiply(self): np.testing.assert_array_equal((pybammD2 @ pybammS1).evaluate(), D2 * S1) np.testing.assert_array_equal((pybammS2 @ pybammD1).evaluate(), S2 * D1) np.testing.assert_array_equal((pybammD1 @ pybammS2).evaluate(), D1 * S2) - with self.assertRaisesRegex(pybamm.ShapeError, "dimension mismatch"): + with pytest.raises(pybamm.ShapeError, match="dimension mismatch"): (pybammS1 @ pybammS1).test_shape() - with self.assertRaisesRegex(pybamm.ShapeError, "dimension mismatch"): + with pytest.raises(pybamm.ShapeError, match="dimension mismatch"): (pybammS2 @ pybammS2).test_shape() def test_sparse_divide(self): @@ -294,133 +294,133 @@ def test_inner(self): disc.process_model(model) # check doesn't evaluate on edges anymore - self.assertEqual(model.variables["inner"].evaluates_on_edges("primary"), False) + assert not model.variables["inner"].evaluates_on_edges("primary") def test_source(self): u = pybamm.Variable("u", domain="current collector") v = pybamm.Variable("v", domain="current collector") source = pybamm.source(u, v) - self.assertIsInstance(source.children[0], pybamm.Mass) + assert isinstance(source.children[0], pybamm.Mass) boundary_source = pybamm.source(u, v, boundary=True) - self.assertIsInstance(boundary_source.children[0], pybamm.BoundaryMass) + assert isinstance(boundary_source.children[0], pybamm.BoundaryMass) def test_source_error(self): # test error with domain not current collector v = pybamm.Vector(np.ones(5), domain="current collector") w = pybamm.Vector(2 * np.ones(3), domain="test") - with self.assertRaisesRegex(pybamm.DomainError, "'source'"): + with pytest.raises(pybamm.DomainError, match="'source'"): pybamm.source(v, w) def test_heaviside(self): b = pybamm.StateVector(slice(0, 1)) heav = 1 < b - self.assertEqual(heav.evaluate(y=np.array([2])), 1) - self.assertEqual(heav.evaluate(y=np.array([1])), 0) - self.assertEqual(heav.evaluate(y=np.array([0])), 0) - self.assertEqual(str(heav), "1.0 < y[0:1]") + assert heav.evaluate(y=np.array([2])) == 1 + assert heav.evaluate(y=np.array([1])) == 0 + assert heav.evaluate(y=np.array([0])) == 0 + assert str(heav) == "1.0 < y[0:1]" heav = 1 >= b - self.assertEqual(heav.evaluate(y=np.array([2])), 0) - self.assertEqual(heav.evaluate(y=np.array([1])), 1) - self.assertEqual(heav.evaluate(y=np.array([0])), 1) - self.assertEqual(str(heav), "y[0:1] <= 1.0") + assert heav.evaluate(y=np.array([2])) == 0 + assert heav.evaluate(y=np.array([1])) == 1 + assert heav.evaluate(y=np.array([0])) == 1 + assert str(heav) == "y[0:1] <= 1.0" # simplifications - self.assertEqual(1 < b + 2, -1 < b) - self.assertEqual(b + 1 > 2, b > 1) + assert (1 < b + 2) == (-1 < b) + assert (b + 1 > 2) == (b > 1) # expression with a subtract expr = 2 * (b < 1) - (b > 3) - self.assertEqual(expr.evaluate(y=np.array([0])), 2) - self.assertEqual(expr.evaluate(y=np.array([2])), 0) - self.assertEqual(expr.evaluate(y=np.array([4])), -1) + assert expr.evaluate(y=np.array([0])) == 2 + assert expr.evaluate(y=np.array([2])) == 0 + assert expr.evaluate(y=np.array([4])) == -1 def test_equality(self): a = pybamm.Scalar(1) b = pybamm.StateVector(slice(0, 1)) equal = pybamm.Equality(a, b) - self.assertEqual(equal.evaluate(y=np.array([1])), 1) - self.assertEqual(equal.evaluate(y=np.array([2])), 0) - self.assertEqual(str(equal), "1.0 == y[0:1]") - self.assertEqual(equal.diff(b), 0) + assert equal.evaluate(y=np.array([1])) == 1 + assert equal.evaluate(y=np.array([2])) == 0 + assert str(equal) == "1.0 == y[0:1]" + assert equal.diff(b) == 0 def test_sigmoid(self): a = pybamm.Scalar(1) b = pybamm.StateVector(slice(0, 1)) sigm = pybamm.sigmoid(a, b, 10) - self.assertAlmostEqual(sigm.evaluate(y=np.array([2]))[0, 0], 1) - self.assertEqual(sigm.evaluate(y=np.array([1])), 0.5) - self.assertAlmostEqual(sigm.evaluate(y=np.array([0]))[0, 0], 0) - self.assertEqual(str(sigm), "0.5 + 0.5 * tanh(-10.0 + 10.0 * y[0:1])") + assert sigm.evaluate(y=np.array([2]))[0, 0] == pytest.approx(1) + assert sigm.evaluate(y=np.array([1])) == 0.5 + pytest.approx(sigm.evaluate(y=np.array([0]))[0, 0], abs=0) + assert str(sigm) == "0.5 + 0.5 * tanh(-10.0 + 10.0 * y[0:1])" sigm = pybamm.sigmoid(b, a, 10) - self.assertAlmostEqual(sigm.evaluate(y=np.array([2]))[0, 0], 0) - self.assertEqual(sigm.evaluate(y=np.array([1])), 0.5) - self.assertAlmostEqual(sigm.evaluate(y=np.array([0]))[0, 0], 1) - self.assertEqual(str(sigm), "0.5 + 0.5 * tanh(10.0 - (10.0 * y[0:1]))") + pytest.approx(sigm.evaluate(y=np.array([2]))[0, 0], abs=0) + assert sigm.evaluate(y=np.array([1])) == 0.5 + pytest.approx(sigm.evaluate(y=np.array([0]))[0, 0], abs=1) + assert str(sigm) == "0.5 + 0.5 * tanh(10.0 - (10.0 * y[0:1]))" def test_modulo(self): a = pybamm.StateVector(slice(0, 1)) b = pybamm.Scalar(3) mod = a % b - self.assertEqual(mod.evaluate(y=np.array([4]))[0, 0], 1) - self.assertEqual(mod.evaluate(y=np.array([3]))[0, 0], 0) - self.assertEqual(mod.evaluate(y=np.array([2]))[0, 0], 2) - self.assertAlmostEqual(mod.evaluate(y=np.array([4.3]))[0, 0], 1.3) - self.assertAlmostEqual(mod.evaluate(y=np.array([2.2]))[0, 0], 2.2) - self.assertEqual(str(mod), "y[0:1] mod 3.0") + assert mod.evaluate(y=np.array([4]))[0, 0] == 1 + assert mod.evaluate(y=np.array([3]))[0, 0] == 0 + assert mod.evaluate(y=np.array([2]))[0, 0] == 2 + assert mod.evaluate(y=np.array([4.3]))[0, 0] == pytest.approx(1.3) + assert mod.evaluate(y=np.array([2.2]))[0, 0] == pytest.approx(2.2) + assert str(mod) == "y[0:1] mod 3.0" def test_minimum_maximum(self): a = pybamm.Scalar(1) b = pybamm.StateVector(slice(0, 1)) minimum = pybamm.minimum(a, b) - self.assertEqual(minimum.evaluate(y=np.array([2])), 1) - self.assertEqual(minimum.evaluate(y=np.array([1])), 1) - self.assertEqual(minimum.evaluate(y=np.array([0])), 0) - self.assertEqual(str(minimum), "minimum(1.0, y[0:1])") + assert minimum.evaluate(y=np.array([2])) == 1 + assert minimum.evaluate(y=np.array([1])) == 1 + assert minimum.evaluate(y=np.array([0])) == 0 + assert str(minimum) == "minimum(1.0, y[0:1])" maximum = pybamm.maximum(a, b) - self.assertEqual(maximum.evaluate(y=np.array([2])), 2) - self.assertEqual(maximum.evaluate(y=np.array([1])), 1) - self.assertEqual(maximum.evaluate(y=np.array([0])), 1) - self.assertEqual(str(maximum), "maximum(1.0, y[0:1])") + assert maximum.evaluate(y=np.array([2])) == 2 + assert maximum.evaluate(y=np.array([1])) == 1 + assert maximum.evaluate(y=np.array([0])) == 1 + assert str(maximum) == "maximum(1.0, y[0:1])" def test_softminus_softplus(self): a = pybamm.Scalar(1) b = pybamm.StateVector(slice(0, 1)) minimum = pybamm.softminus(a, b, 50) - self.assertAlmostEqual(minimum.evaluate(y=np.array([2]))[0, 0], 1) - self.assertAlmostEqual(minimum.evaluate(y=np.array([0]))[0, 0], 0) - self.assertEqual( - str(minimum), "-0.02 * log(1.9287498479639178e-22 + exp(-50.0 * y[0:1]))" + assert minimum.evaluate(y=np.array([2]))[0, 0] == pytest.approx(1) + assert minimum.evaluate(y=np.array([0]))[0, 0] == pytest.approx(0) + assert ( + str(minimum) == "-0.02 * log(1.9287498479639178e-22 + exp(-50.0 * y[0:1]))" ) maximum = pybamm.softplus(a, b, 50) - self.assertAlmostEqual(maximum.evaluate(y=np.array([2]))[0, 0], 2) - self.assertAlmostEqual(maximum.evaluate(y=np.array([0]))[0, 0], 1) - self.assertEqual( - str(maximum)[:20], - "0.02 * log(5.184705528587072e+21 + exp(50.0 * y[0:1]))"[:20], + assert maximum.evaluate(y=np.array([2]))[0, 0] == pytest.approx(2) + assert maximum.evaluate(y=np.array([0]))[0, 0] == pytest.approx(1) + assert ( + str(maximum)[:20] + == "0.02 * log(5.184705528587072e+21 + exp(50.0 * y[0:1]))"[:20] ) - self.assertEqual( - str(maximum)[-20:], - "0.02 * log(5.184705528587072e+21 + exp(50.0 * y[0:1]))"[-20:], + assert ( + str(maximum)[-20:] + == "0.02 * log(5.184705528587072e+21 + exp(50.0 * y[0:1]))"[-20:] ) # Test that smooth min/max are used when the setting is changed pybamm.settings.min_max_mode = "soft" pybamm.settings.min_max_smoothing = 10 - self.assertEqual(str(pybamm.minimum(a, b)), str(pybamm.softminus(a, b, 10))) - self.assertEqual(str(pybamm.maximum(a, b)), str(pybamm.softplus(a, b, 10))) + assert str(pybamm.minimum(a, b)) == str(pybamm.softminus(a, b, 10)) + assert str(pybamm.maximum(a, b)) == str(pybamm.softplus(a, b, 10)) # But exact min/max should still be used if both variables are constant a = pybamm.Scalar(1) b = pybamm.Scalar(2) - self.assertEqual(str(pybamm.minimum(a, b)), str(a)) - self.assertEqual(str(pybamm.maximum(a, b)), str(b)) + assert str(pybamm.minimum(a, b)) == str(a) + assert str(pybamm.maximum(a, b)) == str(b) # Change setting back for other tests pybamm.settings.set_smoothing_parameters("exact") @@ -430,36 +430,34 @@ def test_smooth_minus_plus(self): b = pybamm.StateVector(slice(0, 1)) minimum = pybamm.smooth_min(a, b, 3000) - self.assertAlmostEqual(minimum.evaluate(y=np.array([2]))[0, 0], 1) - self.assertAlmostEqual(minimum.evaluate(y=np.array([0]))[0, 0], 0) + pytest.approx(minimum.evaluate(y=np.array([2]))[0, 0], abs=1) + pytest.approx(minimum.evaluate(y=np.array([0]))[0, 0], abs=0) maximum = pybamm.smooth_max(a, b, 3000) - self.assertAlmostEqual(maximum.evaluate(y=np.array([2]))[0, 0], 2) - self.assertAlmostEqual(maximum.evaluate(y=np.array([0]))[0, 0], 1) + assert maximum.evaluate(y=np.array([2]))[0, 0] == pytest.approx(2) + assert maximum.evaluate(y=np.array([0]))[0, 0] == pytest.approx(1) minimum = pybamm.smooth_min(a, b, 1) - self.assertEqual( - str(minimum), - "0.5 * (1.0 + y[0:1] - sqrt(1.0 + (1.0 - y[0:1]) ** 2.0))", + assert ( + str(minimum) == "0.5 * (1.0 + y[0:1] - sqrt(1.0 + (1.0 - y[0:1]) ** 2.0))" ) maximum = pybamm.smooth_max(a, b, 1) - self.assertEqual( - str(maximum), - "0.5 * (sqrt(1.0 + (1.0 - y[0:1]) ** 2.0) + 1.0 + y[0:1])", + assert ( + str(maximum) == "0.5 * (sqrt(1.0 + (1.0 - y[0:1]) ** 2.0) + 1.0 + y[0:1])" ) # Test that smooth min/max are used when the setting is changed pybamm.settings.min_max_mode = "smooth" pybamm.settings.min_max_smoothing = 1 - self.assertEqual(str(pybamm.minimum(a, b)), str(pybamm.smooth_min(a, b, 1))) - self.assertEqual(str(pybamm.maximum(a, b)), str(pybamm.smooth_max(a, b, 1))) + assert str(pybamm.minimum(a, b)) == str(pybamm.smooth_min(a, b, 1)) + assert str(pybamm.maximum(a, b)) == str(pybamm.smooth_max(a, b, 1)) pybamm.settings.min_max_smoothing = 3000 a = pybamm.Scalar(1) b = pybamm.Scalar(2) - self.assertEqual(str(pybamm.minimum(a, b)), str(a)) - self.assertEqual(str(pybamm.maximum(a, b)), str(b)) + assert str(pybamm.minimum(a, b)) == str(a) + assert str(pybamm.maximum(a, b)) == str(b) # Change setting back for other tests pybamm.settings.set_smoothing_parameters("exact") @@ -480,134 +478,133 @@ def test_binary_simplifications(self): broad2_edge = pybamm.PrimaryBroadcastToEdges(2, "domain") # power - self.assertEqual((c**0), pybamm.Scalar(1)) - self.assertEqual((0**c), pybamm.Scalar(0)) - self.assertEqual((c**1), c) + assert (c**0) == pybamm.Scalar(1) + assert (0**c) == pybamm.Scalar(0) + assert (c**1) == c # power with broadcasts - self.assertEqual((c**broad2), pybamm.PrimaryBroadcast(c**2, "domain")) - self.assertEqual((broad2**c), pybamm.PrimaryBroadcast(2**c, "domain")) - self.assertEqual( - (broad2 ** pybamm.PrimaryBroadcast(c, "domain")), - pybamm.PrimaryBroadcast(2**c, "domain"), - ) + assert (c**broad2) == pybamm.PrimaryBroadcast(c**2, "domain") + assert (broad2**c) == pybamm.PrimaryBroadcast(2**c, "domain") + assert ( + broad2 ** pybamm.PrimaryBroadcast(c, "domain") + ) == pybamm.PrimaryBroadcast(2**c, "domain") # power with broadcasts to edge - self.assertIsInstance(var**broad2_edge, pybamm.Power) - self.assertEqual((var**broad2_edge).left, var) - self.assertEqual((var**broad2_edge).right, broad2_edge) + assert isinstance(var**broad2_edge, pybamm.Power) + assert (var**broad2_edge).left == var + assert (var**broad2_edge).right == broad2_edge # addition - self.assertEqual(a + b, pybamm.Scalar(1)) - self.assertEqual(b + b, pybamm.Scalar(2)) - self.assertEqual(b + a, pybamm.Scalar(1)) - self.assertEqual(0 + b, pybamm.Scalar(1)) - self.assertEqual(0 + c, c) - self.assertEqual(c + 0, c) + assert a + b == pybamm.Scalar(1) + assert b + b == pybamm.Scalar(2) + assert b + a == pybamm.Scalar(1) + assert 0 + b == pybamm.Scalar(1) + assert 0 + c == c + assert c + 0 == c # addition with subtraction - self.assertEqual(c + (d - c), d) - self.assertEqual((c - d) + d, c) + assert c + (d - c) == d + assert (c - d) + d == c # addition with broadcast zero - self.assertIsInstance((1 + broad0), pybamm.PrimaryBroadcast) + assert isinstance((1 + broad0), pybamm.PrimaryBroadcast) np.testing.assert_array_equal((1 + broad0).child.evaluate(), 1) np.testing.assert_array_equal((1 + broad0).domain, "domain") - self.assertIsInstance((broad0 + 1), pybamm.PrimaryBroadcast) + assert isinstance((broad0 + 1), pybamm.PrimaryBroadcast) np.testing.assert_array_equal((broad0 + 1).child.evaluate(), 1) np.testing.assert_array_equal((broad0 + 1).domain, "domain") # addition with broadcasts - self.assertEqual((c + broad2), pybamm.PrimaryBroadcast(c + 2, "domain")) - self.assertEqual((broad2 + c), pybamm.PrimaryBroadcast(2 + c, "domain")) + assert (c + broad2) == pybamm.PrimaryBroadcast(c + 2, "domain") + assert (broad2 + c) == pybamm.PrimaryBroadcast(2 + c, "domain") # addition with negate - self.assertEqual(c + -d, c - d) - self.assertEqual(-c + d, d - c) + assert c + -d == c - d + assert -c + d == d - c # subtraction - self.assertEqual(a - b, pybamm.Scalar(-1)) - self.assertEqual(b - b, pybamm.Scalar(0)) - self.assertEqual(b - a, pybamm.Scalar(1)) + assert a - b == pybamm.Scalar(-1) + assert b - b == pybamm.Scalar(0) + assert b - a == pybamm.Scalar(1) # subtraction with addition - self.assertEqual(c - (d + c), -d) - self.assertEqual(c - (c - d), d) - self.assertEqual((c + d) - d, c) - self.assertEqual((d + c) - d, c) - self.assertEqual((d - c) - d, -c) + assert c - (d + c) == -d + assert c - (c - d) == d + assert (c + d) - d == c + assert (d + c) - d == c + assert (d - c) - d == -c # subtraction with broadcasts - self.assertEqual((c - broad2), pybamm.PrimaryBroadcast(c - 2, "domain")) - self.assertEqual((broad2 - c), pybamm.PrimaryBroadcast(2 - c, "domain")) + assert (c - broad2) == pybamm.PrimaryBroadcast(c - 2, "domain") + assert (broad2 - c) == pybamm.PrimaryBroadcast(2 - c, "domain") # subtraction from itself - self.assertEqual((c - c), pybamm.Scalar(0)) - self.assertEqual((broad2 - broad2), broad0) + assert (c - c) == pybamm.Scalar(0) + assert (broad2 - broad2) == broad0 # subtraction with negate - self.assertEqual((c - (-d)), c + d) + assert (c - (-d)) == c + d # addition and subtraction with matrix zero - self.assertEqual(b + v, pybamm.Vector(np.ones((10, 1)))) - self.assertEqual(v + b, pybamm.Vector(np.ones((10, 1)))) - self.assertEqual(b - v, pybamm.Vector(np.ones((10, 1)))) - self.assertEqual(v - b, pybamm.Vector(-np.ones((10, 1)))) + assert b + v == pybamm.Vector(np.ones((10, 1))) + assert v + b == pybamm.Vector(np.ones((10, 1))) + assert b - v == pybamm.Vector(np.ones((10, 1))) + assert v - b == pybamm.Vector(-np.ones((10, 1))) # multiplication - self.assertEqual(a * b, pybamm.Scalar(0)) - self.assertEqual(b * a, pybamm.Scalar(0)) - self.assertEqual(b * b, pybamm.Scalar(1)) - self.assertEqual(a * a, pybamm.Scalar(0)) - self.assertEqual(a * c, pybamm.Scalar(0)) - self.assertEqual(c * a, pybamm.Scalar(0)) - self.assertEqual(b * c, c) + assert a * b == pybamm.Scalar(0) + assert b * a == pybamm.Scalar(0) + assert b * b == pybamm.Scalar(1) + assert a * a == pybamm.Scalar(0) + assert a * c == pybamm.Scalar(0) + assert c * a == pybamm.Scalar(0) + assert b * c == c # multiplication with -1 - self.assertEqual((c * -1), (-c)) - self.assertEqual((-1 * c), (-c)) + assert (c * -1) == (-c) + assert (-1 * c) == (-c) # multiplication with a negation - self.assertEqual((-c * -f), (c * f)) - self.assertEqual((-c * 4), (c * -4)) - self.assertEqual((4 * -c), (-4 * c)) + assert (-c * -f) == (c * f) + assert (-c * 4) == (c * -4) + assert (4 * -c) == (-4 * c) # multiplication with division - self.assertEqual((c * (d / c)), d) - self.assertEqual((c / d) * d, c) + assert (c * (d / c)) == d + assert (c / d) * d == c # multiplication with broadcasts - self.assertEqual((c * broad2), pybamm.PrimaryBroadcast(c * 2, "domain")) - self.assertEqual((broad2 * c), pybamm.PrimaryBroadcast(2 * c, "domain")) + assert (c * broad2) == pybamm.PrimaryBroadcast(c * 2, "domain") + assert (broad2 * c) == pybamm.PrimaryBroadcast(2 * c, "domain") # multiplication with matrix zero - self.assertEqual(b * v, pybamm.Vector(np.zeros((10, 1)))) - self.assertEqual(v * b, pybamm.Vector(np.zeros((10, 1)))) + assert b * v == pybamm.Vector(np.zeros((10, 1))) + assert v * b == pybamm.Vector(np.zeros((10, 1))) # multiplication with matrix one - self.assertEqual((f * v1), f) - self.assertEqual((v1 * f), f) + assert (f * v1) == f + assert (v1 * f) == f # multiplication with matrix minus one - self.assertEqual((f * (-v1)), (-f)) - self.assertEqual(((-v1) * f), (-f)) + assert (f * (-v1)) == (-f) + assert ((-v1) * f) == (-f) # multiplication with broadcast - self.assertEqual((var * broad2), (var * 2)) - self.assertEqual((broad2 * var), (2 * var)) + assert (var * broad2) == (var * 2) + assert (broad2 * var) == (2 * var) # multiplication with broadcast one - self.assertEqual((var * broad1), var) - self.assertEqual((broad1 * var), var) + assert (var * broad1) == var + assert (broad1 * var) == var # multiplication with broadcast minus one - self.assertEqual((var * -broad1), (-var)) - self.assertEqual((-broad1 * var), (-var)) + assert (var * -broad1) == (-var) + assert (-broad1 * var) == (-var) # division by itself - self.assertEqual((c / c), pybamm.Scalar(1)) - self.assertEqual((broad2 / broad2), broad1) + assert (c / c) == pybamm.Scalar(1) + assert (broad2 / broad2) == broad1 # division with a negation - self.assertEqual((-c / -f), (c / f)) - self.assertEqual((-c / 4), -0.25 * c) - self.assertEqual((4 / -c), (-4 / c)) + assert (-c / -f) == (c / f) + assert (-c / 4) == -0.25 * c + assert (4 / -c) == (-4 / c) # division with multiplication - self.assertEqual((c * d) / c, d) - self.assertEqual((d * c) / c, d) + assert (c * d) / c == d + assert (d * c) / c == d # division with broadcasts - self.assertEqual((c / broad2), pybamm.PrimaryBroadcast(c / 2, "domain")) - self.assertEqual((broad2 / c), pybamm.PrimaryBroadcast(2 / c, "domain")) + assert (c / broad2) == pybamm.PrimaryBroadcast(c / 2, "domain") + assert (broad2 / c) == pybamm.PrimaryBroadcast(2 / c, "domain") # division with matrix one - self.assertEqual((f / v1), f) - self.assertEqual((f / -v1), (-f)) + assert (f / v1) == f + assert (f / -v1) == (-f) # division by zero - with self.assertRaises(ZeroDivisionError): + with pytest.raises(ZeroDivisionError): b / a # division with a common term - self.assertEqual((2 * c) / (2 * var), (c / var)) - self.assertEqual((c * 2) / (var * 2), (c / var)) + assert (2 * c) / (2 * var) == (c / var) + assert (c * 2) / (var * 2) == (c / var) def test_binary_simplifications_concatenations(self): def conc_broad(x, y, z): @@ -625,10 +622,10 @@ def conc_broad(x, y, z): pybamm.InputParameter("y"), pybamm.InputParameter("z"), ) - self.assertEqual((a + 4), conc_broad(5, 6, 7)) - self.assertEqual((4 + a), conc_broad(5, 6, 7)) - self.assertEqual((a + b), conc_broad(12, 14, 16)) - self.assertIsInstance((a + c), pybamm.Concatenation) + assert (a + 4) == conc_broad(5, 6, 7) + assert (4 + a) == conc_broad(5, 6, 7) + assert (a + b) == conc_broad(12, 14, 16) + assert isinstance((a + c), pybamm.Concatenation) # No simplifications if all are Variable or StateVector objects v = pybamm.concatenation( @@ -636,8 +633,8 @@ def conc_broad(x, y, z): pybamm.Variable("y", "separator"), pybamm.Variable("z", "positive electrode"), ) - self.assertIsInstance((v * v), pybamm.Multiplication) - self.assertIsInstance((a * v), pybamm.Multiplication) + assert isinstance((v * v), pybamm.Multiplication) + assert isinstance((a * v), pybamm.Multiplication) def test_advanced_binary_simplifications(self): # MatMul simplifications that often appear when discretising spatial operators @@ -650,120 +647,120 @@ def test_advanced_binary_simplifications(self): # Do A@B first if it is constant expr = A @ (B @ var) - self.assertEqual(expr, ((A @ B) @ var)) + assert expr == ((A @ B) @ var) # Distribute the @ operator to a sum if one of the symbols being summed is # constant expr = A @ (var + vec) - self.assertEqual(expr, ((A @ var) + (A @ vec))) + assert expr == ((A @ var) + (A @ vec)) expr = A @ (var - vec) - self.assertEqual(expr, ((A @ var) - (A @ vec))) + assert expr == ((A @ var) - (A @ vec)) expr = A @ ((B @ var) + vec) - self.assertEqual(expr, (((A @ B) @ var) + (A @ vec))) + assert expr == (((A @ B) @ var) + (A @ vec)) expr = A @ ((B @ var) - vec) - self.assertEqual(expr, (((A @ B) @ var) - (A @ vec))) + assert expr == (((A @ B) @ var) - (A @ vec)) # Distribute the @ operator to a sum if both symbols being summed are matmuls expr = A @ (B @ var + C @ var2) - self.assertEqual(expr, ((A @ B) @ var + (A @ C) @ var2)) + assert expr == ((A @ B) @ var + (A @ C) @ var2) expr = A @ (B @ var - C @ var2) - self.assertEqual(expr, ((A @ B) @ var - (A @ C) @ var2)) + assert expr == ((A @ B) @ var - (A @ C) @ var2) # Reduce (A@var + B@var) to ((A+B)@var) expr = A @ var + B @ var - self.assertEqual(expr, ((A + B) @ var)) + assert expr == ((A + B) @ var) # Do A*e first if it is constant expr = A @ (5 * var) - self.assertEqual(expr, ((A * 5) @ var)) + assert expr == ((A * 5) @ var) expr = A @ (var * 5) - self.assertEqual(expr, ((A * 5) @ var)) + assert expr == ((A * 5) @ var) # Do A/e first if it is constant expr = A @ (var / 2) - self.assertEqual(expr, ((A / 2) @ var)) + assert expr == ((A / 2) @ var) # Do (vec*A) first if it is constant expr = vec * (A @ var) - self.assertEqual(expr, ((vec * A) @ var)) + assert expr == ((vec * A) @ var) expr = (A @ var) * vec - self.assertEqual(expr, ((vec * A) @ var)) + assert expr == ((vec * A) @ var) # Do (A/vec) first if it is constant expr = (A @ var) / vec - self.assertIsInstance(expr, pybamm.MatrixMultiplication) + assert isinstance(expr, pybamm.MatrixMultiplication) np.testing.assert_array_almost_equal(expr.left.evaluate(), (A / vec).evaluate()) - self.assertEqual(expr.children[1], var) + assert expr.children[1] == var # simplify additions and subtractions expr = 7 + (var + 5) - self.assertEqual(expr, (12 + var)) + assert expr == (12 + var) expr = 7 + (5 + var) - self.assertEqual(expr, (12 + var)) + assert expr == (12 + var) expr = (var + 5) + 7 - self.assertEqual(expr, (var + 12)) + assert expr == (var + 12) expr = (5 + var) + 7 - self.assertEqual(expr, (12 + var)) + assert expr == (12 + var) expr = 7 + (var - 5) - self.assertEqual(expr, (2 + var)) + assert expr == (2 + var) expr = 7 + (5 - var) - self.assertEqual(expr, (12 - var)) + assert expr == (12 - var) expr = (var - 5) + 7 - self.assertEqual(expr, (var + 2)) + assert expr == (var + 2) expr = (5 - var) + 7 - self.assertEqual(expr, (12 - var)) + assert expr == (12 - var) expr = 7 - (var + 5) - self.assertEqual(expr, (2 - var)) + assert expr == (2 - var) expr = 7 - (5 + var) - self.assertEqual(expr, (2 - var)) + assert expr == (2 - var) expr = (var + 5) - 7 - self.assertEqual(expr, (var + -2)) + assert expr == (var + -2) expr = (5 + var) - 7 - self.assertEqual(expr, (-2 + var)) + assert expr == (-2 + var) expr = 7 - (var - 5) - self.assertEqual(expr, (12 - var)) + assert expr == (12 - var) expr = 7 - (5 - var) - self.assertEqual(expr, (2 + var)) + assert expr == (2 + var) expr = (var - 5) - 7 - self.assertEqual(expr, (var - 12)) + assert expr == (var - 12) expr = (5 - var) - 7 - self.assertEqual(expr, (-2 - var)) + assert expr == (-2 - var) expr = var - (var + var2) - self.assertEqual(expr, -var2) + assert expr == -var2 # simplify multiplications and divisions expr = 10 * (var * 5) - self.assertEqual(expr, 50 * var) + assert expr == 50 * var expr = (var * 5) * 10 - self.assertEqual(expr, var * 50) + assert expr == var * 50 expr = 10 * (5 * var) - self.assertEqual(expr, 50 * var) + assert expr == 50 * var expr = (5 * var) * 10 - self.assertEqual(expr, 50 * var) + assert expr == 50 * var expr = 10 * (var / 5) - self.assertEqual(expr, (10 / 5) * var) + assert expr == (10 / 5) * var expr = (var / 5) * 10 - self.assertEqual(expr, var * (10 / 5)) + assert expr == var * (10 / 5) expr = (var * 5) / 10 - self.assertEqual(expr, var * (5 / 10)) + assert expr == var * (5 / 10) expr = (5 * var) / 10 - self.assertEqual(expr, (5 / 10) * var) + assert expr == (5 / 10) * var expr = 5 / (10 * var) - self.assertEqual(expr, (5 / 10) / var) + assert expr == (5 / 10) / var expr = 5 / (var * 10) - self.assertEqual(expr, (5 / 10) / var) + assert expr == (5 / 10) / var expr = (5 / var) / 10 - self.assertEqual(expr, (5 / 10) / var) + assert expr == (5 / 10) / var expr = 5 / (10 / var) - self.assertEqual(expr, (5 / 10) * var) + assert expr == (5 / 10) * var expr = 5 / (var / 10) - self.assertEqual(expr, 50 / var) + assert expr == 50 / var # use power rules on multiplications and divisions expr = (var * 5) ** 2 - self.assertEqual(expr, var**2 * 25) + assert expr == var**2 * 25 expr = (5 * var) ** 2 - self.assertEqual(expr, 25 * var**2) + assert expr == 25 * var**2 expr = (5 / var) ** 2 - self.assertEqual(expr, 25 / var**2) + assert expr == 25 / var**2 def test_inner_simplifications(self): a1 = pybamm.Scalar(0) @@ -776,116 +773,105 @@ def test_inner_simplifications(self): np.testing.assert_array_equal( pybamm.inner(a1, M2).evaluate().toarray(), M1.entries ) - self.assertEqual(pybamm.inner(a1, a2).evaluate(), 0) + assert pybamm.inner(a1, a2).evaluate() == 0 np.testing.assert_array_equal( pybamm.inner(M2, a1).evaluate().toarray(), M1.entries ) - self.assertEqual(pybamm.inner(a2, a1).evaluate(), 0) + assert pybamm.inner(a2, a1).evaluate() == 0 np.testing.assert_array_equal( pybamm.inner(M1, a3).evaluate().toarray(), M1.entries ) np.testing.assert_array_equal(pybamm.inner(v1, a3).evaluate(), 3 * v1.entries) - self.assertEqual(pybamm.inner(a2, a3).evaluate(), 3) - self.assertEqual(pybamm.inner(a3, a2).evaluate(), 3) - self.assertEqual(pybamm.inner(a3, a3).evaluate(), 9) + assert pybamm.inner(a2, a3).evaluate() == 3 + assert pybamm.inner(a3, a2).evaluate() == 3 + assert pybamm.inner(a3, a3).evaluate() == 9 def test_to_equation(self): # Test print_name pybamm.Addition.print_name = "test" - self.assertEqual(pybamm.Addition(1, 2).to_equation(), sympy.Symbol("test")) + assert pybamm.Addition(1, 2).to_equation() == sympy.Symbol("test") # Test Power - self.assertEqual(pybamm.Power(7, 2).to_equation(), 49) + assert pybamm.Power(7, 2).to_equation() == 49 # Test Division - self.assertEqual(pybamm.Division(10, 2).to_equation(), 5) + assert pybamm.Division(10, 2).to_equation() == 5 # Test Matrix Multiplication arr1 = pybamm.Array([[1, 0], [0, 1]]) arr2 = pybamm.Array([[4, 1], [2, 2]]) - self.assertEqual( - pybamm.MatrixMultiplication(arr1, arr2).to_equation(), - sympy.Matrix([[4.0, 1.0], [2.0, 2.0]]), + assert pybamm.MatrixMultiplication(arr1, arr2).to_equation() == sympy.Matrix( + [[4.0, 1.0], [2.0, 2.0]] ) # Test EqualHeaviside - self.assertEqual(pybamm.EqualHeaviside(1, 0).to_equation(), False) + assert not pybamm.EqualHeaviside(1, 0).to_equation() # Test NotEqualHeaviside - self.assertEqual(pybamm.NotEqualHeaviside(2, 4).to_equation(), True) + assert pybamm.NotEqualHeaviside(2, 4).to_equation() - def test_to_json(self): + def test_to_json(self, mocker): # Test Addition add_json = { "name": "+", - "id": mock.ANY, + "id": mocker.ANY, "domains": EMPTY_DOMAINS, } add = pybamm.Addition(2, 4) - self.assertEqual(add.to_json(), add_json) + assert add.to_json() == add_json add_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] - self.assertEqual(pybamm.Addition._from_json(add_json), add) + assert pybamm.Addition._from_json(add_json) == add # Test Power pow_json = { "name": "**", - "id": mock.ANY, + "id": mocker.ANY, "domains": EMPTY_DOMAINS, } pow = pybamm.Power(7, 2) - self.assertEqual(pow.to_json(), pow_json) + assert pow.to_json() == pow_json pow_json["children"] = [pybamm.Scalar(7), pybamm.Scalar(2)] - self.assertEqual(pybamm.Power._from_json(pow_json), pow) + assert pybamm.Power._from_json(pow_json) == pow # Test Division div_json = { "name": "/", - "id": mock.ANY, + "id": mocker.ANY, "domains": EMPTY_DOMAINS, } div = pybamm.Division(10, 5) - self.assertEqual(div.to_json(), div_json) + assert div.to_json() == div_json div_json["children"] = [pybamm.Scalar(10), pybamm.Scalar(5)] - self.assertEqual(pybamm.Division._from_json(div_json), div) + assert pybamm.Division._from_json(div_json) == div # Test EqualHeaviside equal_json = { "name": "<=", - "id": mock.ANY, + "id": mocker.ANY, "domains": EMPTY_DOMAINS, } equal_h = pybamm.EqualHeaviside(2, 4) - self.assertEqual(equal_h.to_json(), equal_json) + assert equal_h.to_json() == equal_json equal_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] - self.assertEqual(pybamm.EqualHeaviside._from_json(equal_json), equal_h) + assert pybamm.EqualHeaviside._from_json(equal_json) == equal_h # Test notEqualHeaviside not_equal_json = { "name": "<", - "id": mock.ANY, + "id": mocker.ANY, "domains": EMPTY_DOMAINS, } ne_h = pybamm.NotEqualHeaviside(2, 4) - self.assertEqual(ne_h.to_json(), not_equal_json) + assert ne_h.to_json() == not_equal_json not_equal_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] - self.assertEqual(pybamm.NotEqualHeaviside._from_json(not_equal_json), ne_h) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.NotEqualHeaviside._from_json(not_equal_json) == ne_h diff --git a/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py b/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py index 033dcf5345..95d0b53a64 100644 --- a/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py +++ b/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py @@ -2,9 +2,9 @@ # Tests for the base battery model class # +import pytest from pybamm.models.full_battery_models.base_battery_model import BatteryModelOptions import pybamm -import unittest import io from contextlib import redirect_stdout import os @@ -56,7 +56,7 @@ """ -class TestBaseBatteryModel(unittest.TestCase): +class TestBaseBatteryModel: def test_process_parameters_and_discretise(self): model = pybamm.lithium_ion.SPM() # Set up geometry and parameters @@ -72,9 +72,9 @@ def test_process_parameters_and_discretise(self): * model.variables["X-averaged negative particle concentration [mol.m-3]"] ) processed_c = model.process_parameters_and_discretise(c, parameter_values, disc) - self.assertIsInstance(processed_c, pybamm.Multiplication) - self.assertIsInstance(processed_c.left, pybamm.Scalar) - self.assertIsInstance(processed_c.right, pybamm.StateVector) + assert isinstance(processed_c, pybamm.Multiplication) + assert isinstance(processed_c.left, pybamm.Scalar) + assert isinstance(processed_c.right, pybamm.StateVector) # Process flux manually and check result against flux computed in particle # submodel c_n = model.variables["X-averaged negative particle concentration [mol.m-3]"] @@ -89,47 +89,39 @@ def test_process_parameters_and_discretise(self): flux_2 = model.variables["X-averaged negative particle flux [mol.m-2.s-1]"] param_flux_2 = parameter_values.process_symbol(flux_2) disc_flux_2 = disc.process_symbol(param_flux_2) - self.assertEqual(flux_1, disc_flux_2) + assert flux_1 == disc_flux_2 def test_summary_variables(self): model = pybamm.BaseBatteryModel() model.variables["var"] = pybamm.Scalar(1) model.summary_variables = ["var"] - self.assertEqual(model.summary_variables, ["var"]) - with self.assertRaisesRegex(KeyError, "No cycling variable defined"): + assert model.summary_variables == ["var"] + with pytest.raises(KeyError, match="No cycling variable defined"): model.summary_variables = ["bad var"] def test_default_geometry(self): model = pybamm.BaseBatteryModel({"dimensionality": 0}) - self.assertEqual( - model.default_geometry["current collector"]["z"]["position"], 1 - ) + assert model.default_geometry["current collector"]["z"]["position"] == 1 model = pybamm.BaseBatteryModel({"dimensionality": 1}) - self.assertEqual(model.default_geometry["current collector"]["z"]["min"], 0) + assert model.default_geometry["current collector"]["z"]["min"] == 0 model = pybamm.BaseBatteryModel({"dimensionality": 2}) - self.assertEqual(model.default_geometry["current collector"]["y"]["min"], 0) + assert model.default_geometry["current collector"]["y"]["min"] == 0 def test_default_submesh_types(self): model = pybamm.BaseBatteryModel({"dimensionality": 0}) - self.assertTrue( - issubclass( - model.default_submesh_types["current collector"], - pybamm.SubMesh0D, - ) + assert issubclass( + model.default_submesh_types["current collector"], + pybamm.SubMesh0D, ) model = pybamm.BaseBatteryModel({"dimensionality": 1}) - self.assertTrue( - issubclass( - model.default_submesh_types["current collector"], - pybamm.Uniform1DSubMesh, - ) + assert issubclass( + model.default_submesh_types["current collector"], + pybamm.Uniform1DSubMesh, ) model = pybamm.BaseBatteryModel({"dimensionality": 2}) - self.assertTrue( - issubclass( - model.default_submesh_types["current collector"], - pybamm.ScikitUniform2DSubMesh, - ) + assert issubclass( + model.default_submesh_types["current collector"], + pybamm.ScikitUniform2DSubMesh, ) def test_default_var_pts(self): @@ -149,44 +141,44 @@ def test_default_var_pts(self): "R_p": 30, } model = pybamm.BaseBatteryModel({"dimensionality": 0}) - self.assertDictEqual(var_pts, model.default_var_pts) + assert var_pts == model.default_var_pts var_pts.update({"x_n": 10, "x_s": 10, "x_p": 10}) model = pybamm.BaseBatteryModel({"dimensionality": 2}) - self.assertDictEqual(var_pts, model.default_var_pts) + assert var_pts == model.default_var_pts def test_default_spatial_methods(self): model = pybamm.BaseBatteryModel({"dimensionality": 0}) - self.assertIsInstance( + assert isinstance( model.default_spatial_methods["current collector"], pybamm.ZeroDimensionalSpatialMethod, ) model = pybamm.BaseBatteryModel({"dimensionality": 1}) - self.assertIsInstance( + assert isinstance( model.default_spatial_methods["current collector"], pybamm.FiniteVolume ) model = pybamm.BaseBatteryModel({"dimensionality": 2}) - self.assertIsInstance( + assert isinstance( model.default_spatial_methods["current collector"], pybamm.ScikitFiniteElement, ) def test_options(self): - with self.assertRaisesRegex(pybamm.OptionError, "Option"): + with pytest.raises(pybamm.OptionError, match="Option"): pybamm.BaseBatteryModel({"bad option": "bad option"}) - with self.assertRaisesRegex(pybamm.OptionError, "current collector model"): + with pytest.raises(pybamm.OptionError, match="current collector model"): pybamm.BaseBatteryModel({"current collector": "bad current collector"}) - with self.assertRaisesRegex(pybamm.OptionError, "thermal"): + with pytest.raises(pybamm.OptionError, match="thermal"): pybamm.BaseBatteryModel({"thermal": "bad thermal"}) - with self.assertRaisesRegex(pybamm.OptionError, "cell geometry"): + with pytest.raises(pybamm.OptionError, match="cell geometry"): pybamm.BaseBatteryModel({"cell geometry": "bad geometry"}) - with self.assertRaisesRegex(pybamm.OptionError, "dimensionality"): + with pytest.raises(pybamm.OptionError, match="dimensionality"): pybamm.BaseBatteryModel({"dimensionality": 5}) - with self.assertRaisesRegex(pybamm.OptionError, "current collector"): + with pytest.raises(pybamm.OptionError, match="current collector"): pybamm.BaseBatteryModel( {"dimensionality": 1, "current collector": "bad option"} ) - with self.assertRaisesRegex(pybamm.OptionError, "1D current collectors"): + with pytest.raises(pybamm.OptionError, match="1D current collectors"): pybamm.BaseBatteryModel( { "current collector": "potential pair", @@ -194,7 +186,7 @@ def test_options(self): "thermal": "x-full", } ) - with self.assertRaisesRegex(pybamm.OptionError, "2D current collectors"): + with pytest.raises(pybamm.OptionError, match="2D current collectors"): pybamm.BaseBatteryModel( { "current collector": "potential pair", @@ -202,58 +194,54 @@ def test_options(self): "thermal": "x-full", } ) - with self.assertRaisesRegex(pybamm.OptionError, "surface form"): + with pytest.raises(pybamm.OptionError, match="surface form"): pybamm.BaseBatteryModel({"surface form": "bad surface form"}) - with self.assertRaisesRegex(pybamm.OptionError, "convection"): + with pytest.raises(pybamm.OptionError, match="convection"): pybamm.BaseBatteryModel({"convection": "bad convection"}) - with self.assertRaisesRegex( - pybamm.OptionError, "cannot have transverse convection in 0D model" + with pytest.raises( + pybamm.OptionError, match="cannot have transverse convection in 0D model" ): pybamm.BaseBatteryModel({"convection": "full transverse"}) - with self.assertRaisesRegex(pybamm.OptionError, "particle"): + with pytest.raises(pybamm.OptionError, match="particle"): pybamm.BaseBatteryModel({"particle": "bad particle"}) - with self.assertRaisesRegex(pybamm.OptionError, "working electrode"): + with pytest.raises(pybamm.OptionError, match="working electrode"): pybamm.BaseBatteryModel({"working electrode": "bad working electrode"}) - with self.assertRaisesRegex(pybamm.OptionError, "The 'negative' working"): + with pytest.raises(pybamm.OptionError, match="The 'negative' working"): pybamm.BaseBatteryModel({"working electrode": "negative"}) - with self.assertRaisesRegex(pybamm.OptionError, "particle shape"): + with pytest.raises(pybamm.OptionError, match="particle shape"): pybamm.BaseBatteryModel({"particle shape": "bad particle shape"}) - with self.assertRaisesRegex(pybamm.OptionError, "operating mode"): + with pytest.raises(pybamm.OptionError, match="operating mode"): pybamm.BaseBatteryModel({"operating mode": "bad operating mode"}) - with self.assertRaisesRegex(pybamm.OptionError, "electrolyte conductivity"): + with pytest.raises(pybamm.OptionError, match="electrolyte conductivity"): pybamm.BaseBatteryModel( {"electrolyte conductivity": "bad electrolyte conductivity"} ) # SEI options - with self.assertRaisesRegex(pybamm.OptionError, "SEI"): + with pytest.raises(pybamm.OptionError, match="SEI"): pybamm.BaseBatteryModel({"SEI": "bad sei"}) - with self.assertRaisesRegex(pybamm.OptionError, "SEI film resistance"): + with pytest.raises(pybamm.OptionError, match="SEI film resistance"): pybamm.BaseBatteryModel({"SEI film resistance": "bad SEI film resistance"}) - with self.assertRaisesRegex(pybamm.OptionError, "SEI porosity change"): + with pytest.raises(pybamm.OptionError, match="SEI porosity change"): pybamm.BaseBatteryModel({"SEI porosity change": "bad SEI porosity change"}) # changing defaults based on other options model = pybamm.BaseBatteryModel() - self.assertEqual(model.options["SEI film resistance"], "none") + assert model.options["SEI film resistance"] == "none" model = pybamm.BaseBatteryModel({"SEI": "constant"}) - self.assertEqual(model.options["SEI film resistance"], "distributed") - self.assertEqual( - model.options["total interfacial current density as a state"], "true" - ) + assert model.options["SEI film resistance"] == "distributed" + assert model.options["total interfacial current density as a state"] == "true" model = pybamm.BaseBatteryModel( {"SEI film resistance": "average", "particle phases": "2"} ) - self.assertEqual( - model.options["total interfacial current density as a state"], "true" - ) - with self.assertRaisesRegex(pybamm.OptionError, "must be 'true'"): + assert model.options["total interfacial current density as a state"] == "true" + with pytest.raises(pybamm.OptionError, match="must be 'true'"): pybamm.BaseBatteryModel( { "SEI film resistance": "distributed", "total interfacial current density as a state": "false", } ) - with self.assertRaisesRegex(pybamm.OptionError, "must be 'true'"): + with pytest.raises(pybamm.OptionError, match="must be 'true'"): pybamm.BaseBatteryModel( { "SEI film resistance": "average", @@ -263,9 +251,9 @@ def test_options(self): ) # loss of active material model - with self.assertRaisesRegex(pybamm.OptionError, "loss of active material"): + with pytest.raises(pybamm.OptionError, match="loss of active material"): pybamm.BaseBatteryModel({"loss of active material": "bad LAM model"}) - with self.assertRaisesRegex(pybamm.OptionError, "loss of active material"): + with pytest.raises(pybamm.OptionError, match="loss of active material"): # can't have a 3-tuple pybamm.BaseBatteryModel( { @@ -281,11 +269,11 @@ def test_options(self): model = pybamm.BaseBatteryModel( {"loss of active material": "stress-driven", "SEI on cracks": "true"} ) - self.assertEqual( - model.options["particle mechanics"], - ("swelling and cracking", "swelling only"), + assert model.options["particle mechanics"] == ( + "swelling and cracking", + "swelling only", ) - self.assertEqual(model.options["stress-induced diffusion"], "true") + assert model.options["stress-induced diffusion"] == "true" model = pybamm.BaseBatteryModel( { "working electrode": "positive", @@ -293,29 +281,27 @@ def test_options(self): "SEI on cracks": "true", } ) - self.assertEqual(model.options["particle mechanics"], "swelling and cracking") - self.assertEqual(model.options["stress-induced diffusion"], "true") + assert model.options["particle mechanics"] == "swelling and cracking" + assert model.options["stress-induced diffusion"] == "true" # crack model - with self.assertRaisesRegex(pybamm.OptionError, "particle mechanics"): + with pytest.raises(pybamm.OptionError, match="particle mechanics"): pybamm.BaseBatteryModel({"particle mechanics": "bad particle cracking"}) - with self.assertRaisesRegex(pybamm.OptionError, "particle cracking"): + with pytest.raises(pybamm.OptionError, match="particle cracking"): pybamm.BaseBatteryModel({"particle cracking": "bad particle cracking"}) # SEI on cracks - with self.assertRaisesRegex(pybamm.OptionError, "SEI on cracks"): + with pytest.raises(pybamm.OptionError, match="SEI on cracks"): pybamm.BaseBatteryModel({"SEI on cracks": "bad SEI on cracks"}) - with self.assertRaisesRegex(pybamm.OptionError, "'SEI on cracks' is 'true'"): + with pytest.raises(pybamm.OptionError, match="'SEI on cracks' is 'true'"): pybamm.BaseBatteryModel( {"SEI on cracks": "true", "particle mechanics": "swelling only"} ) # plating model - with self.assertRaisesRegex(pybamm.OptionError, "lithium plating"): + with pytest.raises(pybamm.OptionError, match="lithium plating"): pybamm.BaseBatteryModel({"lithium plating": "bad plating"}) - with self.assertRaisesRegex( - pybamm.OptionError, "lithium plating porosity change" - ): + with pytest.raises(pybamm.OptionError, match="lithium plating porosity change"): pybamm.BaseBatteryModel( { "lithium plating porosity change": "bad lithium " @@ -324,16 +310,16 @@ def test_options(self): ) # contact resistance - with self.assertRaisesRegex(pybamm.OptionError, "contact resistance"): + with pytest.raises(pybamm.OptionError, match="contact resistance"): pybamm.BaseBatteryModel({"contact resistance": "bad contact resistance"}) - with self.assertRaisesRegex(NotImplementedError, "Contact resistance not yet"): + with pytest.raises(NotImplementedError, match="Contact resistance not yet"): pybamm.BaseBatteryModel( { "contact resistance": "true", "operating mode": "explicit power", } ) - with self.assertRaisesRegex(NotImplementedError, "Contact resistance not yet"): + with pytest.raises(NotImplementedError, match="Contact resistance not yet"): pybamm.BaseBatteryModel( { "contact resistance": "true", @@ -342,29 +328,29 @@ def test_options(self): ) # stress-induced diffusion - with self.assertRaisesRegex(pybamm.OptionError, "cannot have stress"): + with pytest.raises(pybamm.OptionError, match="cannot have stress"): pybamm.BaseBatteryModel({"stress-induced diffusion": "true"}) # hydrolysis - with self.assertRaisesRegex(pybamm.OptionError, "surface formulation"): + with pytest.raises(pybamm.OptionError, match="surface formulation"): pybamm.lead_acid.LOQS({"hydrolysis": "true", "surface form": "false"}) # timescale - with self.assertRaisesRegex(pybamm.OptionError, "timescale"): + with pytest.raises(pybamm.OptionError, match="timescale"): pybamm.BaseBatteryModel({"timescale": "bad timescale"}) # thermal x-lumped - with self.assertRaisesRegex(pybamm.OptionError, "x-lumped"): + with pytest.raises(pybamm.OptionError, match="x-lumped"): pybamm.lithium_ion.BaseModel( {"cell geometry": "arbitrary", "thermal": "x-lumped"} ) # thermal half-cell - with self.assertRaisesRegex(pybamm.OptionError, "X-full"): + with pytest.raises(pybamm.OptionError, match="X-full"): pybamm.BaseBatteryModel( {"thermal": "x-full", "working electrode": "positive"} ) - with self.assertRaisesRegex(pybamm.OptionError, "X-lumped"): + with pytest.raises(pybamm.OptionError, match="X-lumped"): pybamm.BaseBatteryModel( { "dimensionality": 2, @@ -374,7 +360,7 @@ def test_options(self): ) # thermal heat of mixing - with self.assertRaisesRegex(NotImplementedError, "Heat of mixing"): + with pytest.raises(NotImplementedError, match="Heat of mixing"): pybamm.BaseBatteryModel( { "heat of mixing": "true", @@ -383,35 +369,35 @@ def test_options(self): ) # surface thermal model - with self.assertRaisesRegex(pybamm.OptionError, "surface temperature"): + with pytest.raises(pybamm.OptionError, match="surface temperature"): pybamm.BaseBatteryModel( {"surface temperature": "lumped", "thermal": "x-full"} ) # phases - with self.assertRaisesRegex(pybamm.OptionError, "multiple particle phases"): + with pytest.raises(pybamm.OptionError, match="multiple particle phases"): pybamm.BaseBatteryModel({"particle phases": "2", "surface form": "false"}) # msmr - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel({"open-circuit potential": "MSMR"}) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel({"particle": "MSMR"}) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel({"intercalation kinetics": "MSMR"}) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel( {"open-circuit potential": "MSMR", "particle": "MSMR"} ) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel( {"open-circuit potential": "MSMR", "intercalation kinetics": "MSMR"} ) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel( {"particle": "MSMR", "intercalation kinetics": "MSMR"} ) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel( { "open-circuit potential": "MSMR", @@ -423,7 +409,7 @@ def test_options(self): def test_build_twice(self): model = pybamm.lithium_ion.SPM() # need to pick a model to set vars and build - with self.assertRaisesRegex(pybamm.ModelError, "Model already built"): + with pytest.raises(pybamm.ModelError, match="Model already built"): model.build_model() def test_get_coupled_variables(self): @@ -431,37 +417,37 @@ def test_get_coupled_variables(self): model.submodels["current collector"] = pybamm.current_collector.Uniform( model.param ) - with self.assertRaisesRegex(pybamm.ModelError, "Missing variable"): + with pytest.raises(pybamm.ModelError, match="Missing variable"): model.build_model() def test_default_solver(self): model = pybamm.BaseBatteryModel() - self.assertIsInstance(model.default_solver, pybamm.CasadiSolver) + assert isinstance(model.default_solver, pybamm.CasadiSolver) # check that default_solver gives you a new solver, not an internal object solver = model.default_solver solver = pybamm.BaseModel() - self.assertIsInstance(model.default_solver, pybamm.CasadiSolver) - self.assertIsInstance(solver, pybamm.BaseModel) + assert isinstance(model.default_solver, pybamm.CasadiSolver) + assert isinstance(solver, pybamm.BaseModel) # check that adding algebraic variables gives algebraic solver a = pybamm.Variable("a") model.algebraic = {a: a - 1} - self.assertIsInstance(model.default_solver, pybamm.CasadiAlgebraicSolver) + assert isinstance(model.default_solver, pybamm.CasadiAlgebraicSolver) def test_option_type(self): # no entry gets default options model = pybamm.BaseBatteryModel() - self.assertIsInstance(model.options, pybamm.BatteryModelOptions) + assert isinstance(model.options, pybamm.BatteryModelOptions) # dict options get converted to BatteryModelOptions model = pybamm.BaseBatteryModel({"thermal": "isothermal"}) - self.assertIsInstance(model.options, pybamm.BatteryModelOptions) + assert isinstance(model.options, pybamm.BatteryModelOptions) # special dict types are not changed options = pybamm.FuzzyDict({"thermal": "isothermal"}) model = pybamm.BaseBatteryModel(options) - self.assertEqual(model.options, options) + assert model.options == options def test_save_load_model(self): model = pybamm.lithium_ion.SPM() @@ -479,7 +465,7 @@ def test_save_load_model(self): ) # raises error if variables are saved without mesh - with self.assertRaises(ValueError): + with pytest.raises(ValueError): model.save_model( filename="test_base_battery_model", variables=model.variables ) @@ -487,71 +473,56 @@ def test_save_load_model(self): os.remove("test_base_battery_model.json") -class TestOptions(unittest.TestCase): +class TestOptions: def test_print_options(self): with io.StringIO() as buffer, redirect_stdout(buffer): BatteryModelOptions(OPTIONS_DICT).print_options() output = buffer.getvalue() - self.assertEqual(output, PRINT_OPTIONS_OUTPUT) + assert output == PRINT_OPTIONS_OUTPUT def test_option_phases(self): options = BatteryModelOptions({}) - self.assertEqual( - options.phases, {"negative": ["primary"], "positive": ["primary"]} - ) + assert options.phases == {"negative": ["primary"], "positive": ["primary"]} options = BatteryModelOptions({"particle phases": ("1", "2")}) - self.assertEqual( - options.phases, - {"negative": ["primary"], "positive": ["primary", "secondary"]}, - ) + assert options.phases == { + "negative": ["primary"], + "positive": ["primary", "secondary"], + } def test_domain_options(self): options = BatteryModelOptions( {"particle": ("Fickian diffusion", "quadratic profile")} ) - self.assertEqual(options.negative["particle"], "Fickian diffusion") - self.assertEqual(options.positive["particle"], "quadratic profile") + assert options.negative["particle"] == "Fickian diffusion" + assert options.positive["particle"] == "quadratic profile" # something that is the same in both domains - self.assertEqual(options.negative["thermal"], "isothermal") - self.assertEqual(options.positive["thermal"], "isothermal") + assert options.negative["thermal"] == "isothermal" + assert options.positive["thermal"] == "isothermal" def test_domain_phase_options(self): options = BatteryModelOptions( {"particle mechanics": (("swelling only", "swelling and cracking"), "none")} ) - self.assertEqual( - options.negative["particle mechanics"], - ("swelling only", "swelling and cracking"), + assert options.negative["particle mechanics"] == ( + "swelling only", + "swelling and cracking", ) - self.assertEqual( - options.negative.primary["particle mechanics"], "swelling only" + assert options.negative.primary["particle mechanics"] == "swelling only" + assert ( + options.negative.secondary["particle mechanics"] == "swelling and cracking" ) - self.assertEqual( - options.negative.secondary["particle mechanics"], "swelling and cracking" - ) - self.assertEqual(options.positive["particle mechanics"], "none") - self.assertEqual(options.positive.primary["particle mechanics"], "none") - self.assertEqual(options.positive.secondary["particle mechanics"], "none") + assert options.positive["particle mechanics"] == "none" + assert options.positive.primary["particle mechanics"] == "none" + assert options.positive.secondary["particle mechanics"] == "none" def test_whole_cell_domains(self): options = BatteryModelOptions({"working electrode": "positive"}) - self.assertEqual( - options.whole_cell_domains, ["separator", "positive electrode"] - ) + assert options.whole_cell_domains == ["separator", "positive electrode"] options = BatteryModelOptions({}) - self.assertEqual( - options.whole_cell_domains, - ["negative electrode", "separator", "positive electrode"], - ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert options.whole_cell_domains == [ + "negative electrode", + "separator", + "positive electrode", + ] diff --git a/tests/unit/test_parameters/test_parameter_sets/test_OKane2022.py b/tests/unit/test_parameters/test_parameter_sets/test_OKane2022.py index 014b467715..91fa8ef87e 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_OKane2022.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_OKane2022.py @@ -2,11 +2,11 @@ # Tests for O'Kane (2022) parameter set # +import pytest import pybamm -import unittest -class TestOKane2022(unittest.TestCase): +class TestOKane2022: def test_functions(self): param = pybamm.ParameterValues("OKane2022") sto = pybamm.Scalar(0.9) @@ -40,16 +40,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_parameters_with_default_models.py b/tests/unit/test_parameters/test_parameter_sets/test_parameters_with_default_models.py index d7133a73e0..77fc3d66e7 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_parameters_with_default_models.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_parameters_with_default_models.py @@ -3,11 +3,10 @@ # import pybamm -import unittest -class TestParameterValuesWithModel(unittest.TestCase): - def test_parameter_values_with_model(self): +class TestParameterValuesWithModel: + def test_parameter_values_with_model(self, subtests): param_to_model = { "Ai2020": pybamm.lithium_ion.DFN( {"particle mechanics": "swelling and cracking"} @@ -46,16 +45,6 @@ def test_parameter_values_with_model(self): # Loop over each parameter set, testing that parameters can be set for param, model in param_to_model.items(): - with self.subTest(param=param): + with subtests.test(param=param): parameter_values = pybamm.ParameterValues(param) parameter_values.process_model(model) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_solvers/test_idaklu_jax.py b/tests/unit/test_solvers/test_idaklu_jax.py index a99f108f40..53abb94c83 100644 --- a/tests/unit/test_solvers/test_idaklu_jax.py +++ b/tests/unit/test_solvers/test_idaklu_jax.py @@ -2,11 +2,10 @@ # Tests for the KLU-Jax interface class # -from parameterized import parameterized +import pytest import pybamm import numpy as np -import unittest testcase = [] if pybamm.has_idaklu() and pybamm.has_jax(): @@ -86,21 +85,21 @@ def no_jit(f): # Check the interface throws an appropriate error if either IDAKLU or JAX not available -@unittest.skipIf( +@pytest.mark.skipif( pybamm.has_idaklu() and pybamm.has_jax(), - "Both IDAKLU and JAX are available", + reason="Both IDAKLU and JAX are available", ) -class TestIDAKLUJax_NoJax(unittest.TestCase): +class TestIDAKLUJax_NoJax: def test_instantiate_fails(self): - with self.assertRaises(ModuleNotFoundError): + with pytest.raises(ModuleNotFoundError): pybamm.IDAKLUJax([], [], []) -@unittest.skipIf( +@pytest.mark.skipif( not pybamm.has_idaklu() or not pybamm.has_jax(), - "IDAKLU Solver and/or JAX are not available", + reason="IDAKLU Solver and/or JAX are not available", ) -class TestIDAKLUJax(unittest.TestCase): +class TestIDAKLUJax: # Initialisation tests def test_initialise_twice(self): @@ -110,7 +109,7 @@ def test_initialise_twice(self): output_variables=output_variables, calculate_sensitivities=True, ) - with self.assertWarns(UserWarning): + with pytest.warns(UserWarning): idaklu_jax_solver.jaxify( model, t_eval, @@ -127,15 +126,15 @@ def test_uninitialised(self): ) # simulate failure in initialisation idaklu_jax_solver.jaxpr = None - with self.assertRaises(pybamm.SolverError): + with pytest.raises(pybamm.SolverError): idaklu_jax_solver.get_jaxpr() - with self.assertRaises(pybamm.SolverError): + with pytest.raises(pybamm.SolverError): idaklu_jax_solver.jax_value() - with self.assertRaises(pybamm.SolverError): + with pytest.raises(pybamm.SolverError): idaklu_jax_solver.jax_grad() def test_no_output_variables(self): - with self.assertRaises(pybamm.SolverError): + with pytest.raises(pybamm.SolverError): idaklu_solver.jaxify( model, t_eval, @@ -170,63 +169,63 @@ def test_no_inputs(self): # Scalar evaluation - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_f_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(f)(t_eval[k], inputs) np.testing.assert_allclose( out, np.array([sim[outvar](t_eval[k]) for outvar in output_variables]).T ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_f_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(f)(t_eval, inputs) np.testing.assert_allclose( out, np.array([sim[outvar](t_eval) for outvar in output_variables]).T ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_f_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.vmap(f, in_axes=in_axes))(t_eval, inputs) np.testing.assert_allclose( out, np.array([sim[outvar](t_eval) for outvar in output_variables]).T ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_f_batch_over_inputs(self, output_variables, idaklu_jax_solver, f, wrapper): inputs_mock = np.array([1.0, 2.0, 3.0]) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): wrapper(jax.vmap(f, in_axes=(None, 0)))(t_eval, inputs_mock) # Get all vars (should mirror test_f_* [above]) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvars_call_signature( self, output_variables, idaklu_jax_solver, f, wrapper ): if wrapper == jax.jit: return # test does not involve a JAX expression - with self.assertRaises(ValueError): + with pytest.raises(ValueError): idaklu_jax_solver.get_vars() # no variable name specified idaklu_jax_solver.get_vars(output_variables) # (okay) idaklu_jax_solver.get_vars(f, output_variables) # (okay) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): idaklu_jax_solver.get_vars(1, 2, 3) # too many arguments - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvars_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(idaklu_jax_solver.get_vars(output_variables))(t_eval[k], inputs) np.testing.assert_allclose( out, np.array([sim[outvar](t_eval[k]) for outvar in output_variables]).T ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvars_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(idaklu_jax_solver.get_vars(output_variables))(t_eval, inputs) np.testing.assert_allclose( out, np.array([sim[outvar](t_eval) for outvar in output_variables]).T ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvars_vector_array( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -236,7 +235,7 @@ def test_getvars_vector_array( out = idaklu_jax_solver.get_vars(array, output_variables) np.testing.assert_allclose(out, array) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvars_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( jax.vmap( @@ -250,20 +249,20 @@ def test_getvars_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): # Isolate single output variable - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_call_signature( self, output_variables, idaklu_jax_solver, f, wrapper ): if wrapper == jax.jit: return # test does not involve a JAX expression - with self.assertRaises(ValueError): + with pytest.raises(ValueError): idaklu_jax_solver.get_var() # no variable name specified idaklu_jax_solver.get_var(output_variables[0]) # (okay) idaklu_jax_solver.get_var(f, output_variables[0]) # (okay) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): idaklu_jax_solver.get_var(1, 2, 3) # too many arguments - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_scalar_float_jaxpr( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -272,7 +271,7 @@ def test_getvar_scalar_float_jaxpr( out = wrapper(idaklu_jax_solver.get_var(outvar))(float(t_eval[k]), inputs) np.testing.assert_allclose(out, sim[outvar](float(t_eval[k]))) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_scalar_float_f( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -283,35 +282,35 @@ def test_getvar_scalar_float_f( ) np.testing.assert_allclose(out, sim[outvar](float(t_eval[k]))) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_scalar_jaxpr(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using the default JAX expression (self.jaxpr) for outvar in output_variables: out = wrapper(idaklu_jax_solver.get_var(outvar))(t_eval[k], inputs) np.testing.assert_allclose(out, sim[outvar](t_eval[k])) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_scalar_f(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using a provided JAX expression (f) for outvar in output_variables: out = wrapper(idaklu_jax_solver.get_var(outvar))(t_eval[k], inputs) np.testing.assert_allclose(out, sim[outvar](t_eval[k])) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_vector_jaxpr(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using the default JAX expression (self.jaxpr) for outvar in output_variables: out = wrapper(idaklu_jax_solver.get_var(outvar))(t_eval, inputs) np.testing.assert_allclose(out, sim[outvar](t_eval)) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_vector_f(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using a provided JAX expression (f) for outvar in output_variables: out = wrapper(idaklu_jax_solver.get_var(f, outvar))(t_eval, inputs) np.testing.assert_allclose(out, sim[outvar](t_eval)) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_vector_array(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using a provided np.ndarray if wrapper == jax.jit: @@ -321,7 +320,7 @@ def test_getvar_vector_array(self, output_variables, idaklu_jax_solver, f, wrapp out = idaklu_jax_solver.get_var(array, outvar) np.testing.assert_allclose(out, sim[outvar](t_eval)) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: out = wrapper( @@ -334,7 +333,7 @@ def test_getvar_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): # Differentiation rules (jacfwd) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacfwd(f, argnums=1))(t_eval[k], inputs) flat_out, _ = tree_flatten(out) @@ -348,7 +347,7 @@ def test_jacfwd_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): ).T np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacfwd(f, argnums=1))(t_eval, inputs) flat_out, _ = tree_flatten(out) @@ -365,7 +364,7 @@ def test_jacfwd_vector(self, output_variables, idaklu_jax_solver, f, wrapper): f"Got: {flat_out}\nExpected: {check}", ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( jax.vmap( @@ -384,11 +383,11 @@ def test_jacfwd_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): ) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vmap_wrt_time( self, output_variables, idaklu_jax_solver, f, wrapper ): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): wrapper( jax.vmap( jax.jacfwd(f, argnums=0), @@ -396,12 +395,12 @@ def test_jacfwd_vmap_wrt_time( ), )(t_eval, inputs) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_batch_over_inputs( self, output_variables, idaklu_jax_solver, f, wrapper ): inputs_mock = np.array([1.0, 2.0, 3.0]) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): wrapper( jax.vmap( jax.jacfwd(f, argnums=1), @@ -411,7 +410,7 @@ def test_jacfwd_batch_over_inputs( # Differentiation rules (jacrev) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacrev(f, argnums=1))(t_eval[k], inputs) flat_out, _ = tree_flatten(out) @@ -425,7 +424,7 @@ def test_jacrev_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): ).T np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacrev(f, argnums=1))(t_eval, inputs) flat_out, _ = tree_flatten(out) @@ -439,7 +438,7 @@ def test_jacrev_vector(self, output_variables, idaklu_jax_solver, f, wrapper): ) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( jax.vmap( @@ -458,12 +457,12 @@ def test_jacrev_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): ) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_batch_over_inputs( self, output_variables, idaklu_jax_solver, f, wrapper ): inputs_mock = np.array([1.0, 2.0, 3.0]) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): wrapper( jax.vmap( jax.jacrev(f, argnums=1), @@ -473,7 +472,7 @@ def test_jacrev_batch_over_inputs( # Forward differentiation rules with get_vars (multiple) and get_var (singular) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_scalar_getvars( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -496,7 +495,7 @@ def test_jacfwd_scalar_getvars( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_scalar_getvar( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -515,7 +514,7 @@ def test_jacfwd_scalar_getvar( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vector_getvars( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -539,7 +538,7 @@ def test_jacfwd_vector_getvars( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vector_getvar( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -558,7 +557,7 @@ def test_jacfwd_vector_getvar( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( jax.vmap( @@ -577,7 +576,7 @@ def test_jacfwd_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapp ) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: out = wrapper( @@ -596,7 +595,7 @@ def test_jacfwd_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrappe # Reverse differentiation rules with get_vars (multiple) and get_var (singular) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_scalar_getvars( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -619,7 +618,7 @@ def test_jacrev_scalar_getvars( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_scalar_getvar( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -640,7 +639,7 @@ def test_jacrev_scalar_getvar( f"Got: {flat_out}\nExpected: {check}", ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vector_getvars( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -664,7 +663,7 @@ def test_jacrev_vector_getvars( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vector_getvar( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -683,7 +682,7 @@ def test_jacrev_vector_getvar( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( jax.vmap( @@ -702,7 +701,7 @@ def test_jacrev_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapp ) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: out = wrapper( @@ -721,7 +720,7 @@ def test_jacrev_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrappe # Gradient rule (takes single variable) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_grad_scalar_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: out = wrapper( @@ -735,7 +734,7 @@ def test_grad_scalar_getvar(self, output_variables, idaklu_jax_solver, f, wrappe check = np.array([sim[outvar].sensitivities[invar][k] for invar in inputs]) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_grad_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: out = wrapper( @@ -754,7 +753,7 @@ def test_grad_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper) # Value and gradient (takes single variable) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_value_and_grad_scalar( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -774,7 +773,7 @@ def test_value_and_grad_scalar( check = np.array([sim[outvar].sensitivities[invar][k] for invar in inputs]) np.testing.assert_allclose(flat_t, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_value_and_grad_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: primals, tangents = wrapper( @@ -797,7 +796,7 @@ def test_value_and_grad_vmap(self, output_variables, idaklu_jax_solver, f, wrapp # Helper functions - These return values (not jaxexprs) so cannot be JITed - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jax_vars(self, output_variables, idaklu_jax_solver, f, wrapper): if wrapper == jax.jit: # Skipping test_jax_vars for jax.jit, jit not supported on helper functions @@ -812,7 +811,7 @@ def test_jax_vars(self, output_variables, idaklu_jax_solver, f, wrapper): f"{outvar}: Got: {flat_out}\nExpected: {check}", ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jax_grad(self, output_variables, idaklu_jax_solver, f, wrapper): if wrapper == jax.jit: # Skipping test_jax_grad for jax.jit, jit not supported on helper functions @@ -829,7 +828,7 @@ def test_jax_grad(self, output_variables, idaklu_jax_solver, f, wrapper): # Wrap jaxified expression in another function and take the gradient - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_grad_wrapper_sse(self, output_variables, idaklu_jax_solver, f, wrapper): # Use surrogate for experimental data data = sim["v"](t_eval) diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index dfaa6c7201..c5516ee880 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -1,15 +1,14 @@ # Tests for the Scipy Solver class # +import pytest import pybamm -import unittest import numpy as np from tests import get_mesh_for_testing, get_discretisation_for_testing import warnings -import sys -class TestScipySolver(unittest.TestCase): +class TestScipySolver: def test_model_solver_python_and_jax(self): if pybamm.has_jax(): formats = ["python", "jax"] @@ -43,10 +42,8 @@ def test_model_solver_python_and_jax(self): np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) # Test time - self.assertEqual( - solution.total_time, solution.solve_time + solution.set_up_time - ) - self.assertEqual(solution.termination, "final time") + assert solution.total_time == solution.solve_time + solution.set_up_time + assert solution.termination == "final time" def test_model_solver_failure(self): # Turn off warnings to ignore sqrt error @@ -65,7 +62,7 @@ def test_model_solver_failure(self): t_eval = np.linspace(0, 3, 100) solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") # Expect solver to fail when y goes negative - with self.assertRaises(pybamm.SolverError): + with pytest.raises(pybamm.SolverError): solver.solve(model, t_eval) # Turn warnings back on @@ -96,7 +93,7 @@ def test_model_solver_with_event_python(self): solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") t_eval = np.linspace(0, 10, 100) solution = solver.solve(model, t_eval) - self.assertLess(len(solution.t), len(t_eval)) + assert len(solution.t) < len(t_eval) np.testing.assert_array_equal(solution.t[:-1], t_eval[: len(solution.t) - 1]) np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t)) np.testing.assert_equal(solution.t_event[0], solution.t[-1]) @@ -222,7 +219,7 @@ def test_step_different_model(self): np.testing.assert_array_almost_equal(step_sol1.y[0], np.exp(0.1 * step_sol1.t)) # Step again, the model has changed so this raises an error - with self.assertRaisesRegex(RuntimeError, "already been initialised"): + with pytest.raises(RuntimeError, match="already been initialised"): solver.step(step_sol1, model2, dt) def test_model_solver_with_inputs(self): @@ -245,11 +242,11 @@ def test_model_solver_with_inputs(self): solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") t_eval = np.linspace(0, 10, 100) solution = solver.solve(model, t_eval, inputs={"rate": 0.1}) - self.assertLess(len(solution.t), len(t_eval)) + assert len(solution.t) < len(t_eval) np.testing.assert_array_equal(solution.t[:-1], t_eval[: len(solution.t) - 1]) np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t)) - def test_model_solver_multiple_inputs_happy_path(self): + def test_model_solver_multiple_inputs_happy_path(self, subtests): for convert_to_format in ["python", "casadi"]: # Create model model = pybamm.BaseModel() @@ -271,7 +268,7 @@ def test_model_solver_multiple_inputs_happy_path(self): solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) for i in range(ninputs): - with self.subTest(i=i): + with subtests.test(i=i): solution = solutions[i] np.testing.assert_array_equal(solution.t, t_eval) np.testing.assert_allclose( @@ -304,12 +301,10 @@ def test_model_solver_multiple_inputs_discontinuity_error(self): event_type=pybamm.EventType.DISCONTINUITY, ) ] - with self.assertRaisesRegex( + with pytest.raises( pybamm.SolverError, - ( - "Cannot solve for a list of input parameters" - " sets with discontinuities" - ), + match="Cannot solve for a list of input parameters" + " sets with discontinuities", ): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) @@ -332,13 +327,14 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self): ninputs = 8 inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] - with self.assertRaisesRegex( + with pytest.raises( pybamm.SolverError, - ("Input parameters cannot appear in expression " "for initial conditions."), + match="Input parameters cannot appear in expression " + "for initial conditions.", ): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) - def test_model_solver_multiple_inputs_jax_format(self): + def test_model_solver_multiple_inputs_jax_format(self, subtests): if pybamm.has_jax(): # Create model model = pybamm.BaseModel() @@ -360,7 +356,7 @@ def test_model_solver_multiple_inputs_jax_format(self): solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) for i in range(ninputs): - with self.subTest(i=i): + with subtests.test(i=i): solution = solutions[i] np.testing.assert_array_equal(solution.t, t_eval) np.testing.assert_allclose( @@ -395,7 +391,7 @@ def test_model_solver_with_event_with_casadi(self): solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") t_eval = np.linspace(0, 10, 100) solution = solver.solve(model_disc, t_eval) - self.assertLess(len(solution.t), len(t_eval)) + assert len(solution.t) < len(t_eval) np.testing.assert_array_equal( solution.t[:-1], t_eval[: len(solution.t) - 1] ) @@ -422,7 +418,7 @@ def test_model_solver_with_inputs_with_casadi(self): solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") t_eval = np.linspace(0, 10, 100) solution = solver.solve(model, t_eval, inputs={"rate": 0.1}) - self.assertLess(len(solution.t), len(t_eval)) + assert len(solution.t) < len(t_eval) np.testing.assert_array_equal(solution.t[:-1], t_eval[: len(solution.t) - 1]) np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t)) @@ -492,7 +488,7 @@ def test_scale_and_reference(self): ) -class TestScipySolverWithSensitivity(unittest.TestCase): +class TestScipySolverWithSensitivity: def test_solve_sensitivity_scalar_var_scalar_input(self): # Create model model = pybamm.BaseModel() @@ -780,12 +776,3 @@ def test_solve_sensitivity_vector_var_vector_input(self): solution["integral of var"].sensitivities["param"], np.vstack([-2 * t * np.exp(-p_eval * t) * l_n / n for t in t_eval]), ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 5a584fabbf..3aff012d5b 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -1,11 +1,12 @@ # # Tests for the Solution class # +import pytest import os - +import io +import logging import json import pybamm -import unittest import numpy as np import pandas as pd from scipy.io import loadmat @@ -13,23 +14,23 @@ from tempfile import TemporaryDirectory -class TestSolution(unittest.TestCase): +class TestSolution: def test_init(self): t = np.linspace(0, 1) y = np.tile(t, (20, 1)) sol = pybamm.Solution(t, y, pybamm.BaseModel(), {}) np.testing.assert_array_equal(sol.t, t) np.testing.assert_array_equal(sol.y, y) - self.assertEqual(sol.t_event, None) - self.assertEqual(sol.y_event, None) - self.assertEqual(sol.termination, "final time") - self.assertEqual(sol.all_inputs, [{}]) - self.assertIsInstance(sol.all_models[0], pybamm.BaseModel) + assert sol.t_event is None + assert sol.y_event is None + assert sol.termination == "final time" + assert sol.all_inputs == [{}] + assert isinstance(sol.all_models[0], pybamm.BaseModel) def test_sensitivities(self): t = np.linspace(0, 1) y = np.tile(t, (20, 1)) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): pybamm.Solution(t, y, pybamm.BaseModel(), {}, sensitivities=1.0) def test_errors(self): @@ -37,8 +38,8 @@ def test_errors(self): sol = pybamm.Solution( bad_ts, [np.ones((1, 3)), np.ones((1, 3))], pybamm.BaseModel(), {} ) - with self.assertRaisesRegex( - ValueError, "Solution time vector must be strictly increasing" + with pytest.raises( + ValueError, match="Solution time vector must be strictly increasing" ): sol.set_t() @@ -48,21 +49,27 @@ def test_errors(self): var = pybamm.StateVector(slice(0, 1)) model.rhs = {var: 0} model.variables = {var.name: var} - with self.assertLogs() as captured: - pybamm.Solution(ts, bad_ys, model, {}) - self.assertIn("exceeds the maximum", captured.records[0].getMessage()) - - with self.assertRaisesRegex( - TypeError, "sensitivities arg needs to be a bool or dict" + log_capture = io.StringIO() + handler = logging.StreamHandler(log_capture) + handler.setLevel(logging.ERROR) + logger = logging.getLogger("pybamm.logger") + logger.addHandler(handler) + pybamm.Solution(ts, bad_ys, model, {}) + log_output = log_capture.getvalue() + assert "exceeds the maximum" in log_output + logger.removeHandler(handler) + + with pytest.raises( + TypeError, match="sensitivities arg needs to be a bool or dict" ): pybamm.Solution(ts, bad_ys, model, {}, all_sensitivities="bad") sol = pybamm.Solution(ts, bad_ys, model, {}, all_sensitivities={}) - with self.assertRaisesRegex(TypeError, "sensitivities arg needs to be a bool"): + with pytest.raises(TypeError, match="sensitivities arg needs to be a bool"): sol.sensitivities = "bad" - with self.assertRaisesRegex( + with pytest.raises( NotImplementedError, - "Setting sensitivities is not supported if sensitivities are already provided as a dict", + match="Setting sensitivities is not supported if sensitivities are already provided as a dict", ): sol.sensitivities = True @@ -84,7 +91,7 @@ def test_add_solutions(self): sol_sum = sol1 + sol2 # Test - self.assertEqual(sol_sum.integration_time, 0.8) + assert sol_sum.integration_time == 0.8 np.testing.assert_array_equal(sol_sum.t, np.concatenate([t1, t2[1:]])) np.testing.assert_array_equal( sol_sum.y, np.concatenate([y1, y2[:, 1:]], axis=1) @@ -92,36 +99,38 @@ def test_add_solutions(self): np.testing.assert_array_equal(sol_sum.all_inputs, [{"a": 1}, {"a": 2}]) # Test sub-solutions - self.assertEqual(len(sol_sum.sub_solutions), 2) + assert len(sol_sum.sub_solutions) == 2 np.testing.assert_array_equal(sol_sum.sub_solutions[0].t, t1) np.testing.assert_array_equal(sol_sum.sub_solutions[1].t, t2) - self.assertEqual(sol_sum.sub_solutions[0].all_models[0], sol_sum.all_models[0]) + assert sol_sum.sub_solutions[0].all_models[0] == sol_sum.all_models[0] np.testing.assert_array_equal(sol_sum.sub_solutions[0].all_inputs[0]["a"], 1) - self.assertEqual(sol_sum.sub_solutions[1].all_models[0], sol2.all_models[0]) - self.assertEqual(sol_sum.all_models[1], sol2.all_models[0]) + assert sol_sum.sub_solutions[1].all_models[0] == sol2.all_models[0] + assert sol_sum.all_models[1] == sol2.all_models[0] np.testing.assert_array_equal(sol_sum.sub_solutions[1].all_inputs[0]["a"], 2) # Add solution already contained in existing solution t3 = np.array([2]) y3 = np.ones((1, 1)) sol3 = pybamm.Solution(t3, y3, pybamm.BaseModel(), {"a": 3}) - self.assertEqual((sol_sum + sol3).all_ts, sol_sum.copy().all_ts) + assert (sol_sum + sol3).all_ts == sol_sum.copy().all_ts # add None sol4 = sol3 + None - self.assertEqual(sol3.all_ys, sol4.all_ys) + assert sol3.all_ys == sol4.all_ys # radd sol5 = None + sol3 - self.assertEqual(sol3.all_ys, sol5.all_ys) + assert sol3.all_ys == sol5.all_ys # radd failure - with self.assertRaisesRegex( - pybamm.SolverError, "Only a Solution or None can be added to a Solution" + with pytest.raises( + pybamm.SolverError, + match="Only a Solution or None can be added to a Solution", ): sol3 + 2 - with self.assertRaisesRegex( - pybamm.SolverError, "Only a Solution or None can be added to a Solution" + with pytest.raises( + pybamm.SolverError, + match="Only a Solution or None can be added to a Solution", ): 2 + sol3 @@ -133,14 +142,12 @@ def test_add_solutions(self): all_sensitivities={"test": [np.ones((1, 3))]}, ) sol2 = pybamm.Solution(t2, y2, pybamm.BaseModel(), {}, all_sensitivities=True) - with self.assertRaisesRegex( - ValueError, "Sensitivities must be of the same type" - ): + with pytest.raises(ValueError, match="Sensitivities must be of the same type"): sol3 = sol1 + sol2 sol1 = pybamm.Solution(t1, y3, pybamm.BaseModel(), {}, all_sensitivities=False) sol2 = pybamm.Solution(t3, y3, pybamm.BaseModel(), {}, all_sensitivities={}) sol3 = sol1 + sol2 - self.assertFalse(sol3._all_sensitivities) + assert not sol3._all_sensitivities def test_add_solutions_different_models(self): # Set up first solution @@ -160,8 +167,8 @@ def test_add_solutions_different_models(self): # Test np.testing.assert_array_equal(sol_sum.t, np.concatenate([t1, t2[1:]])) - with self.assertRaisesRegex( - pybamm.SolverError, "The solution is made up from different models" + with pytest.raises( + pybamm.SolverError, match="The solution is made up from different models" ): sol_sum.y @@ -176,14 +183,14 @@ def test_copy(self): sol1.integration_time = 0.3 sol_copy = sol1.copy() - self.assertEqual(sol_copy.all_ts, sol1.all_ts) + assert sol_copy.all_ts == sol1.all_ts for ys_copy, ys1 in zip(sol_copy.all_ys, sol1.all_ys): np.testing.assert_array_equal(ys_copy, ys1) - self.assertEqual(sol_copy.all_inputs, sol1.all_inputs) - self.assertEqual(sol_copy.all_inputs_casadi, sol1.all_inputs_casadi) - self.assertEqual(sol_copy.set_up_time, sol1.set_up_time) - self.assertEqual(sol_copy.solve_time, sol1.solve_time) - self.assertEqual(sol_copy.integration_time, sol1.integration_time) + assert sol_copy.all_inputs == sol1.all_inputs + assert sol_copy.all_inputs_casadi == sol1.all_inputs_casadi + assert sol_copy.set_up_time == sol1.set_up_time + assert sol_copy.solve_time == sol1.solve_time + assert sol_copy.integration_time == sol1.integration_time def test_last_state(self): # Set up first solution @@ -196,14 +203,14 @@ def test_last_state(self): sol1.integration_time = 0.3 sol_last_state = sol1.last_state - self.assertEqual(sol_last_state.all_ts[0], 2) + assert sol_last_state.all_ts[0] == 2 np.testing.assert_array_equal(sol_last_state.all_ys[0], 2) - self.assertEqual(sol_last_state.all_inputs, sol1.all_inputs[-1:]) - self.assertEqual(sol_last_state.all_inputs_casadi, sol1.all_inputs_casadi[-1:]) - self.assertEqual(sol_last_state.all_models, sol1.all_models[-1:]) - self.assertEqual(sol_last_state.set_up_time, 0) - self.assertEqual(sol_last_state.solve_time, 0) - self.assertEqual(sol_last_state.integration_time, 0) + assert sol_last_state.all_inputs == sol1.all_inputs[-1:] + assert sol_last_state.all_inputs_casadi == sol1.all_inputs_casadi[-1:] + assert sol_last_state.all_models == sol1.all_models[-1:] + assert sol_last_state.set_up_time == 0 + assert sol_last_state.solve_time == 0 + assert sol_last_state.integration_time == 0 def test_cycles(self): model = pybamm.lithium_ion.SPM() @@ -215,14 +222,14 @@ def test_cycles(self): ) sim = pybamm.Simulation(model, experiment=experiment) sol = sim.solve() - self.assertEqual(len(sol.cycles), 2) + assert len(sol.cycles) == 2 len_cycle_1 = len(sol.cycles[0].t) - self.assertIsInstance(sol.cycles[0], pybamm.Solution) + assert isinstance(sol.cycles[0], pybamm.Solution) np.testing.assert_array_equal(sol.cycles[0].t, sol.t[:len_cycle_1]) np.testing.assert_array_equal(sol.cycles[0].y, sol.y[:, :len_cycle_1]) - self.assertIsInstance(sol.cycles[1], pybamm.Solution) + assert isinstance(sol.cycles[1], pybamm.Solution) np.testing.assert_array_equal(sol.cycles[1].t, sol.t[len_cycle_1:]) np.testing.assert_allclose(sol.cycles[1].y, sol.y[:, len_cycle_1:]) @@ -230,7 +237,7 @@ def test_total_time(self): sol = pybamm.Solution(np.array([0]), np.array([[1, 2]]), pybamm.BaseModel(), {}) sol.set_up_time = 0.5 sol.solve_time = 1.2 - self.assertEqual(sol.total_time, 1.7) + assert sol.total_time == 1.7 def test_getitem(self): model = pybamm.BaseModel() @@ -244,13 +251,13 @@ def test_getitem(self): # test create a new processed variable c_sol = solution["c"] - self.assertIsInstance(c_sol, pybamm.ProcessedVariable) + assert isinstance(c_sol, pybamm.ProcessedVariable) np.testing.assert_array_equal(c_sol.entries, c_sol(solution.t)) # test call an already created variable solution.update("2c") twoc_sol = solution["2c"] - self.assertIsInstance(twoc_sol, pybamm.ProcessedVariable) + assert isinstance(twoc_sol, pybamm.ProcessedVariable) np.testing.assert_array_equal(twoc_sol.entries, twoc_sol(solution.t)) np.testing.assert_array_equal(twoc_sol.entries, 2 * c_sol.entries) @@ -283,12 +290,12 @@ def test_save(self): solution = pybamm.ScipySolver().solve(model, np.linspace(0, 1)) # test save data - with self.assertRaises(ValueError): + with pytest.raises(ValueError): solution.save_data(f"{test_stub}.pickle") # set variables first then save solution.update(["c", "d"]) - with self.assertRaisesRegex(ValueError, "pickle"): + with pytest.raises(ValueError, match="pickle"): solution.save_data(to_format="pickle") solution.save_data(f"{test_stub}.pickle") @@ -302,12 +309,12 @@ def test_save(self): np.testing.assert_array_equal(solution.data["c"], data_load["c"].flatten()) np.testing.assert_array_equal(solution.data["d"], data_load["d"]) - with self.assertRaisesRegex(ValueError, "matlab"): + with pytest.raises(ValueError, match="matlab"): solution.save_data(to_format="matlab") # to matlab with bad variables name fails solution.update(["c + d"]) - with self.assertRaisesRegex(ValueError, "Invalid character"): + with pytest.raises(ValueError, match="Invalid character"): solution.save_data(f"{test_stub}.mat", to_format="matlab") # Works if providing alternative name solution.save_data( @@ -319,8 +326,8 @@ def test_save(self): np.testing.assert_array_equal(solution.data["c + d"], data_load["c_plus_d"]) # to csv - with self.assertRaisesRegex( - ValueError, "only 0D variables can be saved to csv" + with pytest.raises( + ValueError, match="only 0D variables can be saved to csv" ): solution.save_data(f"{test_stub}.csv", to_format="csv") # only save "c" and "2c" @@ -330,7 +337,7 @@ def test_save(self): # check string is the same as the file with open(f"{test_stub}.csv") as f: # need to strip \r chars for windows - self.assertEqual(csv_str.replace("\r", ""), f.read()) + assert csv_str.replace("\r", "") == f.read() # read csv df = pd.read_csv(f"{test_stub}.csv") @@ -344,7 +351,7 @@ def test_save(self): # check string is the same as the file with open(f"{test_stub}.json") as f: # need to strip \r chars for windows - self.assertEqual(json_str.replace("\r", ""), f.read()) + assert json_str.replace("\r", "") == f.read() # check if string has the right values json_data = json.loads(json_str) @@ -352,17 +359,15 @@ def test_save(self): np.testing.assert_array_almost_equal(json_data["d"], solution.data["d"]) # raise error if format is unknown - with self.assertRaisesRegex( - ValueError, "format 'wrong_format' not recognised" + with pytest.raises( + ValueError, match="format 'wrong_format' not recognised" ): solution.save_data(f"{test_stub}.csv", to_format="wrong_format") # test save whole solution solution.save(f"{test_stub}.pickle") solution_load = pybamm.load(f"{test_stub}.pickle") - self.assertEqual( - solution.all_models[0].name, solution_load.all_models[0].name - ) + assert solution.all_models[0].name == solution_load.all_models[0].name np.testing.assert_array_equal( solution["c"].entries, solution_load["c"].entries ) @@ -412,14 +417,4 @@ def test_solution_evals_with_inputs(self): inputs = {"Negative electrode conductivity [S.m-1]": 0.1} sim.solve(t_eval=np.linspace(0, 10, 10), inputs=inputs) time = sim.solution["Time [h]"](sim.solution.t) - self.assertEqual(len(time), 10) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert len(time) == 10