Fourier basis not working with Sashimi? #54

Closed
andrewliu2001 opened this issue Jul 13, 2022 · 4 comments

Comments

@andrewliu2001

andrewliu2001 commented Jul 13, 2022

Hi,

I set d_model=128 in Sashimi and changed the basis of SSHiPPOKernel in src/models/sequence/ss/standalone/s4.py to Fourier, and I get the following error:

Traceback (most recent call last):
  File "train_boundary_s4.py", line 52, in <module>
    model, train_accs, val_accs = train_model(model, (train_boundary, train_label), (val_boundary, val_label), DATASET_LENGTHS, config)
  File "/data/al451/ml4fg-project/train_boundary.py", line 108, in train_model
    train_loss, train_acc, train_pr, train_rec = run_one_epoch(True, train_dataloader, model, optimizer, device, math.ceil(len(train_dataset)/batch_size), epoch, config['train'])
  File "/data/al451/ml4fg-project/train_boundary.py", line 44, in run_one_epoch
    output = model(x_cnn) # forward pass
  File "/home/al451/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/al451/ml4fg-project/s4_standalone.py", line 353, in forward
    x, _ = layer(x)
  File "/home/al451/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/al451/ml4fg-project/s4_standalone.py", line 169, in forward
    z, _ = self.layer(z)
  File "/home/al451/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/al451/ml4fg-project/src/models/sequence/ss/standalone/s4.py", line 866, in forward
    y = self.output_linear(y)
  File "/home/al451/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/al451/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/al451/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/al451/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 302, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/al451/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 298, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [256, 128, 1], expected input[64, 256, 4000] to have 128 channels, but got 256 channels instead

Any tips on how to resolve this? It seems to me that the Fourier setting doubles the number of channels.
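
For readers decoding the error message, here is a minimal sketch of the shape arithmetic it reports, written in plain Python rather than PyTorch. The helper name conv1d_channel_check is hypothetical and only mirrors the channel comparison described in the RuntimeError; it does not explain why the kernel output has twice the channels.

```python
def conv1d_channel_check(weight_shape, input_shape, groups=1):
    """Return (expected, actual) input channels for a 1-D convolution.

    weight_shape: (out_channels, in_channels // groups, kernel_size)
    input_shape:  (batch, channels, length)
    Hypothetical helper: it only reproduces the comparison that the
    error message describes, not the convolution itself.
    """
    _out_channels, in_per_group, _kernel_size = weight_shape
    _batch, channels, _length = input_shape
    expected = in_per_group * groups
    return expected, channels


# Shapes copied from the RuntimeError in the traceback above.
expected, actual = conv1d_channel_check((256, 128, 1), (64, 256, 4000))
ratio = actual / expected
print(f"expected {expected} channels, got {actual} (ratio {ratio:g})")
```

The factor of two between the expected and actual channel counts is consistent with the observation that the input to output_linear is twice as wide as d_model.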

@albertfgu
Contributor

Can you give me the specific command you're running?

Also, can you try switching to the v3 branch and using the standalone file there? On that branch, the following command runs without problems:
python -m train wandb=null pipeline=mnist model=s4 model.layer.measure=fourier
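
The key=value arguments in that command look like Hydra-style overrides, where a dotted key such as model.layer.measure addresses a nested config field. Assuming that reading, the sketch below shows how such arguments could map onto a nested structure. The function split_overrides is hypothetical and is not how the training script or Hydra actually parses its arguments; in particular, treating every un-dotted key as a top-level selection is a simplification.

```python
def split_overrides(args):
    """Split key=value arguments into top-level and nested overrides.

    Hypothetical illustration only: un-dotted keys (e.g. model=s4) are
    collected as top-level selections, dotted keys (e.g.
    model.layer.measure=fourier) are expanded into nested dicts.
    Values are kept as strings.
    """
    top_level, nested = {}, {}
    for arg in args:
        key, sep, value = arg.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {arg!r}")
        if "." not in key:
            top_level[key] = value
            continue
        *parents, leaf = key.split(".")
        node = nested
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return top_level, nested


top_level, nested = split_overrides(
    ["wandb=null", "pipeline=mnist", "model=s4", "model.layer.measure=fourier"]
)
print(top_level)
print(nested)
```

Read this way, model=s4 picks the model and model.layer.measure=fourier changes one field inside that model's layer settings.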

@andrewliu2001
Author

OK, thanks, it works. Also, in your "How to Train Your HiPPO" paper you mention a theorem stating that the HiPPO matrices correspond to a Fourier basis in the limit N → ∞. Have you run ablations on how large N should be?

@albertfgu
Contributor

I haven't done extensive ablations on the size of N.

@albertfgu
Contributor

I'll also add two notes:

  • We haven't run experiments with the new Fourier measure, so I have no idea how well it performs on audio.
  • There's also a "mixed" measure that combines the original S4 measure (called LegS) with the new FouT measure. I think it can be enabled by passing measure=hippo.
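
To compare the two measures mentioned in this thread, one option is to generate one training command per measure. The sketch below only builds the command strings and does not run anything. It reuses the command given earlier in the thread; spelling the mixed measure as model.layer.measure=hippo is an assumption based on the tentative "measure=hippo" above, and measure_commands is a hypothetical helper.

```python
import shlex

# Base command taken from the thread; only the measure override varies.
BASE = ["python", "-m", "train", "wandb=null", "pipeline=mnist", "model=s4"]


def measure_commands(measures):
    """Build one shell command string per measure name (hypothetical helper)."""
    return [
        shlex.join(BASE + [f"model.layer.measure={name}"])
        for name in measures
    ]


# "fourier" appears in the thread; "hippo" is the tentatively suggested
# value for the mixed measure and is an assumption here.
commands = measure_commands(["fourier", "hippo"])
for command in commands:
    print(command)
```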
