Add TRT ComposerModel inference wrapper #558

Draft: wants to merge 29 commits into base: main from add_trtllm_wrapper. Changes shown are from 3 of the 29 commits.

Commits (29)
6dbfb27  Add TRT ComposerModel inference wrapper (Aug 24, 2023)
c312d21  Fix precommit (Aug 24, 2023)
1703591  Add base wrapper (Aug 25, 2023)
ce6ff26  Merge branch 'main' into add_trtllm_wrapper (nik-mosaic, Aug 28, 2023)
89adacc  Merge branch 'main' into add_trtllm_wrapper (nik-mosaic, Aug 29, 2023)
45dd3bb  Merge branch 'main' into add_trtllm_wrapper (nik-mosaic, Aug 30, 2023)
73febfb  Merge branch 'main' into add_trtllm_wrapper (nik-mosaic, Sep 8, 2023)
7cee807  Merge branch 'main' into add_trtllm_wrapper (nik-mosaic, Dec 12, 2023)
33e5289  Update model_registry.py (nik-mosaic, Dec 12, 2023)
da7b235  add changes to make llmfoundry install and test trtllm (Dec 13, 2023)
a872c3d  add new yamls, fix trt bugs (Dec 13, 2023)
f01be42  update trt wrapper for new logit format (Dec 13, 2023)
c5a79da  more padding and shape fixes (Dec 13, 2023)
19abfe2  update run script (Dec 13, 2023)
3a3b334  update utils for multigpu trt models (Dec 19, 2023)
f646edd  Metric device updates (Dec 22, 2023)
dfa30b8  Update interface to support QA tasks (Jan 4, 2024)
1c6037c  Update scripts, fix batching (Jan 6, 2024)
3e5b5ee  Update foundry: (Jan 11, 2024)
bb441f1  Merge branch 'main' into add_trtllm_wrapper (nik-mosaic, Jan 11, 2024)
1b28771  Merge branch 'main' into add_trtllm_wrapper (nik-mosaic, Mar 5, 2024)
633d651  Merge branch 'main' into add_trtllm_wrapper (nik-mosaic, Mar 11, 2024)
9cc6dea  update wrappers (Mar 25, 2024)
86de601  update runner (Mar 25, 2024)
1f3eeb6  update script (Mar 25, 2024)
8b3a4b1  Remove prints (Mar 25, 2024)
8382411  update wrappers (Mar 25, 2024)
a92f7cc  update wrapper to properly support MC tasks (Mar 26, 2024)
db3afef  Update TRT wrapper and imports (Jul 16, 2024)
8 changes: 8 additions & 0 deletions llmfoundry/models/inference_api_wrapper/__init__.py
@@ -0,0 +1,8 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.inference_api_wrapper.interface import \
    InferenceAPIEvalWrapper
from llmfoundry.models.inference_api_wrapper.trtllm import TRTLLMEvalWrapper

__all__ = ['InferenceAPIEvalWrapper', 'TRTLLMEvalWrapper']
110 changes: 110 additions & 0 deletions llmfoundry/models/inference_api_wrapper/interface.py
@@ -0,0 +1,110 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Optional

import torch
from composer.core.types import Batch
from composer.metrics import InContextLearningMetric
# required for loading a python model into composer
from composer.metrics.nlp import (InContextLearningLMAccuracy,
                                  InContextLearningLMExpectedCalibrationError,
                                  InContextLearningMCExpectedCalibrationError,
                                  InContextLearningMultipleChoiceAccuracy,
                                  InContextLearningQAAccuracy,
                                  LanguageCrossEntropy, LanguagePerplexity)
from composer.models import ComposerModel
from torchmetrics import Metric
from transformers import AutoTokenizer


class InferenceAPIEvalWrapper(ComposerModel):

    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
        self.model_name = model_cfg['version']
        self.tokenizer = tokenizer
        self.labels = None
        # set up training and eval metrics
        eval_metrics = [
            LanguageCrossEntropy(),
            LanguagePerplexity(),
            InContextLearningLMAccuracy(),
            InContextLearningMultipleChoiceAccuracy(),
            InContextLearningQAAccuracy(),
            InContextLearningLMExpectedCalibrationError(),
            InContextLearningMCExpectedCalibrationError()
        ]
        self.eval_metrics = {
            metric.__class__.__name__: metric for metric in eval_metrics
        }
        super(InferenceAPIEvalWrapper, self).__init__()
        self.mocked_layer = torch.nn.Linear(2, 3)

    def get_metrics(self, is_train: bool = False):
        if is_train:
            metrics = []
        else:
            metrics = self.eval_metrics

        return metrics if metrics else {}

    def get_next_token_logit_tensor(self, prompt: str):
        raise NotImplementedError

    def rebatch(self, batch: Batch):
        # default is a no-op, but Chat API modifies these
        return batch

    def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
        # If the batch mode is generate, we will generate a requested number of tokens using the underlying
        # model's generate function. Extra generation kwargs can be passed in via the batch. Strings will
        # be returned from eval_forward
        output_logits_batch = []
        for tokens, cont_idxs in zip(batch['input_ids'],
                                     batch['continuation_indices']):

            seqlen = tokens.shape[0]
            tokens = tokens.tolist()
            cont_idxs = cont_idxs.tolist()
            expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
            output_logits = torch.nn.functional.one_hot(
                torch.tensor(tokens[1:cont_idxs[0]]),
                num_classes=self.tokenizer.pad_token_id + 1)
            for i in range(len(expected_cont_tokens)):
                # decode one token at a time
                prompt = self.tokenizer.decode(tokens[:cont_idxs[0]] +
                                               expected_cont_tokens[0:i])
                next_logit_tensor = self.get_next_token_logit_tensor(prompt)
                if next_logit_tensor is None:
                    continue
                output_logits = torch.cat(
                    [output_logits,
                     next_logit_tensor.reshape(1, -1)])
            padding = torch.nn.functional.one_hot(
                torch.full((seqlen - output_logits.shape[0],),
                           self.tokenizer.pad_token_id),
                num_classes=self.tokenizer.pad_token_id + 1)
            output_logits = torch.cat([output_logits, padding])
            output_logits_batch.append(output_logits)

        return torch.stack(output_logits_batch).to(batch['input_ids'].device)

    def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
        batch = self.rebatch(batch)
        self.labels = batch.pop('labels')
        self.labels[:, :-1] = self.labels[:, 1:].clone()
        self.labels[:, -1] = -100
        if isinstance(metric, InContextLearningMetric) and batch.get(
                'mode', None) == 'icl_task':
            assert self.labels is not None
            metric.update(batch, outputs, self.labels)
        else:
            metric.update(
                outputs,
                self.labels)  # pyright: ignore [reportGeneralTypeIssues]

    def forward(self):
        pass

    def loss(self):
        pass
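The base class above only requires a subclass to implement get_next_token_logit_tensor; eval_forward then splices the returned per-token logits between the one-hot prompt placeholder and the pad-token padding. Below is a minimal sketch of a hypothetical subclass (HFLocalEvalWrapper and its use of transformers.AutoModelForCausalLM are illustrative assumptions, not part of this PR), trimmed to the placeholder width the base class uses:

# Sketch only: a hypothetical backend for the interface above, not part of this PR.
# Assumes model_cfg['version'] names a local or Hugging Face causal LM.
import torch
from transformers import AutoModelForCausalLM

from llmfoundry.models.inference_api_wrapper import InferenceAPIEvalWrapper


class HFLocalEvalWrapper(InferenceAPIEvalWrapper):
    """Scores prompts with a local causal LM instead of a remote API."""

    def __init__(self, model_cfg, tokenizer):
        super().__init__(model_cfg, tokenizer)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.model.eval()

    def get_next_token_logit_tensor(self, prompt: str):
        # Tokenize the decoded prompt and return logits for the next token.
        input_ids = self.tokenizer(prompt, return_tensors='pt')['input_ids']
        with torch.no_grad():
            logits = self.model(input_ids).logits  # (1, seq_len, vocab_size)
        # The base eval_forward concatenates this with one-hot tensors of width
        # pad_token_id + 1, so truncate the vocab dimension to match.
        return logits[0, -1, :self.tokenizer.pad_token_id + 1]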
181 changes: 181 additions & 0 deletions llmfoundry/models/inference_api_wrapper/trtllm.py
@@ -0,0 +1,181 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Implements a TRT-LLM evaluation model wrapped around a
:class:`.ComposerModel`."""

import json
from pathlib import Path
from typing import Any, Optional

import torch
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer

from llmfoundry.models.inference_api_wrapper.interface import \
    InferenceAPIEvalWrapper

__all__ = ['TRTLLMEvalWrapper']

try:
    import tensorrt_llm
    from tensorrt_llm.runtime import ModelConfig, SamplingConfig
    TRT_LLM_INSTALLED = True
except ImportError:
    TRT_LLM_INSTALLED = False


def check_if_trt_llm_installed():
    if not TRT_LLM_INSTALLED:
        raise ImportError(
            'TRT-LLM is not installed. It must be installed to use the TRTLLMEvalWrapper.'
        )


# From tensorrt_llm/examples/{model_name}/build.py
def get_engine_name(model: str, dtype: str, tp_size: int, rank: int):
    return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)


class TRTLLMEvalWrapper(InferenceAPIEvalWrapper):

    def __init__(
        self,
        model_cfg: DictConfig,
        tokenizer: PreTrainedTokenizer,
    ):
        check_if_trt_llm_installed()

        super().__init__(model_cfg, tokenizer)

        tensorrt_llm.logger.set_level(model_cfg['log_level'])

        # Load TRT config from file
        engine_dir = Path(model_cfg['engine_dir'])
        config_path = engine_dir / 'config.json'
        with open(config_path, 'r') as f:
            config = json.load(f)

        # Set vars from config
        use_gpt_attention_plugin = config['plugin_config'][
            'gpt_attention_plugin']
        inflight_batching_gpt_attention_plugin = config['plugin_config'][
            'inflight_batching_gpt_attention_plugin']
        remove_input_padding = config['plugin_config']['remove_input_padding']
        if remove_input_padding:
            raise ValueError(
                'TRT-LLM Evaluation Wrapper does not support remove_input_padding.'
            )
        dtype = config['builder_config']['precision']
        world_size = config['builder_config']['tensor_parallel']
        assert world_size == tensorrt_llm.mpi_world_size(), \
            f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
        num_heads = config['builder_config']['num_heads'] // world_size
        hidden_size = config['builder_config']['hidden_size'] // world_size
        vocab_size = config['builder_config']['vocab_size']
        num_layers = config['builder_config']['num_layers']
        multi_query_mode = config['builder_config']['multi_query_mode']
        paged_kv_cache = config['builder_config'].get('paged_kv_cache', False)
        tokens_per_block = config['builder_config'].get('tokens_per_block', 64)
        use_prompt_tuning = config['builder_config'].get(
            'use_prompt_tuning', False)

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        # Device and rank
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

        # Tokenization and sampling
        self.END_ID = model_cfg.get('eos_token_id', self.tokenizer.eos_token_id)
        self.PAD_ID = model_cfg.get('pad_token_id', self.tokenizer.pad_token_id)
        if self.PAD_ID is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        print('EOS TOKEN:', self.END_ID)
        print('Pad token:', self.PAD_ID)

        self.sampling_config = SamplingConfig(end_id=self.END_ID,
                                              pad_id=self.PAD_ID,
                                              num_beams=1)

        # Load TRT engine
        engine_name = get_engine_name(model_cfg['version'], dtype, world_size,
                                      runtime_rank)
        serialize_path = engine_dir / engine_name
        with open(serialize_path, 'rb') as f:
            engine_buffer = f.read()

        # Initialize generation session for model
        trt_model_config = ModelConfig(
            num_heads=num_heads,
            hidden_size=self.hidden_size,
            vocab_size=self.vocab_size,
            num_layers=num_layers,
            gpt_attention_plugin=use_gpt_attention_plugin,
            inflight_batching_gpt_attention_plugin=
            inflight_batching_gpt_attention_plugin,
            multi_query_mode=multi_query_mode,
            remove_input_padding=remove_input_padding,
            paged_kv_cache=paged_kv_cache,
            tokens_per_block=tokens_per_block,
            use_prompt_tuning=use_prompt_tuning)
        self.decoder = tensorrt_llm.runtime.GenerationSession(
            trt_model_config, engine_buffer, runtime_mapping)

    def eval_forward(self, batch, outputs: Optional[Any] = None):
        # If the batch mode is generate, we will generate a requested number of tokens using the underlying
        # model's generate function. Strings will be returned from eval_forward
        output_logits_batch = []
        batch = self.rebatch(batch)
        for tokens, cont_idxs in zip(batch['input_ids'],
                                     batch['continuation_indices']):

            seqlen = tokens.shape[0]
            tokens = tokens.tolist()
            cont_idxs = cont_idxs.tolist()
            expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]

            prompt = tokens[:cont_idxs[0]]
            input_ids = torch.tensor([prompt], dtype=torch.int, device='cuda')
            input_lengths = torch.tensor([input_ids.size(1)],
                                         dtype=torch.int,
                                         device='cuda')
            #print("prompt:", self.tokenizer.decode(prompt))
            #print("Input ids data:", input_ids, len(input_ids), input_ids[0].shape)
            #print("Input lengths:", input_lengths)
            #print(cont_idxs[0])
            #print("Expected continuation tokens:", len(expected_cont_tokens))
            self.decoder.setup(input_lengths.size(0),
                               torch.max(input_lengths).item(),
                               len(expected_cont_tokens))

            output_ids, output_logits_list = self.decoder.decode(
                input_ids, input_lengths, self.sampling_config)

            #print("Decoded output:", self.tokenizer.decode(output_ids[0][0][cont_idxs[0]:].tolist()))

            output_logits = torch.nn.functional.one_hot(
                torch.tensor(tokens[1:cont_idxs[0]], device='cuda'),
                num_classes=self.vocab_size)

            for i in range(len(output_logits_list)):
                output_logits_list[i] = output_logits_list[i].squeeze()

            next_logit_tensor = torch.stack(output_logits_list)
            output_logits = torch.cat([output_logits, next_logit_tensor])
            #print(output_logits.shape)
            #print(output_ids[0][0][cont_idxs[0]:].tolist())
            padding = torch.nn.functional.one_hot(
                torch.full((seqlen - output_logits.shape[0],),
                           self.PAD_ID,
                           device=output_logits.device),
                num_classes=self.vocab_size)
            output_logits = torch.cat([output_logits, padding])
            #print("Output logits shape:", output_logits.shape)
            output_logits_batch.append(output_logits)

        return torch.stack(output_logits_batch).to(batch['input_ids'].device)
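For reference, the wrapper above reads the following keys from model_cfg (key names are taken from the code; the concrete values, the 'llama' version string, and the tokenizer name below are placeholder assumptions, not values shipped with this PR):

# Illustrative config sketch for TRTLLMEvalWrapper; placeholder values only.
from omegaconf import DictConfig
from transformers import AutoTokenizer

from llmfoundry.models.inference_api_wrapper import TRTLLMEvalWrapper

model_cfg = DictConfig({
    'version': 'llama',                # engine file is '<version>_<dtype>_tp<world_size>_rank<rank>.engine'
    'engine_dir': '/path/to/engines',  # must contain config.json and the .engine file
    'log_level': 'error',              # forwarded to tensorrt_llm.logger.set_level
    # 'eos_token_id' and 'pad_token_id' are optional; they default to the tokenizer's ids.
})

tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')
wrapper = TRTLLMEvalWrapper(model_cfg, tokenizer)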
2 changes: 2 additions & 0 deletions llmfoundry/models/model_registry.py
@@ -3,11 +3,13 @@

from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
                                  ComposerHFT5)
from llmfoundry.models.inference_api_wrapper import TRTLLMEvalWrapper
from llmfoundry.models.mpt import ComposerMPTCausalLM

COMPOSER_MODEL_REGISTRY = {
    'mpt_causal_lm': ComposerMPTCausalLM,
    'hf_causal_lm': ComposerHFCausalLM,
    'hf_prefix_lm': ComposerHFPrefixLM,
    'hf_t5': ComposerHFT5,
    'trtllm': TRTLLMEvalWrapper
}
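With the registry entry in place, an eval harness can resolve 'trtllm' by name. A short sketch follows; build_eval_model is a hypothetical helper (not part of this PR), and only the TRT entry is shown, whose constructor takes (model_cfg, tokenizer):

# Sketch of resolving the new registry entry by name.
from omegaconf import DictConfig

from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY


def build_eval_model(model_cfg: DictConfig, tokenizer):
    # model_cfg['name'] selects the registry entry, e.g. 'trtllm' -> TRTLLMEvalWrapper.
    model_cls = COMPOSER_MODEL_REGISTRY[model_cfg['name']]
    return model_cls(model_cfg, tokenizer)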