diff --git a/src/Model/RadFM/my_embedding_layer.py b/src/Model/RadFM/my_embedding_layer.py
index 0c1b9b2..ad3ef4b 100644
--- a/src/Model/RadFM/my_embedding_layer.py
+++ b/src/Model/RadFM/my_embedding_layer.py
@@ -26,9 +26,9 @@ def __init__(self, num_embeddings=32000, embedding_dim=5120, perceiver_num=32,vi
         self.frame_patch_size = frame_patch_size
         self.seg_channel = seg_channel
 
-        self.bert_tokenizer = AutoTokenizer.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT")
-        self.bert_model = AutoModel.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT")
-        self.bert_projection_fc = nn.Linear(768,vis_dim)
+        # self.bert_tokenizer = AutoTokenizer.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT")
+        # self.bert_model = AutoModel.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT")
+        # self.bert_projection_fc = nn.Linear(768,vis_dim)
 
         self.vision_encoder = ViT(
             image_size = 512,          # image size
@@ -80,44 +80,44 @@ def forward(self, text_input, vision_x, key_words_query = None):
         vision_x = rearrange(vision_x, "(b s F) v d -> b s F v d", b=B, s=S,F=1)
 
         loss_matching = None
-        if key_words_query != None:
-            # key_words_query list[list[str]] B, words; each word matches a corresponding vision_x embedding
-            query_words = [item for sublist in key_words_query for item in sublist]
-            query_words = list(set(query_words))
-            if len(query_words)>16:
-                random.shuffle(query_words)
-                query_words = query_words[0:16]
-            if query_words != []:
-                contrastive_labels = torch.zeros(B,len(query_words)) #B Q
-                for i,sublist in enumerate(key_words_query):
-                    for j,item in enumerate(query_words):
-                        if item in sublist:
-                            contrastive_labels[i,j] = 1
-                contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device)
+        # if key_words_query != None:
+        #     # key_words_query list[list[str]] B, words; each word matches a corresponding vision_x embedding
+        #     query_words = [item for sublist in key_words_query for item in sublist]
+        #     query_words = list(set(query_words))
+        #     if len(query_words)>16:
+        #         random.shuffle(query_words)
+        #         query_words = query_words[0:16]
+        #     if query_words != []:
+        #         contrastive_labels = torch.zeros(B,len(query_words)) #B Q
+        #         for i,sublist in enumerate(key_words_query):
+        #             for j,item in enumerate(query_words):
+        #                 if item in sublist:
+        #                     contrastive_labels[i,j] = 1
+        #         contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device)
 
-                with torch.no_grad():
-                    query_words_embedding = self.bert_tokenizer(query_words, padding='max_length', truncation=True, max_length=256,return_tensors="pt")
-                    query_words_embedding = self.bert_model(input_ids = query_words_embedding['input_ids'].to(vision_x.device),attention_mask = query_words_embedding['attention_mask'].to(vision_x.device))['last_hidden_state'][:,0,:].to(vision_x.dtype).to(vision_x.device) # Q,D
-                query_words_embedding = self.bert_projection_fc(query_words_embedding)
-                query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B,1,1) # B,Q,D
-                _,N,_ = query_words_embedding.shape
+        #         with torch.no_grad():
+        #             query_words_embedding = self.bert_tokenizer(query_words, padding='max_length', truncation=True, max_length=256,return_tensors="pt")
+        #             query_words_embedding = self.bert_model(input_ids = query_words_embedding['input_ids'].to(vision_x.device),attention_mask = query_words_embedding['attention_mask'].to(vision_x.device))['last_hidden_state'][:,0,:].to(vision_x.dtype).to(vision_x.device) # Q,D
+        #         query_words_embedding = self.bert_projection_fc(query_words_embedding)
+        #         query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B,1,1) # B,Q,D
+        #         _,N,_ = query_words_embedding.shape
 
-                image_embedding = vision_x.mean(dim=1) # B V D; average pooling removes the multi-modal dimension
-                image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d")
-                pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:,0,:,:]
+        #         image_embedding = vision_x.mean(dim=1) # B V D; average pooling removes the multi-modal dimension
+        #         image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d")
+        #         pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:,0,:,:]
 
-                image_embedding = image_embedding.transpose(0,1) # (H/P W/P D/P) B D
-                pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D
-                query_words_embedding = query_words_embedding.transpose(0,1) # N B D
+        #         image_embedding = image_embedding.transpose(0,1) # (H/P W/P D/P) B D
+        #         pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D
+        #         query_words_embedding = query_words_embedding.transpose(0,1) # N B D
 
-                oo_embedding,_ = self.transformer_decoder(query_words_embedding, image_embedding, pos = pos_embedding)
-                oo_embedding = oo_embedding.transpose(0,1) # B Q D
-                oo_embedding = rearrange(oo_embedding, 'b n d -> (b n) d')
-                oo_embedding = self.transformer_decoder_mlp(oo_embedding)
-                oo_embedding = self.cls_head(oo_embedding).mean(dim = -1)
-                oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q
-                # oo_embedding = rearrange(oo_embedding, 'b n d -> b (n d)') # B Q
-                loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels)
+        #         oo_embedding,_ = self.transformer_decoder(query_words_embedding, image_embedding, pos = pos_embedding)
+        #         oo_embedding = oo_embedding.transpose(0,1) # B Q D
+        #         oo_embedding = rearrange(oo_embedding, 'b n d -> (b n) d')
+        #         oo_embedding = self.transformer_decoder_mlp(oo_embedding)
+        #         oo_embedding = self.cls_head(oo_embedding).mean(dim = -1)
+        #         oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q
+        #         # oo_embedding = rearrange(oo_embedding, 'b n d -> b (n d)') # B Q
+        #         loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels)
 
         vision_x = self.perceiver(vision_x)  # reshapes to (b, S, n, d)
         #vision_x = checkpoint(self.perceiver,vision_x)
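Note for reviewers: the branch this diff comments out computed a keyword-image matching loss: report key words were embedded with a frozen MedKEBERT, cross-attended over the pooled ViT patch tokens via the module's transformer decoder, and each (sample, word) pair was scored against a binary presence label with BCE. The following is a minimal, self-contained sketch of that objective for reference only; text_encoder, decoder, and cls_head are hypothetical stand-ins (a plain nn.MultiheadAttention replaces the repo's positional-embedding-aware transformer_decoder, and a single Linear replaces transformer_decoder_mlp + cls_head), not the module's actual API.

# Minimal sketch (assumptions noted above) of the keyword-image matching
# loss disabled by this diff.
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

def keyword_matching_loss(key_words_query, image_tokens, text_encoder,
                          decoder, cls_head, max_queries=16):
    """key_words_query: list[list[str]] of per-sample key words.
    image_tokens: (B, V, D) visual patch embeddings, already pooled over
    the multi-image dimension S as in the original branch."""
    B = image_tokens.shape[0]
    # Deduplicate the batch's key words and cap the query set at 16,
    # mirroring the commented-out code.
    query_words = list({w for ws in key_words_query for w in ws})
    if len(query_words) > max_queries:
        random.shuffle(query_words)
        query_words = query_words[:max_queries]
    if not query_words:
        return None
    # labels[i, j] = 1 iff query word j occurs in sample i's key words.
    labels = torch.zeros(B, len(query_words), device=image_tokens.device)
    for i, ws in enumerate(key_words_query):
        for j, w in enumerate(query_words):
            labels[i, j] = float(w in ws)
    with torch.no_grad():                       # text encoder stays frozen
        q = text_encoder(query_words)           # (Q, D)
    q = q.unsqueeze(0).repeat(B, 1, 1)          # (B, Q, D)
    # Each word query cross-attends over the image tokens.
    attended, _ = decoder(q, image_tokens, image_tokens)  # (B, Q, D)
    logits = cls_head(attended).squeeze(-1)     # (B, Q)
    return F.binary_cross_entropy_with_logits(logits, labels)

With, for example, decoder = nn.MultiheadAttention(embed_dim=D, num_heads=8, batch_first=True) and cls_head = nn.Linear(D, 1), this reproduces the shape flow of the disabled branch ((B, Q) logits against (B, Q) labels); the real module additionally feeds the ViT positional embeddings into its decoder via the pos argument.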