---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[7], line 8
6 x = torch.randn(9, 10)
7 batch = torch.zeros(9, dtype=torch.long)
----> 8 norm(x, batch)
File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.8/dist-packages/segnn/segnn/instance_norm.py:85, in InstanceNorm.forward(self, input, batch)
82 # For scalars first compute and subtract the mean
83 if ir.l == 0:
84 # Compute the mean
---> 85 field_mean = global_mean_pool(field, batch).reshape(-1, mul, 1) # [batch, mul, 1]]
86 # Subtract the mean
87 field = field - field_mean[batch]
File /usr/local/lib/python3.8/dist-packages/torch_geometric/nn/pool/glob.py:63, in global_mean_pool(x, batch, size)
61 return x.mean(dim=dim, keepdim=x.dim() <= 2)
62 size = int(batch.max().item() + 1) if size is None else size
---> 63 return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')
File /usr/local/lib/python3.8/dist-packages/torch_geometric/utils/scatter.py:81, in scatter(src, index, dim, dim_size, reduce)
78 count.scatter_add_(0, index, src.new_ones(src.size(dim)))
79 count = count.clamp(min=1)
---> 81 index = broadcast(index, src, dim)
82 out = src.new_zeros(size).scatter_add_(dim, index, src)
84 return out / broadcast(count, out, dim)
File /usr/local/lib/python3.8/dist-packages/torch_geometric/utils/scatter.py:21, in broadcast(src, ref, dim)
19 size = [1] * ref.dim()
20 size[dim] = -1
---> 21 return src.view(size).expand_as(ref)
RuntimeError: The expanded size of the tensor (10) must match the existing size (9) at non-singleton dimension 1. Target sizes: [9, 10, 1]. Tensor sizes: [1, 9, 1]
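Hi! Thanks for this library :)
I'm trying to use `InstanceNorm` and it appears there's a bug. Running the snippet whose last lines appear at the top of the traceback above (`x = torch.randn(9, 10)`, `batch = torch.zeros(9, dtype=torch.long)`, `norm(x, batch)`) raises that RuntimeError.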
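The failing step can be reproduced with plain `torch`, without segnn or PyG. The sketch below copies the `broadcast` helper shown at the bottom of the traceback (`broadcast_like_pyg` is only a local name for that copy, not a PyG function). It assumes that `global_mean_pool` pools over `dim=-2` for any input with more than one dimension, which is what the PyG source appears to do and is consistent with the `[1, 9, 1]` index shape in the error.

```python
import torch


def broadcast_like_pyg(src: torch.Tensor, ref: torch.Tensor, dim: int) -> torch.Tensor:
    # Copy of the helper in the traceback (torch_geometric/utils/scatter.py):
    # lay the 1D index along `dim`, then expand it to the full shape of `ref`.
    size = [1] * ref.dim()
    size[dim] = -1
    return src.view(size).expand_as(ref)


num_nodes, mul = 9, 10
batch = torch.zeros(num_nodes, dtype=torch.long)  # all 9 nodes in graph 0

# 2D input [num_nodes, mul]: dim -2 is the node axis, so the index fits.
ok = broadcast_like_pyg(batch, torch.randn(num_nodes, mul), dim=-2)
print(ok.shape)  # torch.Size([9, 10])

# 3D scalar field [num_nodes, mul, 1]: dim -2 is now the mul axis (size 10),
# so the 9-element index cannot be expanded and the same RuntimeError is raised.
try:
    broadcast_like_pyg(batch, torch.randn(num_nodes, mul, 1), dim=-2)
    message = None
except RuntimeError as err:
    message = str(err)
print(message)
```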
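It appears this is because `global_mean_pool` from PyTorch Geometric does not handle inputs with more than 2 dimensions the way `InstanceNorm` uses it: for a multi-dimensional input it seems to pool over dim -2, so the scalar field of shape `[num_nodes, mul, 1]` is reduced over the `mul` axis instead of the node axis. That is why the 9-element `batch` index is laid out as `[1, 9, 1]` and cannot be expanded to `[9, 10, 1]`. The solution could be to replace `global_mean_pool(field, batch)` with `global_mean_pool(field.view(-1, mul), batch)`. For scalars (l = 0) the trailing dimension has size 1, so the view only puts the nodes back on dim -2 (which is then dim 0), and the existing `.reshape(-1, mul, 1)` restores the `[batch, mul, 1]` shape.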
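A minimal sketch of the proposed fix that runs without PyG. `mean_pool_nodes` is a hypothetical stand-in for `global_mean_pool` on a 2D `[num_nodes, features]` input (it averages the rows that share a graph id, i.e. pools over dim 0); it is not PyG's implementation. The shapes mirror the traceback: 9 nodes, `mul = 10` scalar channels, one graph.

```python
import torch


def mean_pool_nodes(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for global_mean_pool on a 2D input:
    # per-graph mean over the node dimension (dim 0).
    num_graphs = int(batch.max().item()) + 1
    # Sum the node features belonging to each graph.
    out = x.new_zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
    # Nodes per graph, clamped to 1 so empty graphs do not divide by zero.
    count = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
    return out / count.unsqueeze(1)


num_nodes, mul = 9, 10
field = torch.randn(num_nodes, mul, 1)            # scalar (l = 0) field: [num_nodes, mul, 1]
batch = torch.zeros(num_nodes, dtype=torch.long)  # all nodes belong to graph 0

# Proposed fix: drop the trailing size-1 axis so nodes sit on dim 0 before pooling,
# then reshape back to [num_graphs, mul, 1] as InstanceNorm expects.
field_mean = mean_pool_nodes(field.view(-1, mul), batch).reshape(-1, mul, 1)
# Subtract the per-graph mean from every node, as in instance_norm.py.
centered = field - field_mean[batch]

print(field_mean.shape)  # torch.Size([1, 10, 1])
```

In `instance_norm.py` itself the change would only be the one-line replacement suggested above, with the real `global_mean_pool` instead of the stand-in.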
Cheers,
Pim