diff --git a/VisionLLMv2/visionllmv2/model/modeling_visionllmv2.py b/VisionLLMv2/visionllmv2/model/modeling_visionllmv2.py
index 4e86539..2d3830a 100644
--- a/VisionLLMv2/visionllmv2/model/modeling_visionllmv2.py
+++ b/VisionLLMv2/visionllmv2/model/modeling_visionllmv2.py
@@ -162,6 +162,13 @@ def __init__(self, config: VisionLLMv2Config,
         v_hidden_size = self.v_hidden_size * 4 if self.use_pixelshuffle else self.v_hidden_size
         if vl_bridge_type == "linear":
             self.vl_bridge = nn.Linear(v_hidden_size, self.l_hidden_size)
+        elif vl_bridge_type == 'internvl_mlp' or vl_bridge_type == 'internvl':
+            self.vl_bridge = nn.Sequential(
+                nn.LayerNorm(v_hidden_size),
+                nn.Linear(v_hidden_size, self.l_hidden_size),
+                nn.GELU(),
+                nn.Linear(self.l_hidden_size, self.l_hidden_size)
+            )
         else:
             mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu*', vl_bridge_type)
             if mlp_gelu_match:
@@ -289,7 +296,10 @@ def freeze_region_encoder(self):
         self.region_encoder.requires_grad_(False)
 
     def freeze_emb_embeddings(self):
-        self.emb_embeddings.requires_grad_(False)
+        self.emb_embeddings_det.requires_grad_(False)
+        self.emb_embeddings_pose.requires_grad_(False)
+        self.emb_embeddings_gen.requires_grad_(False)
+        self.emb_embeddings_edit.requires_grad_(False)
 
     def get_vis_encoder(self):
         return getattr(self, 'vis_encoder', None)
@@ -406,12 +416,17 @@ def forward(
         # TODO beautify (init images_pos within forward)
         if inputs_embeds is None:
             inputs_embeds = self.llm.get_input_embeddings()(input_ids)
-
-
+        # ------------------------------------------------------------
         # NOTE: special operation for the [emb] tokens, this works well for both train and generation (use_cache=True)
-        # replace emb embeddings with predefined self.emb_embeddings,
+        # replace with tool emb_embeddings
         # and concat emb ids after special tool ids
+        if self.det_tool_id in input_ids[0] or self.seg_tool_id in input_ids[0] or self.grd_tool_id in input_ids[0] or \
+                self.pose_tool_id in input_ids[0] or self.gen_tool_id in input_ids[0] or self.edit_tool_id in input_ids[0]:
+            if self.emb_token_id not in input_ids[0]:  # for generation, generate tokens 1 by 1
+                gap_len, gap_len_gen = 0, 0
+            else:  # for training, we have added the [EMB] tokens in the input_ids
+                gap_len, gap_len_gen = self.num_embs, self.num_embs_gen
         emb_ids = torch.tensor([x for x in range(self.emb_token_id, self.emb_token_id + self.num_embs)], dtype=torch.long).to(input_ids.device)
         emb_embeddings_det = self.emb_embeddings_det.weight.unsqueeze(0).repeat(inputs_embeds.shape[0], 1, 1)  # [bs, num_embeds, c]
         emb_embeddings_pose = self.emb_embeddings_pose.weight.unsqueeze(0).repeat(inputs_embeds.shape[0], 1, 1)  # [bs, num_embeds, c]
@@ -438,7 +453,7 @@
                     [
                         cur_new_input_ids[: _start_pos + 1],
                         emb_ids,
-                        cur_new_input_ids[_start_pos + self.num_embs + 1 :]
+                        cur_new_input_ids[_start_pos + gap_len + 1 :]
                     ], dim=0
                 )  # repalce with emb embeddings
@@ -446,7 +461,7 @@
                     [
                         cur_new_input_embeds[: _start_pos + 1],
                         cur_emb_embeddings_det,
-                        cur_new_input_embeds[_start_pos + self.num_embs + 1 :]
+                        cur_new_input_embeds[_start_pos + gap_len + 1 :]
                     ], dim=0
                 ).contiguous()  # replace with self.emb_embeddings
                 # using unipose
@@ -456,7 +471,7 @@
                     [
                         cur_new_input_ids[: _start_pos + 1],
                         emb_ids,
-                        cur_new_input_ids[_start_pos + self.num_embs + 1 :]
+                        cur_new_input_ids[_start_pos + gap_len + 1 :]
                     ], dim=0
                 )  # repalce with emb embeddings
@@ -464,7 +479,7 @@
                     [
                         cur_new_input_embeds[: _start_pos + 1],
                         cur_emb_embeddings_pose,
-                        cur_new_input_embeds[_start_pos + self.num_embs + 1 :]
+                        cur_new_input_embeds[_start_pos + gap_len + 1 :]
                     ], dim=0
                 ).contiguous()  # replace with self.emb_embeddings
                 # using sd
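
The hunks above all reduce to one splice whose offset depends on whether the [EMB] placeholders are already present in input_ids. A minimal, self-contained sketch of that splice (token ids, shapes and the tool position are made up for illustration; the real code additionally dispatches per tool and per batch element):

import torch

def splice_emb_tokens(input_ids, inputs_embeds, emb_embeddings, start_pos, emb_token_id, gap_len):
    # Insert the learned [EMB] embeddings right after the tool token at `start_pos`.
    # Training: the [EMB] placeholder ids are already in input_ids, so gap_len == num_embs
    # and they are overwritten. Generation: tokens arrive one by one, nothing is there yet,
    # so gap_len == 0 and the embeddings are simply inserted.
    num_embs = emb_embeddings.shape[0]
    emb_ids = torch.arange(emb_token_id, emb_token_id + num_embs, dtype=input_ids.dtype)
    new_ids = torch.cat([input_ids[: start_pos + 1], emb_ids,
                         input_ids[start_pos + gap_len + 1:]], dim=0)
    new_embeds = torch.cat([inputs_embeds[: start_pos + 1], emb_embeddings,
                            inputs_embeds[start_pos + gap_len + 1:]], dim=0).contiguous()
    return new_ids, new_embeds

# generation-style call (gap_len=0): the sequence grows by num_embs
ids = torch.tensor([1, 2, 3])       # position 1 stands in for a tool token such as [DET]
embeds = torch.randn(3, 8)
emb_weight = torch.randn(4, 8)      # 4 learned [EMB] embeddings, toy hidden size 8
new_ids, new_embeds = splice_emb_tokens(ids, embeds, emb_weight, start_pos=1, emb_token_id=100, gap_len=0)
assert new_ids.shape[0] == 7 and new_embeds.shape == (7, 8)
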
@@ -474,7 +489,7 @@
                     [
                         cur_new_input_ids[: _start_pos + 1],
                         emb_ids_gen,
-                        cur_new_input_ids[_start_pos + self.num_embs_gen + 1 :]
+                        cur_new_input_ids[_start_pos + gap_len_gen + 1 :]
                     ], dim=0
                 )  # repalce with emb embeddings
@@ -482,7 +497,7 @@
                     [
                         cur_new_input_embeds[: _start_pos + 1],
                         cur_emb_embeddings_gen,
-                        cur_new_input_embeds[_start_pos + self.num_embs_gen + 1 :]
+                        cur_new_input_embeds[_start_pos + gap_len_gen + 1 :]
                     ], dim=0
                 ).contiguous()  # replace with self.emb_embeddings
                 # using ip2p
@@ -492,7 +507,7 @@
                     [
                         cur_new_input_ids[: _start_pos + 1],
                         emb_ids_gen,
-                        cur_new_input_ids[_start_pos + self.num_embs_gen + 1 :]
+                        cur_new_input_ids[_start_pos + gap_len_gen + 1 :]
                     ], dim=0
                 )  # repalce with emb embeddings
@@ -500,7 +515,7 @@
                     [
                         cur_new_input_embeds[: _start_pos + 1],
                         cur_emb_embeddings_edit,
-                        cur_new_input_embeds[_start_pos + self.num_embs_gen + 1 :]
+                        cur_new_input_embeds[_start_pos + gap_len_gen + 1 :]
                     ], dim=0
                 ).contiguous()  # replace with self.emb_embeddings
                 # assert cur_new_input_embeds.shape[0] == cur_input_embeds.shape[0]
@@ -518,28 +533,40 @@
                 attention_mask = torch.cat(
                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], add_length))], dim=-1
                 )
-            if input_ids.shape[1] != attention_mask.shape[1]:  # generation
-                # having [emb] in this generation step
-                if input_ids.shape[1] != 1:
+            else:  # useful for multi-round chat
+                total_length = input_ids.shape[1]
+                if attention_mask.shape[1] != total_length:
+                    add_length = total_length - attention_mask.shape[1]
                     attention_mask = torch.cat(
-                        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], self.num_embs))], dim=-1
+                        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], add_length))], dim=-1
                     )
+            if past_key_values is not None:
+                # having [emb] in this generation step
+                if input_ids.shape[1] != 1:
+                    if self.gen_tool_id in input_ids or self.edit_tool_id in input_ids:  # visual generation
+                        attention_mask = torch.cat(
+                            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], self.num_embs_gen))], dim=-1
+                        )
+                    else:  # visual perception
+                        attention_mask = torch.cat(
+                            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], self.num_embs))], dim=-1
+                        )
 
         # ------------------------------------------------------------
         # for the 1st step generation
         if past_key_values is None:
             with torch.no_grad():
                 if type(images) == list:  # 'anyres'
-                    images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
-                    split_sizes = [image.shape[0] for image in images]
-                    concat_images = torch.cat([image for image in images], dim=0)
+                    images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]  # list[tensor], bs x [n_split, 3, h , w]
+                    split_sizes = [image.shape[0] for image in images]  # list[int]
+                    concat_images = torch.cat([image for image in images], dim=0)  # [n_all_image_splits, 3, h, w]
                     image_forward_outs = self.vis_encoder(concat_images, output_hidden_states=True)
                 else:  # 'pad'
                     # here images: after clip preprocess, [b, 3, h, w]
                     image_forward_outs = self.vis_encoder(images, output_hidden_states=True)
                 select_hidden_state_layer = getattr(self.config, "vis_output_layer", -2)
                 select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
-                image_features_ori = select_hidden_state[:, 1:].to(self.llm.dtype)
+                image_features_ori = select_hidden_state[:, 1:].to(self.llm.dtype)  # [bs, img_len, 1024] or [n_all_image_splits, img_len, 1024]
 
         # pixel shuffle
         if self.use_pixelshuffle:
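
The mask bookkeeping in the last hunk boils down to: whenever a decoding step appends extra [EMB] embeddings to the sequence, the attention mask must grow by the same count, and that count differs between the perception tools and the generation/editing tools. A sketch of the rule (the num_embs / num_embs_gen defaults below are illustrative, not taken from the config):

import torch

def extend_attention_mask(attention_mask, is_visual_generation, num_embs=4, num_embs_gen=64):
    # Grow the mask by the number of [EMB] embeddings appended in this step.
    add = num_embs_gen if is_visual_generation else num_embs
    pad = attention_mask.new_ones((attention_mask.shape[0], add))
    return torch.cat([attention_mask, pad], dim=-1)

mask = torch.ones(2, 10, dtype=torch.long)
assert extend_attention_mask(mask, is_visual_generation=False).shape == (2, 14)
assert extend_attention_mask(mask, is_visual_generation=True).shape == (2, 74)
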
@@ -547,67 +574,82 @@
                 image_features_ori = image_features_ori.reshape(image_features_ori.shape[0], h, w, -1)
                 image_features_ori = self.pixel_shuffle(image_features_ori, scale_factor=0.5)
                 image_features_ori = image_features_ori.reshape(image_features_ori.shape[0], -1, image_features_ori.shape[-1])
-            image_features = self.vl_bridge(image_features_ori).to(inputs_embeds.dtype)
+            image_features = self.vl_bridge(image_features_ori).to(inputs_embeds.dtype)  # [bs, img_len, 4096] or [n_all_image_splits, img_len, 4096]
 
             # replace image patch token for multi-modal/nlp datasets
             B, L, C = inputs_embeds.shape
             inputs_embeds = inputs_embeds.reshape(B * L, C)
-            selected = input_ids == self.imp_token_id
-            has_image = selected.sum(-1) != 0
+            selected = input_ids == self.imp_token_id
+            has_image = selected.sum(-1) != 0
             if type(images) == list:
                 has_image = [has_image[i][None].repeat(split_sizes[i]) for i in range(B)]
                 has_image = torch.cat(has_image, dim=0)
-            selected = selected.reshape(-1)
-            temp_embeds = torch.zeros_like(inputs_embeds)
-            temp_embeds[selected] = image_features[has_image].reshape(-1, C)
-            selected = selected.to(inputs_embeds.dtype).unsqueeze(-1)
-            inputs_embeds = inputs_embeds * (1 - selected) + temp_embeds * selected
+            selected = selected.reshape(-1)  # [B*L]
+            # handle interleaved data when num() != num(images)
+            try:
+                vit_embeds = image_features[has_image].reshape(-1, C)
+                inputs_embeds[selected] = inputs_embeds[selected] * 0.0 + vit_embeds
+                ignore_flag = False
+            except Exception as e:
+                vit_embeds = image_features[has_image].reshape(-1, C)
+                print(f'warning: {e}, inputs_embeds[selected].shape={inputs_embeds[selected].shape}, '
+                      f'vit_embeds.shape={vit_embeds.shape}')
+                n_selected_token = selected.sum()
+                n_vit_token = vit_embeds.shape[0]
+                vit_embeds = vit_embeds.repeat(n_selected_token // n_vit_token, 1) if n_selected_token > n_vit_token \
+                    else vit_embeds[:n_vit_token]
+                inputs_embeds[selected] = inputs_embeds[selected] * 0.0 + vit_embeds
+                ignore_flag = True
             inputs_embeds = inputs_embeds.reshape(B, L, C)
 
+            # deal with region/non-region data joint train with zero2/3
            if self.use_region_encoder:
                 if regions is not None:  # region data, list[tensor], bs x [n_region, h, w]
+                    # concat regions in the batch dimension
+                    # regions: list[tensor], bs x [n_region, h, w]
                     num_regions = [len(regions_per_batch) for regions_per_batch in regions]
                     all_regions = [regions_per_batch[:, None] for regions_per_batch in regions]
-                    all_regions = torch.cat(all_regions, dim=0)
+                    all_regions = torch.cat(all_regions, dim=0)
                     # deal with model.generate() when num_beams > 1
                     num_beams = len(images) // len(regions)
                     if num_beams > 1:
                         num_regions = num_regions * num_beams
                         all_regions = all_regions.repeat(num_beams, 1, 1, 1)
+                    # repeat image and image_features, [bs, 3, h, w], [bs, 1024, c]
                     if type(images) == list:  # 'anyres', last split is the global image
                         all_images = [images[i][-1][None].repeat_interleave(num_regions[i], dim=0) for i in range(len(images))]
                     else:  # 'pad'
                         all_images = [images[i][None].repeat_interleave(num_regions[i], dim=0) for i in range(len(images))]  # list[tensor], bs x [n_region, 3, h, w]
-                    all_images = torch.cat(all_images, dim=0)
+                    all_images = torch.cat(all_images, dim=0)
                     # multi-scale of last 3 levels image features for region encoder
                     mlvl_image_features = image_forward_outs.hidden_states[-3:]
                     if type(images) == list:  # 'anyres', last split is global image
                         new_mlvl_image_features = []
-                        for image_features_per_level in mlvl_image_features:
+                        for image_features_per_level in mlvl_image_features:
                             image_features_per_level = torch.split(image_features_per_level, split_sizes, dim=0)
                             image_features_per_level = [x[-1, 1:] for x in image_features_per_level]
-                            image_features_per_level = torch.stack(image_features_per_level, dim=0)
+                            image_features_per_level = torch.stack(image_features_per_level, dim=0)
                             new_mlvl_image_features.append(image_features_per_level)
                         mlvl_image_features = new_mlvl_image_features
                     else:  # 'pad'
-                        mlvl_image_features = [mlvl_image_feature[:, 1:] for mlvl_image_feature in mlvl_image_features]
+                        mlvl_image_features = [mlvl_image_feature[:, 1:] for mlvl_image_feature in mlvl_image_features]
                     all_image_features = []
                     for image_features_per_level in mlvl_image_features:
                         all_image_features_per_level = [image_features_per_level[i][None].repeat_interleave(num_regions[i], dim=0) for i in range(len(images))]
                         all_image_features_per_level = torch.cat(all_image_features_per_level)
-                        all_image_features.append(all_image_features_per_level)
+                        all_image_features.append(all_image_features_per_level)  # 3 x [n_all_regions, img_len, 1024]
 
                     # all_images: [n_all_regions, 3, h, w]
                     # all_regions: [n_all_regions, 1, h, w]
                     # all_image_features: 3 x [n_all_regions, img_len, 1024]
-                    all_region_features = self.region_encoder(all_images, all_regions, all_image_features)
+                    all_region_features = self.region_encoder(all_images, all_regions, all_image_features)
                     all_region_features = all_region_features.to(inputs_embeds.dtype)
 
                     # replace token
                     inputs_embeds = inputs_embeds.reshape(B*L, C)
-                    region_mask = input_ids == self.reg_token_id
-                    region_mask = region_mask.reshape(-1)
-                    temp_embeds = torch.zeros_like(inputs_embeds)
+                    region_mask = input_ids == self.reg_token_id
+                    region_mask = region_mask.reshape(-1)
+                    temp_embeds = torch.zeros_like(inputs_embeds)
                     temp_embeds[region_mask] = all_region_features
                     region_mask = region_mask.to(inputs_embeds.dtype).unsqueeze(-1)
                     inputs_embeds = inputs_embeds * (1 - region_mask) + temp_embeds * region_mask
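
The pixel-shuffle path above is why __init__ sizes the bridge with v_hidden_size * 4: shuffling with scale_factor=0.5 trades spatial resolution for channels, so the ViT tokens become 4x fewer and 4x wider before they reach vl_bridge. A shape-level sketch of that effect (this is a plain space-to-depth rearrangement; the model's own pixel_shuffle may order elements differently, and the 1024-dim features are just an example):

import torch

def space_to_depth(x, r=2):
    # [N, H, W, C] -> [N, H//r, W//r, C*r*r]: fold each r x r patch into the channel dim
    n, h, w, c = x.shape
    x = x.view(n, h // r, r, w // r, r, c)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(n, h // r, w // r, c * r * r)

feat = torch.randn(2, 32, 32, 1024)        # e.g. 32x32 = 1024 ViT patch tokens per image
out = space_to_depth(feat)
assert out.shape == (2, 16, 16, 4 * 1024)  # 4x fewer tokens, 4x wider channels -> v_hidden_size * 4
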
@@ -615,16 +657,16 @@
                 else:  # regions is None
                     if type(images) == list:  # 'anyres'
                         H, W = images[0][0].shape[-2:]
-                        dummy_all_images = torch.zeros((1, 3, H, W), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-                        dummy_all_regions = torch.ones((1, 1, H, W), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-                        dummy_all_image_features = torch.zeros_like(image_forward_outs.hidden_states[-1][:1, 1:])
-                        dummy_all_image_features = [dummy_all_image_features] * 3
+                        dummy_all_images = torch.zeros((1, 3, H, W), dtype=inputs_embeds.dtype, device=inputs_embeds.device)  # [1, 3, h, w]
+                        dummy_all_regions = torch.ones((1, 1, H, W), dtype=inputs_embeds.dtype, device=inputs_embeds.device)  # [1, 1, h, w]
+                        dummy_all_image_features = torch.zeros_like(image_forward_outs.hidden_states[-1][:1, 1:])  # [1, img_len, 1024]
+                        dummy_all_image_features = [dummy_all_image_features] * 3  # multi-scale image features, 3 x [1, img_len, 1024]
                     else:  # 'pad'
                         B, _, H, W = images.shape
-                        dummy_all_images = torch.zeros_like(images)
-                        dummy_all_regions = torch.ones((B, 1, H, W), dtype=images.dtype, device=images.device)
-                        dummy_all_image_features = torch.zeros_like(image_forward_outs.hidden_states[-1][:, 1:])
-                        dummy_all_image_features = [dummy_all_image_features] * 3
+                        dummy_all_images = torch.zeros_like(images)  # [b, 3, h, w]
+                        dummy_all_regions = torch.ones((B, 1, H, W), dtype=images.dtype, device=images.device)  # [b, 1, h, w]
+                        dummy_all_image_features = torch.zeros_like(image_forward_outs.hidden_states[-1][:, 1:])  # [b, img_len, 1024]
+                        dummy_all_image_features = [dummy_all_image_features] * 3  # multi-scale image features, 3 x [b, img_len, 1024]
                     # dummy forward for region encoder
                     dummy_all_region_features = self.region_encoder(dummy_all_images, dummy_all_regions, dummy_all_image_features)
                     dummy_all_region_features = dummy_all_region_features.to(inputs_embeds.dtype)
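
The dummy pass above exists so that batches without any region prompt still route a (zero) gradient through every region-encoder parameter, which is what the "joint train with zero2/3" comment refers to: ZeRO-2/3 style sharding expects all parameters to take part in each step. The diff does not show how the dummy output is consumed, but the usual pattern is to fold it in with zero weight; a generic sketch of that pattern (module and shapes are illustrative):

import torch
import torch.nn as nn

def forward_with_optional_branch(branch: nn.Module, x, use_branch: bool):
    if use_branch:
        return branch(x)
    # dummy forward: run the branch on zeros and add its output with zero weight,
    # so every branch parameter still takes part in the backward pass
    dummy = branch(torch.zeros_like(x))
    return x + dummy.sum() * 0.0

branch = nn.Linear(8, 8)
out = forward_with_optional_branch(branch, torch.randn(2, 8), use_branch=False)
out.sum().backward()
assert branch.weight.grad is not None  # gradient exists, and is all zeros
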
@@ -669,6 +711,8 @@
             # Enable model/pipeline parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
+            if ignore_flag:
+                loss = loss * 0.0
 
         # ---------------------- atom tools ------------------------------
 
@@ -677,22 +721,23 @@
         else:
             task = None
 
+        # det/grd/seg
         # gdino
         if self.use_gdino:
             if task in ['det', 'det_cap', 'grd', 'seg', 'count_text', 'count_visual', 'interactive'] and images_aug is not None:
                 images_aug = nested_tensor_from_tensor_list(images_aug, size_divisibility=32)
-                pixel_values, pixel_mask = images_aug.tensors, ~images_aug.mask
+                pixel_values, pixel_mask = images_aug.tensors, ~images_aug.mask  # [bs, 3, h, w], [bs, h, w]
                 pixel_mask = pixel_values[:, 0, :, :] != 0  # valid is 1
 
                 # select the corresponding [EMB] hidden states as text_query
                 batch_size, seq_len, hidden_size = inputs_embeds.shape
-                emb_select = (input_ids >= self.emb_token_id) & (input_ids <= self.emb_token_id + self.num_embs - 1)
+                emb_select = (input_ids >= self.emb_token_id) & (input_ids <= self.emb_token_id + self.num_embs - 1)  # [bs, seq_len]
                 # if have [EMB] tokens
                 if emb_select.sum() != 0:
                     num_patches = emb_select.sum(-1) // self.num_embs
                     max_num_patches = num_patches.max()
-                    text_query = torch.zeros((batch_size, max_num_patches, self.num_embs, hidden_size), dtype=hidden_states.dtype, device=hidden_states.device)
-                    text_query_masks = torch.zeros(batch_size, max_num_patches, dtype=torch.bool, device=hidden_states.device)
+                    text_query = torch.zeros((batch_size, max_num_patches, self.num_embs, hidden_size), dtype=hidden_states.dtype, device=hidden_states.device)
+                    text_query_masks = torch.zeros(batch_size, max_num_patches, dtype=torch.bool, device=hidden_states.device)
                     for batch_idx in range(batch_size):
                         if num_patches[batch_idx] != 0:
                             text_query_i = hidden_states[batch_idx, emb_select[batch_idx], :].reshape(-1, self.num_embs, hidden_size)
@@ -714,7 +759,7 @@
                 images_aug = nested_tensor_from_tensor_list(images_aug, size_divisibility=32)
                 # select the corresponding [EMB] hidden states as text_query
                 batch_size, seq_len, hidden_size = inputs_embeds.shape
-                emb_select = (input_ids >= self.emb_token_id) & (input_ids <= self.emb_token_id + self.num_embs - 1)
+                emb_select = (input_ids >= self.emb_token_id) & (input_ids <= self.emb_token_id + self.num_embs - 1)  # [bs, seq_len]
                 # if have [EMB] tokens
                 if emb_select.sum() != 0:
                     num_patches = emb_select.sum(-1) // self.num_embs
@@ -723,6 +768,7 @@
                     # this is for pose class patches
                     max_num_kpt_patches = 100
 
+                    # [bs, max_num_obj/kpt_patches, num_embs, c], [bs, max_num_obj/kpt_patches]
                     obj_querys = torch.zeros((batch_size, max_num_obj_patches, self.num_embs, hidden_size), dtype=hidden_states.dtype, device=hidden_states.device)
                     obj_query_masks = torch.zeros((batch_size, max_num_obj_patches), dtype=torch.bool, device=hidden_states.device)
                     kpt_querys = torch.zeros((batch_size, max_num_kpt_patches, self.num_embs, hidden_size), dtype=hidden_states.dtype, device=hidden_states.device)
@@ -733,9 +779,9 @@
                         if num_objcls != 0 and num_kpts != 0:
                             text_query_i = hidden_states[batch_idx, emb_select[batch_idx], :].reshape(-1, self.num_embs, hidden_size)
                             obj_querys[batch_idx, :num_objcls] = text_query_i[:num_objcls, ...]
-                            obj_query_masks[batch_idx, :num_objcls] = 1
+                            obj_query_masks[batch_idx, :num_objcls] = 1
                             kpt_querys[batch_idx, :num_kpts] = text_query_i[num_objcls:, ...]
-                            kpt_query_masks[batch_idx, :num_kpts] = 1
+                            kpt_query_masks[batch_idx, :num_kpts] = 1
 
                     text_query = dict(
                         obj_querys=obj_querys,
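
Both decoder branches above consume the LLM's last hidden states at the [EMB] positions, packed into a padded query tensor plus a validity mask. A condensed sketch of that gathering step (standalone; ids and sizes are made up, and the real code additionally splits the groups into object-class and keypoint queries for the pose branch):

import torch

def gather_emb_queries(hidden_states, input_ids, emb_token_id, num_embs):
    # hidden_states: [bs, seq_len, c]; each run of `num_embs` consecutive [EMB] ids forms one query
    bs, seq_len, hidden = hidden_states.shape
    emb_select = (input_ids >= emb_token_id) & (input_ids <= emb_token_id + num_embs - 1)  # [bs, seq_len]
    num_patches = emb_select.sum(-1) // num_embs
    max_patches = int(num_patches.max())
    text_query = hidden_states.new_zeros(bs, max_patches, num_embs, hidden)
    text_query_masks = torch.zeros(bs, max_patches, dtype=torch.bool, device=hidden_states.device)
    for b in range(bs):
        if num_patches[b] > 0:
            q = hidden_states[b, emb_select[b]].reshape(-1, num_embs, hidden)
            text_query[b, : num_patches[b]] = q
            text_query_masks[b, : num_patches[b]] = True
    return text_query, text_query_masks

ids = torch.tensor([[7, 100, 101, 102, 103, 9, 100, 101, 102, 103]])  # two groups of 4 [EMB] ids
h = torch.randn(1, 10, 16)
q, m = gather_emb_queries(h, ids, emb_token_id=100, num_embs=4)
assert q.shape == (1, 2, 4, 16) and m.all()
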
@@ -803,3 +849,7 @@
             loss_ip2p=loss_ip2p.detach() if loss_ip2p is not None else None,
             ip2p_outputs=ip2p_outputs,
         )
+
+
+AutoConfig.register("visionllmv2", VisionLLMv2Config)
+AutoModel.register(VisionLLMv2Config, VisionLLMv2Model)
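
With the two register calls in place, the custom classes become reachable through the Transformers Auto* factories. A usage sketch (the checkpoint path is a placeholder; the registering module must have been imported first so the calls above have run):

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("/path/to/visionllmv2/checkpoint")  # -> VisionLLMv2Config
model = AutoModel.from_pretrained("/path/to/visionllmv2/checkpoint")    # -> VisionLLMv2Model
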