Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error when using other than 4 channel images with BYOL task #522

Merged
merged 1 commit into from
May 2, 2022

Conversation

gaetanbahl
Copy link
Contributor

To avoid shape errors, we pass the in_channels argument from BYOLTask to BYOL.

@github-actions github-actions bot added the trainers PyTorch Lightning trainers label May 2, 2022
@isaaccorley
Copy link
Collaborator

Thanks for the catch! LGTM. Could you paste the error message you received so it can be referenced in the future?

@gaetanbahl
Copy link
Contributor Author

Thanks for the catch! LGTM. Could you paste the error message you received so it can be referenced in the future?

When I tried to train a ResNet50 with BYOLTask on EuroSAT data (13 bands), I had this error:

Traceback (most recent call last):
  File "train_ssl.py", line 14, in <module>
    task = BYOLTask(in_channels=13,
  File "/opt/conda/lib/python3.8/site-packages/torchgeo/trainers/byol.py", line 370, in __init__
    self.config_task()
  File "/opt/conda/lib/python3.8/site-packages/torchgeo/trainers/byol.py", line 350, in config_task
    self.model = BYOL(encoder, image_size=(256, 256))
  File "/opt/conda/lib/python3.8/site-packages/torchgeo/trainers/byol.py", line 288, in __init__
    self.encoder(torch.zeros(2, self.in_channels, *image_size))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torchgeo/trainers/byol.py", line 228, in forward
    _ = self.model(x)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/models/resnet.py", line 283, in forward
    return self._forward_impl(x)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/models/resnet.py", line 266, in _forward_impl
    x = self.conv1(x)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 447, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [64, 13, 7, 7], expected input[2, 4, 256, 256] to have 13 channels, but got 4 channels instead

Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this fixes #337

@adamjstewart adamjstewart added this to the 0.2.2 milestone May 2, 2022
@adamjstewart adamjstewart merged commit cf2a0cc into microsoft:main May 2, 2022
@isaaccorley
Copy link
Collaborator

isaaccorley commented May 3, 2022

We were basically hardcoding in_channels=4 since we didn't pass it from BYOLTask to BYOL. I think this likely solves the issue using BYOL with datasets that do not use RGBN NAIP imagery.

@adamjstewart
Copy link
Collaborator

Unfortunately this didn't fix #337 but this is obviously a step in the right direction. If anyone figures out how to fix #337 please open a PR!

remtav pushed a commit to remtav/torchgeo that referenced this pull request May 26, 2022
@adamjstewart adamjstewart modified the milestones: 0.2.2, 0.3.0 Jul 2, 2022
@adamjstewart adamjstewart mentioned this pull request Jul 11, 2022
@gaetanbahl gaetanbahl deleted the fix-byol-channels branch September 25, 2022 07:30
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants