diff --git a/pybop/costs/error_costs.py b/pybop/costs/error_costs.py index 82582d52a..2c497d45b 100644 --- a/pybop/costs/error_costs.py +++ b/pybop/costs/error_costs.py @@ -10,7 +10,11 @@ class BaseCost: def __init__(self, problem): self.problem = problem - self._target = problem._target + if problem is not None: + self._target = problem._target + self.x0 = problem.x0 + self.bounds = problem.bounds + self.n_parameters = problem.n_parameters def __call__(self, x, grad=None): """ @@ -18,12 +22,6 @@ def __call__(self, x, grad=None): """ raise NotImplementedError - def n_parameters(self): - """ - Returns the size of the parameter space. - """ - raise NotImplementedError - class RootMeanSquaredError(BaseCost): """ diff --git a/pybop/costs/standalone.py b/pybop/costs/standalone.py new file mode 100644 index 000000000..197dcca5b --- /dev/null +++ b/pybop/costs/standalone.py @@ -0,0 +1,18 @@ +import pybop +import numpy as np + + +class StandaloneCost(pybop.BaseCost): + def __init__(self, problem=None): + super().__init__(problem) + + self.x0 = np.array([4.2]) + self.n_parameters = len(self.x0) + + self.bounds = dict( + lower=[-1], + upper=[10], + ) + + def __call__(self, x, grad=None): + return x[0] ** 2 + 42 diff --git a/pybop/optimisation.py b/pybop/optimisation.py index ac8ffdfd3..6dc947de7 100644 --- a/pybop/optimisation.py +++ b/pybop/optimisation.py @@ -23,11 +23,11 @@ def __init__( verbose=False, ): self.cost = cost - self.problem = cost.problem self.optimiser = optimiser self.verbose = verbose - self.x0 = cost.problem.x0 - self.bounds = self.problem.bounds + self.x0 = cost.x0 + self.bounds = cost.bounds + self.n_parameters = cost.n_parameters self.sigma0 = sigma0 self.log = [] @@ -56,7 +56,7 @@ def __init__( self.pints = False if issubclass(self.optimiser, pybop.NLoptOptimize): - self.optimiser = self.optimiser(self.problem.n_parameters) + self.optimiser = self.optimiser(self.n_parameters) elif issubclass(self.optimiser, pybop.SciPyMinimize): self.optimiser = self.optimiser() diff --git a/tests/unit/test_cost.py b/tests/unit/test_cost.py index 8f528e942..0c7e329f3 100644 --- a/tests/unit/test_cost.py +++ b/tests/unit/test_cost.py @@ -41,8 +41,6 @@ def test_costs(self, cut_off): assert base_cost.problem == problem with pytest.raises(NotImplementedError): base_cost([0.5]) - with pytest.raises(NotImplementedError): - base_cost.n_parameters() # Root Mean Squared Error rmse_cost = pybop.RootMeanSquaredError(problem) diff --git a/tests/unit/test_optimisation.py b/tests/unit/test_optimisation.py index b9d3b0414..406c89ee3 100644 --- a/tests/unit/test_optimisation.py +++ b/tests/unit/test_optimisation.py @@ -1,6 +1,7 @@ import pybop import numpy as np import pytest +from pybop.costs.standalone import StandaloneCost class TestOptimisation: @@ -8,6 +9,20 @@ class TestOptimisation: A class to test the optimisation class. """ + @pytest.mark.unit + def test_standalone(self): + # Build an Optimisation problem with a StandaloneCost + cost = StandaloneCost() + + opt = pybop.Optimisation(cost=cost, optimiser=pybop.NLoptOptimize) + + assert len(opt.x0) == opt.n_parameters + + x, final_cost = opt.run() + + np.testing.assert_allclose(x, 0, atol=1e-2) + np.testing.assert_allclose(final_cost, 42, atol=1e-2) + @pytest.mark.unit def test_prior_sampling(self): # Tests prior sampling