I have a question: according to the paper, the squash function should only be applied after the weighted sum of the predictions u-hat, but in this code there is a squash right after the primary capsule layer. I'm really confused.
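For reference, here are the equations in question as given in the CapsNet paper ("Dynamic Routing Between Capsules", Sabour et al., 2017). Each lower-level capsule output $u_i$ is turned into a prediction $\hat{u}_{j|i}$, the predictions are weighted by the coupling coefficients $c_{ij}$ and summed (Eq. 2), and the squashing non-linearity (Eq. 1) is applied to that sum to give the output $v_j$ of capsule $j$:

$$
\hat{u}_{j|i} = W_{ij}\, u_i, \qquad
s_j = \sum_i c_{ij}\, \hat{u}_{j|i}, \qquad
v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2}\,\frac{s_j}{\|s_j\|}
$$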
```python
import torch.nn as nn


class PrimaryCapsLayer(nn.Module):
    def __init__(self, input_channels, output_caps, output_dim, kernel_size, stride):
        super(PrimaryCapsLayer, self).__init__()
        # input_channels = 256, output_caps = 32, output_dim = 8, kernel_size = 9, stride = 2
        self.conv = nn.Conv2d(input_channels, output_caps * output_dim,
                              kernel_size=kernel_size, stride=stride)
        self.input_channels = input_channels
        self.output_caps = output_caps
        self.output_dim = output_dim

    def forward(self, input):
        out = self.conv(input)
        N, C, H, W = out.size()
        # N x OUT_CAPS x OUT_DIM x H x W
        out = out.view(N, self.output_caps, self.output_dim, H, W)
        # N x OUT_CAPS x H x W x OUT_DIM
        out = out.permute(0, 1, 3, 4, 2).contiguous()
        # N x (OUT_CAPS * H * W) x OUT_DIM: one OUT_DIM-vector per primary capsule
        out = out.view(out.size(0), -1, out.size(4))
        out = squash(out)  # QUESTION: why squash here? (squash() is not shown in this snippet)
        return out
```
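The snippet calls `squash` without showing its definition. Below is a minimal PyTorch sketch of Eq. 1, assuming the capsule dimension is the last axis (as in the output of `PrimaryCapsLayer.forward`, which is N x (OUT_CAPS * H * W) x OUT_DIM). The `eps` guard and the 28x28 MNIST input used to derive the example shape are assumptions, not taken from the repository. As far as I can tell, the paper itself describes PrimaryCapsules as a convolution layer with Eq. 1 as its block non-linearity, so squashing the primary capsule outputs $u_i$ appears consistent with it; the sketch only shows what that step does to the vectors.

```python
import torch


def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    """Eq. 1: rescale each capsule vector to length < 1 while keeping its direction.

    Assumption: capsule vectors live along `dim` (the last axis by default).
    `eps` (an assumed value) avoids division by zero for all-zero vectors.
    """
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)  # ||s||^2 for each capsule
    scale = sq_norm / (1.0 + sq_norm)              # ||s||^2 / (1 + ||s||^2)
    return scale * s / torch.sqrt(sq_norm + eps)   # times the unit vector s / ||s||


# Assumed setting: 28x28 input -> 9x9 conv, stride 1 -> 20x20 -> 9x9 conv, stride 2 -> 6x6,
# so the primary layer yields 32 * 6 * 6 = 1152 capsules of dimension 8 per sample.
u = torch.randn(2, 32 * 6 * 6, 8)  # random stand-in for PrimaryCapsLayer's pre-squash output
v = squash(u)

lengths = v.norm(dim=-1)           # one length per capsule, shape (2, 1152)
print(v.shape)                     # torch.Size([2, 1152, 8])
print(bool((lengths < 1).all()))   # True
```

Every output vector keeps the direction of its input but ends up with a length below 1, which is what lets the length of a capsule's output be read as the probability that the entity it represents is present.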