Skip to content

Commit

Permalink
Add checks on early termination of simulation
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolaCourtier authored Nov 17, 2023
1 parent 8002a27 commit e0cdb75
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions tests/unit/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ class TestCosts:
Class for tests cost functions
"""

@pytest.mark.parametrize("cut_off", [2.5,3.777])
@pytest.mark.unit
def test_costs(self):
# Construct Problem
def test_costs(self, cut_off):
# Construct model
model = pybop.lithium_ion.SPM()

parameters = [
pybop.Parameter(
"Negative electrode active material volume fraction",
Expand All @@ -23,51 +25,53 @@ def test_costs(self):
# Form dataset
x0 = np.array([0.52])
solution = self.getdata(model, x0)

dataset = [
pybop.Dataset("Time [s]", solution["Time [s]"].data),
pybop.Dataset("Current function [A]", solution["Current [A]"].data),
pybop.Dataset("Voltage [V]", solution["Terminal voltage [V]"].data),
]

# Construct Problem
signal = "Voltage [V]"
problem = pybop.Problem(model, parameters, dataset, signal=signal)
model.parameter_set.update({"Lower voltage cut-off [V]": cut_off})
problem = pybop.Problem(model, parameters, dataset, signal=signal, x0=x0)

# Base Cost
cost = pybop.BaseCost(problem)
assert cost.problem == problem
base_cost = pybop.BaseCost(problem)
assert base_cost.problem == problem
with pytest.raises(NotImplementedError):
cost([0.5])
base_cost([0.5])
with pytest.raises(NotImplementedError):
cost.n_parameters()
base_cost.n_parameters()

# Root Mean Squared Error
cost = pybop.RootMeanSquaredError(problem)
cost([0.5])
rmse_cost = pybop.RootMeanSquaredError(problem)
rmse_cost([0.5])

assert type(cost([0.5])) == np.float64
assert cost([0.5]) >= 0
# Sum Squared Error
sums_cost = pybop.SumSquaredError(problem)
sums_cost([0.5])

# Root Mean Squared Error
cost = pybop.SumSquaredError(problem)
cost([0.5])
# Test type of returned value
assert type(rmse_cost([0.5])) == np.float64 or np.isinf(rmse_cost([0.5]))
assert rmse_cost([0.5]) >= 0
assert type(sums_cost([0.5])) == np.float64 or np.isinf(sums_cost([0.5]))
assert sums_cost([0.5]) >= 0

assert type(cost([0.5])) == np.float64
assert cost([0.5]) >= 0
# Test option setting
sums_cost.set_fail_gradient(1)

# Test catch on non-matching vector lengths
# Sum Squared Error
cost = pybop.SumSquaredError(problem)
# Test exception for non-numeric inputs
with pytest.raises(ValueError):
cost(["test-entry"])

rmse_cost(["StringInputShouldNotWork"])
with pytest.raises(ValueError):
cost.evaluateS1(["test-entry"])

# Root Mean Squared Error
cost = pybop.RootMeanSquaredError(problem)
sums_cost(["StringInputShouldNotWork"])
with pytest.raises(ValueError):
cost(["test-entry"])
sums_cost.evaluateS1(["StringInputShouldNotWork"])

# Test treatment of simulations that terminated early
# by variation of the cut-off voltage.


def getdata(self, model, x0):
model.parameter_set = model.pybamm_model.default_parameter_values
Expand Down

0 comments on commit e0cdb75

Please sign in to comment.