Triton deployment improvements for in-framework models (NVIDIA#9600)
* add NemoQueryLLMPyTorch class for Triton query of in-framework models
* nemo_export.py changes to better support in-framework models
* separate out in-framework version of the Triton deploy script
* add generate() function to MegatronLLMDeployable to allow for direct use in export tests
* use NemoQueryLLMPyTorch in deploy tests
* add warning message for when MegatronLLMDeployable overrides transformer_engine
* remove enable_streaming argument from deploy_inframework_triton.py since MegatronLLMDeployable does not support streaming
* add query_inframework.py since the original query.py does not work with in-framework deployments
* Apply isort and black reformatting (Signed-off-by: jukim-nv <[email protected]>)
* skip trtllm support check if in_framework testing
* remove unused imports
* run_existing_checkpoints was passing the wrong prompts argument for in-framework mode
* fix unused import in query_inframework.py

---------

Signed-off-by: jukim-nv <[email protected]>
Co-authored-by: jukim-nv <[email protected]>
Co-authored-by: Onur Yilmaz <[email protected]>
Commit 8898b76 (1 parent: 1c73e1b)
Showing 7 changed files with 376 additions and 70 deletions.
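As a rough orientation before the per-file diffs (this sketch is not part of the commit; the URL, model name, and prompt are placeholder assumptions), the intended workflow is to deploy a .nemo checkpoint with the new deploy_inframework_triton.py script and then query it through the new NemoQueryLLMPyTorch class, as the query_inframework.py script shown further below does:

# Illustrative sketch, not part of this commit. Step 1 (separate process): start the server with
# deploy_inframework_triton.py, passing --nemo_checkpoint and --triton_model_name.
# Step 2 (this process): query the served model. URL, model name, and prompt are placeholders.
from nemo.deploy.nlp.query_llm import NemoQueryLLMPyTorch

nemo_query = NemoQueryLLMPyTorch(url="0.0.0.0", model_name="megatron_model")
outputs = nemo_query.query_llm(
    prompts=["What is the color of a banana?"],
    max_length=128,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    init_timeout=60.0,
)
print(outputs["sentences"][0][0])  # same indexing as query_inframework.py below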
deploy_inframework_triton.py (new file, +103 lines):
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import sys

from nemo.deploy import DeployPyTriton

LOGGER = logging.getLogger("NeMo")

# MegatronLLMDeployable needs a full NeMo installation; record whether the import
# succeeds so the script can raise a clear error later instead of failing at import time.
megatron_llm_supported = True
try:
    from nemo.deploy.nlp import MegatronLLMDeployable
except Exception as e:
    LOGGER.warning(f"Cannot import MegatronLLMDeployable, it will not be available. {type(e).__name__}: {e}")
    megatron_llm_supported = False


def get_args(argv):
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Deploy NeMo models to Triton",
    )
    parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source .nemo file")
    parser.add_argument("-tmn", "--triton_model_name", required=True, type=str, help="Name for the service")
    parser.add_argument("-tmv", "--triton_model_version", default=1, type=int, help="Version for the service")
    parser.add_argument(
        "-trp", "--triton_port", default=8000, type=int, help="Port for the Triton server to listen for requests"
    )
    parser.add_argument(
        "-tha", "--triton_http_address", default="0.0.0.0", type=str, help="HTTP address for the Triton server"
    )
    parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment")
    parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model")
    parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode")
    args = parser.parse_args(argv)
    return args


def get_nemo_deployable(args):
    if args.nemo_checkpoint is None:
        raise ValueError("In-Framework deployment requires a .nemo checkpoint")

    return MegatronLLMDeployable(args.nemo_checkpoint, args.num_gpus)


def nemo_deploy(argv):
    args = get_args(argv)

    if args.debug_mode:
        loglevel = logging.DEBUG
    else:
        loglevel = logging.INFO

    LOGGER.setLevel(loglevel)
    LOGGER.info("Logging level set to {}".format(loglevel))
    LOGGER.info(args)

    if not megatron_llm_supported:
        raise ValueError("MegatronLLMDeployable is not supported in this environment.")
    triton_deployable = get_nemo_deployable(args)

    try:
        # Wrap the in-framework model in a PyTriton deployment.
        nm = DeployPyTriton(
            model=triton_deployable,
            triton_model_name=args.triton_model_name,
            triton_model_version=args.triton_model_version,
            max_batch_size=args.max_batch_size,
            port=args.triton_port,
            address=args.triton_http_address,
        )

        LOGGER.info("Triton deploy function will be called.")
        nm.deploy()
    except Exception as error:
        LOGGER.error("An error occurred during the deploy step. Error message: " + str(error))
        return

    try:
        LOGGER.info("Model serving on Triton will be started.")
        nm.serve()
    except Exception as error:
        LOGGER.error("An error occurred while serving the model. Error message: " + str(error))
        return

    LOGGER.info("Model serving will be stopped.")
    nm.stop()


if __name__ == '__main__':
    nemo_deploy(sys.argv[1:])
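For callers that skip argparse (for example, tests that construct the deployment directly), the script above reduces to a few calls. A minimal sketch mirroring those calls; the checkpoint path and model name below are placeholders, not values from this commit:

# Minimal programmatic sketch of deploy_inframework_triton.py above; the checkpoint
# path and model name are placeholders, not values from this commit.
from nemo.deploy import DeployPyTriton
from nemo.deploy.nlp import MegatronLLMDeployable

deployable = MegatronLLMDeployable("/path/to/model.nemo", 1)  # checkpoint path, num_gpus
nm = DeployPyTriton(
    model=deployable,
    triton_model_name="megatron_model",
    triton_model_version=1,
    max_batch_size=8,
    port=8000,
    address="0.0.0.0",
)
nm.deploy()  # bind the model to PyTriton
nm.serve()   # serve requests until interrupted
nm.stop()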
query_inframework.py (new file, +83 lines):
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import sys

from nemo.deploy.nlp.query_llm import NemoQueryLLMPyTorch


def get_args(argv):
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Queries Triton server running an in-framework NeMo model",
    )
    parser.add_argument("-u", "--url", default="0.0.0.0", type=str, help="url for the triton server")
    parser.add_argument("-mn", "--model_name", required=True, type=str, help="Name of the triton model")
    prompt_group = parser.add_mutually_exclusive_group(required=True)
    prompt_group.add_argument("-p", "--prompt", required=False, type=str, help="Prompt")
    prompt_group.add_argument("-pf", "--prompt_file", required=False, type=str, help="File to read the prompt from")
    parser.add_argument("-mol", "--max_output_len", default=128, type=int, help="Max output token length")
    parser.add_argument("-tk", "--top_k", default=1, type=int, help="top_k")
    parser.add_argument("-tpp", "--top_p", default=0.0, type=float, help="top_p")
    parser.add_argument("-t", "--temperature", default=1.0, type=float, help="temperature")
    parser.add_argument("-it", "--init_timeout", default=60.0, type=float, help="init timeout for the triton server")

    args = parser.parse_args(argv)
    return args


def query_llm(
    url,
    model_name,
    prompts,
    max_output_len=128,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    init_timeout=60.0,
):
    nemo_query = NemoQueryLLMPyTorch(url, model_name)
    return nemo_query.query_llm(
        prompts=prompts,
        max_length=max_output_len,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        init_timeout=init_timeout,
    )


def query(argv):
    args = get_args(argv)

    # Read the prompt from a file when --prompt_file is given (mutually exclusive with --prompt).
    if args.prompt_file is not None:
        with open(args.prompt_file, "r") as f:
            args.prompt = f.read()

    outputs = query_llm(
        url=args.url,
        model_name=args.model_name,
        prompts=[args.prompt],
        max_output_len=args.max_output_len,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        init_timeout=args.init_timeout,
    )
    print(outputs["sentences"][0][0])


if __name__ == '__main__':
    query(sys.argv[1:])
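Since NemoQueryLLMPyTorch.query_llm() takes a list of prompts, a single request can also carry a small batch rather than the one prompt used by query() above. An illustrative sketch; the URL, model name, and prompts are placeholders, and the nesting of the returned "sentences" field is assumed to match the single-prompt indexing used above:

# Illustrative batch query; URL, model name, and prompts are placeholders.
from nemo.deploy.nlp.query_llm import NemoQueryLLMPyTorch

nemo_query = NemoQueryLLMPyTorch(url="0.0.0.0", model_name="megatron_model")
outputs = nemo_query.query_llm(
    prompts=["What is the color of a banana?", "Name a prime number."],
    max_length=64,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    init_timeout=60.0,
)
# query() above reads a single result via outputs["sentences"][0][0]; for a batch,
# inspect the full nested list of generated texts.
print(outputs["sentences"])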
(Diffs for the remaining changed files are not shown.)