Add Lita, Vila and Vita TRTLLM export #9734

Merged
12 commits merged on Jul 19, 2024
examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml (3 additions, 2 deletions)
@@ -6,10 +6,11 @@ infer:
max_input_len: 4096
max_output_len: 256
max_multimodal_len: 3072
vision_max_batch_size: 1 # use 256 for lita/vita when running inference on a video dataset

model:
type: neva
type: neva #neva, video-neva, lita, vila, vita
precision: bfloat16
visual_model_path: /path/to/visual.nemo
llm_model_path: /path/to/llm.nemo
llm_model_type: llama
llm_model_type: llama
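
For context, a minimal sketch (paths and values hypothetical) of how these settings might be loaded and overridden for a LITA/VITA video export, using OmegaConf as the NeMo example scripts do:

```python
from omegaconf import OmegaConf

# Hypothetical config path; adjust to your NeMo checkout.
cfg = OmegaConf.load("examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml")

cfg.model.type = "lita"                # one of: neva, video-neva, lita, vila, vita
cfg.infer.vision_max_batch_size = 256  # larger vision batch for inference on a video dataset
cfg.model.visual_model_path = "/path/to/visual.nemo"
cfg.model.llm_model_path = "/path/to/llm.nemo"

print(OmegaConf.to_yaml(cfg))
```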
examples/multimodal/multimodal_llm/neva/neva_export.py (1 addition, 0 deletions)
@@ -27,6 +27,7 @@ def main(cfg):
tensor_parallel_size=cfg.infer.tensor_parallelism,
max_input_len=cfg.infer.max_input_len,
max_output_len=cfg.infer.max_output_len,
vision_max_batch_size=cfg.infer.vision_max_batch_size,
max_batch_size=cfg.infer.max_batch_size,
max_multimodal_len=cfg.infer.max_multimodal_len,
dtype=cfg.model.precision,
@@ -470,11 +470,10 @@ def __init__(
def create_vision_encoder_and_processor(self, mm_cfg):
# Initialize vision encoder and freeze it
if mm_cfg.vision_encoder.get("from_hf", False):
if (
"clip" in mm_cfg.vision_encoder.from_pretrained
or "vit" in mm_cfg.vision_encoder.from_pretrained
or "clip" in mm_cfg.vision_encoder.get("model_type", "")
):
from transformers import AutoConfig

config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
if config.architectures[0] == "CLIPVisionModel":
vision_encoder = CLIPVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained,
torch_dtype=torch.bfloat16,
@@ -484,9 +483,7 @@ def create_vision_encoder_and_processor(self, mm_cfg):
for param in vision_encoder.parameters():
param.requires_grad = False
vision_encoder = vision_encoder.eval()
elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get(
"model_type", ""
):
elif config.architectures[0] == "SiglipVisionModel":
vision_encoder = SiglipVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained,
torch_dtype=torch.bfloat16,
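
The change above swaps substring matching on the checkpoint name for a lookup of the architecture recorded in the Hugging Face config. As a standalone illustration (not the NeMo code itself), the dispatch boils down to:

```python
import torch
from transformers import AutoConfig, CLIPVisionModel, SiglipVisionModel

def load_vision_encoder(from_pretrained: str):
    # Decide the encoder class from the architecture stored in the HF config,
    # rather than guessing from substrings such as "clip" or "siglip" in the name.
    config = AutoConfig.from_pretrained(from_pretrained)
    arch = config.architectures[0]
    if arch == "CLIPVisionModel":
        encoder = CLIPVisionModel.from_pretrained(from_pretrained, torch_dtype=torch.bfloat16)
    elif arch == "SiglipVisionModel":
        encoder = SiglipVisionModel.from_pretrained(from_pretrained, torch_dtype=torch.bfloat16)
    else:
        raise ValueError(f"Unsupported vision encoder architecture: {arch}")
    encoder.requires_grad_(False)
    return encoder.eval()
```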
nemo/collections/multimodal/parts/utils.py (5 additions, 8 deletions)
@@ -534,17 +534,14 @@ def expand2square(pil_img, background_color):

def create_image_processor(mm_cfg):
if mm_cfg.vision_encoder.get("from_hf", False):
if (
"clip" in mm_cfg.vision_encoder.from_pretrained
or "vit" in mm_cfg.vision_encoder.from_pretrained
or "clip" in mm_cfg.vision_encoder.get("model_type", "")
):
from transformers import AutoConfig

config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
if config.architectures[0] == "CLIPVisionModel":
image_processor = CLIPImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get(
"model_type", ""
):
elif config.architectures[0] == "SiglipVisionModel":
image_processor = SiglipImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
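
The image processor path applies the same rule, so checking which branch a given checkpoint will take only requires inspecting its recorded architecture (checkpoint path hypothetical):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("/path/to/vision_encoder")  # hypothetical local or hub path
print(config.architectures)  # e.g. ["CLIPVisionModel"] or ["SiglipVisionModel"]
```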
nemo/export/multimodal/build.py (104 additions, 30 deletions)
@@ -37,18 +37,19 @@
llm_checkpoint_path: str = None,
model_type: str = "neva",
llm_model_type: str = "llama",
tensor_parallel_size: int = 1,
tensor_parallelism_size: int = 1,
max_input_len: int = 256,
max_output_len: int = 256,
max_batch_size: int = 1,
max_multimodal_len: int = 1024,
dtype: str = "bfloat16",
):
trt_llm_exporter = TensorRTLLM(model_dir=model_dir, load_model=False)
visual_checkpoint_model = ['neva', 'lita', 'vila', 'vita']
trt_llm_exporter.export(
nemo_checkpoint_path=visual_checkpoint_path if model_type == "neva" else llm_checkpoint_path,
nemo_checkpoint_path=visual_checkpoint_path if model_type in visual_checkpoint_model else llm_checkpoint_path,
model_type=llm_model_type,
tensor_parallel_size=tensor_parallel_size,
tensor_parallelism_size=tensor_parallelism_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
@@ -75,12 +76,24 @@


def build_trt_engine(
model_type, input_sizes, output_dir, max_batch_size, dtype=torch.bfloat16, image_size=None, num_frames=None
model_type,
input_sizes,
output_dir,
vision_max_batch_size,
dtype=torch.bfloat16,
image_size=None,
num_frames=None,
nemo_config=None,
):
part_name = 'visual_encoder'
onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name)
engine_file = '%s/%s.engine' % (output_dir, part_name)
config_file = '%s/%s' % (output_dir, "config.json")
nemo_config_file = '%s/%s' % (output_dir, "nemo_config.yaml")

with open(nemo_config_file, 'w') as f:
yaml.dump(nemo_config, f)
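
Since the exported engine directory now carries the originating NeMo config, a downstream runner could recover model-specific settings with a plain YAML load; a hedged sketch (output directory hypothetical):

```python
import os
import yaml

output_dir = "/path/to/visual_engine"  # hypothetical export output directory
with open(os.path.join(output_dir, "nemo_config.yaml")) as f:
    nemo_config = yaml.safe_load(f)

# For example, LITA frame sampling settings when present:
sample_frames = nemo_config.get("mm_cfg", {}).get("lita", {}).get("sample_frames")
```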

logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name)

builder = trt.Builder(logger)
@@ -110,8 +123,8 @@

nBS = -1
nMinBS = 1
nOptBS = max(nMinBS, int(max_batch_size / 2))
nMaxBS = max_batch_size
nOptBS = max(nMinBS, int(vision_max_batch_size / 2))
nMaxBS = vision_max_batch_size

inputT = network.get_input(0)
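
These min/opt/max values define the TensorRT optimization profile on the vision encoder's batch dimension. A minimal standalone sketch of that pattern (the input tensor name and 336x336 image size are assumptions for illustration):

```python
import tensorrt as trt

trt_logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(trt_logger)
config = builder.create_builder_config()
profile = builder.create_optimization_profile()

vision_max_batch_size = 256  # e.g. large frame batches for lita/vita video inference
n_min = 1
n_opt = max(n_min, vision_max_batch_size // 2)
n_max = vision_max_batch_size

# Register min/opt/max shapes for the dynamic batch dimension of the input.
profile.set_shape("input", (n_min, 3, 336, 336), (n_opt, 3, 336, 336), (n_max, 3, 336, 336))
config.add_optimization_profile(profile)
```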

@@ -145,17 +158,41 @@


def build_neva_engine(
model_type: str,
model_dir: str,
visual_checkpoint_path: str,
max_batch_size: int = 1,
vision_max_batch_size: int = 1,
):
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
# extract NeMo checkpoint
with tempfile.TemporaryDirectory() as temp:
mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp)
temp_path = Path(temp)
mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp_path)

vision_config = nemo_config["mm_cfg"]["vision_encoder"]

class DownSampleBlock(torch.nn.Module):
def forward(self, x):
vit_embeds = x
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.flat_square(vit_embeds)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
return vit_embeds

def flat_square(self, x):
n, w, h, c = x.size()
if w % 2 == 1:
x = torch.cat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
n, w, h, c = x.size()
if h % 2 == 1:
x = torch.cat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
n, w, h, c = x.size()
x = x.view(n, w, int(h / 2), int(c * 2))
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
return x
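
As a shape check on DownSampleBlock: flat_square pads odd spatial dimensions and then folds each 2x2 patch neighborhood into the channel dimension, so a [B, 576, 1024] ViT output (a 24x24 token grid) becomes [B, 144, 4096]. A small hedged sketch, assuming the class above is in scope:

```python
import torch

block = DownSampleBlock()                # the class defined above
vit_embeds = torch.randn(2, 576, 1024)   # [batch, num_patches, hidden], 24x24 patch grid
out = block(vit_embeds)
print(out.shape)                         # torch.Size([2, 144, 4096])
```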

class VisionEncoderWrapper(torch.nn.Module):

def __init__(self, encoder, connector):
@@ -166,7 +203,6 @@
def forward(self, images):
vision_x = self.encoder(pixel_values=images, output_hidden_states=True)
vision_x = vision_x.hidden_states[-2]
vision_x = vision_x[:, 1:]
vision_x = self.connector(vision_x)
return vision_x

@@ -178,44 +214,81 @@
dtype = hf_config.torch_dtype

# connector
assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu"
vision_connector = torch.nn.Sequential(
torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True),
torch.nn.GELU(),
torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
).to(dtype=dtype)

key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
for layer in range(0, 3, 2):
vision_connector[layer].load_state_dict(
if nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu":
vision_connector = torch.nn.Sequential(
torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True),
torch.nn.GELU(),
torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
).to(dtype=dtype)

key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
for layer in range(0, 3, 2):
vision_connector[layer].load_state_dict(
{
'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
}
)
elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "linear":
vision_connector = torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True)
key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
vision_connector.load_state_dict(
{
'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
'weight': mp0_weights[f"{key_prefix}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.bias"].to(dtype),
}
)
elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp_downsample":
vision_connector = torch.nn.Sequential(
DownSampleBlock(),
torch.nn.LayerNorm(vision_config["hidden_size"] * 4),
torch.nn.Linear(vision_config["hidden_size"] * 4, nemo_config["hidden_size"], bias=True),
torch.nn.GELU(),
torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True),
).to(dtype=dtype)
key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
for layer in [1, 2, 4]:
vision_connector[layer].load_state_dict(
{
'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
}
)

else:
raise ValueError(f"Unknown projector type: {nemo_config['mm_cfg']['mm_mlp_adapter_type']}")
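
One note on the mlp_downsample branch above: indices 0 (DownSampleBlock) and 3 (GELU) of the Sequential carry no parameters, which is why only indices 1, 2 and 4 are loaded from the checkpoint. A hedged sketch with illustrative sizes:

```python
import torch

vision_hidden = 1152  # illustrative vision encoder hidden size (e.g. a SigLIP variant)
llm_hidden = 4096     # illustrative LLM hidden size

connector = torch.nn.Sequential(
    DownSampleBlock(),                                          # idx 0: no weights to load
    torch.nn.LayerNorm(vision_hidden * 4),                      # idx 1
    torch.nn.Linear(vision_hidden * 4, llm_hidden, bias=True),  # idx 2
    torch.nn.GELU(),                                            # idx 3: no weights to load
    torch.nn.Linear(llm_hidden, llm_hidden, bias=True),         # idx 4
)
for idx, module in enumerate(connector):
    print(idx, sum(p.numel() for p in module.parameters()))
```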

# export the whole wrapper
wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype)
image_size = hf_config.vision_config.image_size
if model_type == "lita" or model_type == "vila":
image_size = hf_config.image_size
if model_type == "lita":
lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames']
else:
image_size = hf_config.vision_config.image_size
if model_type == "vita":
lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames']
dummy_image = torch.empty(
1, 3, image_size, image_size, dtype=dtype, device=device
) # dummy image shape [B, C, H, W]

export_visual_wrapper_onnx(wrapper, dummy_image, model_dir)
build_trt_engine(
"neva",
model_type,
[3, image_size, image_size],
model_dir,
max_batch_size,
vision_max_batch_size,
dtype,
image_size=image_size,
num_frames=lita_num_frames if model_type == "lita" or model_type == 'vita' else None,

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable (Error).
Local variable 'lita_num_frames' may be used before it is initialized.
nemo_config=nemo_config,
)
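
One possible way to resolve the CodeQL finding shown above (not part of this PR) is to give lita_num_frames a default before branching on model_type, for example:

```python
# Sketch of a fix: lita_num_frames is defined on every path before it is
# forwarded to build_trt_engine, so it can be passed unconditionally.
lita_num_frames = None
if model_type in ("lita", "vila"):
    image_size = hf_config.image_size
    if model_type == "lita":
        lita_num_frames = nemo_config["mm_cfg"]["lita"]["sample_frames"]
else:
    image_size = hf_config.vision_config.image_size
    if model_type == "vita":
        lita_num_frames = nemo_config["mm_cfg"]["lita"]["sample_frames"]
```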


def build_video_neva_engine(
model_dir: str,
visual_checkpoint_path: str,
max_batch_size: int = 1,
vision_max_batch_size: int = 1,
):
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
# extract NeMo checkpoint
@@ -279,7 +352,7 @@
"video-neva",
[num_frames, 3, image_size, image_size], # [num_frames, 3, H, W]
model_dir,
max_batch_size,
vision_max_batch_size,
dtype,
image_size=image_size,
num_frames=num_frames,
@@ -290,11 +363,12 @@
model_dir: str,
visual_checkpoint_path: str,
model_type: str = "neva",
max_batch_size: int = 1,
vision_max_batch_size: int = 1,
):
if model_type == "neva":
build_neva_engine(model_dir, visual_checkpoint_path, max_batch_size)
model_list = ['neva', 'lita', 'vila', 'vita']
if model_type in model_list:
build_neva_engine(model_type, model_dir, visual_checkpoint_path, vision_max_batch_size)
elif model_type == "video-neva":
build_video_neva_engine(model_dir, visual_checkpoint_path, max_batch_size)
build_video_neva_engine(model_dir, visual_checkpoint_path, vision_max_batch_size)
else:
raise RuntimeError(f"Invalid model type {model_type}")
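
For illustration, a hedged usage of the updated dispatch (paths hypothetical), e.g. exporting a LITA visual encoder with a larger vision batch for video frames:

```python
from nemo.export.multimodal.build import build_visual_engine

# Hypothetical paths; model_type can be neva, lita, vila, vita or video-neva.
build_visual_engine(
    model_dir="/tmp/trt_visual_engine",
    visual_checkpoint_path="/path/to/lita_visual.nemo",
    model_type="lita",
    vision_max_batch_size=256,
)
```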