Skip to content

Commit

Permalink
Pass x0, bounds and n_parameters into Cost (#101)
Browse files Browse the repository at this point in the history
* Unpack x0, bounds and n_parameters into Cost

* Pass x0, bounds and n_parameters from Cost

* Remove n_parameters function check

* Allow problem=None in BaseCost

* Add StandaloneCost

* Add test of StandaloneCost
  • Loading branch information
NicolaCourtier authored Nov 17, 2023
1 parent 790bb47 commit f67c810
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 13 deletions.
12 changes: 5 additions & 7 deletions pybop/costs/error_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,18 @@ class BaseCost:

def __init__(self, problem):
    """
    Initialise the cost, optionally unpacking attributes from a problem.

    Parameters
    ----------
    problem : object or None
        A pybop problem exposing ``_target``, ``x0``, ``bounds`` and
        ``n_parameters``. May be ``None`` for standalone costs that
        define these attributes themselves (e.g. StandaloneCost).
    """
    self.problem = problem
    # Guard the unpacking: a standalone cost passes problem=None and
    # supplies x0/bounds/n_parameters itself, so nothing may be read
    # from the problem here in that case.
    if problem is not None:
        self._target = problem._target
        self.x0 = problem.x0
        self.bounds = problem.bounds
        self.n_parameters = problem.n_parameters

def __call__(self, x, grad=None):
    """
    Compute and return the cost at the parameter vector ``x``.

    This base implementation always raises; concrete subclasses
    (e.g. RootMeanSquaredError) must override it. The ``grad``
    argument is accepted for compatibility with optimisers that
    supply a gradient buffer — presumably the NLopt-style callback
    signature; confirm against the optimiser wrappers.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    raise NotImplementedError

def n_parameters(self):
"""
Returns the size of the parameter space.
"""
raise NotImplementedError


class RootMeanSquaredError(BaseCost):
"""
Expand Down
18 changes: 18 additions & 0 deletions pybop/costs/standalone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pybop
import numpy as np


class StandaloneCost(pybop.BaseCost):
    """
    A self-contained cost that needs no problem object.

    Supplies its own ``x0``, ``bounds`` and ``n_parameters`` instead of
    unpacking them from a problem, and evaluates the simple quadratic
    ``x[0]**2 + 42``, whose minimum value 42 is attained at x = 0.
    """

    def __init__(self, problem=None):
        """Initialise with no problem; declare the search space directly."""
        super().__init__(problem)

        self.x0 = np.array([4.2])
        self.n_parameters = len(self.x0)
        self.bounds = dict(lower=[-1], upper=[10])

    def __call__(self, x, grad=None):
        """Return the scalar cost x[0]**2 + 42."""
        return x[0] ** 2 + 42
8 changes: 4 additions & 4 deletions pybop/optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def __init__(
verbose=False,
):
self.cost = cost
self.problem = cost.problem
self.optimiser = optimiser
self.verbose = verbose
self.x0 = cost.problem.x0
self.bounds = self.problem.bounds
self.x0 = cost.x0
self.bounds = cost.bounds
self.n_parameters = cost.n_parameters
self.sigma0 = sigma0
self.log = []

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
self.pints = False

if issubclass(self.optimiser, pybop.NLoptOptimize):
self.optimiser = self.optimiser(self.problem.n_parameters)
self.optimiser = self.optimiser(self.n_parameters)

elif issubclass(self.optimiser, pybop.SciPyMinimize):
self.optimiser = self.optimiser()
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ def test_costs(self, cut_off):
assert base_cost.problem == problem
with pytest.raises(NotImplementedError):
base_cost([0.5])
with pytest.raises(NotImplementedError):
base_cost.n_parameters()

# Root Mean Squared Error
rmse_cost = pybop.RootMeanSquaredError(problem)
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_optimisation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
import pybop
import numpy as np
import pytest
from pybop.costs.standalone import StandaloneCost


class TestOptimisation:
"""
A class to test the optimisation class.
"""

@pytest.mark.unit
def test_standalone(self):
    # An Optimisation built from a cost with no attached problem should
    # take x0, bounds and n_parameters from the cost itself.
    standalone = StandaloneCost()
    optim = pybop.Optimisation(cost=standalone, optimiser=pybop.NLoptOptimize)

    # The starting point's length must agree with the declared dimension.
    assert len(optim.x0) == optim.n_parameters

    best_x, best_cost = optim.run()

    # x**2 + 42 is minimised at x = 0 with value 42.
    np.testing.assert_allclose(best_x, 0, atol=1e-2)
    np.testing.assert_allclose(best_cost, 42, atol=1e-2)

@pytest.mark.unit
def test_prior_sampling(self):
# Tests prior sampling
Expand Down

0 comments on commit f67c810

Please sign in to comment.