How to get derivative of final time with respect to a variable #451

Closed
nsteffen opened this issue Jun 26, 2024 · 2 comments
Labels
question User queries

Comments

@nsteffen

Hello!

I am working on a problem where I want to propagate a trajectory using diffrax, end it with a discrete terminating event, and then differentiate the final time with respect to a parameter. On every problem I have tried, the gradient of the final time comes out as zero. Below is an example that models the simplified dynamics of a cannonball and tries to evaluate the gradient of the final time with respect to the drag coefficient.

import jax
import jax.numpy as jnp
import diffrax as dfx
jax.config.update('jax_enable_x64', True)


def dynamics(t, y, CD):
    g = 9.80665 # m/s^2
    rho = 1.225 # kg/m^3
    S = 0.005 # m^2
    m = 1. # kg

    v = y[2]
    gam = y[3]
    
    D = (0.5*rho*v**2)*S*CD
    
    sin_gam = jnp.sin(gam)
    cos_gam = jnp.cos(gam)
    ydot = jnp.array([v*cos_gam,        # rdot
                      v*sin_gam,        # hdot
                      -D/m - g*sin_gam, # vdot
                      -(g*cos_gam)/v])  # gamdot

    return ydot


def event(state, **kwargs):
    h = state.y[1]
    return h < 0.


def obj(CD):
    t0 = 0.
    t1 = 20.
    dt0 = 0.1  # fixed step size, s
    y0 = jnp.array([0.,                 # r0
                    100.,               # h0
                    5.,                 # v0
                    jnp.deg2rad(45.)])  # gam0

    solver = dfx.Tsit5()
    stepsize_controller = dfx.ConstantStepSize()
    saveat = dfx.SaveAt(ts=jnp.linspace(t0, t1, 100))

    sol = dfx.diffeqsolve(dfx.ODETerm(dynamics),
                          solver,
                          t0,
                          t1,
                          dt0,
                          y0,
                          args=CD,
                          stepsize_controller=stepsize_controller,
                          max_steps=None,
                          saveat=saveat,
                          discrete_terminating_event=event,
                          adjoint=dfx.RecursiveCheckpointAdjoint(checkpoints=100)
                          )
    
    ts = sol.ts[jnp.where(jnp.isfinite(sol.ts))]    
    return ts[-1]


if __name__ == '__main__':
    CD = 0.5
    print(jax.grad(obj)(CD))

Is there a correct/better way to do this? Thanks in advance!

@patrick-kidger
Owner

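So this is because you have a "discrete" terminating event: the solve halts at the end of the step in which the event was triggered, so the final time is a step boundary rather than the point where the trajectory actually crosses h = 0. A small change in the parameter does not move that boundary, so there (correctly) is no gradient.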
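
To make the point above concrete, here is a minimal sketch in plain Python (deliberately not diffrax, and not the cannonball model). It assumes a drag-free body dropped from a fixed height, with gravity `g` as the parameter, so that the exact landing time is known in closed form. All function names are hypothetical and only serve the illustration. It compares the sensitivity of a step-boundary stopping time with that of a stopping time refined by bisection inside the final step.

```python
# Toy illustration only: plain Python, not diffrax. All names are hypothetical.
import math

H0 = 100.0   # initial height, m
DT = 0.1     # fixed step size, s


def height(t, g):
    """Closed-form height of a drag-free body dropped from H0."""
    return H0 - 0.5 * g * t * t


def discrete_event_time(g):
    """Stop at the end of the first fixed step on which height < 0."""
    n = 0
    while height(n * DT, g) >= 0.0:
        n += 1
    return n * DT


def root_found_event_time(g, iters=60):
    """Locate the zero crossing inside the final step by bisection."""
    hi = discrete_event_time(g)   # height < 0 at this time
    lo = hi - DT                  # height >= 0 at this time
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if height(mid, g) < 0.0:
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)


def central_difference(f, x, eps=1e-6):
    """Finite-difference estimate of df/dx."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


g = 9.80665
d_discrete = central_difference(discrete_event_time, g)
d_root = central_difference(root_found_event_time, g)
d_exact = -math.sqrt(2.0 * H0 / g) / (2.0 * g)  # d/dg of sqrt(2 * H0 / g)
print(d_discrete, d_root, d_exact)
```

The step-boundary time is piecewise constant in `g`, so its finite-difference sensitivity vanishes, while the bisected time tracks the closed-form derivative. Finite differences stand in for autodiff here only to keep the sketch dependency-free.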

You may like to try #387, which is a more featureful approach to events. In particular this includes the ability to (a) have an event return a real number, for which the solve terminates where that number is zero, and (b) have that exact location determined using a root find.

Tagging @cholberg for visibility, but I believe that should give you gradients.
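
For intuition on why a real-valued event located by a root find can carry a gradient, here is the standard implicit-function argument as a sketch. The notation is an assumption introduced for this note and does not come from the thread: $\theta$ is the parameter (here the drag coefficient), $f$ the vector field, $g$ the real-valued event function, $t^*$ the event time, and $y^* = y(t^*; \theta)$.

$$
g\bigl(y(t^*(\theta); \theta)\bigr) = 0
\;\Longrightarrow\;
\nabla g \cdot \Bigl( f(t^*, y^*, \theta)\,\frac{dt^*}{d\theta} + \frac{\partial y}{\partial \theta}\Big|_{t^*} \Bigr) = 0
\;\Longrightarrow\;
\frac{dt^*}{d\theta} = -\,\frac{\nabla g \cdot \dfrac{\partial y}{\partial \theta}\Big|_{t^*}}{\nabla g \cdot f(t^*, y^*, \theta)}
$$

For the cannonball, taking $g(y) = h$ reduces this to $dt^*/dC_D = -\,(\partial h / \partial C_D)\big|_{t^*} \,/\, (v \sin\gamma)\big|_{t^*}$, which is well defined as long as the trajectory crosses $h = 0$ with nonzero vertical speed.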

@patrick-kidger patrick-kidger added the question User queries label Jun 26, 2024
@nsteffen
Author

Just got to testing that out, and it looks like your suggestion worked! Thanks for the help; I'm excited for this feature to be part of the main code!
