From e5c67b2966e0bc8c47df8265c5795521e1f53c50 Mon Sep 17 00:00:00 2001
From: Li Bo
Date: Sat, 23 Dec 2023 20:38:28 +0800
Subject: [PATCH] Update modeling_otter.py

---
 src/otter_ai/models/otter/modeling_otter.py | 286 +-------------------
 1 file changed, 1 insertion(+), 285 deletions(-)

diff --git a/src/otter_ai/models/otter/modeling_otter.py b/src/otter_ai/models/otter/modeling_otter.py
index 83e7fbb7..c46d55d6 100755
--- a/src/otter_ai/models/otter/modeling_otter.py
+++ b/src/otter_ai/models/otter/modeling_otter.py
@@ -1,4 +1,3 @@
-import builtins
 import random
 import sys
 from typing import List, Optional
@@ -6,16 +5,12 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from accelerate import Accelerator
 from accelerate.hooks import AlignDevicesHook, add_hook_to_module
 from einops import rearrange, repeat
 from peft import LoraConfig, TaskType, get_peft_model
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
-from transformers.models.auto import AutoModel, AutoModelForCausalLM, AutoTokenizer
-
-from pipeline.utils.modeling_value_head import AutoModelForCausalLMWithValueHead
-
+from transformers.models.auto import AutoTokenizer
 from ..falcon.modelling_RW import RWForCausalLM
 from ..mpt.modeling_mpt import MPTForCausalLM
 from ..mpt_redpajama.mosaic_gpt import MosaicGPT
