Add Lita, Vila and Vita TRTLLM export (#9734)
* modify vision encoder config

Signed-off-by: Vivian Chen <[email protected]>

* add lita, vila engine build support and fix export api bugs

Signed-off-by: Vivian Chen <[email protected]>

* add run example for vila, lita and vita

Signed-off-by: Vivian Chen <[email protected]>

* couple of changes for exporter

Signed-off-by: Vivian Chen <[email protected]>

* Apply isort and black reformatting

Signed-off-by: xuanzic <[email protected]>

* address code scanning issues

Signed-off-by: Vivian Chen <[email protected]>

* add triton deployment for lita/vila/vita

Signed-off-by: Vivian Chen <[email protected]>

* Apply isort and black reformatting

Signed-off-by: xuanzic <[email protected]>

* fix code scan

Signed-off-by: Vivian Chen <[email protected]>

---------

Signed-off-by: Vivian Chen <[email protected]>
Signed-off-by: Vivian Chen <[email protected]>
Signed-off-by: xuanzic <[email protected]>
Co-authored-by: Vivian Chen <[email protected]>
Co-authored-by: xuanzic <[email protected]>
3 people authored Jul 19, 2024
1 parent 5546190 commit ab8988e
Showing 9 changed files with 559 additions and 80 deletions.
5 changes: 3 additions & 2 deletions examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml
@@ -6,10 +6,11 @@ infer:
max_input_len: 4096
max_output_len: 256
max_multimodal_len: 3072
vision_max_batch_size: 1 # use 256 for lita/vita when running inference on a video dataset

model:
type: neva
type: neva #neva, video-neva, lita, vila, vita
precision: bfloat16
visual_model_path: /path/to/visual.nemo
llm_model_path: /path/to/llm.nemo
llm_model_type: llama
llm_model_type: llama
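
The new `vision_max_batch_size` field sizes the batch dimension of the vision-encoder TensorRT engine; per the comment, 256 is the suggested value for lita/vita when all sampled frames of a video are encoded in one pass. A minimal sketch of a LITA-style export config assembled with OmegaConf (paths and the exact field set are illustrative, taken from `neva_export.yaml` and the `cfg.infer.*` / `cfg.model.*` accesses in `neva_export.py`):

```python
# Sketch only: a lita export config mirroring neva_export.yaml.
# All paths are placeholders, not shipped defaults.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "infer": {
            "tensor_parallelism": 1,
            "max_batch_size": 1,
            "max_input_len": 4096,
            "max_output_len": 256,
            "max_multimodal_len": 3072,
            # 256 is the suggested value for lita/vita video inference
            "vision_max_batch_size": 256,
        },
        "model": {
            "type": "lita",  # neva, video-neva, lita, vila, vita
            "precision": "bfloat16",
            "visual_model_path": "/path/to/visual.nemo",
            "llm_model_path": "/path/to/llm.nemo",
            "llm_model_type": "llama",
        },
    }
)
print(OmegaConf.to_yaml(cfg))
```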
1 change: 1 addition & 0 deletions examples/multimodal/multimodal_llm/neva/neva_export.py
@@ -27,6 +27,7 @@ def main(cfg):
tensor_parallel_size=cfg.infer.tensor_parallelism,
max_input_len=cfg.infer.max_input_len,
max_output_len=cfg.infer.max_output_len,
vision_max_batch_size=cfg.infer.vision_max_batch_size,
max_batch_size=cfg.infer.max_batch_size,
max_multimodal_len=cfg.infer.max_multimodal_len,
dtype=cfg.model.precision,
nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -470,11 +470,10 @@ def __init__(
def create_vision_encoder_and_processor(self, mm_cfg):
# Initialize vision encoder and freeze it
if mm_cfg.vision_encoder.get("from_hf", False):
if (
"clip" in mm_cfg.vision_encoder.from_pretrained
or "vit" in mm_cfg.vision_encoder.from_pretrained
or "clip" in mm_cfg.vision_encoder.get("model_type", "")
):
from transformers import AutoConfig

config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
if config.architectures[0] == "CLIPVisionModel":
vision_encoder = CLIPVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained,
torch_dtype=torch.bfloat16,
@@ -484,9 +483,7 @@ def create_vision_encoder_and_processor(self, mm_cfg):
for param in vision_encoder.parameters():
param.requires_grad = False
vision_encoder = vision_encoder.eval()
elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get(
"model_type", ""
):
elif config.architectures[0] == "SiglipVisionModel":
vision_encoder = SiglipVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained,
torch_dtype=torch.bfloat16,
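
This hunk (and the matching one in `create_image_processor` in the next file) replaces substring matching on the checkpoint name with a dispatch on the Hugging Face config's `architectures` field. A self-contained sketch of that dispatch; `pick_vision_encoder` is a hypothetical helper, while `AutoConfig`, `CLIPVisionModel` and `SiglipVisionModel` are the `transformers` classes used in the diff:

```python
# Sketch of the architecture-based selection introduced above.
import torch
from transformers import AutoConfig, CLIPVisionModel, SiglipVisionModel


def pick_vision_encoder(pretrained_name: str):
    """Load a frozen CLIP or SigLIP vision tower based on the checkpoint's config."""
    config = AutoConfig.from_pretrained(pretrained_name)
    arch = config.architectures[0]  # e.g. "CLIPVisionModel" or "SiglipVisionModel"
    if arch == "CLIPVisionModel":
        encoder_cls = CLIPVisionModel
    elif arch == "SiglipVisionModel":
        encoder_cls = SiglipVisionModel
    else:
        raise ValueError(f"Unsupported vision encoder architecture: {arch}")
    encoder = encoder_cls.from_pretrained(pretrained_name, torch_dtype=torch.bfloat16)
    encoder.requires_grad_(False)  # the encoder stays frozen, as in the diff
    return encoder.eval()


# Usage, mirroring the call site above:
# vision_encoder = pick_vision_encoder(mm_cfg.vision_encoder.from_pretrained)
```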
13 changes: 5 additions & 8 deletions nemo/collections/multimodal/parts/utils.py
@@ -534,17 +534,14 @@ def expand2square(pil_img, background_color):

def create_image_processor(mm_cfg):
if mm_cfg.vision_encoder.get("from_hf", False):
if (
"clip" in mm_cfg.vision_encoder.from_pretrained
or "vit" in mm_cfg.vision_encoder.from_pretrained
or "clip" in mm_cfg.vision_encoder.get("model_type", "")
):
from transformers import AutoConfig

config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
if config.architectures[0] == "CLIPVisionModel":
image_processor = CLIPImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get(
"model_type", ""
):
elif config.architectures[0] == "SiglipVisionModel":
image_processor = SiglipImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
21 changes: 20 additions & 1 deletion nemo/deploy/multimodal/query_multimodal.py
@@ -56,12 +56,31 @@ def setup_media(self, input_media):
vr = VideoReader(input_media)
frames = [f.asnumpy() for f in vr]
return np.array(frames)
elif self.model_type == "neva":
elif self.model_type == "lita" or self.model_type == "vita":
vr = VideoReader(input_media)
frames = [f.asnumpy() for f in vr]
subsample_len = self.frame_len(frames)
sub_frames = self.get_subsampled_frames(frames, subsample_len)
return np.array(sub_frames)
elif self.model_type == "neva" or self.model_type == "vila":
media = Image.open(input_media).convert('RGB')
return np.expand_dims(np.array(media), axis=0)
else:
raise RuntimeError(f"Invalid model type {self.model_type}")

def frame_len(self, frames):
max_frames = 256
if len(frames) <= max_frames:
return len(frames)
else:
subsample = int(np.ceil(float(len(frames)) / max_frames))
return int(np.round(float(len(frames)) / subsample))

def get_subsampled_frames(self, frames, subsample_len):
idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int)
sub_frames = [frames[i] for i in idx]
return sub_frames

def query(
self,
input_text,
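
For lita/vita, `setup_media` now caps a video at 256 frames and subsamples the rest evenly via `np.linspace`. A small standalone reproduction of the index arithmetic in `frame_len` / `get_subsampled_frames` (the helper name is made up; the numbers are just an illustration):

```python
import numpy as np

MAX_FRAMES = 256  # same cap as frame_len()


def subsample_indices(num_frames: int, max_frames: int = MAX_FRAMES) -> np.ndarray:
    """Indices kept by frame_len() + get_subsampled_frames() for a clip of num_frames."""
    if num_frames <= max_frames:
        subsample_len = num_frames
    else:
        stride = int(np.ceil(float(num_frames) / max_frames))
        subsample_len = int(np.round(float(num_frames) / stride))
    return np.round(np.linspace(0, num_frames - 1, subsample_len)).astype(int)


# A 1000-frame clip: stride = ceil(1000 / 256) = 4, so 250 evenly spaced
# frames are kept, from index 0 through 999.
idx = subsample_indices(1000)
print(len(idx), idx[:5], idx[-1])  # 250 [ 0  4  8 12 16] 999
```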
136 changes: 106 additions & 30 deletions nemo/export/multimodal/build.py
@@ -17,6 +17,7 @@
import shutil
import tarfile
import tempfile
from pathlib import Path
from time import time

import tensorrt as trt
@@ -37,18 +38,19 @@ def build_trtllm_engine(
llm_checkpoint_path: str = None,
model_type: str = "neva",
llm_model_type: str = "llama",
tensor_parallel_size: int = 1,
tensor_parallelism_size: int = 1,
max_input_len: int = 256,
max_output_len: int = 256,
max_batch_size: int = 1,
max_multimodal_len: int = 1024,
dtype: str = "bfloat16",
):
trt_llm_exporter = TensorRTLLM(model_dir=model_dir, load_model=False)
visual_checkpoint_model = ['neva', 'lita', 'vila', 'vita']
trt_llm_exporter.export(
nemo_checkpoint_path=visual_checkpoint_path if model_type == "neva" else llm_checkpoint_path,
nemo_checkpoint_path=visual_checkpoint_path if model_type in visual_checkpoint_model else llm_checkpoint_path,
model_type=llm_model_type,
tensor_parallel_size=tensor_parallel_size,
tensor_parallelism_size=tensor_parallelism_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
@@ -75,12 +77,24 @@ def export_visual_wrapper_onnx(


def build_trt_engine(
model_type, input_sizes, output_dir, max_batch_size, dtype=torch.bfloat16, image_size=None, num_frames=None
model_type,
input_sizes,
output_dir,
vision_max_batch_size,
dtype=torch.bfloat16,
image_size=None,
num_frames=None,
nemo_config=None,
):
part_name = 'visual_encoder'
onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name)
engine_file = '%s/%s.engine' % (output_dir, part_name)
config_file = '%s/%s' % (output_dir, "config.json")
nemo_config_file = '%s/%s' % (output_dir, "nemo_config.yaml")

with open(nemo_config_file, 'w') as f:
yaml.dump(nemo_config, f)

logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name)

builder = trt.Builder(logger)
@@ -110,8 +124,8 @@ def build_trt_engine(

nBS = -1
nMinBS = 1
nOptBS = max(nMinBS, int(max_batch_size / 2))
nMaxBS = max_batch_size
nOptBS = max(nMinBS, int(vision_max_batch_size / 2))
nMaxBS = vision_max_batch_size

inputT = network.get_input(0)

@@ -145,17 +159,41 @@


def build_neva_engine(
model_type: str,
model_dir: str,
visual_checkpoint_path: str,
max_batch_size: int = 1,
vision_max_batch_size: int = 1,
):
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
# extract NeMo checkpoint
with tempfile.TemporaryDirectory() as temp:
mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp)
temp_path = Path(temp)
mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp_path)

vision_config = nemo_config["mm_cfg"]["vision_encoder"]

class DownSampleBlock(torch.nn.Module):
def forward(self, x):
vit_embeds = x
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.flat_square(vit_embeds)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
return vit_embeds

def flat_square(self, x):
n, w, h, c = x.size()
if w % 2 == 1:
x = torch.cat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
n, w, h, c = x.size()
if h % 2 == 1:
x = torch.cat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
n, w, h, c = x.size()
x = x.view(n, w, int(h / 2), int(c * 2))
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
return x

class VisionEncoderWrapper(torch.nn.Module):

def __init__(self, encoder, connector):
@@ -166,7 +204,6 @@ def __init__(self, encoder, connector):
def forward(self, images):
vision_x = self.encoder(pixel_values=images, output_hidden_states=True)
vision_x = vision_x.hidden_states[-2]
vision_x = vision_x[:, 1:]
vision_x = self.connector(vision_x)
return vision_x

@@ -178,44 +215,82 @@ def forward(self, images):
dtype = hf_config.torch_dtype

# connector
assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu"
vision_connector = torch.nn.Sequential(
torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True),
torch.nn.GELU(),
torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
).to(dtype=dtype)

key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
for layer in range(0, 3, 2):
vision_connector[layer].load_state_dict(
if nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu":
vision_connector = torch.nn.Sequential(
torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True),
torch.nn.GELU(),
torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
).to(dtype=dtype)

key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
for layer in range(0, 3, 2):
vision_connector[layer].load_state_dict(
{
'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
}
)
elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "linear":
vision_connector = torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True)
key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
vision_connector.load_state_dict(
{
'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
'weight': mp0_weights[f"{key_prefix}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.bias"].to(dtype),
}
)
elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp_downsample":
vision_connector = torch.nn.Sequential(
DownSampleBlock(),
torch.nn.LayerNorm(vision_config["hidden_size"] * 4),
torch.nn.Linear(vision_config["hidden_size"] * 4, nemo_config["hidden_size"], bias=True),
torch.nn.GELU(),
torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
).to(dtype=dtype)
key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
for layer in [1, 2, 4]:
vision_connector[layer].load_state_dict(
{
'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
}
)

else:
raise ValueError(f"Unknown projector type: {nemo_config['mm_cfg']['mm_mlp_adapter_type']}")

# export the whole wrapper
lita_num_frames = None
wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype)
image_size = hf_config.vision_config.image_size
if model_type == "lita" or model_type == "vila":
image_size = hf_config.image_size
if model_type == "lita":
lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames']
else:
image_size = hf_config.vision_config.image_size
if model_type == "vita":
lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames']
dummy_image = torch.empty(
1, 3, image_size, image_size, dtype=dtype, device=device
) # dummy image shape [B, C, H, W]

export_visual_wrapper_onnx(wrapper, dummy_image, model_dir)
build_trt_engine(
"neva",
model_type,
[3, image_size, image_size],
model_dir,
max_batch_size,
vision_max_batch_size,
dtype,
image_size=image_size,
num_frames=lita_num_frames if model_type == "lita" or model_type == 'vita' else None,
nemo_config=nemo_config,
)


def build_video_neva_engine(
model_dir: str,
visual_checkpoint_path: str,
max_batch_size: int = 1,
vision_max_batch_size: int = 1,
):
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
# extract NeMo checkpoint
@@ -279,7 +354,7 @@ def forward(self, images):
"video-neva",
[num_frames, 3, image_size, image_size], # [num_frames, 3, H, W]
model_dir,
max_batch_size,
vision_max_batch_size,
dtype,
image_size=image_size,
num_frames=num_frames,
@@ -290,11 +365,12 @@ def build_visual_engine(
model_dir: str,
visual_checkpoint_path: str,
model_type: str = "neva",
max_batch_size: int = 1,
vision_max_batch_size: int = 1,
):
if model_type == "neva":
build_neva_engine(model_dir, visual_checkpoint_path, max_batch_size)
model_list = ['neva', 'lita', 'vila', 'vita']
if model_type in model_list:
build_neva_engine(model_type, model_dir, visual_checkpoint_path, vision_max_batch_size)
elif model_type == "video-neva":
build_video_neva_engine(model_dir, visual_checkpoint_path, max_batch_size)
build_video_neva_engine(model_dir, visual_checkpoint_path, vision_max_batch_size)
else:
raise RuntimeError(f"Invalid model type {model_type}")
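
Putting the new pieces together: for lita/vila/vita the LLM engine is built from the same `.nemo` checkpoint that holds the vision weights, and the vision tower is exported separately with its own batch size. A rough end-to-end sketch using the two entry points in this file; paths are placeholders, the values echo the example config, and the module path `nemo.export.multimodal.build` is assumed from the file location:

```python
# Sketch only: end-to-end lita export with the builders defined above.
from nemo.export.multimodal.build import build_trtllm_engine, build_visual_engine

checkpoint = "/path/to/lita.nemo"       # placeholder .nemo with LLM + vision weights
llm_engine_dir = "/tmp/lita/llm"        # placeholder output dir for the TRT-LLM engine
visual_engine_dir = "/tmp/lita/visual"  # placeholder output dir for the vision engine

# For lita/vila/vita the .nemo checkpoint itself is passed as the LLM checkpoint
# (see visual_checkpoint_model above), so llm_checkpoint_path can stay None.
build_trtllm_engine(
    model_dir=llm_engine_dir,
    visual_checkpoint_path=checkpoint,
    model_type="lita",
    llm_model_type="llama",
    tensor_parallelism_size=1,
    max_input_len=4096,
    max_output_len=256,
    max_batch_size=1,
    max_multimodal_len=3072,
    dtype="bfloat16",
)

# Builds visual_encoder.onnx / visual_encoder.engine plus config.json and
# nemo_config.yaml under visual_engine_dir; 256 matches the suggested
# vision_max_batch_size for video inference.
build_visual_engine(
    model_dir=visual_engine_dir,
    visual_checkpoint_path=checkpoint,
    model_type="lita",
    vision_max_batch_size=256,
)
```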