Add Lita, Vila and Vita TRTLLM export #9734

Merged: 12 commits, Jul 19, 2024
5 changes: 3 additions & 2 deletions examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml
@@ -6,10 +6,11 @@ infer:
max_input_len: 4096
max_output_len: 256
max_multimodal_len: 3072
vision_max_batch_size: 1 # use 256 for lita/vita when running inference on a video dataset

model:
type: neva
type: neva #neva, video-neva, lita, vila, vita
precision: bfloat16
visual_model_path: /path/to/visual.nemo
llm_model_path: /path/to/llm.nemo
llm_model_type: llama
llm_model_type: llama
1 change: 1 addition & 0 deletions examples/multimodal/multimodal_llm/neva/neva_export.py
@@ -27,6 +27,7 @@ def main(cfg):
tensor_parallel_size=cfg.infer.tensor_parallelism,
max_input_len=cfg.infer.max_input_len,
max_output_len=cfg.infer.max_output_len,
vision_max_batch_size=cfg.infer.vision_max_batch_size,
max_batch_size=cfg.infer.max_batch_size,
max_multimodal_len=cfg.infer.max_multimodal_len,
dtype=cfg.model.precision,
@@ -470,11 +470,10 @@ def __init__(
def create_vision_encoder_and_processor(self, mm_cfg):
# Initialize vision encoder and freeze it
if mm_cfg.vision_encoder.get("from_hf", False):
if (
"clip" in mm_cfg.vision_encoder.from_pretrained
or "vit" in mm_cfg.vision_encoder.from_pretrained
or "clip" in mm_cfg.vision_encoder.get("model_type", "")
):
from transformers import AutoConfig

config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
if config.architectures[0] == "CLIPVisionModel":
vision_encoder = CLIPVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained,
torch_dtype=torch.bfloat16,
@@ -484,9 +483,7 @@ def create_vision_encoder_and_processor(self, mm_cfg):
for param in vision_encoder.parameters():
param.requires_grad = False
vision_encoder = vision_encoder.eval()
elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get(
"model_type", ""
):
elif config.architectures[0] == "SiglipVisionModel":
vision_encoder = SiglipVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained,
torch_dtype=torch.bfloat16,
13 changes: 5 additions & 8 deletions nemo/collections/multimodal/parts/utils.py
@@ -534,17 +534,14 @@ def expand2square(pil_img, background_color):

def create_image_processor(mm_cfg):
if mm_cfg.vision_encoder.get("from_hf", False):
if (
"clip" in mm_cfg.vision_encoder.from_pretrained
or "vit" in mm_cfg.vision_encoder.from_pretrained
or "clip" in mm_cfg.vision_encoder.get("model_type", "")
):
from transformers import AutoConfig

config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
if config.architectures[0] == "CLIPVisionModel":
image_processor = CLIPImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get(
"model_type", ""
):
elif config.architectures[0] == "SiglipVisionModel":
image_processor = SiglipImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
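Side note on the two hunks above: both now key off the architectures field reported by the Hugging Face config instead of substring checks on the checkpoint name. A minimal standalone sketch of that detection pattern (the helper name is illustrative, not part of this PR):

from transformers import AutoConfig

def detect_vision_arch(pretrained_name: str) -> str:
    """Return the first architecture declared in the checkpoint's HF config."""
    config = AutoConfig.from_pretrained(pretrained_name)
    return config.architectures[0] if config.architectures else ""

# The diff branches on the returned string: "CLIPVisionModel" selects
# CLIPVisionModel/CLIPImageProcessor, while "SiglipVisionModel" selects
# SiglipVisionModel/SiglipImageProcessor.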
21 changes: 20 additions & 1 deletion nemo/deploy/multimodal/query_multimodal.py
@@ -56,12 +56,31 @@ def setup_media(self, input_media):
vr = VideoReader(input_media)
frames = [f.asnumpy() for f in vr]
return np.array(frames)
elif self.model_type == "neva":
elif self.model_type == "lita" or self.model_type == "vita":
vr = VideoReader(input_media)
frames = [f.asnumpy() for f in vr]
subsample_len = self.frame_len(frames)
sub_frames = self.get_subsampled_frames(frames, subsample_len)
return np.array(sub_frames)
elif self.model_type == "neva" or self.model_type == "vila":
media = Image.open(input_media).convert('RGB')
return np.expand_dims(np.array(media), axis=0)
else:
raise RuntimeError(f"Invalid model type {self.model_type}")

def frame_len(self, frames):
max_frames = 256
if len(frames) <= max_frames:
return len(frames)
else:
subsample = int(np.ceil(float(len(frames)) / max_frames))
return int(np.round(float(len(frames)) / subsample))

def get_subsampled_frames(self, frames, subsample_len):
idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int)
sub_frames = [frames[i] for i in idx]
return sub_frames

def query(
self,
input_text,
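A quick aside on the arithmetic in frame_len and get_subsampled_frames above: long videos are strided so that at most 256 evenly spaced frames are kept. A self-contained sketch with a made-up 600-frame clip:

import numpy as np

frames = list(range(600))                              # stand-in for decoded video frames
max_frames = 256
subsample = int(np.ceil(len(frames) / max_frames))     # ceil(600 / 256) = 3
target_len = int(np.round(len(frames) / subsample))    # round(600 / 3) = 200
idx = np.round(np.linspace(0, len(frames) - 1, target_len)).astype(int)
sub_frames = [frames[i] for i in idx]                  # 200 evenly spaced frames, under the cap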
136 changes: 106 additions & 30 deletions nemo/export/multimodal/build.py
@@ -17,6 +17,7 @@
import shutil
import tarfile
import tempfile
from pathlib import Path
from time import time

import tensorrt as trt
@@ -37,18 +38,19 @@ def build_trtllm_engine(
llm_checkpoint_path: str = None,
model_type: str = "neva",
llm_model_type: str = "llama",
tensor_parallel_size: int = 1,
tensor_parallelism_size: int = 1,
max_input_len: int = 256,
max_output_len: int = 256,
max_batch_size: int = 1,
max_multimodal_len: int = 1024,
dtype: str = "bfloat16",
):
trt_llm_exporter = TensorRTLLM(model_dir=model_dir, load_model=False)
visual_checkpoint_model = ['neva', 'lita', 'vila', 'vita']
trt_llm_exporter.export(
nemo_checkpoint_path=visual_checkpoint_path if model_type == "neva" else llm_checkpoint_path,
nemo_checkpoint_path=visual_checkpoint_path if model_type in visual_checkpoint_model else llm_checkpoint_path,
model_type=llm_model_type,
tensor_parallel_size=tensor_parallel_size,
tensor_parallelism_size=tensor_parallelism_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
@@ -75,12 +77,24 @@


def build_trt_engine(
model_type, input_sizes, output_dir, max_batch_size, dtype=torch.bfloat16, image_size=None, num_frames=None
model_type,
input_sizes,
output_dir,
vision_max_batch_size,
dtype=torch.bfloat16,
image_size=None,
num_frames=None,
nemo_config=None,
):
part_name = 'visual_encoder'
onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name)
engine_file = '%s/%s.engine' % (output_dir, part_name)
config_file = '%s/%s' % (output_dir, "config.json")
nemo_config_file = '%s/%s' % (output_dir, "nemo_config.yaml")

with open(nemo_config_file, 'w') as f:
yaml.dump(nemo_config, f)

logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name)

builder = trt.Builder(logger)
@@ -110,8 +124,8 @@ def build_trt_engine(

nBS = -1
nMinBS = 1
nOptBS = max(nMinBS, int(max_batch_size / 2))
nMaxBS = max_batch_size
nOptBS = max(nMinBS, int(vision_max_batch_size / 2))
nMaxBS = vision_max_batch_size

inputT = network.get_input(0)

@@ -145,17 +159,41 @@


def build_neva_engine(
model_type: str,
model_dir: str,
visual_checkpoint_path: str,
max_batch_size: int = 1,
vision_max_batch_size: int = 1,
):
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
# extract NeMo checkpoint
with tempfile.TemporaryDirectory() as temp:
mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp)
temp_path = Path(temp)
mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp_path)

vision_config = nemo_config["mm_cfg"]["vision_encoder"]

class DownSampleBlock(torch.nn.Module):
def forward(self, x):
vit_embeds = x
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.flat_square(vit_embeds)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
return vit_embeds

def flat_square(self, x):
n, w, h, c = x.size()
if w % 2 == 1:
x = torch.cat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
n, w, h, c = x.size()
if h % 2 == 1:
x = torch.cat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
n, w, h, c = x.size()
x = x.view(n, w, int(h / 2), int(c * 2))
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
return x

class VisionEncoderWrapper(torch.nn.Module):

def __init__(self, encoder, connector):
@@ -166,7 +204,6 @@ def __init__(self, encoder, connector):
def forward(self, images):
vision_x = self.encoder(pixel_values=images, output_hidden_states=True)
vision_x = vision_x.hidden_states[-2]
vision_x = vision_x[:, 1:]
vision_x = self.connector(vision_x)
return vision_x

@@ -178,44 +215,82 @@ def forward(self, images):
dtype = hf_config.torch_dtype

# connector
assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu"
vision_connector = torch.nn.Sequential(
torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True),
torch.nn.GELU(),
torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
).to(dtype=dtype)

key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
for layer in range(0, 3, 2):
vision_connector[layer].load_state_dict(
if nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu":
vision_connector = torch.nn.Sequential(
torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True),
torch.nn.GELU(),
torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
).to(dtype=dtype)

key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
for layer in range(0, 3, 2):
vision_connector[layer].load_state_dict(
{
'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
}
)
elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "linear":
vision_connector = torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True)
key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
vision_connector.load_state_dict(
{
'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
'weight': mp0_weights[f"{key_prefix}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.bias"].to(dtype),
}
)
elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp_downsample":
vision_connector = torch.nn.Sequential(
DownSampleBlock(),
torch.nn.LayerNorm(vision_config["hidden_size"] * 4),
torch.nn.Linear(vision_config["hidden_size"] * 4, nemo_config["hidden_size"], bias=True),
torch.nn.GELU(),
torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
).to(dtype=dtype)
key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
for layer in [1, 2, 4]:
vision_connector[layer].load_state_dict(
{
'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
}
)

else:
raise ValueError(f"Unknown projector type: {nemo_config['mm_cfg']['mm_mlp_adapter_type']}")

# export the whole wrapper
lita_num_frames = None
wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype)
image_size = hf_config.vision_config.image_size
if model_type == "lita" or model_type == "vila":
image_size = hf_config.image_size
if model_type == "lita":
lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames']
else:
image_size = hf_config.vision_config.image_size
if model_type == "vita":
lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames']
dummy_image = torch.empty(
1, 3, image_size, image_size, dtype=dtype, device=device
) # dummy image shape [B, C, H, W]

export_visual_wrapper_onnx(wrapper, dummy_image, model_dir)
build_trt_engine(
"neva",
model_type,
[3, image_size, image_size],
model_dir,
max_batch_size,
vision_max_batch_size,
dtype,
image_size=image_size,
num_frames=lita_num_frames if model_type == "lita" or model_type == 'vita' else None,
nemo_config=nemo_config,
)

Check failure from Code scanning / CodeQL: Potentially uninitialized local variable. "Local variable 'lita_num_frames' may be used before it is initialized."
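One more aside for reviewers: the new mlp_downsample connector above trades sequence length for channel width before its LayerNorm/Linear/GELU stack. A self-contained shape walk-through of the same reshape/permute steps, assuming an illustrative 24x24 patch grid with hidden size 1024:

import torch

x = torch.randn(1, 576, 1024)                 # 576 patch tokens from a 24 x 24 grid
n, hw, c = x.shape
h = w = int(hw ** 0.5)                        # 24
x = x.reshape(n, h, w, c)                     # (1, 24, 24, 1024)
x = x.view(n, w, h // 2, c * 2)               # pair adjacent positions along one grid axis, doubling channels
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, h // 2, w // 2, c * 4)          # pair along the other axis: each 2x2 block becomes one 4x-wide token
x = x.reshape(n, -1, c * 4)                   # (1, 144, 4096): 4x fewer tokens, 4x wider features
print(x.shape)                                # torch.Size([1, 144, 4096])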


def build_video_neva_engine(
model_dir: str,
visual_checkpoint_path: str,
max_batch_size: int = 1,
vision_max_batch_size: int = 1,
):
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
# extract NeMo checkpoint
@@ -279,7 +354,7 @@ def forward(self, images):
"video-neva",
[num_frames, 3, image_size, image_size], # [num_frames, 3, H, W]
model_dir,
max_batch_size,
vision_max_batch_size,
dtype,
image_size=image_size,
num_frames=num_frames,
@@ -290,11 +365,12 @@ def build_visual_engine(
model_dir: str,
visual_checkpoint_path: str,
model_type: str = "neva",
max_batch_size: int = 1,
vision_max_batch_size: int = 1,
):
if model_type == "neva":
build_neva_engine(model_dir, visual_checkpoint_path, max_batch_size)
model_list = ['neva', 'lita', 'vila', 'vita']
if model_type in model_list:
build_neva_engine(model_type, model_dir, visual_checkpoint_path, vision_max_batch_size)
elif model_type == "video-neva":
build_video_neva_engine(model_dir, visual_checkpoint_path, max_batch_size)
build_video_neva_engine(model_dir, visual_checkpoint_path, vision_max_batch_size)
else:
raise RuntimeError(f"Invalid model type {model_type}")
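For context, a minimal sketch of driving the updated dispatcher for a LITA checkpoint; the paths and the batch size are placeholders, not values taken from this PR:

from nemo.export.multimodal.build import build_visual_engine

# Hypothetical invocation: build the TRT vision engine for a LITA checkpoint,
# using the new per-modality batch size introduced in this PR.
build_visual_engine(
    model_dir="/tmp/trt_engines/lita/visual",       # placeholder output directory
    visual_checkpoint_path="/path/to/visual.nemo",  # placeholder .nemo path
    model_type="lita",
    vision_max_batch_size=256,                      # 256 is suggested for lita/vita video inference
)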