From 8a679821d2b346df7c3d38db4392181b0c191014 Mon Sep 17 00:00:00 2001 From: Vivian Chen Date: Thu, 11 Jul 2024 16:42:58 +0000 Subject: [PATCH 1/9] modify vision encoder config Signed-off-by: Vivian Chen --- .../models/multimodal_llm/neva/neva_model.py | 12 ++++-------- nemo/collections/multimodal/parts/utils.py | 12 ++++-------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index 92f13c28c287..594c66fe44aa 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -479,11 +479,9 @@ def __init__( def create_vision_encoder_and_processor(self, mm_cfg): # Initialize vision encoder and freeze it if mm_cfg.vision_encoder.get("from_hf", False): - if ( - "clip" in mm_cfg.vision_encoder.from_pretrained - or "vit" in mm_cfg.vision_encoder.from_pretrained - or "clip" in mm_cfg.vision_encoder.get("model_type", "") - ): + from transformers import AutoConfig + config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained) + if config.architectures[0] == "CLIPVisionModel": vision_encoder = CLIPVisionModel.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16, @@ -493,9 +491,7 @@ def create_vision_encoder_and_processor(self, mm_cfg): for param in vision_encoder.parameters(): param.requires_grad = False vision_encoder = vision_encoder.eval() - elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get( - "model_type", "" - ): + elif config.architectures[0] == "SiglipVisionModel": vision_encoder = SiglipVisionModel.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16, diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index 75804b8acd00..5f406994d371 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -534,17 +534,13 @@ def expand2square(pil_img, background_color): def create_image_processor(mm_cfg): if mm_cfg.vision_encoder.get("from_hf", False): - if ( - "clip" in mm_cfg.vision_encoder.from_pretrained - or "vit" in mm_cfg.vision_encoder.from_pretrained - or "clip" in mm_cfg.vision_encoder.get("model_type", "") - ): + from transformers import AutoConfig + config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained) + if config.architectures[0] == "CLIPVisionModel": image_processor = CLIPImageProcessor.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16 ) - elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get( - "model_type", "" - ): + elif config.architectures[0] == "SiglipVisionModel": image_processor = SiglipImageProcessor.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16 ) From 3e65e6a1150bc315c2394f90eb79ddd58da3cf4b Mon Sep 17 00:00:00 2001 From: Vivian Chen Date: Fri, 12 Jul 2024 05:52:10 +0000 Subject: [PATCH 2/9] add lita, vila engine build support and fix export api bugs Signed-off-by: Vivian Chen --- .../multimodal_llm/neva/neva_export.py | 1 + nemo/export/multimodal/build.py | 124 +++++++++++++----- nemo/export/tensorrt_mm_exporter.py | 5 +- .../trt_llm/nemo_ckpt_loader/nemo_file.py | 4 + 4 files changed, 102 insertions(+), 32 deletions(-) diff --git a/examples/multimodal/multimodal_llm/neva/neva_export.py 
b/examples/multimodal/multimodal_llm/neva/neva_export.py index 2c081d00a003..6cf44084a564 100644 --- a/examples/multimodal/multimodal_llm/neva/neva_export.py +++ b/examples/multimodal/multimodal_llm/neva/neva_export.py @@ -27,6 +27,7 @@ def main(cfg): tensor_parallel_size=cfg.infer.tensor_parallelism, max_input_len=cfg.infer.max_input_len, max_output_len=cfg.infer.max_output_len, + vision_max_batch_size=cfg.infer.vision_max_batch_size, max_batch_size=cfg.infer.max_batch_size, max_multimodal_len=cfg.infer.max_multimodal_len, dtype=cfg.model.precision, diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py index b21e5383b57f..10580afbfa96 100644 --- a/nemo/export/multimodal/build.py +++ b/nemo/export/multimodal/build.py @@ -37,7 +37,7 @@ def build_trtllm_engine( llm_checkpoint_path: str = None, model_type: str = "neva", llm_model_type: str = "llama", - tensor_parallel_size: int = 1, + tensor_parallelism_size: int = 1, max_input_len: int = 256, max_output_len: int = 256, max_batch_size: int = 1, @@ -45,10 +45,11 @@ def build_trtllm_engine( dtype: str = "bfloat16", ): trt_llm_exporter = TensorRTLLM(model_dir=model_dir, load_model=False) + visual_checkpoint_model = ['neva', 'lita', 'vila'] trt_llm_exporter.export( - nemo_checkpoint_path=visual_checkpoint_path if model_type == "neva" else llm_checkpoint_path, + nemo_checkpoint_path=visual_checkpoint_path if model_type in visual_checkpoint_model else llm_checkpoint_path, model_type=llm_model_type, - tensor_parallel_size=tensor_parallel_size, + tensor_parallelism_size=tensor_parallelism_size, max_input_len=max_input_len, max_output_len=max_output_len, max_batch_size=max_batch_size, @@ -75,7 +76,7 @@ def export_visual_wrapper_onnx( def build_trt_engine( - model_type, input_sizes, output_dir, max_batch_size, dtype=torch.bfloat16, image_size=None, num_frames=None + model_type, input_sizes, output_dir, vision_max_batch_size, dtype=torch.bfloat16, image_size=None, num_frames=None ): part_name = 'visual_encoder' onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name) @@ -110,8 +111,8 @@ def build_trt_engine( nBS = -1 nMinBS = 1 - nOptBS = max(nMinBS, int(max_batch_size / 2)) - nMaxBS = max_batch_size + nOptBS = max(nMinBS, int(vision_max_batch_size / 2)) + nMaxBS = vision_max_batch_size inputT = network.get_input(0) @@ -145,9 +146,10 @@ def build_trt_engine( def build_neva_engine( + model_type: str, model_dir: str, visual_checkpoint_path: str, - max_batch_size: int = 1, + vision_max_batch_size: int = 1, ): device = torch.device("cuda") if torch.cuda.is_available() else "cpu" # extract NeMo checkpoint @@ -155,6 +157,28 @@ def build_neva_engine( mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp) vision_config = nemo_config["mm_cfg"]["vision_encoder"] + + class DownSampleBlock(torch.nn.Module): + def forward(self, x): + vit_embeds = x + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.flat_square(vit_embeds) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + return vit_embeds + + def flat_square(self, x): + n, w, h, c = x.size() + if w % 2 == 1: + x = torch.cat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous() + n, w, h, c = x.size() + if h % 2 == 1: + x = torch.cat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous() + n, w, h, c = x.size() + x = x.view(n, w, int(h / 2), int(c * 2)) + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(n, int(h 
/ 2), int(w / 2), int(c * 4)) + return x class VisionEncoderWrapper(torch.nn.Module): @@ -178,44 +202,83 @@ def forward(self, images): dtype = hf_config.torch_dtype # connector - assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu" - vision_connector = torch.nn.Sequential( - torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True), - torch.nn.GELU(), - torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True), - ).to(dtype=dtype) - - key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" - for layer in range(0, 3, 2): - vision_connector[layer].load_state_dict( + #assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu" + + if nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu": + vision_connector = torch.nn.Sequential( + torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True), + torch.nn.GELU(), + torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True), + ).to(dtype=dtype) + + key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" + for layer in range(0, 3, 2): + vision_connector[layer].load_state_dict( + { + 'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype), + 'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype), + } + ) + elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "linear": + vision_connector = torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True) + key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" + vision_connector.load_state_dict( { - 'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype), - 'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype), + 'weight': mp0_weights[f"{key_prefix}.weight"].to(dtype), + 'bias': mp0_weights[f"{key_prefix}.bias"].to(dtype), } ) - + elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp_downsample": + vision_connector = torch.nn.Sequential( + DownSampleBlock(), + torch.nn.LayerNorm(vision_config["hidden_size"] * 4), + torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True), + torch.nn.GELU(), + torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True), + ).to(dtype=dtype) + key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" + for layer in [1, 2, 4]: + vision_connector[layer].load_state_dict( + { + 'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype), + 'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype), + } + ) + + else: + raise ValueError(f"Unknown projector type: {nemo_config['mm_cfg']['mm_mlp_adapter_type']}") + # export the whole wrapper wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype) - image_size = hf_config.vision_config.image_size + if model_type == "lita": + lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames'] + + print(hf_config) + + if model_type == 'lita': + image_size = hf_config.image_size + else: + image_size = hf_config.vision_config.image_size dummy_image = torch.empty( 1, 3, image_size, image_size, dtype=dtype, device=device ) # dummy image shape [B, C, H, W] export_visual_wrapper_onnx(wrapper, dummy_image, model_dir) build_trt_engine( - "neva", + model_type, [3, image_size, image_size], model_dir, - max_batch_size, + vision_max_batch_size, dtype, image_size=image_size, + num_frames=lita_num_frames if model_type == "lita" else 
None, ) def build_video_neva_engine( model_dir: str, visual_checkpoint_path: str, - max_batch_size: int = 1, + vision_max_batch_size: int = 1, ): device = torch.device("cuda") if torch.cuda.is_available() else "cpu" # extract NeMo checkpoint @@ -279,7 +342,7 @@ def forward(self, images): "video-neva", [num_frames, 3, image_size, image_size], # [num_frames, 3, H, W] model_dir, - max_batch_size, + vision_max_batch_size, dtype, image_size=image_size, num_frames=num_frames, @@ -290,11 +353,12 @@ def build_visual_engine( model_dir: str, visual_checkpoint_path: str, model_type: str = "neva", - max_batch_size: int = 1, -): - if model_type == "neva": - build_neva_engine(model_dir, visual_checkpoint_path, max_batch_size) + vision_max_batch_size: int = 1, +): + model_list = ['neva', 'lita', 'vila'] + if model_type in model_list: + build_neva_engine(model_type, model_dir, visual_checkpoint_path, vision_max_batch_size) elif model_type == "video-neva": - build_video_neva_engine(model_dir, visual_checkpoint_path, max_batch_size) + build_video_neva_engine(model_dir, visual_checkpoint_path, vision_max_batch_size) else: raise RuntimeError(f"Invalid model type {model_type}") diff --git a/nemo/export/tensorrt_mm_exporter.py b/nemo/export/tensorrt_mm_exporter.py index 13bc82b39334..7eee48d1f9f2 100644 --- a/nemo/export/tensorrt_mm_exporter.py +++ b/nemo/export/tensorrt_mm_exporter.py @@ -91,6 +91,7 @@ def export( max_input_len: int = 4096, max_output_len: int = 256, max_batch_size: int = 1, + vision_max_batch_size: int = 1, max_multimodal_len: int = 3072, dtype: str = "bfloat16", delete_existing_files: bool = True, @@ -119,7 +120,7 @@ def export( llm_checkpoint_path=llm_checkpoint_path, model_type=model_type, llm_model_type=llm_model_type, - tensor_parallel_size=tensor_parallel_size, + tensor_parallelism_size=tensor_parallel_size, max_input_len=max_input_len, max_output_len=max_output_len, max_batch_size=max_batch_size, @@ -128,7 +129,7 @@ def export( ) visual_dir = os.path.join(self.model_dir, "visual_engine") - build_visual_engine(visual_dir, visual_checkpoint_path, model_type, max_batch_size) + build_visual_engine(visual_dir, visual_checkpoint_path, model_type, vision_max_batch_size) if load_model: self._load() diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 1d473f497f51..004069bddfdd 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -140,6 +140,9 @@ def _update_config_entry(key, file_pattern): def copy_tokenizer_files(config, out_dir): + # workaround for temp dir + out_dir = Path(out_dir) + basenames = { "model": "tokenizer", "vocab_file": "vocab", @@ -232,6 +235,7 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat model = load_sharded_metadata(dist_ckpt_folder) nemo_model_config = unpacked_checkpoint_dir.model_config + if nemo_model_config["tokenizer"].get("library", None) == "huggingface": tokenizer = AutoTokenizer.from_pretrained( nemo_model_config["tokenizer"]["type"], From d87455e34f51af5f1226d7144e83641c9fddd5e4 Mon Sep 17 00:00:00 2001 From: Vivian Chen Date: Mon, 15 Jul 2024 23:34:23 +0000 Subject: [PATCH 3/9] add run example for vila, lita and vita Signed-off-by: Vivian Chen --- nemo/export/multimodal/build.py | 36 +-- nemo/export/multimodal/run.py | 424 ++++++++++++++++++++++++++++++-- 2 files changed, 422 insertions(+), 38 deletions(-) diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py index 
10580afbfa96..922bddcc1846 100644 --- a/nemo/export/multimodal/build.py +++ b/nemo/export/multimodal/build.py @@ -45,7 +45,7 @@ def build_trtllm_engine( dtype: str = "bfloat16", ): trt_llm_exporter = TensorRTLLM(model_dir=model_dir, load_model=False) - visual_checkpoint_model = ['neva', 'lita', 'vila'] + visual_checkpoint_model = ['neva', 'lita', 'vila', 'vita'] trt_llm_exporter.export( nemo_checkpoint_path=visual_checkpoint_path if model_type in visual_checkpoint_model else llm_checkpoint_path, model_type=llm_model_type, @@ -76,12 +76,18 @@ def export_visual_wrapper_onnx( def build_trt_engine( - model_type, input_sizes, output_dir, vision_max_batch_size, dtype=torch.bfloat16, image_size=None, num_frames=None + model_type, input_sizes, output_dir, vision_max_batch_size, dtype=torch.bfloat16, image_size=None, num_frames=None, nemo_config=None ): part_name = 'visual_encoder' onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name) engine_file = '%s/%s.engine' % (output_dir, part_name) config_file = '%s/%s' % (output_dir, "config.json") + nemo_config_file = '%s/%s' % (output_dir, "nemo_config.yaml") + + # save the nemo config to the output directory/visual_engine + with open(nemo_config_file, 'w') as f: + yaml.dump(nemo_config, f) + logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name) builder = trt.Builder(logger) @@ -155,7 +161,10 @@ def build_neva_engine( # extract NeMo checkpoint with tempfile.TemporaryDirectory() as temp: mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp) - + + # save the nemo config to the output directory/visual_engine + + vision_config = nemo_config["mm_cfg"]["vision_encoder"] class DownSampleBlock(torch.nn.Module): @@ -189,8 +198,7 @@ def __init__(self, encoder, connector): def forward(self, images): vision_x = self.encoder(pixel_values=images, output_hidden_states=True) - vision_x = vision_x.hidden_states[-2] - vision_x = vision_x[:, 1:] + vision_x = vision_x.hidden_states[-2] vision_x = self.connector(vision_x) return vision_x @@ -232,7 +240,7 @@ def forward(self, images): vision_connector = torch.nn.Sequential( DownSampleBlock(), torch.nn.LayerNorm(vision_config["hidden_size"] * 4), - torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True), + torch.nn.Linear(vision_config["hidden_size"] * 4, nemo_config["hidden_size"], bias=True), torch.nn.GELU(), torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True), ).to(dtype=dtype) @@ -250,15 +258,14 @@ def forward(self, images): # export the whole wrapper wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype) - if model_type == "lita": - lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames'] - - print(hf_config) - - if model_type == 'lita': + if model_type == "lita" or model_type == "vila": image_size = hf_config.image_size + if model_type == "lita": + lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames'] else: image_size = hf_config.vision_config.image_size + if model_type == "vita": + lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames'] dummy_image = torch.empty( 1, 3, image_size, image_size, dtype=dtype, device=device ) # dummy image shape [B, C, H, W] @@ -271,7 +278,8 @@ def forward(self, images): vision_max_batch_size, dtype, image_size=image_size, - num_frames=lita_num_frames if model_type == "lita" else None, + num_frames=lita_num_frames if model_type == "lita" or model_type == 'vita' else None, + nemo_config=nemo_config, ) @@ -355,7 +363,7 @@ def 
build_visual_engine( model_type: str = "neva", vision_max_batch_size: int = 1, ): - model_list = ['neva', 'lita', 'vila'] + model_list = ['neva', 'lita', 'vila', 'vita'] if model_type in model_list: build_neva_engine(model_type, model_dir, visual_checkpoint_path, vision_max_batch_size) elif model_type == "video-neva": diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py index f94c2e3f3944..7c9e2c5a02b6 100644 --- a/nemo/export/multimodal/run.py +++ b/nemo/export/multimodal/run.py @@ -16,17 +16,21 @@ import json import os +import decord +import einops +import yaml import numpy as np import tensorrt as trt import tensorrt_llm import tensorrt_llm.profiler as profiler import torch +from torch.nn import functional as F from PIL import Image from tensorrt_llm import logger from tensorrt_llm._utils import str_dtype_to_trt from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo from torchvision import transforms -from transformers import CLIPImageProcessor +from transformers import CLIPImageProcessor, AutoProcessor, AutoModel def trt_dtype_to_torch(dtype): @@ -67,6 +71,8 @@ def __init__(self, visual_engine_dir, llm_engine_dir): self.init_image_encoder(visual_engine_dir) self.init_tokenizer(llm_engine_dir) self.init_llm(llm_engine_dir) + if self.model_type == 'lita' or self.model_type == 'vila' or self.model_type == 'vita': + self.init_vision_preprocessor(visual_engine_dir) def init_tokenizer(self, llm_engine_dir): if os.path.exists(os.path.join(llm_engine_dir, 'huggingface_tokenizer')): @@ -74,6 +80,12 @@ def init_tokenizer(self, llm_engine_dir): self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(llm_engine_dir, 'huggingface_tokenizer')) self.tokenizer.pad_token = self.tokenizer.eos_token + + if self.model_type == 'vita': + self.tokenizer.im_start_id = self.tokenizer.convert_tokens_to_ids("") + self.tokenizer.im_end_id = self.tokenizer.convert_tokens_to_ids("") + self.tokenizer.vid_start_id = self.tokenizer.convert_tokens_to_ids("") + self.tokenizer.vid_end_id = self.tokenizer.convert_tokens_to_ids("") else: from sentencepiece import SentencePieceProcessor @@ -114,6 +126,12 @@ def batch_decode(self, x, **kwargs): self.tokenizer.pad_token_id = sp.pad_id() self.tokenizer.padding_side = "right" + + if self.model_type == 'lita': + self.tokenizer.im_start_id = sp.piece_to_id("") + self.tokenizer.im_end_id = sp.piece_to_id("") + self.tokenizer.vid_start_id = sp.piece_to_id("") + self.tokenizer.vid_end_id = sp.piece_to_id("") def init_image_encoder(self, visual_engine_dir): vision_encoder_path = os.path.join(visual_engine_dir, 'visual_encoder.engine') @@ -122,6 +140,20 @@ def init_image_encoder(self, visual_engine_dir): engine_buffer = f.read() logger.info(f'Creating session from engine {vision_encoder_path}') self.visual_encoder_session = Session.from_serialized_engine(engine_buffer) + + def init_vision_preprocessor(self, visual_encoder_dir): + with open(os.path.join(visual_encoder_dir, 'nemo_config.yaml'), 'r') as f: + self.nemo_config = yaml.safe_load(f) + + vision_config = self.nemo_config["mm_cfg"]["vision_encoder"] + + if self.model_type == 'lita': + self.image_processor = AutoProcessor.from_pretrained(vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True) + elif self.model_type == 'vila' or self.model_type == 'vita': + from transformers import SiglipImageProcessor + self.image_processor = SiglipImageProcessor.from_pretrained(vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True) + else: + raise 
ValueError(f"Invalid model type: {self.model_type}") def init_llm(self, llm_engine_dir): self.model = ModelRunner.from_dir( @@ -137,25 +169,25 @@ def video_preprocess(self, video_path): vr = VideoReader(video_path) num_frames = self.num_frames if num_frames == -1: - frames = [Image.fromarray(frame.asnumpy()[:, :, ::-1]).convert('RGB') for frame in vr] + frames = [Image.fromarray(frame.asnumpy()).convert('RGB') for frame in vr] else: # equally sliced frames into self.num_frames frames # if self.num_frames is greater than the number of frames in the video, we will repeat the last frame num_frames = min(num_frames, len(vr)) indices = np.linspace(0, len(vr) - 1, num=num_frames, dtype=int) - frames = [Image.fromarray(vr[idx].asnumpy()[:, :, ::-1]).convert('RGB') for idx in indices] + frames = [Image.fromarray(vr[idx].asnumpy()).convert('RGB') for idx in indices] if len(frames) < num_frames: frames += [frames[-1]] * (num_frames - len(frames)) elif isinstance(video_path, np.ndarray): num_frames = self.num_frames if num_frames == -1: - frames = [Image.fromarray(frame[:, :, ::-1]).convert('RGB') for frame in video_path] + frames = [Image.fromarray(frame).convert('RGB') for frame in video_path] else: # equally sliced frames into self.num_frames frames # if self.num_frames is greater than the number of frames in the video, we will repeat the last frame num_frames = min(num_frames, video_path.shape[0]) indices = np.linspace(0, video_path.shape[0] - 1, num=num_frames, dtype=int) - frames = [Image.fromarray(video_path[idx][:, :, ::-1]).convert('RGB') for idx in indices] + frames = [Image.fromarray(video_path[idx]).convert('RGB') for idx in indices] if len(frames) < num_frames: frames += [frames[-1]] * (num_frames - len(frames)) else: @@ -168,26 +200,109 @@ def video_preprocess(self, video_path): tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision) ) # [num_frames, 3, H, W] return media_tensors.unsqueeze(0) # [1, num_frames, 3, H, W] + + def insert_tokens_by_index(self, input_ids, nemo_config): + im_start_id = self.tokenizer.im_start_id + im_end_id = self.tokenizer.im_end_id + vid_start_id = self.tokenizer.vid_start_id + vid_end_id = self.tokenizer.vid_end_id + num_frames = nemo_config['mm_cfg']['lita']['sample_frames'] + + image_token_indices = (input_ids == 0).nonzero(as_tuple=False).squeeze().tolist() + input_ids = input_ids.squeeze().tolist() + offset = 0 + + # Insert the image tokens and corresponding start/end tokens + for i in range(num_frames): + idx = image_token_indices[1] + offset + input_ids.insert(idx + 1, im_end_id) + input_ids.insert(idx + 1, 0) + input_ids.insert(idx + 1, im_start_id) + offset += 3 + + # Insert the video start and end tokens around the video token + vid_idx = image_token_indices[1] + offset + input_ids.insert(vid_idx + 1, vid_end_id) + input_ids.insert(vid_idx + 1, 0) + input_ids.insert(vid_idx + 1, vid_start_id) + + input_ids.pop(image_token_indices[1]) + input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) + + return input_ids def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, batch_size): if not warmup: profiler.start("Vision") - visual_features, visual_atts = self.get_visual_features(image, attention_mask) - if not warmup: profiler.stop("Vision") - pre_input_ids = self.tokenizer(pre_prompt, return_tensors="pt", padding=True).input_ids - if post_prompt[0] is not None: - post_input_ids = self.tokenizer(post_prompt, return_tensors="pt", padding=True).input_ids - if self.model_type == 'video-neva': - length = 
pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[2] * visual_atts.shape[1] + if self.model_type == 'vila': + visual_features, visual_atts = self.get_visual_features(image, attention_mask) + input_ids = self.tokenizer_image_token( + batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer) + batch_split_prompts = self.split_prompt_by_images(input_ids) + first_batch_split_prompts = batch_split_prompts[0] + # compute prompt length + visual length + length = sum([ids.shape[1] for ids in first_batch_split_prompts]) + if batch_size == 1 and len(image) > 1: + # mode 1: multiple image as a whole, flatten visual dims + length += visual_atts.shape[0] * visual_atts.shape[1] else: - length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[1] + # mode 2: multiple images individually (replicate prompt for each image) + length += visual_atts.shape[1] + + input_lengths = torch.IntTensor([length] * batch_size).to( + torch.int32) + input_ids, ptuning_args = self.setup_fake_prompts_vila( + batch_size, visual_features, first_batch_split_prompts, + input_lengths) + return input_ids, input_lengths, ptuning_args, visual_features + + elif self.model_type == 'lita'or self.model_type == 'vita': + for i, img in enumerate(image): + visual_features, visual_atts = self.get_visual_features(img, attention_mask) + visual_features = visual_features.unsqueeze(0) + im_tokens, vid_tokens = self.preprocess_lita_visual(visual_features, self.nemo_config) + input_ids = self.tokenizer_image_token( + batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer) + input_ids = self.insert_tokens_by_index(input_ids, self.nemo_config) + batch_splits = self.split_prompt_by_images(input_ids) + first_batch_split_prompts = batch_splits[0] + length = sum([ids.shape[1] for ids in first_batch_split_prompts]) + + visual_input = [] + visual_input.append(im_tokens) + visual_input.append(vid_tokens) + + + # we need to update visual atts shape to match im_tokens shape and vid_tokens shape + im_tokens = im_tokens.view(1, -1, im_tokens.shape[-1]) + visual_features = torch.cat([im_tokens, vid_tokens], dim=1) + visual_atts = torch.ones(visual_features.size()[:-1], dtype=torch.long).to(image.device) + + if batch_size == 1: + length += visual_atts.shape[0] * visual_atts.shape[1] + + input_lengths = torch.IntTensor([length] * batch_size).to( + torch.int32) + input_ids, ptuning_args = self.setup_fake_prompts_vila( + batch_size, visual_input, first_batch_split_prompts, + input_lengths) + return input_ids, input_lengths, ptuning_args, visual_features else: - post_input_ids = None - length = pre_input_ids.shape[1] + visual_atts.shape[1] + visual_features, visual_atts = self.get_visual_features(image, attention_mask) + pre_input_ids = self.tokenizer(pre_prompt, return_tensors="pt", padding=True).input_ids + if post_prompt[0] is not None: + post_input_ids = self.tokenizer(post_prompt, return_tensors="pt", padding=True).input_ids + if self.model_type == 'video-neva': + length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[2] * visual_atts.shape[1] + else: + length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[1] + else: + post_input_ids = None + length = pre_input_ids.shape[1] + visual_atts.shape[1] input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32) @@ -196,6 +311,59 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, bat ) return input_ids, input_lengths, ptuning_args, visual_features + + @staticmethod + def 
tokenizer_image_token(batch_size, + prompt, + tokenizer, + image_token_index=-200): + prompt_chunks = [ + tokenizer(chunk).input_ids for chunk in prompt.split("") + ] + + def insert_separator(X, sep): + return [ + ele for sublist in zip(X, [sep] * len(X)) for ele in sublist + ][:-1] + + input_ids = [] + offset = 0 + if (len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 + and prompt_chunks[0][0] == tokenizer.bos_token_id): + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, + [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + input_ids = torch.tensor(input_ids, dtype=torch.long) + input_ids[input_ids == image_token_index] = 0 + input_ids = input_ids.unsqueeze(0).expand(batch_size, -1) + + return input_ids + + def split_prompt_by_images(self, tensor): + batch_splits = [] + for batch in tensor: + # Find indices where value is zero () + zero_indices = (batch == 0).nonzero(as_tuple=False).squeeze(0) + # Add starting point for slicing + start_idx = 0 + splits = [] + for idx in zero_indices: + if start_idx != idx: # Ensure not slicing zero-length tensors + splits.append(batch[start_idx:idx].unsqueeze(0)) + start_idx = idx + 1 # Move start index past the zero + if start_idx < len( + batch): # Handle last segment if it's not zero-ending + splits.append(batch[start_idx:].unsqueeze(0)) + # Remove empty tensors resulting from consecutive zeros + splits = [split for split in splits if split.numel() > 0] + batch_splits.append(splits) + + return batch_splits + def generate( self, @@ -312,9 +480,119 @@ def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids, inp ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths) return input_ids, ptuning_args + + def setup_fake_prompts_vila(self, batch_size, visual_features, + split_input_ids, input_lengths): + + if self.model_type == 'lita' or self.model_type == 'vita': + squeeze_img_tokens = visual_features[0].squeeze(0) + reshape_img_tokens = [t.unsqueeze(0) for t in squeeze_img_tokens] + visual_features = reshape_img_tokens + [visual_features[1]] + + fake_prompt_counter = self.model_config.vocab_size + if batch_size == 1: + # only check for multi-image inference (mode 1) + assert len(visual_features) <= len( + split_input_ids + ), "Unexpected number of visual features. Please check # in prompt and the #image files." + + input_ids = [] + if batch_size == 1: + input_ids = [split_input_ids[0]] + + if self.model_type == 'vila': + # mode 1: multiple image as a whole, concat all prompts together,
...
+                for idx, visual_feature in enumerate(visual_features):
+                    fake_prompt_id = torch.arange(
+                        fake_prompt_counter,
+                        fake_prompt_counter + visual_feature.shape[0])
+                    fake_prompt_counter += visual_feature.shape[0]
+                    fake_prompt_id = fake_prompt_id.unsqueeze(0)
+                    input_ids.append(fake_prompt_id)
+                    
+                    # in case no post prompt
+                    if len(split_input_ids) > idx + 1:
+                        input_ids.append(split_input_ids[idx + 1])
+            elif self.model_type == 'lita' or self.model_type == 'vita':
+                for idx, visual_f in enumerate(visual_features):
+                    fake_prompt_id = torch.arange(
+                        fake_prompt_counter,
+                        fake_prompt_counter + visual_f.shape[1])
+                    fake_prompt_id = fake_prompt_id.reshape(visual_f.shape[1])
+                    fake_prompt_counter += visual_f.shape[1]
+                    fake_prompt_id = fake_prompt_id.unsqueeze(0)
+                    input_ids.append(fake_prompt_id)
+                
+                    # in case no post prompt
+                    if len(split_input_ids) > idx + 1:
+                        input_ids.append(split_input_ids[idx + 1])
+                    
+
+        elif batch_size > 1 and self.model_type == 'vila':
+            # mode 2: each image has an individual prompt

+            for idx, visual_feature in enumerate(visual_features):
+                input_ids.append(split_input_ids[0])
+                fake_prompt_id = torch.arange(
+                    fake_prompt_counter,
+                    fake_prompt_counter + visual_feature.shape[0])
+                fake_prompt_counter += visual_feature.shape[0]
+                fake_prompt_id = fake_prompt_id.unsqueeze(0)
+                input_ids.append(fake_prompt_id)
+                if len(split_input_ids) > 1:
+                    input_ids.append(split_input_ids[1])
+
+        input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)
+        input_ids = input_ids.reshape(batch_size, -1)
+        ptuning_args = self.ptuning_setup(visual_features, input_ids,
+                                              input_lengths)
+        return input_ids, ptuning_args
+    
+    def preprocess_lita_visual(self, visual_features, config):
+        
+        b, t, s, d = visual_features.shape
+        
+        num_frames = t
+        if 'visual_token_format' in config['mm_cfg']['lita'] and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end':
+            num_image_frames = min(num_frames, config['mm_cfg']['lita']['sample_frames'])
+            idx = np.round(np.linspace(0, num_frames - 1,
+                                        num_image_frames)).astype(int)
+
+            # Image and video features
+            im_features = visual_features[:, idx, ...]
+
+            vid_features = einops.reduce(visual_features,
+                                            'b t s d -> b t d',
+                                            'mean')
+            return im_features, vid_features
+        
+        elif 'lita_video_arch' in config['mm_cfg']['lita'] and config['mm_cfg']['lita']['lita_video_arch'] == 'temporal_spatial_pool':
+            pool_size = 2
+            selected_frames = np.round(
+                np.linspace(0, visual_features.shape[1] - 1,
+                            pool_size * pool_size)).astype(int)
+            s_tokens = visual_features[:, selected_frames, ...]
+            s_tokens = einops.rearrange(s_tokens,
+                                        'b t (h w) d -> (b t) d h w',
+                                        h=16,
+                                        w=16)
+            s_tokens = F.avg_pool2d(s_tokens, kernel_size=pool_size)
+            s_tokens = einops.rearrange(s_tokens,
+                                        '(b t) d h w -> b (t h w) d',
+                                        b=b)
+
+            t_tokens = einops.reduce(visual_features, 'b t s d -> b t d',
+                                        'mean')
+            
+            return t_tokens, s_tokens
+
+        else:
+            raise ValueError(f'Invalid visual token format: {config["mm_cfg"]["lita"]["visual_token_format"]}')
 
     def ptuning_setup(self, prompt_table, input_ids, input_lengths):
         hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size
+        
+        if self.model_type == 'lita' or self.model_type == 'vita':
+            prompt_table = torch.cat(prompt_table, dim=1)
         if prompt_table is not None:
             task_vocab_size = torch.tensor(
                 [prompt_table.shape[1]],
@@ -337,6 +615,85 @@ def ptuning_setup(self, prompt_table, input_ids, input_lengths):
             tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()
 
         return [prompt_table, tasks, task_vocab_size]
+    
+    def expand2square_pt(self, images, background_color):
+        height, width = images.shape[-2:]
+        b = len(images)
+        background_color = torch.Tensor(background_color)
+        if width == height:
+            return images
+        elif width > height:
+            result = einops.repeat(background_color, 'c -> b c h w', b=b, h=width, w=width).clone()
+            paste_start = (width - height) // 2
+            paste_end = paste_start + height
+            result[:, :, paste_start:paste_end, :] = images
+            return result
+        else:
+            result = einops.repeat(background_color, 'c -> b c h w', b=b, h=height, w=height).clone()
+            paste_start = (height - width) // 2
+            paste_end = paste_start + width
+            result[:, :, :, paste_start:paste_end] = images
+            return result
+    
+    def load_video(self, config, video_path, processor, num_frames):
+        
+        decord.bridge.set_bridge('torch')
+        video_reader = decord.VideoReader(uri=video_path)
+        idx = np.round(
+            np.linspace(0, len(video_reader) - 1, num_frames)).astype(int)
+        frames = video_reader.get_batch(idx)
+        frames = einops.rearrange(frames, 't h w c -> t c h w')
+        
+        if config['data']['image_aspect_ratio'] == 'pad':
+            frames = self.expand2square_pt(frames, tuple(int(x*255) for x in processor.image_mean))
+        processed_frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
+        
+        return processed_frames
+    
+    def get_num_sample_frames(self, config, vid_len):
+        if 'visual_token_format' in config['mm_cfg']['lita'] and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end':
+            max_frames = config['data']['num_frames']
+            if vid_len <= max_frames:
+                return vid_len
+            else:
+                subsample = int(np.ceil(float(vid_len) / max_frames))
+                return int(np.round(float(vid_len) / subsample))
+        else:
+            return config['mm_cfg']['lita']['sample_frames']
+    
+    def process_video(self, nemo_config, video_path, image_processor):
+        vid_len = len(decord.VideoReader(video_path))
+        num_sample_frames = self.get_num_sample_frames(nemo_config, vid_len)
+        image = self.load_video(nemo_config, video_path, image_processor, num_sample_frames).unsqueeze(0).to(self.device, dtype=torch.bfloat16)
+        return image
+    
+    def process_image(self, image_file, image_processor, nemo_config, image_folder):
+        image_processor = image_processor
+        if isinstance(image_file, str):
+            if image_folder is not None:
+                image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
+            else:
+                image = Image.open(image_file).convert("RGB")
+        else:
+            # image is stored in bytearray
+            image = image_file
+
+        crop_size = nemo_config['mm_cfg']['vision_encoder']['crop_size']
+        crop_size = tuple(crop_size)
+        image = image.resize(crop_size)
+        if nemo_config['data']['image_aspect_ratio'] == 'pad':
+            image = self.expand2square_pt(image, tuple(int(x * 255) for x in image_processor.image_mean))
+            image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+        else:
+            image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+        return image
+
+    def process_vila_img(self, images):
+        new_images = [self.process_image(image, self.image_processor, self.nemo_config, None) for image in images]
+
+        if all(x.shape == new_images[0].shape for x in new_images):
+            new_images = torch.stack(new_images, dim=0)
+        return new_images
 
     def setup_inputs(self, input_text, raw_image, batch_size):
         attention_mask = None
@@ -370,21 +727,39 @@ def setup_inputs(self, input_text, raw_image, batch_size):
                 f"\n{input_text}\nAssistant\nquality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4\n"
                 ""
             )
+        elif self.model_type in ['vila', 'lita', 'vita']:
+            if self.model_type == "vila" or self.model_type == "lita":
+                pre_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "
+                if input_text is None:
+                    input_text = "<image>\n Please elaborate what you see in the images?"
+                post_prompt = input_text + " ASSISTANT:"
+            
+            elif self.model_type == "vita":
+                pre_prompt = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. USER: "
+                if input_text is None:
+                    input_text = "<image>\n Please elaborate what you see in the images?"
+                post_prompt = input_text + " ASSISTANT:"
+            
         else:
             raise RuntimeError(f"Invalid model type {self.model_type}")
+        
+        if self.model_type == 'lita' or self.model_type == 'vita':
+            image = self.process_video(self.nemo_config, raw_image, self.image_processor)
+        
+        if self.model_type == 'vila':
+            raw_image = [raw_image] * batch_size
+            image = self.process_vila_img(raw_image)
 
         # Repeat inputs to match batch size
         pre_prompt = [pre_prompt] * batch_size
         post_prompt = [post_prompt] * batch_size
-        if image.dim() == 5:
-            image = image.expand(batch_size, -1, -1, -1, -1).contiguous()
-        else:
-            image = image.expand(batch_size, -1, -1, -1).contiguous()
+        if self.model_type not in ['vila', 'lita', 'vita']:
+            if image.dim() == 5:
+                image = image.expand(batch_size, -1, -1, -1, -1).contiguous()
+            else:
+                image = image.expand(batch_size, -1, -1, -1).contiguous()
         image = image.to(self.device)
 
-        # Generate decoder_input_ids for enc-dec models
-        # Custom prompts can be added as:
-        # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids
         decoder_input_ids = None
 
         return input_text, pre_prompt, post_prompt, image, decoder_input_ids, attention_mask
@@ -473,9 +848,10 @@ def print_result(self, input_text, output_text, batch_size, num_beams, run_profi
         logger.info("---------------------------------------------------------")
 
     def load_test_media(self, input_media):
-        if self.model_type == "video-neva":
+        media_model = ["video-neva", "lita", "vita"]
+        if self.model_type in media_model:
             media = input_media
-        elif self.model_type == "neva":
+        elif self.model_type == "neva" or self.model_type == "vila":
             media = Image.open(input_media).convert('RGB')
         else:
             raise RuntimeError(f"Invalid model type {self.model_type}")
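
For context, the setup_fake_prompts_vila path added above follows the usual TensorRT-LLM multimodal pattern: every image/video placeholder in the tokenized prompt is replaced by a contiguous run of "fake" token ids starting at vocab_size, and the visual features are passed to the runtime as a prompt-embedding (p-tuning) table indexed by those ids. Below is a minimal, torch-only sketch of that idea; the vocab size, token ids and feature shapes are toy placeholders rather than the exporter's real tokenizer or engine output.

import torch

vocab_size = 32000                         # toy value; the real size comes from the LLM engine config
visual_feats = torch.randn(1, 196, 4096)   # [batch, visual tokens, hidden], as produced by the vision engine

pre_ids = torch.tensor([[1, 3148, 1001]])      # placeholder ids for the text before the image token
post_ids = torch.tensor([[319, 1799, 9047]])   # placeholder ids for the text after the image token

# ids >= vocab_size do not hit the word embeddings; they index into the prompt table instead
fake_ids = torch.arange(vocab_size, vocab_size + visual_feats.shape[1]).unsqueeze(0)
input_ids = torch.cat([pre_ids, fake_ids, post_ids], dim=1).to(torch.int32)

# flattened visual features become the prompt (p-tuning) table handed to the runner
prompt_table = visual_feats.view(-1, visual_feats.shape[-1])
print(input_ids.shape, prompt_table.shape)     # torch.Size([1, 202]) torch.Size([196, 4096])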

From df4228f8f04732598710ff27f8d6b6ba4056501b Mon Sep 17 00:00:00 2001
From: Vivian Chen 
Date: Tue, 16 Jul 2024 21:12:38 +0000
Subject: [PATCH 4/9] couple of changes for exporter

Signed-off-by: Vivian Chen 
---
 .../multimodal/multimodal_llm/neva/conf/neva_export.yaml | 5 +++--
 nemo/export/multimodal/build.py                          | 9 ++-------
 nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py        | 2 --
 scripts/deploy/multimodal/deploy_triton.py               | 4 +++-
 4 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml b/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml
index 5a163b250566..1ab9bdbd6398 100644
--- a/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml
+++ b/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml
@@ -6,10 +6,11 @@ infer:
   max_input_len: 4096
   max_output_len: 256
   max_multimodal_len: 3072
+  vision_max_batch_size: 1 # 256 for lita/vita when running inference on a video dataset
 
 model:
-  type: neva
+  type: neva #neva, video-neva, lita, vila, vita
   precision: bfloat16
   visual_model_path: /path/to/visual.nemo
   llm_model_path: /path/to/llm.nemo
-  llm_model_type: llama
+  llm_model_type: llama 
diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py
index 922bddcc1846..96ae1af065f0 100644
--- a/nemo/export/multimodal/build.py
+++ b/nemo/export/multimodal/build.py
@@ -84,7 +84,6 @@ def build_trt_engine(
     config_file = '%s/%s' % (output_dir, "config.json")
     nemo_config_file = '%s/%s' % (output_dir, "nemo_config.yaml")
     
-    # save the nemo config to the output directory/visual_engine
     with open(nemo_config_file, 'w') as f:
         yaml.dump(nemo_config, f)
     
@@ -160,10 +159,8 @@ def build_neva_engine(
     device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
     # extract NeMo checkpoint
     with tempfile.TemporaryDirectory() as temp:
-        mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp)
-    
-    # save the nemo config to the output directory/visual_engine
-    
+        temp_path = Path(temp)
+        mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp_path)
     
     vision_config = nemo_config["mm_cfg"]["vision_encoder"]
     
@@ -210,8 +207,6 @@ def forward(self, images):
     dtype = hf_config.torch_dtype
 
     # connector
-    #assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu"
-    
     if nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu":
         vision_connector = torch.nn.Sequential(
             torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True),
diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
index 004069bddfdd..72003c4bb8e0 100644
--- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
+++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -140,8 +140,6 @@ def _update_config_entry(key, file_pattern):
 
 
 def copy_tokenizer_files(config, out_dir):
-    # workaround for temp dir
-    out_dir = Path(out_dir)
     
     basenames = {
         "model": "tokenizer",
diff --git a/scripts/deploy/multimodal/deploy_triton.py b/scripts/deploy/multimodal/deploy_triton.py
index 1e339b3405cf..a0161d67f8c4 100755
--- a/scripts/deploy/multimodal/deploy_triton.py
+++ b/scripts/deploy/multimodal/deploy_triton.py
@@ -82,8 +82,9 @@ def get_args(argv):
     )
     parser.add_argument("-mil", "--max_input_len", default=4096, type=int, help="Max input length of the model")
     parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model")
-    parser.add_argument("-mbs", "--max_batch_size", default=1, type=int, help="Max batch size of the model")
+    parser.add_argument("-mbs", "--max_batch_size", default=1, type=int, help="Max batch size of the llm model")
     parser.add_argument("-mml", "--max_multimodal_len", default=3072, type=int, help="Max length of multimodal input")
+    parser.add_argument("-vmb", "--vision_max_batch_size", default=1, type=int, help="Max batch size of the vision model")
     args = parser.parse_args(argv)
     return args
 
@@ -131,6 +132,7 @@ def get_trt_deployable(args):
                 tensor_parallel_size=args.num_gpus,
                 max_input_len=args.max_input_len,
                 max_output_len=args.max_output_len,
+                vision_max_batch_size=args.vision_max_batch_size,
                 max_batch_size=args.max_batch_size,
                 max_multimodal_len=args.max_multimodal_len,
                 dtype=args.dtype,
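
With this patch the export path now threads vision_max_batch_size from the YAML config and the deploy script down to the visual engine build, separately from the LLM max_batch_size. A hedged usage sketch from Python is shown below; the TensorRTMMExporter class name and constructor arguments are assumed from nemo/export/tensorrt_mm_exporter.py, the checkpoint path is a placeholder, and the keyword arguments mirror the export() signature touched in these patches.

from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter

exporter = TensorRTMMExporter(model_dir="/tmp/lita_trt_engines", load_model=False)
exporter.export(
    visual_checkpoint_path="/path/to/lita.nemo",  # placeholder; for neva/lita/vila/vita the LLM weights come from this checkpoint too
    model_type="lita",                            # one of neva, video-neva, lita, vila, vita
    llm_model_type="llama",
    tensor_parallel_size=1,
    max_input_len=4096,
    max_output_len=256,
    max_batch_size=1,
    vision_max_batch_size=256,                    # e.g. 256 for lita/vita video inference, per the neva_export.yaml comment
    max_multimodal_len=3072,
    dtype="bfloat16",
)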

From d1f05fdf7460dbffa4b7d293878abc76f8b98621 Mon Sep 17 00:00:00 2001
From: xuanzic 
Date: Wed, 17 Jul 2024 03:22:09 +0000
Subject: [PATCH 5/9] Apply isort and black reformatting

Signed-off-by: xuanzic 
---
 .../models/multimodal_llm/neva/neva_model.py  |   1 +
 nemo/collections/multimodal/parts/utils.py    |   1 +
 nemo/export/multimodal/build.py               |  27 ++-
 nemo/export/multimodal/run.py                 | 229 ++++++++----------
 .../trt_llm/nemo_ckpt_loader/nemo_file.py     |   3 +-
 scripts/deploy/multimodal/deploy_triton.py    |   4 +-
 6 files changed, 129 insertions(+), 136 deletions(-)

diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
index 5972ca2bc976..c5805c972ad0 100644
--- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
+++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -471,6 +471,7 @@ def create_vision_encoder_and_processor(self, mm_cfg):
         # Initialize vision encoder and freeze it
         if mm_cfg.vision_encoder.get("from_hf", False):
             from transformers import AutoConfig
+
             config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
             if config.architectures[0] == "CLIPVisionModel":
                 vision_encoder = CLIPVisionModel.from_pretrained(
diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py
index 5f406994d371..1fe932ec046c 100644
--- a/nemo/collections/multimodal/parts/utils.py
+++ b/nemo/collections/multimodal/parts/utils.py
@@ -535,6 +535,7 @@ def expand2square(pil_img, background_color):
 def create_image_processor(mm_cfg):
     if mm_cfg.vision_encoder.get("from_hf", False):
         from transformers import AutoConfig
+
         config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
         if config.architectures[0] == "CLIPVisionModel":
             image_processor = CLIPImageProcessor.from_pretrained(
diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py
index 96ae1af065f0..b51ceb4590ff 100644
--- a/nemo/export/multimodal/build.py
+++ b/nemo/export/multimodal/build.py
@@ -76,17 +76,24 @@ def export_visual_wrapper_onnx(
 
 
 def build_trt_engine(
-    model_type, input_sizes, output_dir, vision_max_batch_size, dtype=torch.bfloat16, image_size=None, num_frames=None, nemo_config=None
+    model_type,
+    input_sizes,
+    output_dir,
+    vision_max_batch_size,
+    dtype=torch.bfloat16,
+    image_size=None,
+    num_frames=None,
+    nemo_config=None,
 ):
     part_name = 'visual_encoder'
     onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name)
     engine_file = '%s/%s.engine' % (output_dir, part_name)
     config_file = '%s/%s' % (output_dir, "config.json")
     nemo_config_file = '%s/%s' % (output_dir, "nemo_config.yaml")
-    
+
     with open(nemo_config_file, 'w') as f:
         yaml.dump(nemo_config, f)
-    
+
     logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name)
 
     builder = trt.Builder(logger)
@@ -161,9 +168,9 @@ def build_neva_engine(
     with tempfile.TemporaryDirectory() as temp:
         temp_path = Path(temp)
         mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp_path)
-    
+
     vision_config = nemo_config["mm_cfg"]["vision_encoder"]
-    
+
     class DownSampleBlock(torch.nn.Module):
         def forward(self, x):
             vit_embeds = x
@@ -172,7 +179,7 @@ def forward(self, x):
             vit_embeds = self.flat_square(vit_embeds)
             vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
             return vit_embeds
-        
+
         def flat_square(self, x):
             n, w, h, c = x.size()
             if w % 2 == 1:
@@ -195,7 +202,7 @@ def __init__(self, encoder, connector):
 
         def forward(self, images):
             vision_x = self.encoder(pixel_values=images, output_hidden_states=True)
-            vision_x = vision_x.hidden_states[-2] 
+            vision_x = vision_x.hidden_states[-2]
             vision_x = self.connector(vision_x)
             return vision_x
 
@@ -247,10 +254,10 @@ def forward(self, images):
                     'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
                 }
             )
-        
+
     else:
         raise ValueError(f"Unknown projector type: {nemo_config['mm_cfg']['mm_mlp_adapter_type']}")
-    
+
     # export the whole wrapper
     wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype)
     if model_type == "lita" or model_type == "vila":
@@ -357,7 +364,7 @@ def build_visual_engine(
     visual_checkpoint_path: str,
     model_type: str = "neva",
     vision_max_batch_size: int = 1,
-):  
+):
     model_list = ['neva', 'lita', 'vila', 'vita']
     if model_type in model_list:
         build_neva_engine(model_type, model_dir, visual_checkpoint_path, vision_max_batch_size)
diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py
index 7c9e2c5a02b6..5c7466db7279 100644
--- a/nemo/export/multimodal/run.py
+++ b/nemo/export/multimodal/run.py
@@ -18,19 +18,19 @@
 
 import decord
 import einops
-import yaml
 import numpy as np
 import tensorrt as trt
 import tensorrt_llm
 import tensorrt_llm.profiler as profiler
 import torch
-from torch.nn import functional as F
+import yaml
 from PIL import Image
 from tensorrt_llm import logger
 from tensorrt_llm._utils import str_dtype_to_trt
 from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo
+from torch.nn import functional as F
 from torchvision import transforms
-from transformers import CLIPImageProcessor, AutoProcessor, AutoModel
+from transformers import AutoModel, AutoProcessor, CLIPImageProcessor
 
 
 def trt_dtype_to_torch(dtype):
@@ -80,7 +80,7 @@ def init_tokenizer(self, llm_engine_dir):
 
             self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(llm_engine_dir, 'huggingface_tokenizer'))
             self.tokenizer.pad_token = self.tokenizer.eos_token
-            
+
             if self.model_type == 'vita':
                 self.tokenizer.im_start_id = self.tokenizer.convert_tokens_to_ids("")
                 self.tokenizer.im_end_id = self.tokenizer.convert_tokens_to_ids("")
@@ -126,7 +126,7 @@ def batch_decode(self, x, **kwargs):
             self.tokenizer.pad_token_id = sp.pad_id()
 
             self.tokenizer.padding_side = "right"
-            
+
             if self.model_type == 'lita':
                 self.tokenizer.im_start_id = sp.piece_to_id("")
                 self.tokenizer.im_end_id = sp.piece_to_id("")
@@ -140,18 +140,23 @@ def init_image_encoder(self, visual_engine_dir):
             engine_buffer = f.read()
         logger.info(f'Creating session from engine {vision_encoder_path}')
         self.visual_encoder_session = Session.from_serialized_engine(engine_buffer)
-    
+
     def init_vision_preprocessor(self, visual_encoder_dir):
         with open(os.path.join(visual_encoder_dir, 'nemo_config.yaml'), 'r') as f:
             self.nemo_config = yaml.safe_load(f)
-        
+
         vision_config = self.nemo_config["mm_cfg"]["vision_encoder"]
-        
+
         if self.model_type == 'lita':
-            self.image_processor = AutoProcessor.from_pretrained(vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True)
+            self.image_processor = AutoProcessor.from_pretrained(
+                vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True
+            )
         elif self.model_type == 'vila' or self.model_type == 'vita':
             from transformers import SiglipImageProcessor
-            self.image_processor = SiglipImageProcessor.from_pretrained(vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True)
+
+            self.image_processor = SiglipImageProcessor.from_pretrained(
+                vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True
+            )
         else:
             raise ValueError(f"Invalid model type: {self.model_type}")
 
@@ -200,14 +205,14 @@ def video_preprocess(self, video_path):
             tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision)
         )  # [num_frames, 3, H, W]
         return media_tensors.unsqueeze(0)  # [1, num_frames, 3, H, W]
-    
+
     def insert_tokens_by_index(self, input_ids, nemo_config):
         im_start_id = self.tokenizer.im_start_id
         im_end_id = self.tokenizer.im_end_id
         vid_start_id = self.tokenizer.vid_start_id
         vid_end_id = self.tokenizer.vid_end_id
         num_frames = nemo_config['mm_cfg']['lita']['sample_frames']
-        
+
         image_token_indices = (input_ids == 0).nonzero(as_tuple=False).squeeze().tolist()
         input_ids = input_ids.squeeze().tolist()
         offset = 0
@@ -215,16 +220,16 @@ def insert_tokens_by_index(self, input_ids, nemo_config):
         # Insert the image tokens and corresponding start/end tokens
         for i in range(num_frames):
             idx = image_token_indices[1] + offset
-            input_ids.insert(idx + 1, im_end_id)  
-            input_ids.insert(idx + 1, 0)  
-            input_ids.insert(idx + 1, im_start_id)  
-            offset += 3  
-            
+            input_ids.insert(idx + 1, im_end_id)
+            input_ids.insert(idx + 1, 0)
+            input_ids.insert(idx + 1, im_start_id)
+            offset += 3
+
         # Insert the video start and end tokens around the video token
         vid_idx = image_token_indices[1] + offset
-        input_ids.insert(vid_idx + 1, vid_end_id)  
-        input_ids.insert(vid_idx + 1, 0) 
-        input_ids.insert(vid_idx + 1, vid_start_id)  
+        input_ids.insert(vid_idx + 1, vid_end_id)
+        input_ids.insert(vid_idx + 1, 0)
+        input_ids.insert(vid_idx + 1, vid_start_id)
 
         input_ids.pop(image_token_indices[1])
         input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
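For readers following the LITA/VITA path, here is a minimal, self-contained sketch of the bracketed layout that `insert_tokens_by_index` builds. The numeric ids and the single-placeholder prompt are illustrative assumptions, not the real tokenizer values:

```python
# Sketch of the token layout produced by insert_tokens_by_index.
# Ids 100-103 stand in for the <im_start>/<im_end>/<vid_start>/<vid_end>
# special tokens; 0 marks an image/video slot, as in the method above.
IM_START, IM_END, VID_START, VID_END = 100, 101, 102, 103
PLACEHOLDER = 0

def insert_media_tokens(input_ids, num_frames):
    ids = list(input_ids)
    anchor = ids.index(PLACEHOLDER)  # the real method picks a specific placeholder index
    offset = 0
    # One <im_start> 0 <im_end> triple per sampled frame.
    for _ in range(num_frames):
        idx = anchor + offset
        ids[idx + 1:idx + 1] = [IM_START, PLACEHOLDER, IM_END]
        offset += 3
    # One <vid_start> 0 <vid_end> triple for the pooled video tokens.
    vid_idx = anchor + offset
    ids[vid_idx + 1:vid_idx + 1] = [VID_START, PLACEHOLDER, VID_END]
    ids.pop(anchor)  # drop the original placeholder that was expanded
    return ids

print(insert_media_tokens([1, 0, 5, 6], num_frames=2))
# [1, 100, 0, 101, 100, 0, 101, 102, 0, 103, 5, 6]
```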
@@ -240,8 +245,7 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, bat
 
         if self.model_type == 'vila':
             visual_features, visual_atts = self.get_visual_features(image, attention_mask)
-            input_ids = self.tokenizer_image_token(
-                batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer)
+            input_ids = self.tokenizer_image_token(batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer)
             batch_split_prompts = self.split_prompt_by_images(input_ids)
             first_batch_split_prompts = batch_split_prompts[0]
             # compute prompt length + visual length
@@ -253,43 +257,39 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, bat
                 # mode 2: multiple images individually (replicate prompt for each image)
                 length += visual_atts.shape[1]
 
-            input_lengths = torch.IntTensor([length] * batch_size).to(
-                torch.int32)
+            input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32)
             input_ids, ptuning_args = self.setup_fake_prompts_vila(
-                batch_size, visual_features, first_batch_split_prompts,
-                input_lengths)
+                batch_size, visual_features, first_batch_split_prompts, input_lengths
+            )
             return input_ids, input_lengths, ptuning_args, visual_features
-        
-        elif self.model_type == 'lita'or self.model_type == 'vita':
+
+        elif self.model_type == 'lita' or self.model_type == 'vita':
             for i, img in enumerate(image):
                 visual_features, visual_atts = self.get_visual_features(img, attention_mask)
-            visual_features = visual_features.unsqueeze(0) 
-            im_tokens, vid_tokens = self.preprocess_lita_visual(visual_features, self.nemo_config) 
-            input_ids = self.tokenizer_image_token(
-                batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer)
+            visual_features = visual_features.unsqueeze(0)
+            im_tokens, vid_tokens = self.preprocess_lita_visual(visual_features, self.nemo_config)
+            input_ids = self.tokenizer_image_token(batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer)
             input_ids = self.insert_tokens_by_index(input_ids, self.nemo_config)
             batch_splits = self.split_prompt_by_images(input_ids)
             first_batch_split_prompts = batch_splits[0]
             length = sum([ids.shape[1] for ids in first_batch_split_prompts])
-            
+
             visual_input = []
             visual_input.append(im_tokens)
             visual_input.append(vid_tokens)
-            
-            
+
             # we need to update visual atts shape to match im_tokens shape and vid_tokens shape
             im_tokens = im_tokens.view(1, -1, im_tokens.shape[-1])
             visual_features = torch.cat([im_tokens, vid_tokens], dim=1)
             visual_atts = torch.ones(visual_features.size()[:-1], dtype=torch.long).to(image.device)
-            
+
             if batch_size == 1:
                 length += visual_atts.shape[0] * visual_atts.shape[1]
-            
-            input_lengths = torch.IntTensor([length] * batch_size).to(
-                torch.int32)
+
+            input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32)
             input_ids, ptuning_args = self.setup_fake_prompts_vila(
-                batch_size, visual_input, first_batch_split_prompts,
-                input_lengths)
+                batch_size, visual_input, first_batch_split_prompts, input_lengths
+            )
             return input_ids, input_lengths, ptuning_args, visual_features
         else:
             visual_features, visual_atts = self.get_visual_features(image, attention_mask)
@@ -297,7 +297,9 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, bat
             if post_prompt[0] is not None:
                 post_input_ids = self.tokenizer(post_prompt, return_tensors="pt", padding=True).input_ids
                 if self.model_type == 'video-neva':
-                    length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[2] * visual_atts.shape[1]
+                    length = (
+                        pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[2] * visual_atts.shape[1]
+                    )
                 else:
                     length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[1]
             else:
@@ -311,36 +313,27 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, bat
         )
 
         return input_ids, input_lengths, ptuning_args, visual_features
-    
+
     @staticmethod
-    def tokenizer_image_token(batch_size,
-                              prompt,
-                              tokenizer,
-                              image_token_index=-200):
-        prompt_chunks = [
-            tokenizer(chunk).input_ids for chunk in prompt.split("<image>")
-        ]
+    def tokenizer_image_token(batch_size, prompt, tokenizer, image_token_index=-200):
+        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
 
         def insert_separator(X, sep):
-            return [
-                ele for sublist in zip(X, [sep] * len(X)) for ele in sublist
-            ][:-1]
+            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
 
         input_ids = []
         offset = 0
-        if (len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0
-                and prompt_chunks[0][0] == tokenizer.bos_token_id):
+        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
             offset = 1
             input_ids.append(prompt_chunks[0][0])
 
-        for x in insert_separator(prompt_chunks,
-                                  [image_token_index] * (offset + 1)):
+        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
             input_ids.extend(x[offset:])
 
         input_ids = torch.tensor(input_ids, dtype=torch.long)
         input_ids[input_ids == image_token_index] = 0
         input_ids = input_ids.unsqueeze(0).expand(batch_size, -1)
-        
+
         return input_ids
 
     def split_prompt_by_images(self, tensor):
@@ -355,8 +348,7 @@ def split_prompt_by_images(self, tensor):
                 if start_idx != idx:  # Ensure not slicing zero-length tensors
                     splits.append(batch[start_idx:idx].unsqueeze(0))
                 start_idx = idx + 1  # Move start index past the zero
-            if start_idx < len(
-                    batch):  # Handle last segment if it's not zero-ending
+            if start_idx < len(batch):  # Handle last segment if it's not zero-ending
                 splits.append(batch[start_idx:].unsqueeze(0))
             # Remove empty tensors resulting from consecutive zeros
             splits = [split for split in splits if split.numel() > 0]
@@ -364,7 +356,6 @@ def split_prompt_by_images(self, tensor):
 
         return batch_splits
 
-
     def generate(
         self,
         pre_prompt,
@@ -480,15 +471,14 @@ def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids, inp
         ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths)
 
         return input_ids, ptuning_args
-    
-    def setup_fake_prompts_vila(self, batch_size, visual_features,
-                                split_input_ids, input_lengths):
-        
+
+    def setup_fake_prompts_vila(self, batch_size, visual_features, split_input_ids, input_lengths):
+
         if self.model_type == 'lita' or self.model_type == 'vita':
             squeeze_img_tokens = visual_features[0].squeeze(0)
             reshape_img_tokens = [t.unsqueeze(0) for t in squeeze_img_tokens]
             visual_features = reshape_img_tokens + [visual_features[1]]
-            
+
         fake_prompt_counter = self.model_config.vocab_size
         if batch_size == 1:
             # only check for multi-image inference (mode 1)
@@ -499,42 +489,35 @@ def setup_fake_prompts_vila(self, batch_size, visual_features,
         input_ids = []
         if batch_size == 1:
             input_ids = [split_input_ids[0]]
-        
+
             if self.model_type == 'vila':
                 # mode 1: multiple images as a whole, concat all prompts together: <pre><image1><image2>...<post>
                 for idx, visual_feature in enumerate(visual_features):
-                    fake_prompt_id = torch.arange(
-                        fake_prompt_counter,
-                        fake_prompt_counter + visual_feature.shape[0])
+                    fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_feature.shape[0])
                     fake_prompt_counter += visual_feature.shape[0]
                     fake_prompt_id = fake_prompt_id.unsqueeze(0)
                     input_ids.append(fake_prompt_id)
-                    
+
                     # in case no post prompt
                     if len(split_input_ids) > idx + 1:
                         input_ids.append(split_input_ids[idx + 1])
             elif self.model_type == 'lita' or self.model_type == 'vita':
                 for idx, visual_f in enumerate(visual_features):
-                    fake_prompt_id = torch.arange(
-                        fake_prompt_counter,
-                        fake_prompt_counter + visual_f.shape[1])
+                    fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_f.shape[1])
                     fake_prompt_id = fake_prompt_id.reshape(visual_f.shape[1])
                     fake_prompt_counter += visual_f.shape[1]
                     fake_prompt_id = fake_prompt_id.unsqueeze(0)
                     input_ids.append(fake_prompt_id)
-                
+
                     # in case no post prompt
                     if len(split_input_ids) > idx + 1:
                         input_ids.append(split_input_ids[idx + 1])
-                    
 
         elif batch_size > 1 and self.model_type == 'vila':
             # mode 2: each image has an individual prompt: <pre><image><post>
             for idx, visual_feature in enumerate(visual_features):
                 input_ids.append(split_input_ids[0])
-                fake_prompt_id = torch.arange(
-                    fake_prompt_counter,
-                    fake_prompt_counter + visual_feature.shape[0])
+                fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_feature.shape[0])
                 fake_prompt_counter += visual_feature.shape[0]
                 fake_prompt_id = fake_prompt_id.unsqueeze(0)
                 input_ids.append(fake_prompt_id)
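The fake-prompt-id scheme used by `setup_fake_prompts_vila` is easier to follow with concrete numbers. Below is a hedged sketch along the lines of the lita/vita branch; the vocabulary size, text ids, and feature shapes are all made up:

```python
import torch

vocab_size = 32000                        # assumed; the real value comes from model_config
text_ids = torch.tensor([[1, 10, 11]])    # one split prompt chunk, shape [1, seq]
visual_feature = torch.randn(1, 4, 4096)  # 4 visual embedding positions

# One out-of-vocabulary id per visual position, starting right after the vocab.
fake_prompt_id = torch.arange(vocab_size, vocab_size + visual_feature.shape[1]).unsqueeze(0)
input_ids = torch.cat([text_ids, fake_prompt_id], dim=1).to(torch.int32)
print(input_ids)  # [[1, 10, 11, 32000, 32001, 32002, 32003]]

# At decode time, ids >= vocab_size are resolved through the prompt-tuning
# table (the flattened visual features) rather than the word embeddings.
prompt_table = visual_feature.view(-1, visual_feature.shape[-1])  # [4, 4096]
```

Roughly speaking, the total number of these out-of-vocabulary positions is what the engine's multimodal-length budget has to accommodate.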
@@ -543,46 +526,40 @@ def setup_fake_prompts_vila(self, batch_size, visual_features,
 
         input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)
         input_ids = input_ids.reshape(batch_size, -1)
-        ptuning_args = self.ptuning_setup(visual_features, input_ids,
-                                              input_lengths)
+        ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths)
         return input_ids, ptuning_args
-    
+
     def preprocess_lita_visual(self, visual_features, config):
-        
+
         b, t, s, d = visual_features.shape
-        
+
         num_frames = t
-        if 'visual_token_format' in config['mm_cfg']['lita'] and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end':
+        if (
+            'visual_token_format' in config['mm_cfg']['lita']
+            and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end'
+        ):
             num_image_frames = min(num_frames, config['mm_cfg']['lita']['sample_frames'])
-            idx = np.round(np.linspace(0, num_frames - 1,
-                                        num_image_frames)).astype(int)
+            idx = np.round(np.linspace(0, num_frames - 1, num_image_frames)).astype(int)
 
             # Image and video features
             im_features = visual_features[:, idx, ...]
 
-            vid_features = einops.reduce(visual_features,
-                                            'b t s d -> b t d',
-                                            'mean')
+            vid_features = einops.reduce(visual_features, 'b t s d -> b t d', 'mean')
             return im_features, vid_features
-        
-        elif 'lita_video_arch' in config['mm_cfg']['lita'] and config['mm_cfg']['lita']['lita_video_arch'] == 'temporal_spatial_pool':
+
+        elif (
+            'lita_video_arch' in config['mm_cfg']['lita']
+            and config['mm_cfg']['lita']['lita_video_arch'] == 'temporal_spatial_pool'
+        ):
             pool_size = 2
-            selected_frames = np.round(
-                np.linspace(0, visual_features.shape[1] - 1,
-                            pool_size * pool_size)).astype(int)
+            selected_frames = np.round(np.linspace(0, visual_features.shape[1] - 1, pool_size * pool_size)).astype(int)
             s_tokens = visual_features[:, selected_frames, ...]
-            s_tokens = einops.rearrange(s_tokens,
-                                        'b t (h w) d -> (b t) d h w',
-                                        h=16,
-                                        w=16)
+            s_tokens = einops.rearrange(s_tokens, 'b t (h w) d -> (b t) d h w', h=16, w=16)
             s_tokens = F.avg_pool2d(s_tokens, kernel_size=pool_size)
-            s_tokens = einops.rearrange(s_tokens,
-                                        '(b t) d h w -> b (t h w) d',
-                                        b=b)
+            s_tokens = einops.rearrange(s_tokens, '(b t) d h w -> b (t h w) d', b=b)
+
+            t_tokens = einops.reduce(visual_features, 'b t s d -> b t d', 'mean')
 
-            t_tokens = einops.reduce(visual_features, 'b t s d -> b t d',
-                                        'mean')
-            
             return t_tokens, s_tokens
 
         else:
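To make the shape bookkeeping in the `temporal_spatial_pool` branch above easier to follow, here is a self-contained sketch with random features and toy sizes (batch 1, 8 frames, a 16x16 patch grid, hidden size 64; all sizes are assumptions):

```python
import numpy as np
import torch
import einops
from torch.nn import functional as F

b, t, s, d = 1, 8, 256, 64                 # toy sizes; s must form a 16x16 grid here
visual_features = torch.randn(b, t, s, d)

pool_size = 2
selected = np.round(np.linspace(0, t - 1, pool_size * pool_size)).astype(int)

# Spatial tokens: 4 selected frames, each 16x16 grid average-pooled 2x2 -> 8x8.
s_tokens = visual_features[:, selected, ...]
s_tokens = einops.rearrange(s_tokens, 'b t (h w) d -> (b t) d h w', h=16, w=16)
s_tokens = F.avg_pool2d(s_tokens, kernel_size=pool_size)
s_tokens = einops.rearrange(s_tokens, '(b t) d h w -> b (t h w) d', b=b)

# Temporal tokens: one mean-pooled token per frame.
t_tokens = einops.reduce(visual_features, 'b t s d -> b t d', 'mean')

print(s_tokens.shape)  # torch.Size([1, 256, 64])  (4 frames * 8 * 8 pooled patches)
print(t_tokens.shape)  # torch.Size([1, 8, 64])
```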
@@ -590,7 +567,7 @@ def preprocess_lita_visual(self, visual_features, config):
 
     def ptuning_setup(self, prompt_table, input_ids, input_lengths):
         hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size
-        
+
         if self.model_type == 'lita' or self.model_type == 'vita':
             prompt_table = torch.cat(prompt_table, dim=1)
         if prompt_table is not None:
@@ -615,7 +592,7 @@ def ptuning_setup(self, prompt_table, input_ids, input_lengths):
             tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()
 
         return [prompt_table, tasks, task_vocab_size]
-    
+
     def expand2square_pt(self, images, background_color):
         height, width = images.shape[-2:]
         b = len(images)
@@ -634,24 +611,26 @@ def expand2square_pt(self, images, background_color):
             paste_end = paste_start + width
             result[:, :, :, paste_start:paste_end] = images
             return result
-    
+
     def load_video(self, config, video_path, processor, num_frames):
-        
+
         decord.bridge.set_bridge('torch')
         video_reader = decord.VideoReader(uri=video_path)
-        idx = np.round(
-            np.linspace(0, len(video_reader) - 1, num_frames)).astype(int)
+        idx = np.round(np.linspace(0, len(video_reader) - 1, num_frames)).astype(int)
         frames = video_reader.get_batch(idx)
         frames = einops.rearrange(frames, 't h w c -> t c h w')
-        
+
         if config['data']['image_aspect_ratio'] == 'pad':
-            frames = self.expand2square_pt(frames, tuple(int(x*255) for x in processor.image_mean))
+            frames = self.expand2square_pt(frames, tuple(int(x * 255) for x in processor.image_mean))
         processed_frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
-        
+
         return processed_frames
-    
+
     def get_num_sample_frames(self, config, vid_len):
-        if 'visual_token_format' in config['mm_cfg']['lita'] and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end':
+        if (
+            'visual_token_format' in config['mm_cfg']['lita']
+            and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end'
+        ):
             max_frames = config['data']['num_frames']
             if vid_len <= max_frames:
                 return vid_len
@@ -660,13 +639,17 @@ def get_num_sample_frames(self, config, vid_len):
                 return int(np.round(float(vid_len) / subsample))
         else:
             return config['mm_cfg']['lita']['sample_frames']
-    
+
     def process_video(self, nemo_config, video_path, image_processor):
         vid_len = len(decord.VideoReader(video_path))
         num_sample_frames = self.get_num_sample_frames(nemo_config, vid_len)
-        image = self.load_video(nemo_config, video_path, image_processor, num_sample_frames).unsqueeze(0).to(self.device, dtype=torch.bfloat16)
+        image = (
+            self.load_video(nemo_config, video_path, image_processor, num_sample_frames)
+            .unsqueeze(0)
+            .to(self.device, dtype=torch.bfloat16)
+        )
         return image
-    
+
     def process_image(self, image_file, image_processor, nemo_config, image_folder):
         image_processor = image_processor
         if isinstance(image_file, str):
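For concreteness, a small worked example of the frame-sampling decision in `get_num_sample_frames` above; the `nemo_config` fragment is illustrative only, the real one is read from the `nemo_config.yaml` exported next to the visual engine:

```python
import numpy as np

nemo_config = {  # illustrative fragment, not a full NeMo config
    'mm_cfg': {'lita': {'visual_token_format': 'im_vid_start_end', 'sample_frames': 4}},
    'data': {'num_frames': 128, 'image_aspect_ratio': 'pad'},
}

def get_num_sample_frames(config, vid_len):
    lita_cfg = config['mm_cfg']['lita']
    if lita_cfg.get('visual_token_format') == 'im_vid_start_end':
        max_frames = config['data']['num_frames']
        if vid_len <= max_frames:
            return vid_len
        subsample = int(np.ceil(float(vid_len) / max_frames))
        return int(np.round(float(vid_len) / subsample))
    return lita_cfg['sample_frames']

print(get_num_sample_frames(nemo_config, 90))    # 90: short clip, keep every frame
print(get_num_sample_frames(nemo_config, 1000))  # 125: 1000 / ceil(1000 / 128) = 1000 / 8
```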
@@ -733,19 +716,19 @@ def setup_inputs(self, input_text, raw_image, batch_size):
                 if input_text is None:
                     input_text = "<image>\n Please elaborate what you see in the images?"
                 post_prompt = input_text + " ASSISTANT:"
-            
+
             elif self.model_type == "vita":
                 pre_prompt = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. USER: "
                 if input_text is None:
                     input_text = "<image>\n Please elaborate what you see in the images?"
                 post_prompt = input_text + " ASSISTANT:"
-            
+
         else:
             raise RuntimeError(f"Invalid model type {self.model_type}")
-        
+
         if self.model_type == 'lita' or self.model_type == 'vita':
             image = self.process_video(self.nemo_config, raw_image, self.image_processor)
-        
+
         if self.model_type == 'vila':
             raw_image = [raw_image] * batch_size
             image = self.process_vila_img(raw_image)
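The default prompts above embed an `<image>` placeholder that `tokenizer_image_token` later splits on. As a quick illustration of that splitting, here is a hedged, self-contained sketch with a stand-in tokenizer; the toy ids and the `ToyTokenizer` class are assumptions, the real runner uses the tokenizer loaded in `init_tokenizer`:

```python
from types import SimpleNamespace

import torch

class ToyTokenizer:
    """Stand-in tokenizer: BOS id 1, then one made-up id per whitespace token."""
    bos_token_id = 1

    def __call__(self, text):
        return SimpleNamespace(input_ids=[self.bos_token_id] + [10 + i for i, _ in enumerate(text.split())])

def tokenizer_image_token(batch_size, prompt, tokenizer, image_token_index=-200):
    # Same structure as the static method above: split on <image>, rejoin with
    # placeholder ids, zero them out, then expand to the batch size.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])
    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    input_ids[input_ids == image_token_index] = 0
    return input_ids.unsqueeze(0).expand(batch_size, -1)

prompt = "USER: <image>\n Please elaborate what you see in the images? ASSISTANT:"
print(tokenizer_image_token(1, prompt, ToyTokenizer()))
# tensor([[ 1, 10,  0, 10, 11, 12, 13, 14, 15, 16, 17, 18]])
```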
diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
index 72003c4bb8e0..a55a6813ce46 100644
--- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
+++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -140,7 +140,7 @@ def _update_config_entry(key, file_pattern):
 
 
 def copy_tokenizer_files(config, out_dir):
-    
+
     basenames = {
         "model": "tokenizer",
         "vocab_file": "vocab",
@@ -233,7 +233,6 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat
             model = load_sharded_metadata(dist_ckpt_folder)
             nemo_model_config = unpacked_checkpoint_dir.model_config
 
-            
             if nemo_model_config["tokenizer"].get("library", None) == "huggingface":
                 tokenizer = AutoTokenizer.from_pretrained(
                     nemo_model_config["tokenizer"]["type"],
diff --git a/scripts/deploy/multimodal/deploy_triton.py b/scripts/deploy/multimodal/deploy_triton.py
index a0161d67f8c4..15bade1b7dd2 100755
--- a/scripts/deploy/multimodal/deploy_triton.py
+++ b/scripts/deploy/multimodal/deploy_triton.py
@@ -84,7 +84,9 @@ def get_args(argv):
     parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model")
     parser.add_argument("-mbs", "--max_batch_size", default=1, type=int, help="Max batch size of the llm model")
     parser.add_argument("-mml", "--max_multimodal_len", default=3072, type=int, help="Max length of multimodal input")
-    parser.add_argument("-vmb", "--vision_max_batch_size", default=1, type=int, help="Max batch size of the vision model")
+    parser.add_argument(
+        "-vmb", "--vision_max_batch_size", default=1, type=int, help="Max batch size of the vision model"
+    )
     args = parser.parse_args(argv)
     return args
 

From f33bf64890735951bc981a02c9d83d4de3b5c6e7 Mon Sep 17 00:00:00 2001
From: Vivian Chen 
Date: Wed, 17 Jul 2024 04:39:25 +0000
Subject: [PATCH 6/9] address code scanning issues

Signed-off-by: Vivian Chen 
---
 nemo/export/multimodal/build.py | 1 +
 nemo/export/multimodal/run.py   | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py
index b51ceb4590ff..cd4a18706946 100644
--- a/nemo/export/multimodal/build.py
+++ b/nemo/export/multimodal/build.py
@@ -259,6 +259,7 @@ def forward(self, images):
         raise ValueError(f"Unknown projector type: {nemo_config['mm_cfg']['mm_mlp_adapter_type']}")
 
     # export the whole wrapper
+    lita_num_frames = None
     wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype)
     if model_type == "lita" or model_type == "vila":
         image_size = hf_config.image_size
diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py
index 5c7466db7279..3861753c4a6a 100644
--- a/nemo/export/multimodal/run.py
+++ b/nemo/export/multimodal/run.py
@@ -30,7 +30,7 @@
 from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo
 from torch.nn import functional as F
 from torchvision import transforms
-from transformers import AutoModel, AutoProcessor, CLIPImageProcessor
+from transformers import AutoProcessor, CLIPImageProcessor
 
 
 def trt_dtype_to_torch(dtype):
@@ -651,7 +651,6 @@ def process_video(self, nemo_config, video_path, image_processor):
         return image
 
     def process_image(self, image_file, image_processor, nemo_config, image_folder):
-        image_processor = image_processor
         if isinstance(image_file, str):
             if image_folder is not None:
                 image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
@@ -680,6 +679,7 @@ def process_vila_img(self, images):
 
     def setup_inputs(self, input_text, raw_image, batch_size):
         attention_mask = None
+        image = None
 
         if self.model_type == "neva":
             image_size = self.image_size

From c0427d934f503681afb2e3ad8f1e2e8803d97494 Mon Sep 17 00:00:00 2001
From: Vivian Chen 
Date: Thu, 18 Jul 2024 01:16:11 +0000
Subject: [PATCH 7/9] add triton deployment for lita/vila/vita

Signed-off-by: Vivian Chen 
---
 nemo/deploy/multimodal/query_multimodal.py    | 21 ++++-
 nemo/export/multimodal/build.py               |  1 +
 nemo/export/multimodal/run.py                 | 76 ++++++++++++-------
 nemo/export/tensorrt_mm_exporter.py           |  5 +-
 .../trt_llm/nemo_ckpt_loader/nemo_file.py     |  1 -
 scripts/deploy/multimodal/deploy_triton.py    |  6 +-
 6 files changed, 74 insertions(+), 36 deletions(-)

diff --git a/nemo/deploy/multimodal/query_multimodal.py b/nemo/deploy/multimodal/query_multimodal.py
index 9f747ff6d306..b2d720c5b276 100644
--- a/nemo/deploy/multimodal/query_multimodal.py
+++ b/nemo/deploy/multimodal/query_multimodal.py
@@ -56,11 +56,30 @@ def setup_media(self, input_media):
             vr = VideoReader(input_media)
             frames = [f.asnumpy() for f in vr]
             return np.array(frames)
-        elif self.model_type == "neva":
+        elif self.model_type == "lita" or self.model_type == "vita":
+            vr = VideoReader(input_media)
+            frames = [f.asnumpy() for f in vr]
+            subsample_len = self.frame_len(frames)
+            sub_frames = self.get_subsampled_frames(frames, subsample_len)
+            return np.array(sub_frames)
+        elif self.model_type == "neva" or self.model_type == "vila":
             media = Image.open(input_media).convert('RGB')
             return np.expand_dims(np.array(media), axis=0)
         else:
             raise RuntimeError(f"Invalid model type {self.model_type}")
+    
+    def frame_len(self, frames):
+        max_frames = 256
+        if len(frames) <= max_frames:
+            return len(frames)
+        else:
+            subsample = int(np.ceil(float(len(frames)) / max_frames))
+            return int(np.round(float(len(frames)) / subsample))
+    
+    def get_subsampled_frames(self, frames, subsample_len):
+        idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int)
+        sub_frames = [frames[i] for i in idx]
+        return sub_frames
 
     def query(
         self,
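The new `frame_len`/`get_subsampled_frames` helpers cap what the query client sends to Triton at 256 frames, which is also why the deploy script recommends `--vision_max_batch_size 256` for lita/vita video inference. A standalone sketch with a pretend 600-frame video (array sizes are illustrative):

```python
import numpy as np

def frame_len(num_frames, max_frames=256):
    # Same rule as Query.frame_len above: keep at most max_frames frames.
    if num_frames <= max_frames:
        return num_frames
    subsample = int(np.ceil(float(num_frames) / max_frames))
    return int(np.round(float(num_frames) / subsample))

def get_subsampled_frames(frames, subsample_len):
    idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int)
    return [frames[i] for i in idx]

frames = [np.zeros((224, 224, 3), dtype=np.uint8)] * 600   # pretend decoded video
sub_frames = get_subsampled_frames(frames, frame_len(len(frames)))
print(len(sub_frames))  # 200: ceil(600 / 256) = 3, round(600 / 3) = 200
```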
diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py
index cd4a18706946..9d63a19248ce 100644
--- a/nemo/export/multimodal/build.py
+++ b/nemo/export/multimodal/build.py
@@ -22,6 +22,7 @@
 import tensorrt as trt
 import torch
 import yaml
+from pathlib import Path
 from tensorrt_llm.builder import Builder
 from transformers import AutoModel
 
diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py
index 3861753c4a6a..fb5ed5edae34 100644
--- a/nemo/export/multimodal/run.py
+++ b/nemo/export/multimodal/run.py
@@ -206,12 +206,11 @@ def video_preprocess(self, video_path):
         )  # [num_frames, 3, H, W]
         return media_tensors.unsqueeze(0)  # [1, num_frames, 3, H, W]
 
-    def insert_tokens_by_index(self, input_ids, nemo_config):
+    def insert_tokens_by_index(self, input_ids, num_frames):
         im_start_id = self.tokenizer.im_start_id
         im_end_id = self.tokenizer.im_end_id
         vid_start_id = self.tokenizer.vid_start_id
         vid_end_id = self.tokenizer.vid_end_id
-        num_frames = nemo_config['mm_cfg']['lita']['sample_frames']
 
         image_token_indices = (input_ids == 0).nonzero(as_tuple=False).squeeze().tolist()
         input_ids = input_ids.squeeze().tolist()
@@ -264,27 +263,29 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, bat
             return input_ids, input_lengths, ptuning_args, visual_features
 
         elif self.model_type == 'lita' or self.model_type == 'vita':
+            visual_input = []
             for i, img in enumerate(image):
                 visual_features, visual_atts = self.get_visual_features(img, attention_mask)
             visual_features = visual_features.unsqueeze(0)
-            im_tokens, vid_tokens = self.preprocess_lita_visual(visual_features, self.nemo_config)
+            im_tokens, vid_tokens, num_sample_frames = self.preprocess_lita_visual(visual_features, self.nemo_config)
+            visual_input.extend([im_tokens, vid_tokens])
+            
             input_ids = self.tokenizer_image_token(batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer)
-            input_ids = self.insert_tokens_by_index(input_ids, self.nemo_config)
+            input_ids = self.insert_tokens_by_index(input_ids, num_sample_frames)
             batch_splits = self.split_prompt_by_images(input_ids)
             first_batch_split_prompts = batch_splits[0]
             length = sum([ids.shape[1] for ids in first_batch_split_prompts])
 
-            visual_input = []
-            visual_input.append(im_tokens)
-            visual_input.append(vid_tokens)
 
-            # we need to update visual atts shape to match im_tokens shape and vid_tokens shape
+            # Update visual atts shape to match im_tokens shape and vid_tokens shape
             im_tokens = im_tokens.view(1, -1, im_tokens.shape[-1])
             visual_features = torch.cat([im_tokens, vid_tokens], dim=1)
             visual_atts = torch.ones(visual_features.size()[:-1], dtype=torch.long).to(image.device)
 
             if batch_size == 1:
                 length += visual_atts.shape[0] * visual_atts.shape[1]
+            else:
+                raise ValueError("Batch size greater than 1 is not supported for LITA and VITA models")
 
             input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32)
             input_ids, ptuning_args = self.setup_fake_prompts_vila(
@@ -545,7 +546,7 @@ def preprocess_lita_visual(self, visual_features, config):
             im_features = visual_features[:, idx, ...]
 
             vid_features = einops.reduce(visual_features, 'b t s d -> b t d', 'mean')
-            return im_features, vid_features
+            return im_features, vid_features, num_image_frames
 
         elif (
             'lita_video_arch' in config['mm_cfg']['lita']
@@ -560,7 +561,7 @@ def preprocess_lita_visual(self, visual_features, config):
 
             t_tokens = einops.reduce(visual_features, 'b t s d -> b t d', 'mean')
 
-            return t_tokens, s_tokens
+            return t_tokens, s_tokens, pool_size**2
 
         else:
             raise ValueError(f'Invalid visual token format: {config["mm_cfg"]["lita"]["visual_token_format"]}')
@@ -612,18 +613,25 @@ def expand2square_pt(self, images, background_color):
             result[:, :, :, paste_start:paste_end] = images
             return result
 
-    def load_video(self, config, video_path, processor, num_frames):
-
-        decord.bridge.set_bridge('torch')
-        video_reader = decord.VideoReader(uri=video_path)
-        idx = np.round(np.linspace(0, len(video_reader) - 1, num_frames)).astype(int)
-        frames = video_reader.get_batch(idx)
+    def load_video(self, config, video_path, processor, num_frames=None):
+        if isinstance(video_path, str):
+            decord.bridge.set_bridge('torch')
+            video_reader = decord.VideoReader(uri=video_path)
+            if num_frames is not None:
+                idx = np.round(np.linspace(0, len(video_reader) - 1, num_frames)).astype(int)
+                frames = video_reader.get_batch(idx)
+            else:
+                frames = torch.stack([video_reader[i] for i in range(len(video_reader))])
+        elif isinstance(video_path, np.ndarray):
+            frames = torch.tensor(video_path, dtype=torch.float32)
+        
+        return self.preprocess_frames(frames, config, processor)
+    
+    def preprocess_frames(self, frames, config, processor):
         frames = einops.rearrange(frames, 't h w c -> t c h w')
-
         if config['data']['image_aspect_ratio'] == 'pad':
             frames = self.expand2square_pt(frames, tuple(int(x * 255) for x in processor.image_mean))
         processed_frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
-
         return processed_frames
 
     def get_num_sample_frames(self, config, vid_len):
@@ -640,14 +648,21 @@ def get_num_sample_frames(self, config, vid_len):
         else:
             return config['mm_cfg']['lita']['sample_frames']
 
-    def process_video(self, nemo_config, video_path, image_processor):
-        vid_len = len(decord.VideoReader(video_path))
-        num_sample_frames = self.get_num_sample_frames(nemo_config, vid_len)
-        image = (
-            self.load_video(nemo_config, video_path, image_processor, num_sample_frames)
-            .unsqueeze(0)
-            .to(self.device, dtype=torch.bfloat16)
-        )
+    def process_lita_video(self, nemo_config, video_path, image_processor):
+        if isinstance(video_path, str):
+            vid_len = len(decord.VideoReader(video_path))
+            num_sample_frames = self.get_num_sample_frames(nemo_config, vid_len)
+            image = (
+                self.load_video(nemo_config, video_path, image_processor, num_sample_frames)
+                .unsqueeze(0)
+                .to(self.device, dtype=torch.bfloat16)
+            )
+        elif isinstance(video_path, np.ndarray):
+            image = (
+                self.load_video(nemo_config, video_path, image_processor)
+                .unsqueeze(0)
+                .to(self.device, dtype=torch.bfloat16)
+            )
         return image
 
     def process_image(self, image_file, image_processor, nemo_config, image_folder):
@@ -718,16 +733,19 @@ def setup_inputs(self, input_text, raw_image, batch_size):
                 post_prompt = input_text + " ASSISTANT:"
 
             elif self.model_type == "vita":
-                pre_prompt = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. USER: "
+                # llama3 prompt template
+                pre_prompt = (
+                    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+                    "You are a helpful language and vision assistant. "
+                    "You are able to understand the visual content that the user provides, "
+                    "and assist the user with a variety of tasks using natural language. "
+                    "<|start_header_id|>user<|end_header_id|>\n\n"
+                )
                 if input_text is None:
                     input_text = "<image>\n Please elaborate what you see in the images?"
-                post_prompt = input_text + " ASSISTANT:"
+                post_prompt = input_text + "<|start_header_id|>assistant<|end_header_id|>\n\n"
 
         else:
             raise RuntimeError(f"Invalid model type {self.model_type}")
 
         if self.model_type == 'lita' or self.model_type == 'vita':
-            image = self.process_video(self.nemo_config, raw_image, self.image_processor)
+            image = self.process_lita_video(self.nemo_config, raw_image, self.image_processor)
 
         if self.model_type == 'vila':
             raw_image = [raw_image] * batch_size
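Putting the pieces together, here is a short sketch of the final vita prompt string built from the llama3-style template added in the hunk above, before tokenization. It only concatenates strings; the `<image>` placeholder is what `tokenizer_image_token` and `insert_tokens_by_index` later expand into image and video token spans:

```python
pre_prompt = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language. "
    "<|start_header_id|>user<|end_header_id|>\n\n"
)
input_text = "<image>\n Please elaborate what you see in the images?"
post_prompt = input_text + "<|start_header_id|>assistant<|end_header_id|>\n\n"

# The runner tokenizes pre_prompt + post_prompt and splits on <image>.
print(pre_prompt + post_prompt)
```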
diff --git a/nemo/export/tensorrt_mm_exporter.py b/nemo/export/tensorrt_mm_exporter.py
index 7eee48d1f9f2..b0536a55f95f 100644
--- a/nemo/export/tensorrt_mm_exporter.py
+++ b/nemo/export/tensorrt_mm_exporter.py
@@ -193,9 +193,10 @@ def triton_infer_fn(self, **inputs: np.ndarray):
                 )
 
             infer_input = {"input_text": str_ndarray2list(inputs.pop("input_text")[0])}
-            if self.runner.model_type == "neva":
+            video_model_list = ["video-neva", "lita", "vita"]
+            if self.runner.model_type == "neva" or self.runner.model_type == "vila":
                 infer_input["input_image"] = ndarray2img(inputs.pop("input_media")[0])[0]
-            elif self.runner.model_type == "video-neva":
+            elif self.runner.model_type in video_model_list:
                 infer_input["input_image"] = inputs.pop("input_media")[0]
             if "batch_size" in inputs:
                 infer_input["batch_size"] = inputs.pop("batch_size")[0][0]
diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
index a55a6813ce46..1d473f497f51 100644
--- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
+++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -140,7 +140,6 @@ def _update_config_entry(key, file_pattern):
 
 
 def copy_tokenizer_files(config, out_dir):
-
     basenames = {
         "model": "tokenizer",
         "vocab_file": "vocab",
diff --git a/scripts/deploy/multimodal/deploy_triton.py b/scripts/deploy/multimodal/deploy_triton.py
index 15bade1b7dd2..ef3298b60075 100755
--- a/scripts/deploy/multimodal/deploy_triton.py
+++ b/scripts/deploy/multimodal/deploy_triton.py
@@ -48,8 +48,8 @@ def get_args(argv):
         "--model_type",
         type=str,
         required=True,
-        choices=["neva", "video-neva"],
-        help="Type of the model. neva and video-neva are only supported.",
+        choices=["neva", "video-neva", "lita", "vila", "vita"],
+        help="Type of the model that is supported.",
     )
     parser.add_argument(
         "-lmt",
@@ -85,7 +85,7 @@ def get_args(argv):
     parser.add_argument("-mbs", "--max_batch_size", default=1, type=int, help="Max batch size of the llm model")
     parser.add_argument("-mml", "--max_multimodal_len", default=3072, type=int, help="Max length of multimodal input")
     parser.add_argument(
-        "-vmb", "--vision_max_batch_size", default=1, type=int, help="Max batch size of the vision model"
+        "-vmb", "--vision_max_batch_size", default=1, type=int, help="Max batch size of the visual inputs, for lita/vita model with video inference, this should be set to 256"
     )
     args = parser.parse_args(argv)
     return args

From 62b5fa71e3115b7043df9d91bd1cc56544c84401 Mon Sep 17 00:00:00 2001
From: xuanzic 
Date: Thu, 18 Jul 2024 01:17:33 +0000
Subject: [PATCH 8/9] Apply isort and black reformatting

Signed-off-by: xuanzic 
---
 nemo/deploy/multimodal/query_multimodal.py | 4 ++--
 nemo/export/multimodal/build.py            | 2 +-
 nemo/export/multimodal/run.py              | 7 +++----
 scripts/deploy/multimodal/deploy_triton.py | 6 +++++-
 4 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/nemo/deploy/multimodal/query_multimodal.py b/nemo/deploy/multimodal/query_multimodal.py
index b2d720c5b276..ee3d24d4ec1e 100644
--- a/nemo/deploy/multimodal/query_multimodal.py
+++ b/nemo/deploy/multimodal/query_multimodal.py
@@ -67,7 +67,7 @@ def setup_media(self, input_media):
             return np.expand_dims(np.array(media), axis=0)
         else:
             raise RuntimeError(f"Invalid model type {self.model_type}")
-    
+
     def frame_len(self, frames):
         max_frames = 256
         if len(frames) <= max_frames:
@@ -75,7 +75,7 @@ def frame_len(self, frames):
         else:
             subsample = int(np.ceil(float(len(frames)) / max_frames))
             return int(np.round(float(len(frames)) / subsample))
-    
+
     def get_subsampled_frames(self, frames, subsample_len):
         idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int)
         sub_frames = [frames[i] for i in idx]
diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py
index 9d63a19248ce..03afec176325 100644
--- a/nemo/export/multimodal/build.py
+++ b/nemo/export/multimodal/build.py
@@ -17,12 +17,12 @@
 import shutil
 import tarfile
 import tempfile
+from pathlib import Path
 from time import time
 
 import tensorrt as trt
 import torch
 import yaml
-from pathlib import Path
 from tensorrt_llm.builder import Builder
 from transformers import AutoModel
 
diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py
index fb5ed5edae34..c7797f169dc1 100644
--- a/nemo/export/multimodal/run.py
+++ b/nemo/export/multimodal/run.py
@@ -269,14 +269,13 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, bat
             visual_features = visual_features.unsqueeze(0)
             im_tokens, vid_tokens, num_sample_frames = self.preprocess_lita_visual(visual_features, self.nemo_config)
             visual_input.extend([im_tokens, vid_tokens])
-            
+
             input_ids = self.tokenizer_image_token(batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer)
             input_ids = self.insert_tokens_by_index(input_ids, num_sample_frames)
             batch_splits = self.split_prompt_by_images(input_ids)
             first_batch_split_prompts = batch_splits[0]
             length = sum([ids.shape[1] for ids in first_batch_split_prompts])
 
-
             # Update visual atts shape to match im_tokens shape and vid_tokens shape
             im_tokens = im_tokens.view(1, -1, im_tokens.shape[-1])
             visual_features = torch.cat([im_tokens, vid_tokens], dim=1)
@@ -624,9 +623,9 @@ def load_video(self, config, video_path, processor, num_frames=None):
                 frames = torch.stack([video_reader[i] for i in range(len(video_reader))])
         elif isinstance(video_path, np.ndarray):
             frames = torch.tensor(video_path, dtype=torch.float32)
-        
+
         return self.preprocess_frames(frames, config, processor)
-    
+
     def preprocess_frames(self, frames, config, processor):
         frames = einops.rearrange(frames, 't h w c -> t c h w')
         if config['data']['image_aspect_ratio'] == 'pad':
diff --git a/scripts/deploy/multimodal/deploy_triton.py b/scripts/deploy/multimodal/deploy_triton.py
index ef3298b60075..d0bf8f10548a 100755
--- a/scripts/deploy/multimodal/deploy_triton.py
+++ b/scripts/deploy/multimodal/deploy_triton.py
@@ -85,7 +85,11 @@ def get_args(argv):
     parser.add_argument("-mbs", "--max_batch_size", default=1, type=int, help="Max batch size of the llm model")
     parser.add_argument("-mml", "--max_multimodal_len", default=3072, type=int, help="Max length of multimodal input")
     parser.add_argument(
-        "-vmb", "--vision_max_batch_size", default=1, type=int, help="Max batch size of the visual inputs, for lita/vita model with video inference, this should be set to 256"
+        "-vmb",
+        "--vision_max_batch_size",
+        default=1,
+        type=int,
+        help="Max batch size of the visual inputs, for lita/vita model with video inference, this should be set to 256",
     )
     args = parser.parse_args(argv)
     return args

From a1ea8cc0ab2d8b5896e163faafc2a7d97f6e61fa Mon Sep 17 00:00:00 2001
From: Vivian Chen 
Date: Thu, 18 Jul 2024 02:26:34 +0000
Subject: [PATCH 9/9] fix code scan

Signed-off-by: Vivian Chen 
---
 nemo/export/multimodal/run.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py
index c7797f169dc1..1809a6fc8ce7 100644
--- a/nemo/export/multimodal/run.py
+++ b/nemo/export/multimodal/run.py
@@ -613,6 +613,7 @@ def expand2square_pt(self, images, background_color):
             return result
 
     def load_video(self, config, video_path, processor, num_frames=None):
+        frames = None
         if isinstance(video_path, str):
             decord.bridge.set_bridge('torch')
             video_reader = decord.VideoReader(uri=video_path)
@@ -648,6 +649,7 @@ def get_num_sample_frames(self, config, vid_len):
             return config['mm_cfg']['lita']['sample_frames']
 
     def process_lita_video(self, nemo_config, video_path, image_processor):
+        image = None
         if isinstance(video_path, str):
             vid_len = len(decord.VideoReader(video_path))
             num_sample_frames = self.get_num_sample_frames(nemo_config, vid_len)