Replace multiple applications of single dispatch with `multimethod` multiple dispatch (#295)
Looks great! Could you also re-run the notebook?
I sympathize with the goal of this PR, but I'm skeptical about using `multimethod` this way, and about a number of the other changes here. Maybe a simpler thing to do would be to stop distinguishing between subtypes of the `Dynamics` interface (e.g. `ODEDynamics`) and dispatch only on the solver.
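For concreteness, the suggestion above could be sketched with the stdlib's `functools.singledispatch`: keep a single `Dynamics` interface and let the solver type alone select the backend. All class and function names below are illustrative stand-ins, not the project's actual API.

```python
from functools import singledispatch

class Dynamics:             # single interface, no ODEDynamics subtype
    pass

class Solver:
    pass

class TorchDiffEq(Solver):  # stand-in for a torchdiffeq-backed solver
    pass

# singledispatch keys on the first argument, so the solver comes first
@singledispatch
def simulate(solver: Solver, dynamics: Dynamics, *args):
    raise NotImplementedError(f"no backend for {type(solver).__name__}")

@simulate.register(TorchDiffEq)
def _(solver, dynamics, *args):
    return "torchdiffeq backend"

print(simulate(TorchDiffEq(), Dynamics()))  # torchdiffeq backend
```

One consequence of this design is that every solver implementation must accept any `Dynamics` instance, since the dynamics type no longer participates in dispatch.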
```diff
@@ -102,10 +99,10 @@ def _batched_odeint(
     return yt if event_fn is None else (event_t, yt)

-@ode_simulate.register(TorchDiffEq)
+@_simulate.register(ODEDynamics, TorchDiffEq)
```
`multimethod` dispatches on all of the arguments to `_simulate`, so what does this registration mean?
My understanding is that this registration means we dispatch only on the first two arguments of `_simulate` and not on the remaining ones. This isn't documented in `multimethod` from what I could find, but when I remove the arguments I get the default `NotImplementedError` for `_simulate` or the other dispatches.
I don't know if that's precisely true (I would guess it implicitly provides a type for the remaining arguments, maybe `object` or whatever is in the signature?), but I don't think it's a good idea to rely on undocumented behavior in upstream dependencies, since it can change without warning in ways that are difficult or impossible to work around.
Good point. I'll try to debug why the default dispatch on the full collection of arguments isn't working as expected.
@eb8680, after a bit of digging I've found two options that rely less on undocumented behavior and pass tests. I have not been able to get `multimethod` to work with fully specified parametric types.

Option 1: Pass unparameterized types to the `*args` of `register` and leave fully parametric types explicit in the function signature, e.g.:
```python
@_simulate.register(ODEDynamics, TorchDiffEq, State, torch.Tensor, torch.Tensor)
def torchdiffeq_ode_simulate(
    dynamics: ODEDynamics[torch.Tensor, torch.Tensor],
    solver: TorchDiffEq,
    initial_state: State[torch.Tensor],
    start_time: torch.Tensor,
    end_time: torch.Tensor,
) -> State[torch.Tensor]:
    timespan = torch.stack((start_time, end_time))
    trajectory = _torchdiffeq_ode_simulate_inner(
        dynamics, initial_state, timespan, **solver.odeint_kwargs
    )
    return trajectory[..., -1].to_state()
```
Option 2: Remove type parameters from the function signature and use the default `register` without any `*args`, e.g.:
```python
@_simulate.register
def torchdiffeq_ode_simulate(
    dynamics: ODEDynamics,
    solver: TorchDiffEq,
    initial_state: State,
    start_time: torch.Tensor,
    end_time: torch.Tensor,
) -> State:
    timespan = torch.stack((start_time, end_time))
    trajectory = _torchdiffeq_ode_simulate_inner(
        dynamics, initial_state, timespan, **solver.odeint_kwargs
    )
    return trajectory[..., -1].to_state()
```
Which of these two would you prefer? Alternatively, would you like something else?
Unrelated: I think there is some misuse of types scattered throughout the module, so a separate type-refactoring PR (ideally pair-programmed) might be nice before merging into master.
I actually quite like the idea of dispatching on both the type of the dynamical system and the solver, as this clearly communicates to the user what restrictions we impose/assume about the dynamical system, even though we don't impose any explicit restrictions on what's inside the […]

@eb8680, could you share a bit more about your skepticism? Is it just your comment above about the […]?
I generally agree with this point. Does this mean the onus is on the user to know what the restrictions are on any particular solver? Alternatively, are there static or runtime checks that each solver can implement to enforce its restrictions? I'm very nervous about a user e.g. adding […]
I see. I think I need to read up on the distinction here.
This could be addressed by passing in […]
Fair. It is a bit sad to have this be the only other dependency besides Pyro, and it's not obvious how stable it is.
Also fair. I've experienced this a bit already (see the workaround above to your question about […]).
Keep in mind that this is currently the case […]
We could implement simple validation effect handlers that get applied to […]
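One shape such a runtime validation check could take, sketched here as a plain function rather than an effect handler: each solver declares which dynamics interface it supports, and a check raises early instead of silently mis-dispatching. The `supported_dynamics` attribute and all class names are hypothetical, not project API.

```python
class Dynamics: pass
class ODEDynamics(Dynamics): pass
class SDEDynamics(Dynamics): pass

class TorchDiffEq:
    # Restriction made explicit: this solver only handles ODE dynamics.
    supported_dynamics = (ODEDynamics,)

def validate(solver, dynamics):
    """Raise early if this solver doesn't support this dynamics type."""
    if not isinstance(dynamics, solver.supported_dynamics):
        raise TypeError(
            f"{type(solver).__name__} does not support "
            f"{type(dynamics).__name__}"
        )

validate(TorchDiffEq(), ODEDynamics())  # ok, no error
try:
    validate(TorchDiffEq(), SDEDynamics())
except TypeError as e:
    print("rejected:", e)
```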
Hmm, no, my point was that someone could write a signature with a finer type than […]
Closing this PR in favor of an alternative that removes […]
Prior to this small refactoring PR, dispatching on both the type of the dynamical system and the type of the solver involved chaining two `functools.singledispatchmethod` definitions, in which the order of arguments (and thus the particular variable being dispatched on) was shuffled around. Instead, this PR has each backend interface method (`simulate`, `simulate_trajectory`, and `get_next_interruptions_dynamic`) dispatch jointly on the dynamical system and solver types. This substantially reduces indirection, as well as the development burden when implementing a new backend or dynamical system formalism.
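The pre-refactor chaining pattern described above can be sketched with the stdlib's `functools.singledispatch` (the function form of `singledispatchmethod`): the first layer dispatches on the dynamics type, then re-orders the arguments so a second layer can dispatch on the solver. Names are illustrative, not the project's actual code.

```python
from functools import singledispatch

class Dynamics: pass
class ODEDynamics(Dynamics): pass
class Solver: pass
class TorchDiffEq(Solver): pass

# Layer 1: dispatch on the dynamics type.
@singledispatch
def simulate(dynamics: Dynamics, solver: Solver):
    raise NotImplementedError

@simulate.register(ODEDynamics)
def _(dynamics, solver):
    # Arguments shuffled so the solver becomes the dispatch key below.
    return ode_simulate(solver, dynamics)

# Layer 2: dispatch on the solver type.
@singledispatch
def ode_simulate(solver: Solver, dynamics: Dynamics):
    raise NotImplementedError

@ode_simulate.register(TorchDiffEq)
def _(solver, dynamics):
    return "torchdiffeq ODE backend"

print(simulate(ODEDynamics(), TorchDiffEq()))  # torchdiffeq ODE backend
```

Each (dynamics, solver) pair requires threading through two registries with inconsistent argument orders, which is the indirection the joint-dispatch refactor removes.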