
InstanceNorm not working #9

Open
pimdh opened this issue Aug 6, 2023 · 2 comments
pimdh commented Aug 6, 2023

Hi! Thanks for this library :)
I'm trying to use InstanceNorm and it appears there's a bug.
When I run the following

import torch
# import paths assumed; adjust to match your segnn installation
from segnn.balanced_irreps import BalancedIrreps
from segnn.instance_norm import InstanceNorm

irreps = BalancedIrreps(3, 20)
norm = InstanceNorm(irreps)
x = torch.randn(9, 20)
batch = torch.zeros(9, dtype=torch.long)
norm(x, batch)

I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 8
      6 x = torch.randn(9, 10)
      7 batch = torch.zeros(9, dtype=torch.long)
----> 8 norm(x, batch)

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/dist-packages/segnn/segnn/instance_norm.py:85, in InstanceNorm.forward(self, input, batch)
     82 # For scalars first compute and subtract the mean
     83 if ir.l == 0:
     84     # Compute the mean
---> 85     field_mean = global_mean_pool(field, batch).reshape(-1, mul, 1)  # [batch, mul, 1]]
     86     # Subtract the mean
     87     field = field - field_mean[batch]

File /usr/local/lib/python3.8/dist-packages/torch_geometric/nn/pool/glob.py:63, in global_mean_pool(x, batch, size)
     61     return x.mean(dim=dim, keepdim=x.dim() <= 2)
     62 size = int(batch.max().item() + 1) if size is None else size
---> 63 return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')

File /usr/local/lib/python3.8/dist-packages/torch_geometric/utils/scatter.py:81, in scatter(src, index, dim, dim_size, reduce)
     78 count.scatter_add_(0, index, src.new_ones(src.size(dim)))
     79 count = count.clamp(min=1)
---> 81 index = broadcast(index, src, dim)
     82 out = src.new_zeros(size).scatter_add_(dim, index, src)
     84 return out / broadcast(count, out, dim)

File /usr/local/lib/python3.8/dist-packages/torch_geometric/utils/scatter.py:21, in broadcast(src, ref, dim)
     19 size = [1] * ref.dim()
     20 size[dim] = -1
---> 21 return src.view(size).expand_as(ref)

RuntimeError: The expanded size of the tensor (10) must match the existing size (9) at non-singleton dimension 1.  Target sizes: [9, 10, 1].  Tensor sizes: [1, 9, 1]

It appears this is because global_mean_pool from PyTorch Geometric does not support inputs with more than two dimensions, while field here has shape [nodes, mul, 1] (as the target sizes [9, 10, 1] in the error show). The fix could be to replace global_mean_pool(field, batch) with global_mean_pool(field.view(-1, mul), batch), as sketched below.
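For reference, here is a minimal sketch of what the patched scalar branch of InstanceNorm.forward could look like, assuming the fix above (variable names follow segnn/instance_norm.py as shown in the traceback; for l == 0 irreps, field has shape [nodes, mul, 1]):

# global_mean_pool expects a 2D input, so flatten the trailing
# singleton dimension before pooling, then restore it afterwards
field_mean = global_mean_pool(field.view(-1, mul), batch)  # [batch, mul]
field_mean = field_mean.reshape(-1, mul, 1)                # [batch, mul, 1]
# Subtract each graph's mean from its own nodes
field = field - field_mean[batch]

This keeps the downstream shape expectation ([batch, mul, 1]) intact while feeding global_mean_pool the 2D tensor it supports.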

Cheers,
Pim

RobDHess (Owner) commented Sep 1, 2023

Hi,

Thanks for alerting us to this; I will fix it in the near future.

Cheers,

Rob

z4-qqq commented Sep 30, 2024

@RobDHess Hi!

Are there any updates on this bug? I'm running into it as well.

Cheers,
Ivan
