Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass x0, bounds and n_parameters into Cost #101

Merged
merged 9 commits into from
Nov 17, 2023
12 changes: 5 additions & 7 deletions pybop/costs/error_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,18 @@ class BaseCost:

def __init__(self, problem):
self.problem = problem
self._target = problem._target
if problem is not None:
self._target = problem._target
self.x0 = problem.x0
self.bounds = problem.bounds
self.n_parameters = problem.n_parameters

def __call__(self, x, grad=None):
"""
Returns the cost function value and computes the cost.
"""
raise NotImplementedError

def n_parameters(self):
"""
Returns the size of the parameter space.
"""
raise NotImplementedError


class RootMeanSquaredError(BaseCost):
"""
Expand Down
18 changes: 18 additions & 0 deletions pybop/costs/standalone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pybop
import numpy as np


class StandaloneCost(pybop.BaseCost):
def __init__(self, problem=None):
super().__init__(problem)

self.x0 = np.array([4.2])
self.n_parameters = len(self.x0)

self.bounds = dict(
lower=[-1],
upper=[10],
)

def __call__(self, x, grad=None):
return x[0] ** 2 + 42
8 changes: 4 additions & 4 deletions pybop/optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def __init__(
verbose=False,
):
self.cost = cost
self.problem = cost.problem
self.optimiser = optimiser
self.verbose = verbose
self.x0 = cost.problem.x0
self.bounds = self.problem.bounds
self.x0 = cost.x0
self.bounds = cost.bounds
self.n_parameters = cost.n_parameters
self.sigma0 = sigma0
self.log = []

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
self.pints = False

if issubclass(self.optimiser, pybop.NLoptOptimize):
self.optimiser = self.optimiser(self.problem.n_parameters)
self.optimiser = self.optimiser(self.n_parameters)

elif issubclass(self.optimiser, pybop.SciPyMinimize):
self.optimiser = self.optimiser()
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ def test_costs(self, cut_off):
assert base_cost.problem == problem
with pytest.raises(NotImplementedError):
base_cost([0.5])
with pytest.raises(NotImplementedError):
base_cost.n_parameters()

# Root Mean Squared Error
rmse_cost = pybop.RootMeanSquaredError(problem)
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_optimisation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
import pybop
import numpy as np
import pytest
from pybop.costs.standalone import StandaloneCost


class TestOptimisation:
"""
A class to test the optimisation class.
"""

@pytest.mark.unit
def test_standalone(self):
# Build an Optimisation problem with a StandaloneCost
cost = StandaloneCost()

opt = pybop.Optimisation(cost=cost, optimiser=pybop.NLoptOptimize)

assert len(opt.x0) == opt.n_parameters

x, final_cost = opt.run()

np.testing.assert_allclose(x, 0, atol=1e-2)
np.testing.assert_allclose(final_cost, 42, atol=1e-2)

@pytest.mark.unit
def test_prior_sampling(self):
# Tests prior sampling
Expand Down