
Question about derivatives #441

Closed
abaillod opened this issue Aug 1, 2024 · 11 comments

Labels: question (Further information is requested)

abaillod commented Aug 1, 2024

Hi!

When writing an objective that depends on an Optimizable that has no local dofs, I run into an error (see below) that is quite obscure to me. I would appreciate some help!

Here is a simple example of a class that generates the problem:

import jax.numpy as jnp
from jax import vjp, grad

from simsopt._core.optimizable import Optimizable
from simsopt._core.derivative import Derivative, derivative_dec

def pure_objective(gamma, current):
    r = jnp.sqrt(gamma[:, 0]**2 + gamma[:, 1]**2 + gamma[:, 2]**2)
    return jnp.mean(r) * current

class TestObjective(Optimizable):
    def __init__(self, coil):
        self.coil = coil
        self.J_jax = lambda gamma, current: pure_objective(gamma, current)
        self.dobj_by_dgamma_vjp = lambda gamma, current, v: vjp(lambda g: self.J_jax(g, current), gamma)[1](v)[0]
        self.dobj_by_dcurrent_vjp = lambda gamma, current, v: vjp(lambda c: self.J_jax(gamma, c), current)[1](v)[0] 
        
        super().__init__(depends_on=[coil])


    def J(self):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()

        return self.J_jax(gamma, current)


    def vjp(self, v):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()

        return Derivative(
            {
                self.coil.curve: self.dobj_by_dgamma_vjp(gamma, current, v),
                self.coil.current: self.dobj_by_dcurrent_vjp(gamma, current, v)
            }
        )


def squared_objective(obj):
    return obj**2

class ObjectiveSquared(Optimizable):
    def __init__(self, obj):
        self.obj = obj

        self.J_jax = lambda x: squared_objective(x)
        self.thisgrad = lambda x: grad(self.J_jax)(x)
        
        super().__init__(depends_on=[obj])

    def J(self):
        return self.J_jax(self.obj.J())

    @derivative_dec
    def dJ(self):
        x = self.obj.J()
        grad = self.thisgrad( x )

        return self.obj.vjp( grad )

Then, for a given coil, if we do

tt = TestObjective(coil)
sq = ObjectiveSquared(tt)
sq.dJ()

we get:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[43], line 1
----> 1 sq.dJ()

File ~/Github/simsopt/src/simsopt/_core/derivative.py:217, in derivative_dec.<locals>._derivative_dec(self, partials, *args, **kwargs)
    215     return func(self, *args, **kwargs)
    216 else:
--> 217     return func(self, *args, **kwargs)(self)

File ~/Github/simsopt/src/simsopt/_core/derivative.py:185, in Derivative.__call__(self, optim, as_derivative)
    183 local_derivs = np.zeros(k.local_dof_size)
    184 for opt in k.dofs.dep_opts():
--> 185     local_derivs += self.data[opt][opt.local_dofs_free_status]
    186     keys.append(opt)
    187 derivs.append(local_derivs)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/array.py:317, in ArrayImpl.__getitem__(self, idx)
    315   return lax_numpy._rewriting_take(self, idx)
    316 else:
--> 317   return lax_numpy._rewriting_take(self, idx)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4142, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4136     if (isinstance(aval, core.DShapedArray) and aval.shape == () and
   4137         dtypes.issubdtype(aval.dtype, np.integer) and
   4138         not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
   4139         isinstance(arr.shape[0], int)):
   4140       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4142 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   4143 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4144                unique_indices, mode, fill_value)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4220, in _split_index_for_jit(idx, shape)
   4216   raise TypeError(f"JAX does not support string indexing; got {idx=}")
   4218 # Expand any (concrete) boolean indices. We can then use advanced integer
   4219 # indexing logic to handle them.
-> 4220 idx = _expand_bool_indices(idx, shape)
   4222 leaves, treedef = tree_flatten(idx)
   4223 dynamic = [None] * len(leaves)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt2/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4542, in _expand_bool_indices(idx, shape)
   4540     expected_shape = shape[start: start + _ndim(i)]
   4541     if i_shape != expected_shape:
-> 4542       raise IndexError("boolean index did not match shape of indexed array in index "
   4543                        f"{dim_number}: got {i_shape}, expected {expected_shape}")
   4544     out.extend(np.where(i))
   4545 else:

IndexError: boolean index did not match shape of indexed array in index 0: got (1,), expected ()
abaillod added the question (Further information is requested) label on Aug 1, 2024

abaillod commented Aug 1, 2024

Investigating further, I found the following: when the @derivative_dec decorator is removed, sq.dJ() does not raise an error. The error only occurs when gathering all the partial derivatives together, i.e. when calling sq.dJ()(sq) (or when the derivative_dec decorator is kept).
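
A minimal sketch of the two calls being compared (assuming the decorator has been removed from ObjectiveSquared.dJ, with tt and sq as in the snippet above):

deriv = sq.dJ()    # fine: returns a simsopt Derivative object holding the partial derivatives
flat = deriv(sq)   # fails: gathering the partials into one flat array raises the IndexError above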

mbkumar commented Aug 6, 2024

@abaillod
I'll take a look at this in a couple of days.
One quick question: why define self.J_jax instead of using pure_objective directly?

abaillod commented Aug 8, 2024

@mbkumar Thank you for your answer.

Here I defined self.J_jax to mimic what I do in another, more complex class I am working on. In general objectives can have additional input parameters, and defining self.J_jax as a lambda function makes things more readable in my opinion.

For example, I could define

def pure_objective(gamma, current, k):
    r = jnp.sqrt(gamma[:,0]**2+gamma[:,1]**2+gamma[:,2]**2)
    return  jnp.mean(r) * current**k

class TestObjective(Optimizable):
    def __init__(self, coil, k):
        self.coil = coil
        self.k = k
        self.J_jax = lambda gamma, current: pure_objective(gamma, current, self.k)
        ...

abaillod commented:

@mbkumar any news about this?

andrewgiuliani commented Aug 13, 2024

What is the dimension of the array returned by dJ()?

abaillod commented:

This is what I get:

[screenshot: output of dJ() showing the partial derivative arrays]

The derivative w.r.t. the current does not have the right shape - do I understand this correctly?
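
For reference, the shape mismatch in the last frame of the traceback can be reproduced in isolation: the partial w.r.t. the current is a 0-d JAX scalar, while Derivative.__call__ indexes it with a length-1 boolean mask (opt.local_dofs_free_status). A minimal sketch with hypothetical values:

import numpy as np
import jax.numpy as jnp

partial = jnp.array(2.5)   # 0-d scalar, like the partial w.r.t. the current
mask = np.array([True])    # shape (1,), like Current's local_dofs_free_status
partial[mask]              # IndexError: boolean index did not match shape of indexed array in index 0: got (1,), expected ()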

abaillod added a commit that referenced this issue Aug 14, 2024
abaillod commented:

To give further context to this issue: it is related to the branch coil_forces, in which @phuslage and I are working on a way to optimize for the critical current.

You can have a look at these lines, where we attempt to take the derivative of a new objective, called CriticalCurrent. It works for every degree of freedom, except when taking the derivative w.r.t. the coil current.

Anyway, I would appreciate it if someone could give me an example of how to implement these derivatives correctly, or help me debug it.

andrewgiuliani commented:

Hi @abaillod, I am happy to debug this together with you. Send me a calendar invite.

abaillod commented:

Thank you. I just sent you an email to find a time.

mbkumar commented Aug 19, 2024 via email

abaillod commented:

Thanks to Andrew and Bharat, we figured out how to fix it. In a nutshell, I have to pass an array to simsopt.field.coil.Current.vjp, not a scalar. The class TestObjective should then be:

class TestObjective(Optimizable):
    def __init__(self, coil):
        self.coil = coil
        self.J_jax = lambda gamma, current: pure_objective(gamma, current)
        self.dobj_by_dgamma_vjp = lambda gamma, current, v: vjp(lambda g: self.J_jax(g, current), gamma)[1](v)[0]
        self.dobj_by_dcurrent_vjp = lambda gamma, current, v: vjp(lambda c: self.J_jax(gamma, c), current)[1](v)[0] 
        
        super().__init__(depends_on=[coil])


    def J(self):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()

        return self.J_jax(gamma, current)


    def vjp(self, v):
        gamma = self.coil.curve.gamma()
        current = self.coil.current.get_value()

        grad0 = self.dobj_by_dgamma_vjp(gamma, current, v)
        # The partial w.r.t. the current is a scalar; wrap it in a 1-element array
        # before passing it to Current.vjp, otherwise gathering the Derivative fails.
        grad1 = jnp.array([self.dobj_by_dcurrent_vjp(gamma, current, v)])

        return self.coil.curve.dgamma_by_dcoeff_vjp(grad0) + self.coil.current.vjp(grad1)
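
With this change the original example goes through; a quick check (sketch, assuming coil is the same simsopt coil as in the first message):

tt = TestObjective(coil)
sq = ObjectiveSquared(tt)
dJ = sq.dJ()   # gathers the partials over all free dofs without raising the IndexError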
