Question about derivatives #441
Investigating more, I found the following. When removing the
@abaillod
@mbkumar Thank you for your answer. Here I defined
For example, I could define

    import jax.numpy as jnp
    from simsopt._core.optimizable import Optimizable

    def pure_objective(gamma, current, k):
        # Mean distance of the curve points from the origin, scaled by current**k
        r = jnp.sqrt(gamma[:,0]**2 + gamma[:,1]**2 + gamma[:,2]**2)
        return jnp.mean(r) * current**k

    class TestObjective(Optimizable):
        def __init__(self, coil, k):
            self.coil = coil
            self.k = k
            self.J_jax = lambda gamma, current: pure_objective(gamma, current, self.k)
            ...
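As a sketch of what the derivatives of this k-dependent objective look like in plain JAX (outside of simsopt), closing over `k` and calling `jax.grad` with `argnums=(0, 1)` gives the gradient with respect to `gamma` (same shape as `gamma`) and with respect to the scalar current (a 0-d array). The random `gamma` array and the values of `current` and `k` are made-up inputs for illustration only.

```python
import jax
import jax.numpy as jnp


def pure_objective(gamma, current, k):
    # Mean distance of the curve points from the origin, scaled by current**k
    r = jnp.sqrt(gamma[:, 0]**2 + gamma[:, 1]**2 + gamma[:, 2]**2)
    return jnp.mean(r) * current**k


# Hypothetical inputs: 10 points on a curve, a scalar current, and exponent k = 2
gamma = jax.random.normal(jax.random.PRNGKey(0), (10, 3))
current = 1.5
k = 2


def J_jax(gamma, current):
    # Closes over k, like the lambda stored in TestObjective.J_jax
    return pure_objective(gamma, current, k)


# argnums=(0, 1) differentiates with respect to both gamma and current
dJ_dgamma, dJ_dcurrent = jax.grad(J_jax, argnums=(0, 1))(gamma, current)

print(dJ_dgamma.shape)    # (10, 3)
print(dJ_dcurrent.shape)  # ()
```

Note that the derivative with respect to the current comes back as a 0-d array rather than a length-1 vector.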
@mbkumar any news about this?
What is the dimension of the array returned by dJ()?
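One quick way to check the shapes involved is to pull a cotangent back through the dependence on the current alone. With a plain Python float as the primal, `jax.vjp` returns a 0-d cotangent; assuming, as the fix later in this thread suggests, that the coil current expects its derivative as a length-1 array, it has to be wrapped first. The `pure_objective` below is a hypothetical stand-in, not the objective from the branch.

```python
import jax.numpy as jnp
from jax import vjp


def pure_objective(gamma, current):
    # Hypothetical stand-in objective: mean radius of the curve points times the current
    r = jnp.linalg.norm(gamma, axis=1)
    return jnp.mean(r) * current


gamma = jnp.ones((8, 3))  # every point at distance sqrt(3) from the origin
current = 2.0             # a plain Python float, like a scalar current value

# VJP with respect to the current only
out, pullback = vjp(lambda c: pure_objective(gamma, c), current)
(grad_current,) = pullback(jnp.ones_like(out))
print(grad_current.shape)  # ()

# Assumption: a single-dof current expects a vector of length 1, not a 0-d array
grad_current_vec = jnp.atleast_1d(grad_current)
print(grad_current_vec.shape)  # (1,)
```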
…ives w.r.t. the current, see issue #441
To give further context to this issue, this is related to the branch
You can have a look at these lines, where we attempt to take the derivative of a new objective, called
Anyway, I would appreciate it if someone could give me an example of how to implement these derivatives correctly, or help me debug it.
Hi @abaillod, I am happy to dual debug with you. Send me a calendar invite.
Thank you. I just sent you an email to find a time.
@abaillod,
Sorry, I went radio silent after our last conversation. I had a lot of work over the last couple of weeks and couldn't respond. Please add me as an optional attendee.
Bharat Medasani
On Mon, Aug 19, 2024 at 10:00 AM abaillod wrote:
Thank you. I just sent you an email to find a time.
Thanks to Andrew and Bharat, we found out how to fix it. In a nutshell, I have to pass an array to

    from jax import vjp
    import jax.numpy as jnp
    from simsopt._core.optimizable import Optimizable

    # pure_objective(gamma, current) is assumed to be defined elsewhere (pure JAX)
    class TestObjective(Optimizable):
        def __init__(self, coil):
            self.coil = coil
            self.J_jax = lambda gamma, current: pure_objective(gamma, current)
            # VJPs of the pure JAX objective w.r.t. gamma and w.r.t. the (scalar) current
            self.dobj_by_dgamma_vjp = lambda gamma, current, v: vjp(lambda g: self.J_jax(g, current), gamma)[1](v)[0]
            self.dobj_by_dcurrent_vjp = lambda gamma, current, v: vjp(lambda c: self.J_jax(gamma, c), current)[1](v)[0]
            super().__init__(depends_on=[coil])

        def J(self):
            gamma = self.coil.curve.gamma()
            current = self.coil.current.get_value()
            return self.J_jax(gamma, current)

        def vjp(self, v):
            gamma = self.coil.curve.gamma()
            current = self.coil.current.get_value()
            grad0 = self.dobj_by_dgamma_vjp(gamma, current, v)
            # The current VJP is a 0-d array; wrap it in a length-1 array before passing it on
            grad1 = jnp.array([self.dobj_by_dcurrent_vjp(gamma, current, v)])
            return self.coil.curve.dgamma_by_dcoeff_vjp(grad0) + self.coil.current.vjp(grad1)
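The two separate `jax.vjp` calls in the fix above can, as an alternative design, be replaced by a single call that differentiates with respect to both arguments; the pullback then returns one cotangent per argument from a single forward pass. The sketch below is pure JAX (no simsopt), uses a hypothetical stand-in `pure_objective`, and checks the current derivative against a central finite difference. The `jnp.atleast_1d` step mirrors the `jnp.array([...])` wrapping used in the fix.

```python
import jax
import jax.numpy as jnp
from jax import vjp


def pure_objective(gamma, current):
    # Hypothetical stand-in objective: mean radius of the curve points times current**2
    r = jnp.linalg.norm(gamma, axis=1)
    return jnp.mean(r) * current**2


gamma = jax.random.normal(jax.random.PRNGKey(1), (16, 3))
current = 0.7

# One forward pass, one pullback: cotangents for gamma and current together
out, pullback = vjp(pure_objective, gamma, current)
grad_gamma, grad_current = pullback(jnp.ones_like(out))  # scalar objective, so v = 1

# Length-1 array, mirroring the jnp.array([...]) wrapping in the fix
grad_current = jnp.atleast_1d(grad_current)

# Central finite difference in the current (exact for a quadratic, up to rounding)
eps = 1e-2
fd = (pure_objective(gamma, current + eps) - pure_objective(gamma, current - eps)) / (2 * eps)

print(grad_gamma.shape, grad_current.shape)  # (16, 3) (1,)
print(float(grad_current[0]), float(fd))
```

Keeping two separate VJPs, as in the fix, is still convenient when the gamma and current cotangents are consumed by different simsopt objects; the single call only saves a forward pass.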
Hi!
When writing an objective that depends on an Optimizable that has no local dofs, I run into an error (see below) that is quite obscure to me. I would appreciate some help! Here is a simple example of a class that generates the problem:
Then, for a given coil, if we do
We get