
SpViT on Swin Transformer #4

Open

fbragman opened this issue Aug 27, 2024 · 2 comments

Comments

@fbragman

Hi,

Thanks a lot for your contribution!

I was wondering if you could release the SpViT implementation for the Swin Transformer?

Many thanks!

@PeiyanFlying
Owner

Yes, in principle this workflow can also be applied to the Swin Transformer. At the moment we do not have enough time to work on that implementation, but we will release it in the future if we can make time.

Thanks for your support and understanding.

@fbragman
Author

fbragman commented Aug 27, 2024

Thank you for the quick response!

Do you plan on releasing the code used to generate the results in Table 3, such as SPViT vs. Swin-S and SPViT vs. Swin-T?

I am currently implementing your method in my own Swin code. I would be grateful if you could help with the following questions.

In terms of the implementation, the ViT code, for instance, has:

    policy = torch.ones(B, init_n + 1, 1, dtype=x.dtype, device=x.device)

When you used Swin, did you generate a new policy for each layer, since the number of patches changes at each successive layer?
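For reference, here is roughly what I do at the moment, as a minimal sketch (my own code and naming, not from this repo), assuming the policy should be rebuilt whenever the token count changes:

    import torch

    def fresh_policy(x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, C) tokens entering a Swin stage. The +1 slot is
        # reserved for the representative (packaged) token, mirroring
        # the init_n + 1 in the ViT code above.
        B, N, _ = x.shape
        return torch.ones(B, N + 1, 1, dtype=x.dtype, device=x.device)

    # Hypothetical usage inside the Swin forward pass: PatchMerging
    # changes N (and C), so the policy is rebuilt for each stage.
    # x = patch_merging(x)        # (B, N, C) -> (B, N/4, 2C)
    # policy = fresh_policy(x)    # fresh all-ones policy for this stage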

Also, it is stated that the Token Selector is applied after PatchMerging for layers 2, 3, and 4 of Swin. How does this affect how patch merging is performed once token packaging has been applied?

Lastly, just to confirm my understanding of the code: does the snippet below correspond to the token packaging?

    x2 = spatial_x * hard_drop_decision          # apply the drop decision mask; shape [96, 196, 384]
    x2_sum = torch.sum(x2, dim=1)                # sum over the token dimension: (B, N, C) -> (B, C), i.e. [96, 384]
    drop_len = torch.sum(hard_drop_decision)     # total number of dropped tokens
    represent_token = (x2_sum / drop_len).reshape(B, 1, -1)  # average the dropped tokens into one token
    x = torch.cat((x, represent_token), dim=1)   # append the representative token

while the policy tensor is used in the attention module to ensure that the softmax is computed only over tokens that are kept or repackaged?
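For completeness, this is how I currently implement that masking, as a minimal sketch (my own function and variable names, not taken from this repo):

    import torch

    def masked_softmax(attn: torch.Tensor, policy: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
        # attn:   (B, H, N, N) raw attention logits
        # policy: (B, N, 1), 1 = kept / packaged token, 0 = dropped
        mask = policy.squeeze(-1)[:, None, None, :]           # (B, 1, 1, N)
        attn = attn - attn.max(dim=-1, keepdim=True).values   # numerical stability
        exp = attn.exp() * mask                                # zero out dropped keys
        return exp / (exp.sum(dim=-1, keepdim=True) + eps)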

Thanks for your guidance!
