Extracting attention maps #5
@vibrant-galaxy I'm not actually sure it will be very interpretable as it is, since attention is done along each axis separately, and information can take up to two steps to be routed. However, I think what may be worth trying (and I haven't built it into this repo yet) is to do axial attention, then expand the attention map of each axis along the other axes, sum them, softmax, and aggregate the values. Perhaps that could lead to something more interpretable, since you would have the full attention map. Would you be interested in trying this if I were to build it?
That sounds like a good approach to get the full map. Yes, I am very much interested in trying that!
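As an aside, here is a minimal 2D (height and width only) sketch of the expand-and-sum idea described above, before the full video version that follows; all names are illustrative and nothing here is taken from the repo.

```python
import torch
from torch import einsum
from einops import rearrange

# toy single-head setup, assuming a feature map of shape (b, h, w, d)
b, h, w, d = 1, 4, 4, 8
q = torch.randn(b, h, w, d) * d ** -0.5
k_height = torch.randn(b, h, w, d)
k_width = torch.randn(b, h, w, d)
v = torch.randn(b, h, w, d)

# similarity along each axis separately
sim_h = einsum('b i j d, b k j d -> b i j k', q, k_height)   # (b, h, w, h)
sim_w = einsum('b i j d, b i l d -> b i j l', q, k_width)    # (b, h, w, w)

# broadcast each per-axis map along the *other* axis so both cover the full (h, w) grid
sim = sim_h[..., :, None] + sim_w[..., None, :]               # (b, h, w, h, w)

# softmax over the full map, then aggregate values with it
attn = rearrange(sim, 'b i j k l -> b i j (k l)').softmax(dim = -1)
attn = rearrange(attn, 'b i j (k l) -> b i j k l', k = h, l = w)
out = einsum('b i j k l, b k l d -> b i j d', attn, v)        # (b, h, w, d)
```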
I tried to do something like the below, but it actually goes out of memory when you try to expand and sum the pre-attention maps. So basically I don't think it's possible lol, unless you see a way to make it work.

```python
import torch
from torch import einsum, nn
from einops import rearrange

class AxialAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_height_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_width_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_frame_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        heads, b, f, c, h, w = self.heads, *x.shape
        x = rearrange(x, 'b f c h w -> b f h w c')

        # shared queries and values, separate keys per axis
        q = self.to_q(x)
        k_height = self.to_height_k(x)
        k_width = self.to_width_k(x)
        k_frame = self.to_frame_k(x)
        v = self.to_v(x)

        q, k_height, k_width, k_frame, v = map(lambda t: rearrange(t, 'b f x y (h d) -> (b h) f x y d', h = heads), (q, k_height, k_width, k_frame, v))

        q *= q.shape[-1] ** -0.5

        # per-axis similarities, each expanded along the other two axes
        sim_frame = einsum('b f h w d, b j h w d -> b f h w j', q, k_frame)
        sim_frame = sim_frame[..., :, None, None].expand(-1, -1, -1, -1, -1, h, w)

        sim_height = einsum('b f h w d, b f k w d -> b f h w k', q, k_height)
        sim_height = sim_height[..., None, :, None].expand(-1, -1, -1, -1, f, -1, w)

        sim_width = einsum('b f h w d, b f h l d -> b f h w l', q, k_width)
        sim_width = sim_width[..., None, None, :].expand(-1, -1, -1, -1, f, h, -1)

        # sum the expanded maps, softmax over the full (frame, height, width) attention map
        sim = rearrange(sim_frame + sim_height + sim_width, 'b f h w j k l -> b f h w (j k l)')
        attn = sim.softmax(dim = -1)
        attn = rearrange(attn, 'b f h w (j k l) -> b f h w j k l', j = f, k = h, l = w)

        # aggregate values with the full attention map
        out = einsum('b f h w j k l, b j k l d -> b f h w d', attn, v)

        out = rearrange(out, '(b h) f x y d -> b f x y (h d)', h = heads)
        out = self.to_out(out)
        out = rearrange(out, 'b f x y d -> b f d x y')
        return out, attn

layer = AxialAttention(dim = 16)
video = torch.randn(1, 5, 16, 32, 32)
out, attn = layer(video)
```
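As a rough sanity check on the out-of-memory behaviour (my own back-of-the-envelope numbers, not from the repo): `.expand()` itself is only a view, but the sum `sim_frame + sim_height + sim_width` materialises a tensor of shape `((b·heads), f, h, w, f, h, w)`, and the softmax allocates another of the same size.

```python
# rough fp32 memory for one materialised full-attention tensor (assumed shapes as above)
def full_map_bytes(heads, f, h, w, batch = 1):
    n = f * h * w                     # tokens in the flattened (frame, height, width) map
    return batch * heads * n * n * 4  # one (b*heads, f, h, w, f, h, w) tensor in float32

print(full_map_bytes(8, 5, 32, 32) / 1e9)     # toy video above: ~0.84 GB per tensor
print(full_map_bytes(8, 5, 256, 256) / 1e12)  # the (1, 5, 128, 256, 256) video: ~3.4 TB
```

So beyond tiny inputs the full map cannot be materialised, which matches the comment above; keeping the three per-axis maps separate is what keeps axial attention tractable in the first place.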
Hi there,
Excellent project!
I'm using axial-attention with video of shape (1, 5, 128, 256, 256) and sum_axial_out=True, and I wish to visualise the attention maps. Essentially, given my video and two frame indices frame_a_idx and frame_b_idx, I need to extract the attention map over frame_b for a chosen pixel (x, y) in frame_a (after the axial sum).

My understanding is that I should be able to reshape the dots (after softmax) according to the permutations in calculate_permutations, then sum these permuted dots together to form a final attention score tensor of an accessible shape, thus ready for visualisation. I am slightly stuck due to the numerous axial permutations and shape mismatches. What I am doing is as follows:
In SelfAttention.forward():

In PermuteToFrom.forward():

However, I am unsure of how to un-permute the dots appropriately such that all the resulting "axes" (of different sizes) can be summed. If you have suggestions or code for doing so, it would be very much appreciated, thanks!
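Not an authoritative answer, but one hedged sketch of the un-permuting step for the temporal axis is below. It assumes (my reading of the library, not confirmed in this thread) that PermuteToFrom flattens the non-axial dimensions in (b, h, w) order before SelfAttention merges the heads, so the temporal dots come out with shape ((b·h·w)·heads, f, f) for a (b, f, c, h, w) video. Note also that with sum_axial_out=True the three per-axis maps are summed only through their outputs, never combined into one joint map, so this only exposes how pixel (x, y) in frame_a attends to the same pixel in frame_b; a full spatial map over frame_b would need something like the full-attention variant sketched earlier in the thread.

```python
import torch

# Hedged sketch only: NOT code from this repo. The shape assumptions below (how
# PermuteToFrom flattens the non-axial dimensions before SelfAttention splits the
# heads) should be verified against the actual implementation.

b, f, h, w, heads = 1, 5, 256, 256, 8

# stand-in for the softmaxed `dots` captured from the temporal-axis pass,
# assumed shape: ((b * h * w) * heads, f, f)
dots_frame = torch.rand(b * h * w * heads, f, f).softmax(dim = -1)

# un-flatten back to named axes, assuming the merged leading order is (b, h, w, heads)
dots_frame = dots_frame.reshape(b, h, w, heads, f, f)

# temporal attention that pixel (x, y) in frame_a pays to the same pixel in frame_b,
# averaged over heads (frame_a_idx, frame_b_idx, x, y are your chosen indices)
frame_a_idx, frame_b_idx, x, y = 0, 3, 100, 50
attn_a_to_b = dots_frame[0, y, x, :, frame_a_idx, frame_b_idx].mean()
print(attn_a_to_b)
```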