Commit

Update my_embedding_layer.py

chaoyi-wu authored Aug 13, 2023
1 parent e582d73 commit b23757f
Showing 1 changed file with 37 additions and 37 deletions.
74 changes: 37 additions & 37 deletions src/Model/RadFM/my_embedding_layer.py
@@ -26,9 +26,9 @@ def __init__(self, num_embeddings=32000, embedding_dim=5120, perceiver_num=32,vi
self.frame_patch_size = frame_patch_size
self.seg_channel = seg_channel

- self.bert_tokenizer = AutoTokenizer.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT")
- self.bert_model = AutoModel.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT")
- self.bert_projection_fc = nn.Linear(768,vis_dim)
+ # self.bert_tokenizer = AutoTokenizer.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT")
+ # self.bert_model = AutoModel.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT")
+ # self.bert_projection_fc = nn.Linear(768,vis_dim)

self.vision_encoder = ViT(
image_size = 512, # image size
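The three lines commented out above form the key-word text branch: a MedKEBERT tokenizer and encoder loaded from a hard-coded cluster path, plus a linear projection from the 768-d BERT feature into the vision feature space. A minimal, self-contained sketch of that branch, with the hard-coded /gpfs path replaced by a hypothetical medkebert_path argument (an assumption, not part of the committed code):

# Hedged sketch, not the committed code: the MedKEBERT text branch that the
# commit comments out, with the hard-coded path made configurable.
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class KeywordTextBranch(nn.Module):
    def __init__(self, medkebert_path, vis_dim=768):
        super().__init__()
        # Tokenizer and encoder for the medical knowledge-enhanced BERT.
        self.bert_tokenizer = AutoTokenizer.from_pretrained(medkebert_path)
        self.bert_model = AutoModel.from_pretrained(medkebert_path)
        # Project the 768-d [CLS] feature into the vision feature space.
        self.bert_projection_fc = nn.Linear(768, vis_dim)

    def forward(self, words):
        # Embed a list of key words; returns one vis_dim vector per word.
        tokens = self.bert_tokenizer(words, padding=True, truncation=True,
                                     max_length=256, return_tensors="pt")
        cls = self.bert_model(**tokens).last_hidden_state[:, 0, :]  # (Q, 768)
        return self.bert_projection_fc(cls)                          # (Q, vis_dim)

Passing the checkpoint path in as an argument keeps the module importable on machines where the MedKEBERT weights are not stored at the original /gpfs location.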
@@ -80,44 +80,44 @@ def forward(self, text_input, vision_x, key_words_query = None):
vision_x = rearrange(vision_x, "(b s F) v d -> b s F v d", b=B, s=S,F=1)

loss_matching = None
- if key_words_query != None:
- # key_words_query list[list[str]] B, words, each word matches corresponding vision_x embedding
- query_words = [item for sublist in key_words_query for item in sublist]
- query_words = list(set(query_words))
- if len(query_words)>16:
- random.shuffle(query_words)
- query_words = query_words[0:16]
- if query_words != []:
- contrastive_labels = torch.zeros(B,len(query_words)) #B Q
- for i,sublist in enumerate(key_words_query):
- for j,item in enumerate(query_words):
- if item in sublist:
- contrastive_labels[i,j] = 1
- contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device)
+ # if key_words_query != None:
+ # # key_words_query list[list[str]] B, words, each word matches corresponding vision_x embedding
+ # query_words = [item for sublist in key_words_query for item in sublist]
+ # query_words = list(set(query_words))
+ # if len(query_words)>16:
+ # random.shuffle(query_words)
+ # query_words = query_words[0:16]
+ # if query_words != []:
+ # contrastive_labels = torch.zeros(B,len(query_words)) #B Q
+ # for i,sublist in enumerate(key_words_query):
+ # for j,item in enumerate(query_words):
+ # if item in sublist:
+ # contrastive_labels[i,j] = 1
+ # contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device)

- with torch.no_grad():
- query_words_embedding = self.bert_tokenizer(query_words, padding='max_length', truncation=True, max_length=256,return_tensors="pt")
- query_words_embedding = self.bert_model(input_ids = query_words_embedding['input_ids'].to(vision_x.device),attention_mask = query_words_embedding['attention_mask'].to(vision_x.device))['last_hidden_state'][:,0,:].to(vision_x.dtype).to(vision_x.device) # Q,D
- query_words_embedding = self.bert_projection_fc(query_words_embedding)
- query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B,1,1) # B,Q,D
- _,N,_ = query_words_embedding.shape
+ # with torch.no_grad():
+ # query_words_embedding = self.bert_tokenizer(query_words, padding='max_length', truncation=True, max_length=256,return_tensors="pt")
+ # query_words_embedding = self.bert_model(input_ids = query_words_embedding['input_ids'].to(vision_x.device),attention_mask = query_words_embedding['attention_mask'].to(vision_x.device))['last_hidden_state'][:,0,:].to(vision_x.dtype).to(vision_x.device) # Q,D
+ # query_words_embedding = self.bert_projection_fc(query_words_embedding)
+ # query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B,1,1) # B,Q,D
+ # _,N,_ = query_words_embedding.shape

- image_embedding = vision_x.mean(dim=1) # B V D average pooling to remove the multi-modal dimension
- image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d")
- pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:,0,:,:]
+ # image_embedding = vision_x.mean(dim=1) # B V D average pooling to remove the multi-modal dimension
+ # image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d")
+ # pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:,0,:,:]

- image_embedding = image_embedding.transpose(0,1) # (H/P W/P D/P) B D
- pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D
- query_words_embedding = query_words_embedding.transpose(0,1) # N B D
+ # image_embedding = image_embedding.transpose(0,1) # (H/P W/P D/P) B D
+ # pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D
+ # query_words_embedding = query_words_embedding.transpose(0,1) # N B D

- oo_embedding,_ = self.transformer_decoder(query_words_embedding, image_embedding, pos = pos_embedding)
- oo_embedding = oo_embedding.transpose(0,1) # B Q D
- oo_embedding = rearrange(oo_embedding, 'b n d -> (b n) d')
- oo_embedding = self.transformer_decoder_mlp(oo_embedding)
- oo_embedding = self.cls_head(oo_embedding).mean(dim = -1)
- oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q
- # oo_embedding = rearrange(oo_embedding, 'b n d -> b (n d)') # B Q
- loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels)
+ # oo_embedding,_ = self.transformer_decoder(query_words_embedding, image_embedding, pos = pos_embedding)
+ # oo_embedding = oo_embedding.transpose(0,1) # B Q D
+ # oo_embedding = rearrange(oo_embedding, 'b n d -> (b n) d')
+ # oo_embedding = self.transformer_decoder_mlp(oo_embedding)
+ # oo_embedding = self.cls_head(oo_embedding).mean(dim = -1)
+ # oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q
+ # # oo_embedding = rearrange(oo_embedding, 'b n d -> b (n d)') # B Q
+ # loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels)

vision_x = self.perceiver(vision_x) # reshapes to (b, S, n, d)
#vision_x = checkpoint(self.perceiver,vision_x)
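The second hunk disables the auxiliary word-image matching objective: the removed code pooled the vision tokens, embedded up to 16 deduplicated key words with MedKEBERT, cross-attended the words to the image tokens through transformer_decoder, and scored each (sample, word) pair with binary cross-entropy against contrastive_labels. A minimal, self-contained sketch of that objective under simplifying assumptions; a plain dot-product scorer stands in for transformer_decoder, transformer_decoder_mlp, and cls_head, and embed_words is a hypothetical callable returning one embedding per word:

# Hedged sketch, not the committed implementation: the dot-product scorer
# replaces the original transformer-decoder head, and `embed_words` is a
# hypothetical stand-in for the MedKEBERT [CLS] + projection pipeline.
import random
import torch
import torch.nn.functional as F

def keyword_matching_loss(key_words_query, image_embedding, embed_words, max_words=16):
    # key_words_query: list[list[str]], one key-word list per batch element.
    # image_embedding: (B, D) pooled image features.
    # embed_words: callable mapping list[str] -> (Q, D) tensor (assumed).
    B, _ = image_embedding.shape
    # Deduplicate key words across the batch and cap them, as in the original code.
    query_words = list(set(w for sub in key_words_query for w in sub))
    if len(query_words) > max_words:
        random.shuffle(query_words)
        query_words = query_words[:max_words]
    if not query_words:
        return None
    # Binary label matrix: labels[i, j] = 1 iff word j appears in sample i's list.
    labels = torch.zeros(B, len(query_words), device=image_embedding.device)
    for i, sub in enumerate(key_words_query):
        for j, word in enumerate(query_words):
            if word in sub:
                labels[i, j] = 1.0
    word_embedding = embed_words(query_words).to(image_embedding.device)  # (Q, D)
    logits = image_embedding @ word_embedding.t()                          # (B, Q)
    return F.binary_cross_entropy_with_logits(logits, labels)

# Shape-only usage with random features and a dummy word embedder:
# loss = keyword_matching_loss([["nodule"], ["effusion", "nodule"]],
#                              torch.randn(2, 768),
#                              lambda ws: torch.randn(len(ws), 768))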
