PyTorch-only Cauchy kernel for easier testing #9
These extensions are required because the algorithm needs functionality that vanilla PyTorch does not support. There is also a pure PyTorch implementation of the kernel, which is slower and uses more memory. To use it, in the line https://github.com/HazyResearch/state-spaces/blob/2af126108991d214fd82ffc1899f6d4e31a1eda3/src/models/sequence/ss/kernel.py#L322 replace
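For readers wondering what a "pure PyTorch" Cauchy kernel looks like, here is a minimal sketch. It assumes the kernel computes `out[..., l] = sum_n v[..., n] / (z[l] - w[..., n])` for complex inputs; the function name `naive_cauchy` and the shapes are illustrative assumptions, not the repository's code. It builds the full `(..., N, L)` matrix of Cauchy terms before summing. Materializing that matrix is a plausible reason why this route is slower and more memory hungry than a fused CUDA or KeOps kernel.

```python
import math

import torch


def naive_cauchy(v: torch.Tensor, z: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Hypothetical pure PyTorch Cauchy kernel (illustrative sketch).

    v, w: complex tensors of shape (..., N)
    z:    complex tensor of shape (L,)
    returns: complex tensor of shape (..., L) with
             out[..., l] = sum_n v[..., n] / (z[l] - w[..., n])
    """
    # Broadcasting (..., N, 1) against (L,) gives a (..., N, L) matrix.
    # This whole matrix lives in memory at once, unlike in a fused kernel.
    cauchy_matrix = v.unsqueeze(-1) / (z - w.unsqueeze(-1))
    # Reduce over the N axis to get one value per evaluation point z[l].
    return cauchy_matrix.sum(dim=-2)


# Usage: evaluate at L roots of unity with N poles in the left half-plane.
N, L = 4, 8
v = torch.randn(N, dtype=torch.cfloat)
w = torch.complex(-2.0 - torch.rand(N), torch.randn(N))
z = torch.exp(1j * 2 * math.pi * torch.arange(L, dtype=torch.float32) / L)
out = naive_cauchy(v, z, w)
print(out.shape)  # torch.Size([8])
```

Because it uses only broadcasting and `sum`, this sketch runs on CPU or GPU on any platform PyTorch supports, including Windows, at the cost of O(N·L) memory per batch element.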
That worked, @albertfgu, thank you!
No problem. I can also make the requisite changes to fall back on this automatically; it should be pretty simple. Keep in mind, again, that this version is less efficient, so it may be useful for preliminary testing, but we recommend setting up the full environment if possible. There is also an excellent port of S4 to JAX here, if that is easier to set up: https://srush.github.io/annotated-s4/
I just added a fallback to the slow version. Can you test it?
Works perfectly, @albertfgu, thank you!
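A fallback like the one described above is often written as an import guard. The sketch below shows that general pattern only; it is not necessarily how the repository does it. The module `hypothetical_cauchy_ext` and its `cauchy_mult(v, z, w)` function are hypothetical stand-ins for a compiled CUDA or KeOps kernel, and `naive_cauchy` is an assumed pure PyTorch version.

```python
import warnings

import torch


def naive_cauchy(v: torch.Tensor, z: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Assumed pure PyTorch kernel: out[..., l] = sum_n v[..., n] / (z[l] - w[..., n])."""
    return (v.unsqueeze(-1) / (z - w.unsqueeze(-1))).sum(dim=-2)


try:
    # Hypothetical module name standing in for a compiled fast kernel.
    from hypothetical_cauchy_ext import cauchy_mult
    HAS_FAST_CAUCHY = True
except ImportError:
    HAS_FAST_CAUCHY = False
    warnings.warn(
        "Fast Cauchy kernel not found; falling back to the slower, "
        "more memory-hungry pure PyTorch version."
    )


def cauchy(v: torch.Tensor, z: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Use the fast kernel when it imported successfully, else the slow one."""
    if HAS_FAST_CAUCHY:
        return cauchy_mult(v, z, w)
    return naive_cauchy(v, z, w)
```

The warning makes the slower path visible at import time, so a user testing on an unsupported platform knows they are not running the efficient kernel.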
Hello,
I loved Structured State Spaces and obtained fantastic performance compared to LSTM/SRU/Transformers.
I want to introduce S4 to some researchers and students, and the self-contained S4 layer is super great! However, it requires the Cauchy kernel.
There are two versions of the Cauchy kernel: a CUDA version and a PyKeOps version.
However, extensions/cauchy requires CUDA, PyKeOps does not support Windows, and the target audience works in very diverse environments.
Would it be possible for the HazyResearch team to implement a self-contained S4 layer that includes a simpler PyTorch Cauchy kernel?