Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass x0, bounds and n_parameters into Cost #101

Merged
merged 9 commits into from
Nov 17, 2023
9 changes: 3 additions & 6 deletions pybop/costs/error_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,16 @@ class BaseCost:
def __init__(self, problem):
self.problem = problem
self._target = problem._target
self.x0 = problem.x0
self.bounds = problem.bounds
self.n_parameters = problem.n_parameters

def __call__(self, x, grad=None):
"""
Returns the cost function value and computes the cost.
"""
raise NotImplementedError

def n_parameters(self):
"""
Returns the size of the parameter space.
"""
raise NotImplementedError


class RootMeanSquaredError(BaseCost):
"""
Expand Down
8 changes: 4 additions & 4 deletions pybop/optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def __init__(
verbose=False,
):
self.cost = cost
self.problem = cost.problem
self.optimiser = optimiser
self.verbose = verbose
self.x0 = cost.problem.x0
self.bounds = self.problem.bounds
self.x0 = cost.x0
self.bounds = cost.bounds
self.n_parameters = cost.n_parameters
self.sigma0 = sigma0
self.log = []

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
self.pints = False

if issubclass(self.optimiser, pybop.NLoptOptimize):
self.optimiser = self.optimiser(self.problem.n_parameters)
self.optimiser = self.optimiser(self.n_parameters)

elif issubclass(self.optimiser, pybop.SciPyMinimize):
self.optimiser = self.optimiser()
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def test_costs(self):
assert cost.problem == problem
with pytest.raises(NotImplementedError):
cost([0.5])
with pytest.raises(NotImplementedError):
cost.n_parameters()

# Root Mean Squared Error
cost = pybop.RootMeanSquaredError(problem)
Expand Down
Loading