Why can't UBP be backproped? #490

Open
lockwo opened this issue Aug 16, 2024 · 7 comments
Labels
question User queries

Comments

@lockwo
Contributor

lockwo commented Aug 16, 2024

In the docs it says "You do not need to backpropagate through the differential equation." for UBP (`UnsafeBrownianPath`) usage. However, this restriction doesn't seem theoretically necessary: you could just backprop through the solver with the added noise, right? What's the motivation for requiring this?

The docs also say "Internally this operates by just sampling a fresh normal random variable over every interval, ignoring the correlation between samples exhibited in true Brownian motion. Hence the restrictions above. (They describe the general case for which the correlation structure isn't needed.)" That makes sense with regard to adaptivity (since you need a Brownian bridge or something similar), but not for differentiation.

@patrick-kidger
Owner

The difficulty is that backprop via RecursiveCheckpointAdjoint involves recomputing intermediate states from checkpoints. I'm not confident that this recomputation will be bitwise identical to the original forward pass, e.g. perhaps due to nondeterminism in convolutions.

Meanwhile, UBP casts float->int->key and uses that to sample its random noise. This means that floating point fluctuations will produce entirely different Brownian motion samples.
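To make that sensitivity concrete, here is a minimal sketch assuming a simplified scheme in which the bits of a float32 evaluation time are reinterpreted as an int32 and folded into a key. The helper `noise_at_time` is hypothetical and is not diffrax's actual implementation; it only mimics the float->int->key idea. A one-ulp change in the time gives an unrelated sample.

```python
import jax
import jax.numpy as jnp


def noise_at_time(key, t):
    # Hypothetical illustration of float -> int -> key (not diffrax's real code):
    # reinterpret the float32 bits of `t` as an int32 and fold them into the key.
    t_bits = jax.lax.bitcast_convert_type(jnp.asarray(t, jnp.float32), jnp.int32)
    return jax.random.normal(jax.random.fold_in(key, t_bits))


key = jax.random.PRNGKey(0)
t = jnp.float32(0.3)
# The next representable float32 above t: a one-ulp floating point fluctuation.
t_perturbed = jnp.nextafter(t, jnp.float32(1.0))

a = noise_at_time(key, t)
b = noise_at_time(key, t_perturbed)
# `a` and `b` are unrelated draws, even though t and t_perturbed differ by ~3e-8.
```

If a recomputed time on the backward pass differs from the forward pass by even one ulp, the noise it sees is a different sample altogether.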

@patrick-kidger patrick-kidger added the question User queries label Aug 17, 2024
@lockwo
Contributor Author

lockwo commented Aug 17, 2024

I guess it seems to me that these are solvable engineering problems (I hope), but there aren't any theoretical limitations. For example, we could store every sample drawn on the forward pass and then, on the backward pass, reuse those samples (or use a bridge); we have a lot of memory to spare since neural networks aren't in play. That seems possible, maybe not with a UBP, but with something simpler than a VBT. @frankschae can help articulate this better, and explain why we are interested in this direction. The motivation is basically that VBT is potentially overkill for us and might introduce slowdowns we don't need. If there were a distinction between UBP, VBT, and some third object that samples differentiably, has some miscellaneous support for the weak solvers we are working on, and has bridges, then that third object might be the most interesting.
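The "store every sample" idea can be sketched with a fixed-step Euler-Maruyama loop. Assumptions, all for illustration: a fixed time grid, a scalar SDE dy = -theta * y dt + sigma dW, and all Brownian increments presampled from a single key up front. Because the noise is an explicit array computed from the key, the forward and backward passes see exactly the same samples, and `jax.grad` differentiates through the solve. `euler_maruyama` is a hypothetical helper, not part of diffrax.

```python
import jax
import jax.numpy as jnp


def euler_maruyama(theta, y0, key, sigma=0.5, t1=1.0, num_steps=100):
    # Illustrative SDE (an assumption): dy = -theta * y dt + sigma dW on a fixed grid.
    dt = t1 / num_steps
    # Presample every Brownian increment up front; the same key always yields the
    # same increments, so there is nothing fragile to recompute on the backward pass.
    dW = jnp.sqrt(dt) * jax.random.normal(key, (num_steps,))

    def step(y, dw):
        y = y + (-theta * y) * dt + sigma * dw
        return y, None

    y1, _ = jax.lax.scan(step, jnp.asarray(y0, jnp.float32), dW)
    return y1


key = jax.random.PRNGKey(0)
# Gradient of the terminal value with respect to theta, through the whole solve.
grad_theta = jax.grad(euler_maruyama)(1.0, 1.0, key)
```

The cost is memory proportional to the number of steps, which is the trade-off described above.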

@patrick-kidger
Owner

Yup, this all makes sense.

If our controls consumed a step counter or something then one possible approach would just be to jr.fold_in the step index into a key, and that should operate reliably. Or if controls were stateful then we could jr.split a key each step. The current approach does what it does just because right now, a control consumes only the times at which it is evaluated.
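Both alternatives can be sketched in plain JAX. `noise_for_step` and `noise_by_splitting` are hypothetical helpers, not existing diffrax controls. Folding in an integer step index (rather than the bits of a float time) means that recomputing a step from a checkpoint that knows its step index gives back the same sample; threading a key through the loop and splitting it every step achieves the same thing statefully.

```python
import jax
import jax.numpy as jnp


def noise_for_step(key, step_index, shape=()):
    # The key depends only on the integer step index, never on float times, so
    # recomputing step `step_index` always reproduces the same sample.
    return jax.random.normal(jax.random.fold_in(key, step_index), shape)


def noise_by_splitting(key, num_steps, shape=()):
    # Stateful alternative: carry a key through the loop and split it every step.
    def body(key, _):
        key, subkey = jax.random.split(key)
        return key, jax.random.normal(subkey, shape)

    _, samples = jax.lax.scan(body, key, None, length=num_steps)
    return samples


key = jax.random.PRNGKey(0)
indexed = jnp.stack([noise_for_step(key, i) for i in range(5)])
split_based = noise_by_splitting(key, 5)
```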

I think this dovetails with the VBT discussion in #489 -- maybe we should think about modifying the way controls are handled, and if we pick our abstractions right we can tackle this issue whilst we're at it.

@lockwo
Contributor Author

lockwo commented Aug 21, 2024

I think in general, having some concept of state makes sense for controls. Even if all existing controls can be implemented without it, we are working on more flexible controls that might take advantage of it (and it would make custom controls easier to implement). That said, exactly how to do it is an open question. I think once we have a weak solver PR open to motivate this, it can provide more concrete examples.

@lockwo
Contributor Author

lockwo commented Sep 16, 2024

As we've explored more, I think stateful controls can make a lot of sense and would be useful. For the limited usage we currently have, just leeching off the solver state and adding an argument to UBP is sufficient (see: https://github.com/lockwo/diffrax_extensions/blob/Owen/new-weak-internal/diffrax/_brownian/path.py#L114). A more robust and mainline implementation would be beneficial for a couple reasons:

  • It strengthens the flexibility between UBP and VBT: UBP has a very narrow scope, but VBT comes with overhead that is unnecessary in some cases, so providing more of a continuum is nice
  • It removes 2.5 of the restrictions on UBP (since UBP could now have a default state that is just a key to be split): because it would be deterministic with respect to the key, it could be differentiated, and it could be used with adaptive solvers as long as they don't require Brownian bridge (BB) calculations (e.g. previsible adaptive stepping solvers)
  • Allows for useful subclassing: it is not uncommon to want to sample a conditional noise value in the evaluate step, and making things stateful allows for that
  • Requires no change for the default user
  • etc.
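A stateful control along these lines might look like the following minimal sketch. `StatefulBrownianIncrements`, its `init`, and its `evaluate(t0, t1, state)` signature are hypothetical, chosen for illustration; they are not the `AbstractPath` interface. The default state is just a key, split once per call, so the random draws depend only on the key and the sequence of calls, while `t0` and `t1` only set the scale.

```python
import jax
import jax.numpy as jnp


class StatefulBrownianIncrements:
    """Hypothetical stateful control whose state is a PRNG key."""

    def __init__(self, key, shape=()):
        self.key = key
        self.shape = shape

    def init(self):
        # Default state: just the key, to be split on every evaluation.
        return self.key

    def evaluate(self, t0, t1, state):
        # Consume one split of the key; t0/t1 only scale the standard normal draw.
        state, subkey = jax.random.split(state)
        increment = jnp.sqrt(t1 - t0) * jax.random.normal(subkey, self.shape)
        return increment, state


control = StatefulBrownianIncrements(jax.random.PRNGKey(0))
state = control.init()
dw1, state = control.evaluate(0.0, 0.1, state)
dw2, state = control.evaluate(0.1, 0.2, state)
```

Replaying the same calls from `control.init()` reproduces `dw1` and `dw2` exactly, which is what makes recomputation on the backward pass safe.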

However, as I went about a draft implementation, one possible concern came up: it would be a breaking change for a (maybe small) class of users (or it would require some backwards-compatibility checking). In the most abstract form, my approach was to add a state argument to AbstractPath's evaluate and an init method, then call that init in integrate alongside all the other inits. For anyone using UBP or VBT, this is totally invisible, since in the default workflow the user never actually calls evaluate on these classes (and even with the new changes, their states would be optional and default to None, so this wouldn't be breaking). However, if you have a subclassed path, this would break. A possible fix would be to monkeypatch or wrap the provided class so that it has a no-op init and accepts None states, but even if that worked, it would still be breaking from a developer's perspective, since such subclasses wouldn't adhere to the abstract method specifications. Curious whether you have any thoughts on this approach (or whether I'm missing a non-breaking way to do this).
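The wrapping workaround, giving an existing stateless path a no-op `init` and letting it accept a `None` state, might look like this sketch. `LegacyPath`, `WithNoneState`, and the `(value, state)` return convention are hypothetical names and conventions for illustration only.

```python
class LegacyPath:
    """Hypothetical pre-existing stateless path: a deterministic control X(t) = t."""

    def evaluate(self, t0, t1):
        return t1 - t0


class WithNoneState:
    """Hypothetical adapter giving a stateless path the stateful interface."""

    def __init__(self, path):
        self.path = path

    def init(self):
        # No-op init: a stateless path carries no state.
        return None

    def evaluate(self, t0, t1, state=None):
        # Ignore the (None) state, delegate, and pass the state straight through.
        return self.path.evaluate(t0, t1), state


adapted = WithNoneState(LegacyPath())
state = adapted.init()
value, state = adapted.evaluate(0.0, 0.5, state)
```

Because `state` defaults to `None`, callers that never pass a state keep working, which mirrors the "optional = None" idea above.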

@patrick-kidger
Owner

So I think this is probably too dangerous to implement in the main library itself. Right now, integration happens with respect to the same path regardless of the step size controller, and having a special case where that's not true is definitely a footgun.
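The footgun can be illustrated under simplified assumptions: a driftless scalar case, where the solution at t=1 is just W(1), the sum of the Brownian increments. If the increments come from one fixed path (here sampled once on a fine grid, an assumption made for the illustration), coarse and fine step sizes see the same W(1). If instead a fresh key is split per step, changing the number of steps changes the path itself. Both helpers are hypothetical.

```python
import jax
import jax.numpy as jnp


def endpoint_fixed_path(key, num_steps, fine_steps=64, t1=1.0):
    # One Brownian path, sampled once on a fine grid; each coarse step sums the
    # fine increments it covers. Requires fine_steps % num_steps == 0.
    fine = jnp.sqrt(t1 / fine_steps) * jax.random.normal(key, (fine_steps,))
    coarse = fine.reshape(num_steps, -1).sum(axis=1)
    return coarse.sum()


def endpoint_split_per_step(key, num_steps, t1=1.0):
    # A fresh key per step: the sampled "path" depends on the step count.
    keys = jax.random.split(key, num_steps)
    dt = t1 / num_steps
    increments = jax.vmap(lambda k: jnp.sqrt(dt) * jax.random.normal(k))(keys)
    return increments.sum()


key = jax.random.PRNGKey(0)
fixed_4, fixed_8 = endpoint_fixed_path(key, 4), endpoint_fixed_path(key, 8)
split_4, split_8 = endpoint_split_per_step(key, 4), endpoint_split_per_step(key, 8)
```

With the fixed path, `fixed_4` and `fixed_8` agree up to float rounding; with per-step keys, `split_4` and `split_8` are unrelated, so a step size controller would silently change which path is integrated.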

Fortunately, via the custom solver API you can write something that does this regardless!

Something like this should work:

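import functools
from collections.abc import Callable

import jax.tree_util as jtu
from diffrax import AbstractPath, AbstractWrappedSolver


class MyControl(AbstractPath):
    # A control whose evaluation additionally depends on some per-step state.
    def evaluate(..., control_state):
        ...


class CallToEvaluate(AbstractPath):
    # Wraps a callable so it can stand in for an ordinary (stateless) path.
    fn: Callable

    def evaluate(self, ...):
        return self.fn(...)


def is_control(x):
    return isinstance(x, MyControl)


def bind_control_state(terms, control_state):
    # Replace every stateful control in `terms` with a stateless path that has
    # the current control state baked in.
    def _bind_control_state(x):
        if is_control(x):
            return CallToEvaluate(functools.partial(x.evaluate, control_state=control_state))
        else:
            return x

    return jtu.tree_map(_bind_control_state, terms, is_leaf=is_control)


class MySolver(AbstractWrappedSolver):
    def step(...):
        # The control state is threaded through alongside the wrapped solver's state.
        solver_state, control_state = solver_state
        terms = bind_control_state(terms, control_state)
        ... = self.solver.step(...)
        # Advance the control state once per step (e.g. split a PRNG key);
        # `step_control_state` is a user-defined helper, not a diffrax method.
        control_state = self.step_control_state(control_state)
        solver_state = (solver_state, control_state)
        return ...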

Which is arguably the appropriate amount of fiddly for doing something questionable like this :)))
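Outside diffrax, the same bind-then-step pattern can be shown end to end in plain JAX. This is a toy under stated assumptions, not the custom solver API: `brownian_increment`, `step_control_state`, `inner_euler_step`, `wrapped_step`, and `solve` are hypothetical; the inner solver is a bare Euler-Maruyama step that only sees a stateless `control(t0, t1)`; and the control state is a PRNG key advanced once per step. Because the noise is determined by the key rather than by float times, the solve is reproducible and `jax.grad` can differentiate through it.

```python
import functools

import jax
import jax.numpy as jnp


def brownian_increment(t0, t1, control_state):
    # Hypothetical stateful control: draws from the second half of a split of the state.
    _, subkey = jax.random.split(control_state)
    return jnp.sqrt(t1 - t0) * jax.random.normal(subkey)


def step_control_state(control_state):
    # Advance the state with the first half of the same split (no key reuse).
    next_state, _ = jax.random.split(control_state)
    return next_state


def inner_euler_step(drift, control, t0, t1, y):
    # The "inner solver" only sees a stateless control(t0, t1), like an ordinary path.
    return y + drift(y) * (t1 - t0) + control(t0, t1)


def wrapped_step(drift, t0, t1, y, state):
    solver_state, control_state = state  # Euler has no solver state; kept to mirror the pattern.
    control = functools.partial(brownian_increment, control_state=control_state)
    y = inner_euler_step(drift, control, t0, t1, y)
    return y, (solver_state, step_control_state(control_state))


def solve(theta, y0, key, num_steps=50, t1=1.0):
    ts = jnp.linspace(0.0, t1, num_steps + 1)
    intervals = jnp.stack([ts[:-1], ts[1:]], axis=1)

    def drift(y):
        return -theta * y

    def body(carry, interval):
        y, state = carry
        return wrapped_step(drift, interval[0], interval[1], y, state), None

    init = (jnp.asarray(y0, jnp.float32), (None, key))
    (y1, _), _ = jax.lax.scan(body, init, intervals)
    return y1


key = jax.random.PRNGKey(0)
y1 = solve(1.0, 1.0, key)
grad_theta = jax.grad(solve)(1.0, 1.0, key)
```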

@lockwo
Contributor Author

lockwo commented Sep 16, 2024

Yea, that's basically the implementation we have currently.

> Which is arguably the appropriate amount of fiddly for doing something questionable like this :)))

I do agree, which is why we were looking into supporting less questionable ways. But since it's too dangerous for the main library, we'll formalize it more and put it in our extensions.

Projects
None yet
Development

No branches or pull requests

2 participants