Just to keep track of these in a separate place, here are my remarks on the contents of your JOSS paper itself:
> Dynamax supports canonical SSMs and allows the user to construct bespoke models as needed.
Can you give more details on how this is implemented API-wise? For instance, how generic can observation distributions be?
This question was a major motivation for my own take on HMMs, coded in Julia (https://joss.theoj.org/papers/10.21105/joss.06436). Hopefully I didn't misrepresent dynamax in the state-of-the-art section there.
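For context, here is my (possibly outdated) understanding of the high-level API, pieced together from the README; the names below are how I read the dynamax docs, so treat this as a sketch rather than gospel. What I would like the paper to spell out is how far one can deviate from these canonical emission families:

```python
import jax.random as jr
from dynamax.hidden_markov_model import GaussianHMM

# Canonical model with Gaussian emissions -- but how generic can the
# emission distribution be when constructing a bespoke model?
hmm = GaussianHMM(num_states=3, emission_dim=2)

key1, key2 = jr.split(jr.PRNGKey(0))
params, props = hmm.initialize(key1)  # initial parameters + trainability flags
states, emissions = hmm.sample(params, key2, num_timesteps=100)

# EM fitting; `props` controls which parameter blocks are updated
fitted_params, log_probs = hmm.fit_em(params, props, emissions)
```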
> Dynamax provides a unique combination of low-level inference algorithms and high-level modeling objects that can support a wide range of research applications in JAX.
Would it be possible to provide concrete examples of what is missing from the previous libraries? Obviously JAX support is a big aspect, since I know that some of them are coded in NumPy (hmmlearn) or PyTorch (pomegranate).
> Parallel message passing routines that leverage GPU or TPU acceleration to perform message passing in sublinear time.
What do you mean by sublinear time? Isn't it just (roughly speaking) the total sequential time divided by the amount of parallelism?
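To make the question concrete: my mental model of the parallel filters is an associative scan, where the O(log T) figure is the parallel *depth* (span), while the total work stays O(T). A toy sketch of that idea (a scalar linear recurrence, not dynamax's actual filtering code):

```python
import jax
import jax.numpy as jnp

def combine(e1, e2):
    # Compose two affine maps: applying (a1, b1) then (a2, b2) to x gives
    # a2 * (a1 * x + b1) + b2 = (a2 * a1) * x + (a2 * b1 + b2).
    a1, b1 = e1
    a2, b2 = e2
    return a2 * a1, a2 * b1 + b2

T = 1024
a = jax.random.uniform(jax.random.PRNGKey(0), (T,))
b = jnp.ones(T)

# All prefix compositions of the recurrence x_t = a_t * x_{t-1} + b_t,
# in O(log T) parallel depth given enough processors (but still O(T) work).
a_cum, b_cum = jax.lax.associative_scan(combine, (a, b))
x = b_cum  # trajectory for x_0 = 0
```

On a single core this is not sublinear at all, which is why I would phrase the claim in terms of depth/span rather than time.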
> While other libraries exist for state space modeling in Python [...], Dynamax provides a unique combination of low-level inference algorithms and high-level modeling objects that can support a wide range of research applications in JAX.
I would like to see a more detailed discussion of the state of the field: what do you think is missing from existing packages? Is it just that some of them are not written in JAX? Does writing in JAX allow you to do things you aren't able to achieve with e.g. PyTorch?
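To sharpen the JAX question, the distinguishing feature I have in mind is the free composition of program transformations: jit-compiling the gradient of a vmapped likelihood, for instance. A hypothetical sketch (`loglik` below is a placeholder, not dynamax's actual `marginal_log_prob`):

```python
import jax
import jax.numpy as jnp

def loglik(params, emissions):
    # Placeholder for a scan-based forward filter returning a scalar
    # log-likelihood for one sequence.
    return jnp.sum(emissions * params)

# Gradient of the likelihood, batched over sequences, compiled for
# GPU/TPU -- a composition that is awkward in NumPy or eager PyTorch.
batched_grad = jax.jit(jax.vmap(jax.grad(loglik), in_axes=(None, 0)))

params = jnp.ones(4)
batch = jnp.ones((32, 100, 4))       # 32 sequences of length 100
grads = batched_grad(params, batch)  # shape (32, 4)
```

If that composability is the core differentiator from hmmlearn/pomegranate, it would be worth stating it explicitly in the paper.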