
Objective function not passed extra arguments when Hessian is provided #175

Closed
ForceBru opened this issue Jan 6, 2023 · 4 comments

ForceBru commented Jan 6, 2023

Code

from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import cyipopt

def basic_function(x: jnp.ndarray, data: jnp.ndarray):
    return (x[0] * data + x[1]).mean()

cyipopt.minimize_ipopt(
    basic_function,
    jnp.array([1.0, 2.0]), (jnp.array([1.,2,5,1,5,2,8,2,0,4,9]), ),
    jac=jax.jacfwd(basic_function),
    hess=jax.hessian(basic_function),
)

Error

# Message from JAX trimmed
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/forcebru/mambaforge/lib/python3.10/site-packages/cyipopt/scipy_interface.py", line 314, in minimize_ipopt
    x, info = nlp.solve(_x0)
  File "cyipopt/cython/ipopt_wrapper.pyx", line 642, in ipopt_wrapper.Problem.solve
  File "cyipopt/cython/ipopt_wrapper.pyx", line 895, in ipopt_wrapper.hessian_cb
  File "/Users/forcebru/mambaforge/lib/python3.10/site-packages/cyipopt/scipy_interface.py", line 164, in hessian
    H = obj_factor * self.obj_hess(x)  # type: ignore
  File "/Users/forcebru/mambaforge/lib/python3.10/site-packages/jax/_src/api.py", line 1286, in jacfun
    y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
TypeError: basic_function() missing 1 required positional argument: 'data'
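The last frame shows the Hessian callable receiving only x, so data never reaches basic_function. On cyipopt 1.2.0, one possible workaround (an assumption, not something proposed in this thread) is to bind the extra arguments into the Hessian callable yourself, so that x is the only argument it needs. The sketch below uses a hypothetical hand-written Hessian, basic_hessian, so it runs without JAX or cyipopt.

```python
import functools

# Data vector from the reproduction in the issue.
DATA = [1.0, 2, 5, 1, 5, 2, 8, 2, 0, 4, 9]


def basic_hessian(x, data):
    # Hypothetical hand-written Hessian of basic_function from the issue.
    # The objective is linear in x, so all second derivatives are zero,
    # but the signature still requires `data`, like jax.hessian(basic_function).
    return [[0.0, 0.0], [0.0, 0.0]]


# Calling it with x alone raises the same kind of TypeError as the traceback.
try:
    basic_hessian([1.0, 2.0])
except TypeError as exc:
    print("unbound:", exc)

# Workaround sketch: bind `data` in advance so the callable needs only x.
hess_x_only = functools.partial(basic_hessian, data=DATA)
print(hess_x_only([1.0, 2.0]))
```

With JAX, the same idea could be written as hess=lambda x: jax.hessian(basic_function)(x, data), where data is the array otherwise passed through args. The objective and jac can keep receiving data through args, since those calls already forward it.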

I think in this code:

def hessian(self, x, lagrange, obj_factor):
    H = obj_factor * self.obj_hess(x)  # type: ignore

...self.obj_hess should be called with self.args and self.kwargs, like self.fun and self.jac:

return self.jac(x, *self.args, **self.kwargs) # .T

  • cyipopt 1.2.0
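The change suggested above can be sketched with a simplified wrapper class. ObjectiveWrapperSketch and its method names are hypothetical stand-ins, not cyipopt's actual class. The real hessian method in scipy_interface.py also has to add constraint Hessians weighted by lagrange, which this sketch omits. What it shows is the one idea under discussion: forward self.args and self.kwargs to obj_hess exactly as they are forwarded to fun and jac.

```python
class ObjectiveWrapperSketch:
    """Hypothetical, simplified stand-in for cyipopt's SciPy-interface wrapper."""

    def __init__(self, fun, jac, hess, args=(), kwargs=None):
        self.fun = fun
        self.jac = jac
        self.obj_hess = hess
        self.args = tuple(args)
        self.kwargs = dict(kwargs or {})

    def objective(self, x):
        return self.fun(x, *self.args, **self.kwargs)

    def gradient(self, x):
        return self.jac(x, *self.args, **self.kwargs)

    def hessian(self, x, lagrange, obj_factor):
        # Forward the extra arguments here too; the reported code passed only x.
        # Constraint terms weighted by `lagrange` are omitted in this sketch.
        H = self.obj_hess(x, *self.args, **self.kwargs)
        return [[obj_factor * h for h in row] for row in H]


# Usage with a pure-Python version of the objective from the issue.
def f(x, data):
    return sum(x[0] * d + x[1] for d in data) / len(data)


def grad(x, data):
    return [sum(data) / len(data), 1.0]


def hess(x, data):
    return [[0.0, 0.0], [0.0, 0.0]]


w = ObjectiveWrapperSketch(f, grad, hess, args=([1.0, 2.0, 5.0],))
print(w.objective([1.0, 2.0]))         # mean of 3.0, 4.0 and 7.0
print(w.hessian([1.0, 2.0], [], 1.0))  # no TypeError: data is forwarded
```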
brocksam (Collaborator) commented Jan 9, 2023

Thanks for reporting this. I agree, it does look like self.args and self.kwargs are needed where you suggest. Would you be happy to submit a PR making this change?

ForceBru (Author)

Sure, I'll give it a try this week!

moorepants (Collaborator)

This now works on master:

$ ipython
Python 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 20:08:06) [GCC 11.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.12.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: %paste
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import cyipopt

def basic_function(x: jnp.ndarray, data: jnp.ndarray):
    return (x[0] * data + x[1]).mean()

cyipopt.minimize_ipopt(
    basic_function,
    jnp.array([1.0, 2.0]), (jnp.array([1.,2,5,1,5,2,8,2,0,4,9]), ),
    jac=jax.jacfwd(basic_function),
    hess=jax.hessian(basic_function),
)

## -- End pasted text --
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

******************************************************************************
This program contains Ipopt, a library for large-scale nonlinear optimization.
 Ipopt is released as open source code under the Eclipse Public License (EPL).
         For more information visit https://github.com/coin-or/Ipopt
******************************************************************************

Out[1]: 
 message: b'It seems that the iterates diverge.'
 success: False
  status: 4
     fun: -1.1315674525145682e+21
       x: [-2.956e+20 -8.339e+19]
     nit: 34
    info:     status: 4
                   x: [-2.956e+20 -8.339e+19]
                   g: []
             obj_val: -1.1315674525145682e+21
              mult_g: []
            mult_x_L: [ 0.000e+00  0.000e+00]
            mult_x_U: [ 0.000e+00  0.000e+00]
          status_msg: b'It seems that the iterates diverge.'
    nfev: 35
    njev: 36

Solved by #197

moorepants (Collaborator)

I'm not sure why it reports status_msg: b'It seems that the iterates diverge.', but the error about the args/kwargs is gone.
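The thread leaves the divergence unexplained, but it follows from the test problem itself. The objective mean(x[0] * d + x[1]) equals x[0] * mean(data) + x[1], which is linear in x. With no bounds or constraints in the call, it is unbounded below, so there is no finite minimizer for Ipopt to converge to. The pure-Python check below (no JAX or cyipopt needed) shows the objective decreasing without limit along one direction. It also reproduces the final fun from the output above, up to the 4-digit rounding of the printed x.

```python
# Data vector from the reproduction in the issue.
DATA = [1.0, 2, 5, 1, 5, 2, 8, 2, 0, 4, 9]


def basic_function(x, data):
    # Same objective as the issue, without JAX. Because it equals
    # x[0] * mean(data) + x[1], it is linear in x: constant gradient, zero Hessian.
    return sum(x[0] * d + x[1] for d in data) / len(data)


mean_data = sum(DATA) / len(DATA)  # 39 / 11
print("constant gradient:", [mean_data, 1.0])

# Along x = (-t, -t) the objective is -(mean_data + 1) * t, which keeps falling.
values = [basic_function([-t, -t], DATA) for t in (1.0, 1e3, 1e6, 1e9)]
print(values)

# Final iterate as printed in the thread (rounded to 4 significant digits).
x_final = [-2.956e20, -8.339e19]
reported_fun = -1.1315674525145682e21
print(basic_function(x_final, DATA), reported_fun)
```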