@@ -1045,282 +1040,3 @@ def generate(
 
         self.lang_encoder.clear_conditioned_layers()
         return output
-
-
-class OtterForConditionalGenerationWithValueHead(OtterPreTrainedModel):
-    config_class = OtterConfig
-
-    def __init__(
-        self,
-        config: OtterConfig,
-    ):
-        super().__init__(config)
-        ### TODO: give "LlamaForCausalLM" as the name of text_config.architectures of Llama_based flamingo
-        if "llama" not in config.text_config._name_or_path:
-            if config.text_config.architectures[0] == "MPTForCausalLM":
-                text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
-                lang_encoder = MPTForCausalLM(config=config.text_config)
-            elif config.text_config.architectures[0] == "MosaicGPT":
-                text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mosaic-llama-redpajama-final-candidate")
-                lang_encoder = MosaicGPT(config=config.text_config)
-            elif config.text_config.architectures[0] == "RWForCausalLM":
-                text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
-                lang_encoder = RWForCausalLM(config=config.text_config)
-            elif config.text_config.architectures[0] == "LlamaForCausalLM":
-                text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
-                lang_encoder = LlamaForCausalLM(config=config.text_config)
-        else:
-            text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
-            lang_encoder = LlamaForCausalLM(config=config.text_config)
-        vision_encoder = CLIPVisionModel(config=config.vision_config)
-
-        text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]})
-        if text_tokenizer.pad_token is None:
-            text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
-        self.text_tokenizer = text_tokenizer
-        self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
-        self.media_token_id = text_tokenizer.encode("<image>")[-1]
-        self.lang_encoder_with_vhead = AutoModelForCausalLMWithValueHead(lang_encoder)
-        extend_instance(self.lang_encoder_with_vhead.pretrained_model, OtterLMMixin)
-        decoder_layers_attr_name = _infer_decoder_layers_attr_name(self.lang_encoder_with_vhead.pretrained_model)
-        self.lang_encoder_with_vhead.pretrained_model.set_decoder_layers_attr_name(decoder_layers_attr_name)
-        if self.lang_encoder_with_vhead.pretrained_model.__class__.__name__ == "LlamaForCausalLM":
-            self.lang_encoder_with_vhead.pretrained_model.resize_token_embeddings(len(text_tokenizer))
-
-        self.cross_attn_every_n_layers = config.cross_attn_every_n_layers
-        # use_media_placement_augmentation is strictly false for Otter model
-        self.use_media_placement_augmentation = False  # config.use_media_placement_augmentation
-        self.max_num_frames = config.max_num_frames if hasattr(config, "max_num_frames") else None
-
-        # Informative master_print statement
-        if self.max_num_frames is None or self.max_num_frames == 1:
-            master_print(f"The current model version is configured for Otter-Image with max_num_frames set to {self.max_num_frames}.")
-        else:
-            master_print(f"The current model version is configured for Otter-Video with a maximum of {self.max_num_frames} frames.")
-
-        vision_encoder.output_tokens = True
-        self.vision_encoder = vision_encoder
-
-        self.vis_dim = 1024
-        self.perceiver = OtterPerceiverResampler(dim=self.vis_dim, max_num_frames=self.max_num_frames)
-
-        self.lang_encoder_with_vhead.pretrained_model.init_otter(
-            media_token_id=self.media_token_id,
-            vis_hidden_size=self.vis_dim,
-            cross_attn_every_n_layers=self.cross_attn_every_n_layers,
-            use_media_placement_augmentation=self.use_media_placement_augmentation,
-        )
-
-        if "lora_config" in config.__dict__:
-            original_architecture_name = self.lang_encoder_with_vhead.pretrained_model.__class__.__name__
-            master_print(f"Using LoRA with config:{config.lora_config}")
-            standard_modules = ["q_proj", "v_proj"]
-            lang_encoder_short_name = MODEL_CLASSES[config.text_config.architectures[0]]
-            model_to_lora_modules = {
-                "llama": standard_modules,
-                "opt": standard_modules,
-                "gptj": standard_modules,
-                "gpt_neox": ["query_key_value"],
-                "mpt": ["Wqkv"],
-            }
-            lora_config = LoraConfig(
-                r=config.lora_config["r"],
-                lora_alpha=config.lora_config["lora_alpha"],
-                lora_dropout=config.lora_config["lora_dropout"],
-                task_type=TaskType.CAUSAL_LM,
-                target_modules=model_to_lora_modules[lang_encoder_short_name],
-            )
-            self.lang_encoder_with_vhead.pretrained_model = get_peft_model(self.lang_encoder_with_vhead.pretrained_model, lora_config)
-            self.lang_encoder_with_vhead.pretrained_model.master_print_trainable_parameters()
-            self.lang_encoder_with_vhead.pretrained_model.__class__.__name__ = f"{original_architecture_name}LoRA"
-
-        self.post_init()
-
-    def get_input_embeddings(self) -> nn.Module:
-        return self.lang_encoder.get_input_embeddings()
-
-    def set_input_embeddings(self, new_embeddings):
-        self.lang_encoder.set_input_embeddings(new_embeddings)
-
-    def get_output_embeddings(self) -> nn.Module:
-        return self.lang_encoder.get_output_embeddings()
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lang_encoder.set_output_embeddings(new_embeddings)
-
-    def get_image_encoder(self) -> nn.Module:
-        return self.vision_encoder
-
-    def get_lang_encoder(self) -> nn.Module:
-        return self.lang_encoder
-
-    def init_weights(self):
-        # Freeze all parameters in self.model
-        for param in self.parameters():
-            param.requires_grad = False
-
-        # Freeze all parameters in vision encoder
-        if "train_vision_encoder" in self.config.__dict__ and self.config.train_vision_encoder is True:
-            for param in self.vision_encoder.parameters():
-                param.requires_grad = True
-
-        # Freeze all parameters in lang encoders except gated_cross_attn_layers
-        if "train_lang_encoder" in self.config.__dict__ and self.config.train_lang_encoder is True:
-            for name, param in self.lang_encoder_with_vhead.named_parameters():
-                param.requires_grad = True
-
-        # Freeze all parameters in lang encoders except gated_cross_attn_layers
-        if "train_connector" in self.config.__dict__ and self.config.train_connector is True:
-            for (
-                name,
-                param,
-            ) in self.lang_encoder_with_vhead.pretrained_model.named_parameters():
-                if "gated_cross_attn_layer" in name:
-                    param.requires_grad = True
-            for name, param in self.named_parameters():
-                if "perceiver" in name:
-                    param.requires_grad = True
-
-        if "lora_config" in self.config.__dict__:
-            # Use another logic to unfreeze gated_cross_attn_layers and perceivers
-            master_print(f"LoRA trainable param: {(sum(p.numel() for p in self.lang_encoder_with_vhead.pretrained_model.parameters() if p.requires_grad)) / 1e9:.3f} B")
-
-        # Unfreeze LM input and output embeddings
-        self.lang_encoder_with_vhead.pretrained_model.get_input_embeddings().requires_grad_(True)
-        ## MPTForCausalLM is tied word embedding
-        if "LlamaForCausalLM" in self.lang_encoder_with_vhead.__class__.__name__:
-            self.lang_encoder_with_vhead.lm_head.requires_grad_(True)
-        # master_print("====================Model Grad Part====================")
-        total_params = 0
-        for name, param in self.named_parameters():
-            if param.requires_grad:
-                total_params += param.numel()
-                master_print(f"Parameter: {name}, Size: {param.numel() / 1e6:.6f} M")
-        master_print(f"Total Trainable param: {total_params / 1e9:.6f} B")
-
-    def forward(
-        self,
-        vision_x: torch.Tensor,
-        lang_x: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        use_cached_vision_x: bool = False,
-        clear_conditioned_layers: bool = True,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        use_cache: bool = False,
-        **kwargs,
-    ) -> CausalLMOutputWithPast:
-        """
-        Forward pass of Otter.
-
-        Args:
-            vision_x (torch.Tensor): Vision input
-                shape (B, T_img, F, C, H, W) with F=1
-            lang_x (torch.Tensor): Language input ids
-                shape (B, T_txt)
-            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
-            labels (torch.Tensor, optional): Labels. Defaults to None.
-            clear_conditioned_layers: if True, clear the conditioned layers
-                once the foward pass is completed. Set this to false if the
-                same set of images will be reused in another subsequent
-                forward pass.
-            past_key_values: pre-computed values to pass to language model.
-                See past_key_values documentation in Hugging Face
-                CausalLM models.
-            use_cache: whether to use cached key values. See use_cache
-                documentation in Hugging Face CausalLM models.
-        """
-        assert (vision_x is not None) or use_cached_vision_x, "Must provide either vision_x or use_cached_vision_x to True."
-
-        if use_cached_vision_x:
-            # Case: use cached; vision_x should be cached and other
-            # vision-related inputs should not be provided.
-            assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
-            assert self.lang_encoder_with_vhead.is_conditioned()
-
-        else:
-            # Case: do not use caching (i.e. this is a standard forward pass);
-            self._encode_vision_x(vision_x=vision_x)
-
-        output = self.lang_encoder_with_vhead(
-            input_ids=lang_x,
-            attention_mask=attention_mask,
-            labels=labels,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            **kwargs,
-        )
-
-        if clear_conditioned_layers:
-            self.lang_encoder_with_vhead.clear_conditioned_layers()
-
-        return output
-
-    def _encode_vision_x(self, vision_x: torch.Tensor):
-        """
-        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
-        Args:
-            vision_x (torch.Tensor): Vision input
-                shape (B, T_img, F, C, H, W)
-                Images in the same chunk are collated along T_img, and frames are collated along F
-                Currently only F=1 is supported (single-frame videos)
-
-        rearrange code based on https://github.com/dhansmair/flamingo-mini
-        """
-
-        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
-        b, T, F = vision_x.shape[:3]
-
-        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
-        vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
-        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
-
-        vision_x = self.perceiver(vision_x)  # reshapes to (b, T, n, d)
-
-        for layer in self.lang_encoder._get_decoder_layers():
-            layer.condition_vis_x(vision_x)
-
-    @torch.no_grad()
-    def generate(
-        self,
-        vision_x: torch.Tensor,
-        lang_x: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        **generate_kwargs,
-    ):
-        """
-        Generate text conditioned on vision and language inputs.
-
-        Args:
-            vision_x (torch.Tensor): Vision input
-                shape (B, T_img, F, C, H, W)
-                images in the same chunk are collated along T_img, and frames are collated along F
-                currently only F=1 is supported (single-frame videos)
-            lang_x (torch.Tensor): Language input
-                shape (B, T_txt)
-            max_length (int, optional): Maximum length of the output. Defaults to None.
-            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
-        Returns:
-            torch.Tensor: lang_x with generated tokens appended to it
-        """
-        if hasattr(self, "_hf_hook"):
-            # add a hook to make sure that the output of lang_encoder is mapped to the same device as the lang_x
-            hook = AlignDevicesHook(
-                execution_device=lang_x.device,
-                io_same_device=True,
-                place_submodules=False,
-            )
-            add_hook_to_module(self.lang_encoder, hook)
-        num_beams = generate_kwargs.get("num_beams", 1)
-        if num_beams > 1:
-            vision_x = vision_x.repeat_interleave(num_beams, dim=0)
-        self._encode_vision_x(vision_x=vision_x)
-        output = self.lang_encoder.generate(
-            input_ids=lang_x,
-            attention_mask=attention_mask,
-            eos_token_id=self.eoc_token_id,
-            **generate_kwargs,
-        )
-
-        self.lang_encoder.clear_conditioned_layers()
-        return output