Skip to content

Commit

Permalink
Merge pull request #219 from python-adaptive/value_scale_ND
Browse files Browse the repository at this point in the history
pass value_scale to the LearnerND's loss_per_simplex function
  • Loading branch information
jbweston authored Sep 20, 2019
2 parents d9fc5dd + ca76230 commit fc297c3
Showing 1 changed file with 78 additions and 16 deletions.
94 changes: 78 additions & 16 deletions adaptive/learner/learnerND.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,30 +47,78 @@ def orientation(simplex):
return sign


def uniform_loss(simplex, ys=None):
def uniform_loss(simplex, values, value_scale):
"""
Uniform loss.
Parameters
----------
simplex : list of tuples
Each entry is one point of the simplex.
values : list of values
The scaled function values of each of the simplex points.
value_scale : float
The scale of values, where ``values = function_values * value_scale``.
Returns
-------
loss : float
"""
return volume(simplex)


def std_loss(simplex, ys):
r = np.linalg.norm(np.std(ys, axis=0))
def std_loss(simplex, values, value_scale):
"""
Computes the loss of the simplex based on the standard deviation.
Parameters
----------
simplex : list of tuples
Each entry is one point of the simplex.
values : list of values
The scaled function values of each of the simplex points.
value_scale : float
The scale of values, where ``values = function_values * value_scale``.
Returns
-------
loss : float
"""

r = np.linalg.norm(np.std(values, axis=0))
vol = volume(simplex)

dim = len(simplex) - 1

return r.flat * np.power(vol, 1.0 / dim) + vol


def default_loss(simplex, ys):
# return std_loss(simplex, ys)
if isinstance(ys[0], Iterable):
pts = [(*x, *y) for x, y in zip(simplex, ys)]
def default_loss(simplex, values, value_scale):
"""
Computes the average of the volumes of the simplex.
Parameters
----------
simplex : list of tuples
Each entry is one point of the simplex.
values : list of values
The scaled function values of each of the simplex points.
value_scale : float
The scale of values, where ``values = function_values * value_scale``.
Returns
-------
loss : float
"""
if isinstance(values[0], Iterable):
pts = [(*x, *y) for x, y in zip(simplex, values)]
else:
pts = [(*x, y) for x, y in zip(simplex, ys)]
pts = [(*x, y) for x, y in zip(simplex, values)]
return simplex_volume_in_embedding(pts)


@uses_nth_neighbors(1)
def triangle_loss(simplex, values, neighbors, neighbor_values):
def triangle_loss(simplex, values, value_scale, neighbors, neighbor_values):
"""
Computes the average of the volumes of the simplex combined with each
neighbouring point.
Expand All @@ -80,7 +128,9 @@ def triangle_loss(simplex, values, neighbors, neighbor_values):
simplex : list of tuples
Each entry is one point of the simplex.
values : list of values
The function values of each of the simplex points.
The scaled function values of each of the simplex points.
value_scale : float
The scale of values, where ``values = function_values * value_scale``.
neighbors : list of tuples
The neighboring points of the simplex, ordered such that simplex[0]
exacly opposes neighbors[0], etc.
Expand Down Expand Up @@ -108,20 +158,22 @@ def triangle_loss(simplex, values, neighbors, neighbor_values):
def curvature_loss_function(exploration=0.05):
# XXX: add doc-string!
@uses_nth_neighbors(1)
def curvature_loss(simplex, values, neighbors, neighbor_values):
def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values):
"""Compute the curvature loss of a simplex.
Parameters
----------
simplex : list of tuples
Each entry is one point of the simplex.
values : list of values
The function values of each of the simplex points.
The scaled function values of each of the simplex points.
value_scale : float
The scale of values, where ``values = function_values * value_scale``.
neighbors : list of tuples
The neighboring points of the simplex, ordered such that simplex[0]
exacly opposes neighbors[0], etc.
neighbor_values : list of values
The function values for each of the neighboring points.
The scaled function values for each of the neighboring points.
Returns
-------
Expand All @@ -130,7 +182,9 @@ def curvature_loss(simplex, values, neighbors, neighbor_values):
dim = len(simplex[0]) # the number of coordinates
loss_input_volume = volume(simplex)

loss_curvature = triangle_loss(simplex, values, neighbors, neighbor_values)
loss_curvature = triangle_loss(
simplex, values, value_scale, neighbors, neighbor_values
)
return (
loss_curvature + exploration * loss_input_volume ** ((2 + dim) / dim)
) ** (1 / (2 + dim))
Expand Down Expand Up @@ -563,7 +617,9 @@ def _compute_loss(self, simplex):

if self.nth_neighbors == 0:
# compute the loss on the scaled simplex
return float(self.loss_per_simplex(vertices, values))
return float(
self.loss_per_simplex(vertices, values, self._output_multiplier)
)

# We do need the neighbors
neighbors = self.tri.get_opposing_vertices(simplex)
Expand All @@ -580,7 +636,13 @@ def _compute_loss(self, simplex):
neighbor_values[i] = self._output_multiplier * value

return float(
self.loss_per_simplex(vertices, values, neighbor_points, neighbor_values)
self.loss_per_simplex(
vertices,
values,
self._output_multiplier,
neighbor_points,
neighbor_values,
)
)

def _update_losses(self, to_delete: set, to_add: set):
Expand Down

0 comments on commit fc297c3

Please sign in to comment.