
Commit

update lbfgs to avoid the randomness caused by paddle.dot() temporarily (#60591)

* update lbfgs to avoid the randomness caused by paddle.dot() temporarily

* add note
lijialin03 authored Jan 8, 2024
1 parent 41679e4 commit fa1f901
Showing 1 changed file with 16 additions and 8 deletions.
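
In short, the patch swaps `paddle.dot` for an elementwise multiply followed by a sum over the last axis, which is mathematically equivalent for the 1-D flattened tensors this L-BFGS code works with. A minimal sketch of the idea (the test vectors below are illustrative, not part of the patch):

```python
import paddle

def dot(x, y):
    # Same workaround as in the patch: multiply elementwise, then reduce.
    return (x * y).sum(axis=-1)

# Illustrative check on small 1-D tensors (not from the patch).
x = paddle.to_tensor([1.0, 2.0, 3.0])
y = paddle.to_tensor([4.0, 5.0, 6.0])
print(float(dot(x, y)))         # 32.0
print(float(paddle.dot(x, y)))  # 32.0, same value
```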
24 changes: 16 additions & 8 deletions python/paddle/optimizer/lbfgs.py
@@ -23,6 +23,14 @@
__all__ = []


def dot(x, y):
r"""
NOTE: This is a temporary workaround for the unstable results computed by `paddle.dot`,
which will be reverted when the problem is fixed.
"""
return (x * y).sum(axis=-1)


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
r"""Cubic interpolation between (x1, f1, g1) and (x2, f2, g2).
Use two points and their gradient to determine a cubic function and get the minimum point
@@ -152,7 +160,7 @@ def _strong_wolfe(
# evaluate objective and gradient using initial step
loss_new, grad_new = obj_func(xk, alpha, d)
ls_func_evals = 1
gtd_new = paddle.dot(grad_new, d)
gtd_new = dot(grad_new, d)

# bracket an interval containing a point satisfying the Wolfe criteria
t_prev, f_prev, g_prev, gtd_prev = (0, loss, grad, gtd)
@@ -205,7 +213,7 @@ def _strong_wolfe(

loss_new, grad_new = obj_func(xk, alpha, d)
ls_func_evals += 1
gtd_new = grad_new.dot(d)
gtd_new = dot(grad_new, d)
ls_iter += 1

# reached max number of iterations?
@@ -265,7 +273,7 @@ def _strong_wolfe(
# Evaluate new point
loss_new, grad_new = obj_func(xk, alpha, d)
ls_func_evals += 1
gtd_new = grad_new.dot(d)
gtd_new = dot(grad_new, d)
ls_iter += 1

if (
@@ -644,7 +652,7 @@ def step(self, closure):
# do lbfgs update (update memory)
y = flat_grad.subtract(prev_flat_grad)
s = d.multiply(paddle.to_tensor(alpha, dtype=d.dtype))
ys = y.dot(s)
ys = dot(y, s)
if ys > 1e-10:
# updating memory
if len(old_yk) == history_size:
@@ -659,7 +667,7 @@
ro.append(1.0 / ys)

# update scale of initial Hessian approximation
H_diag = ys / y.dot(y) # (y*y)
H_diag = ys / dot(y, y) # (y*y)

# compute the approximate (L-BFGS) inverse Hessian
# multiplied by the gradient
@@ -672,14 +680,14 @@
# iteration in L-BFGS loop collapsed to use just one buffer
q = flat_grad.neg()
for i in range(num_old - 1, -1, -1):
al[i] = old_sk[i].dot(q) * ro[i]
al[i] = dot(old_sk[i], q) * ro[i]
paddle.assign(q.add(old_yk[i] * (-al[i])), q)

# multiply by initial Hessian
# r/d is the final direction
d = r = paddle.multiply(q, H_diag)
for i in range(num_old):
be_i = old_yk[i].dot(r) * ro[i]
be_i = dot(old_yk[i], r) * ro[i]
paddle.assign(r.add(old_sk[i] * (al[i] - be_i)), r)

if prev_flat_grad is None:
@@ -700,7 +708,7 @@
alpha = learning_rate

# directional derivative
gtd = flat_grad.dot(d)
gtd = dot(flat_grad, d)

# directional derivative is below tolerance
if gtd > -tolerance_change:
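
For readers less familiar with the code being touched: the last hunks sit inside the standard L-BFGS two-loop recursion, which turns the stored curvature pairs `(old_sk, old_yk)` into an approximate inverse-Hessian product with the negative gradient, i.e. the search direction `d`. A self-contained sketch of that recursion using the same workaround `dot` (illustrative only; the real optimizer also manages its history buffers and the line search):

```python
import paddle

def dot(x, y):
    # Workaround dot from the patch.
    return (x * y).sum(axis=-1)

def lbfgs_direction(flat_grad, old_sk, old_yk, ro, H_diag):
    """Two-loop recursion: approximate (inverse Hessian) @ (-gradient)."""
    num_old = len(old_sk)
    al = [None] * num_old
    q = flat_grad.neg()
    # First loop: newest to oldest stored pair.
    for i in range(num_old - 1, -1, -1):
        al[i] = dot(old_sk[i], q) * ro[i]
        q = q + old_yk[i] * (-al[i])
    # Scale by the initial (diagonal) Hessian approximation.
    r = q * H_diag
    # Second loop: oldest to newest stored pair.
    for i in range(num_old):
        be_i = dot(old_yk[i], r) * ro[i]
        r = r + old_sk[i] * (al[i] - be_i)
    return r  # search direction d, as in the hunk above
```

In the patch, `ro[i]` stores `1.0 / ys` and `H_diag` is `ys / dot(y, y)`, so every inner product on this path now goes through the deterministic workaround rather than `paddle.dot`.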
