diff --git a/Dockerfile.win10.min b/Dockerfile.win10.min
index dfb6930b15..97dd4c24c3 100644
--- a/Dockerfile.win10.min
+++ b/Dockerfile.win10.min
@@ -115,6 +115,7 @@ ARG CUDA_PACKAGES="nvcc_${CUDA_MAJOR}.${CUDA_MINOR} \
     curand_${CUDA_MAJOR}.${CUDA_MINOR} curand_dev_${CUDA_MAJOR}.${CUDA_MINOR} \
     cusolver_${CUDA_MAJOR}.${CUDA_MINOR} cusolver_dev_${CUDA_MAJOR}.${CUDA_MINOR} \
     cusparse_${CUDA_MAJOR}.${CUDA_MINOR} cusparse_dev_${CUDA_MAJOR}.${CUDA_MINOR} \
+    cupti_${CUDA_MAJOR}.${CUDA_MINOR} \
     thrust_${CUDA_MAJOR}.${CUDA_MINOR} \
     visual_studio_integration_${CUDA_MAJOR}.${CUDA_MINOR}"
 ARG CUDA_INSTALL_ROOT_WP="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v${CUDA_MAJOR}.${CUDA_MINOR}"
diff --git a/build.py b/build.py
index a6bb111520..4b6e83fcb4 100755
--- a/build.py
+++ b/build.py
@@ -598,6 +598,18 @@ def backend_cmake_args(images, components, be, install_dir, library_paths):
     cargs.append(
         cmake_backend_enable(be, 'TRITON_ENABLE_METRICS',
                              FLAGS.enable_metrics))
+    # [DLIS-4950] always enable below once Windows image is updated with CUPTI
+    # cargs.append(cmake_backend_enable(be, 'TRITON_ENABLE_MEMORY_TRACKER', True))
+    if (target_platform() == 'windows') and (not FLAGS.no_container_build):
+        print(
+            "Warning: Detected Docker build for Windows; backend utility 'device memory tracker' will be disabled due to a missing library (CUPTI) in the CUDA Windows Docker image."
+        )
+        cargs.append(
+            cmake_backend_enable(be, 'TRITON_ENABLE_MEMORY_TRACKER', False))
+    elif FLAGS.enable_gpu:
+        cargs.append(
+            cmake_backend_enable(be, 'TRITON_ENABLE_MEMORY_TRACKER', True))
+
     cargs += cmake_backend_extra_args(be)
     cargs.append('..')
     return cargs
diff --git a/docs/protocol/extension_statistics.md b/docs/protocol/extension_statistics.md
index 6e11f3623e..6e82e971ba 100644
--- a/docs/protocol/extension_statistics.md
+++ b/docs/protocol/extension_statistics.md
@@ -78,7 +78,8 @@ $model_stat =
   "inference_count" : $number,
   "execution_count" : $number,
   "inference_stats" : $inference_stats,
-  "batch_stats" : [ $batch_stat, ... ]
+  "batch_stats" : [ $batch_stat, ... ],
+  "memory_usage" : [ $memory_usage, ... ]
 }
 ```
 
@@ -119,6 +120,14 @@ $model_stat =
   due to different batch size (for example, larger batches typically take
   longer to compute).
 
+- "memory_usage" : The memory usage detected during model loading, which may
+  be used to estimate the memory that will be released once the model is
+  unloaded. Note that this estimation is inferred from the profiling tools and
+  the framework's memory allocation scheme, so it is advised to run experiments
+  to understand the scenarios in which the reported memory usage can be relied
+  upon. As a starting point, the GPU memory usage reported for models on the
+  ONNX Runtime and TensorRT backends usually aligns with the actual usage.
+
 ```
 $inference_stats =
 {
@@ -217,6 +226,22 @@ $duration_stat =
 
 - “ns” : The total duration for the statistic in nanoseconds.
 
+```
+$memory_usage =
+{
+  "type" : $string,
+  "id" : $number,
+  "byte_size" : $number
+}
+```
+
+- "type" : The type of memory; the value can be "CPU", "CPU_PINNED", or "GPU".
+
+- "id" : The id of the memory, typically used with "type" to identify
+  a device that hosts the memory.
+
+- "byte_size" : The byte size of the memory.
+
 ### Statistics Response JSON Error Object
 
 A failed statistics request will be indicated by an HTTP error status
@@ -325,7 +350,16 @@ message ModelStatistics
   // executed in the model. The batch statistics indicate how many actual
   // model executions were performed and show differences due to different
   // batch size (for example, larger batches typically take longer to compute).
-  InferBatchStatistics batch_stats = 7;
+  repeated InferBatchStatistics batch_stats = 7;
+
+  // The memory usage detected during model loading, which may be used to
+  // estimate the memory that will be released once the model is unloaded.
+  // Note that this estimation is inferred from the profiling tools and the
+  // framework's memory allocation scheme, so it is advised to run experiments
+  // to understand the scenarios in which the reported memory usage can be
+  // relied upon. As a starting point, the GPU memory usage reported for the
+  // ONNX Runtime and TensorRT backends usually aligns with the actual usage.
+  repeated MemoryUsage memory_usage = 8;
 }
 
 // Inference statistics.
@@ -416,4 +450,18 @@ message InferBatchStatistics
   // tensor data from the GPU.
   StatisticDuration compute_output = 4;
 }
+
+// Memory usage.
+message MemoryUsage
+{
+  // The type of memory; the value can be "CPU", "CPU_PINNED", or "GPU".
+  string type = 1;
+
+  // The id of the memory, typically used with "type" to identify
+  // a device that hosts the memory.
+  int64 id = 2;
+
+  // The byte size of the memory.
+  uint64 byte_size = 3;
+}
 ```
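As an illustration of the schema above (not part of the patch itself): a minimal Python sketch that reads the new "memory_usage" field through the existing HTTP/JSON statistics endpoint and sums the reported GPU usage per model. The endpoint URL and the assumption that the device memory tracker is enabled (see test.sh below) are hypothetical defaults, not requirements stated by the patch.

```
# Minimal sketch, not part of this patch. Assumes a Triton server on
# localhost:8000 with the device memory tracker enabled and GPU models loaded.
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")
# The HTTP client returns the statistics response as parsed JSON.
stats = client.get_inference_statistics()

for model_stat in stats["model_stats"]:
    # "memory_usage" may be absent or empty when the tracker is disabled.
    gpu_bytes = sum(
        int(usage["byte_size"])
        for usage in model_stat.get("memory_usage", [])
        if usage["type"] == "GPU")
    print("{}: {} bytes of GPU memory recorded at load time".format(
        model_stat["name"], gpu_bytes))
```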
diff --git a/qa/L0_device_memory_tracker/test.py b/qa/L0_device_memory_tracker/test.py
new file mode 100644
index 0000000000..0265f043d5
--- /dev/null
+++ b/qa/L0_device_memory_tracker/test.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import unittest
+import time
+from functools import partial
+
+import tritonclient.http as httpclient
+import tritonclient.grpc as grpcclient
+
+import nvidia_smi
+
+
+class UnifiedClientProxy:
+
+    def __init__(self, client):
+        self.client_ = client
+
+    def __getattr__(self, attr):
+        forward_attr = getattr(self.client_, attr)
+        if type(self.client_) == grpcclient.InferenceServerClient:
+            if attr == "get_model_config":
+                return lambda *args, **kwargs: forward_attr(
+                    *args, **kwargs, as_json=True)["config"]
+            elif attr == "get_inference_statistics":
+                return partial(forward_attr, as_json=True)
+        return forward_attr
+
+
+class MemoryUsageTest(unittest.TestCase):
+
+    def setUp(self):
+        nvidia_smi.nvmlInit()
+        self.gpu_handle_ = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
+        self.http_client_ = httpclient.InferenceServerClient(
+            url="localhost:8000")
+        self.grpc_client_ = grpcclient.InferenceServerClient(
+            url="localhost:8001")
+
+    def tearDown(self):
+        nvidia_smi.nvmlShutdown()
+
+    def report_used_gpu_memory(self):
+        info = nvidia_smi.nvmlDeviceGetMemoryInfo(self.gpu_handle_)
+        return info.used
+
+    def is_testing_backend(self, model_name, backend_name):
+        return self.client_.get_model_config(
+            model_name)["backend"] == backend_name
+
+    def verify_recorded_usage(self, model_stat):
+        recorded_gpu_usage = 0
+        for usage in model_stat["memory_usage"]:
+            if usage["type"] == "GPU":
+                recorded_gpu_usage += int(usage["byte_size"])
+        # unload and verify recorded usage
+        before_total_usage = self.report_used_gpu_memory()
+        self.client_.unload_model(model_stat["name"])
+        # unload can return before the model is fully unloaded,
+        # wait for it to finish
+        time.sleep(2)
+        usage_delta = before_total_usage - self.report_used_gpu_memory()
+        # check with tolerance as the obtained GPU usage is the device-wide usage
+        self.assertTrue(
+            usage_delta * 0.9 <= recorded_gpu_usage <= usage_delta * 1.1,
+            msg=
+            "For model {}, expect recorded usage to be in range [{}, {}], got {}"
+            .format(model_stat["name"], usage_delta * 0.9, usage_delta * 1.1,
+                    recorded_gpu_usage))
+
+    def test_onnx_http(self):
+        self.client_ = UnifiedClientProxy(self.http_client_)
+        model_stats = self.client_.get_inference_statistics()["model_stats"]
+        for model_stat in model_stats:
+            if self.is_testing_backend(model_stat["name"], "onnxruntime"):
+                self.verify_recorded_usage(model_stat)
+
+    def test_plan_grpc(self):
+        self.client_ = UnifiedClientProxy(self.grpc_client_)
+        model_stats = self.client_.get_inference_statistics()["model_stats"]
+        for model_stat in model_stats:
+            if self.is_testing_backend(model_stat["name"], "tensorrt"):
+                self.verify_recorded_usage(model_stat)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/qa/L0_device_memory_tracker/test.sh b/qa/L0_device_memory_tracker/test.sh
new file mode 100644
index 0000000000..7eb0d745da
--- /dev/null
+++ b/qa/L0_device_memory_tracker/test.sh
@@ -0,0 +1,128 @@
+#!/bin/bash
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+REPO_VERSION=${NVIDIA_TRITON_SERVER_VERSION}
+if [ "$#" -ge 1 ]; then
+    REPO_VERSION=$1
+fi
+if [ -z "$REPO_VERSION" ]; then
+    echo -e "Repository version must be specified"
+    echo -e "\n***\n*** Test Failed\n***"
+    exit 1
+fi
+if [ ! -z "$TEST_REPO_ARCH" ]; then
+    REPO_VERSION=${REPO_VERSION}_${TEST_REPO_ARCH}
+fi
+
+export CUDA_VISIBLE_DEVICES=0
+
+TEST_LOG="./test.log"
+TEST_PY=test.py
+
+DATADIR=/data/inferenceserver/${REPO_VERSION}
+rm -f *.log
+
+TEST_RESULT_FILE='test_results.txt'
+SERVER=/opt/tritonserver/bin/tritonserver
+SERVER_LOG="./server.log"
+
+source ../common/util.sh
+
+RET=0
+
+# Prepare the model repository; it only contains ONNX and TRT models as the
+# corresponding backends are known to report memory usage.
+rm -rf models && mkdir models
+# ONNX
+cp -r /data/inferenceserver/${REPO_VERSION}/onnx_model_store/* models/.
+rm -r models/*cpu
+
+# Convert Caffe models to generate TRT models for the system under test
+CAFFE2PLAN=../common/caffe2plan
+set +e
+mkdir -p models/vgg19_plan/1 && rm -f models/vgg19_plan/1/model.plan && \
+    $CAFFE2PLAN -b32 -n prob -o models/vgg19_plan/1/model.plan \
+        $DATADIR/caffe_models/vgg19.prototxt $DATADIR/caffe_models/vgg19.caffemodel
+if [ $? -ne 0 ]; then
+    echo -e "\n***\n*** Failed to generate vgg19 PLAN\n***"
+    exit 1
+fi
+
+mkdir -p models/resnet50_plan/1 && rm -f models/resnet50_plan/1/model.plan && \
+    $CAFFE2PLAN -b32 -n prob -o models/resnet50_plan/1/model.plan \
+        $DATADIR/caffe_models/resnet50.prototxt $DATADIR/caffe_models/resnet50.caffemodel
+if [ $? -ne 0 ]; then
+    echo -e "\n***\n*** Failed to generate resnet50 PLAN\n***"
+    exit 1
+fi
+
+mkdir -p models/resnet152_plan/1 && rm -f models/resnet152_plan/1/model.plan && \
+    $CAFFE2PLAN -h -b32 -n prob -o models/resnet152_plan/1/model.plan \
+        $DATADIR/caffe_models/resnet152.prototxt $DATADIR/caffe_models/resnet152.caffemodel
+if [ $? -ne 0 ]; then
+    echo -e "\n***\n*** Failed to generate resnet152 PLAN\n***"
+    exit 1
+fi
+set -e
+
+# Set multiple instances on selected models to test instance-wise collection
+# and accumulation.
+echo "instance_group [{ count: 2; kind: KIND_GPU }]" >> models/resnet152_plan/config.pbtxt
+echo "instance_group [{ count: 2; kind: KIND_GPU }]" >> models/densenet/config.pbtxt
+
+# The test uses the NVML Python bindings (nvidia-ml-py3) to validate the reported usage
+pip install nvidia-ml-py3
+
+# Start the server to load all models (in parallel), then gradually unload
+# the models and expect the memory usage changes to match what is reported
+# in the statistics.
+SERVER_ARGS="--backend-config=triton-backend-memory-tracker=true --model-repository=models --model-control-mode=explicit --load-model=*"
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+    echo -e "\n***\n*** Failed to start $SERVER\n***"
+    cat $SERVER_LOG
+    exit 1
+fi
+
+set +e
+python $TEST_PY > $TEST_LOG 2>&1
+if [ $? -ne 0 ]; then
+    RET=1
+fi
+set -e
+kill $SERVER_PID
+wait $SERVER_PID
+
+if [ $RET -eq 0 ]; then
+    echo -e "\n***\n*** Test Passed\n***"
+else
+    cat $SERVER_LOG
+    cat $TEST_LOG
+    echo -e "\n***\n*** Test FAILED\n***"
+fi
+
+exit $RET
diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc
index 9795e147f4..d6bb1e6189 100644
--- a/src/grpc/grpc_server.cc
+++ b/src/grpc/grpc_server.cc
@@ -1183,6 +1183,37 @@ CommonHandler::RegisterModelStatistics()
               batch_statistics->mutable_compute_output()->set_ns(ucnt);
             }
           }
+
+          triton::common::TritonJson::Value memory_usage_json;
+          err = model_stat.MemberAsArray("memory_usage", &memory_usage_json);
+          GOTO_IF_ERR(err, earlyexit);
+
+          for (size_t idx = 0; idx < memory_usage_json.ArraySize(); ++idx) {
+            triton::common::TritonJson::Value usage;
+            err = memory_usage_json.IndexAsObject(idx, &usage);
+            GOTO_IF_ERR(err, earlyexit);
+
+            auto memory_usage = statistics->add_memory_usage();
+            {
+              const char* type;
+              size_t type_len;
+              err = usage.MemberAsString("type", &type, &type_len);
+              GOTO_IF_ERR(err, earlyexit);
+              memory_usage->set_type(std::string(type, type_len));
+            }
+            {
+              int64_t id;
+              err = usage.MemberAsInt("id", &id);
+              GOTO_IF_ERR(err, earlyexit);
+              memory_usage->set_id(id);
+            }
+            {
+              uint64_t byte_size;
+              err = usage.MemberAsUInt("byte_size", &byte_size);
+              GOTO_IF_ERR(err, earlyexit);
+              memory_usage->set_byte_size(byte_size);
+            }
+          }
         }
       }
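For completeness (also not part of the patch), a sketch of reading the repeated MemoryUsage field that the gRPC handler above populates, using the existing Python gRPC client. The endpoint address is an assumed default, and the field layout follows the proto additions in extension_statistics.md.

```
# Minimal sketch, not part of this patch. Assumes a Triton server with gRPC on
# localhost:8001 and the device memory tracker enabled.
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")
# Without as_json=True the gRPC client returns the ModelStatisticsResponse protobuf.
response = client.get_inference_statistics()

for model_stat in response.model_stats:
    for usage in model_stat.memory_usage:
        # usage.type is "CPU", "CPU_PINNED" or "GPU"; usage.id identifies the device.
        print(model_stat.name, usage.type, usage.id, usage.byte_size)
```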