Add and document memory usage in statistic protocol (#5642)
* Add and document memory usage in statistic protocol

* Fix doc

* Fix up

* [DO NOT MERGE] Add test. FIXME: model generation

* Fix up

* Fix style

* Address comment

* Fix up

* Set memory tracker backend option in build.py

* Fix up

* Add CUPTI library in Windows image build

* Add note to build with memory tracker by default
GuanLuo authored May 25, 2023
1 parent cbb1964 commit eaa2fa0
Showing 6 changed files with 333 additions and 2 deletions.
1 change: 1 addition & 0 deletions Dockerfile.win10.min
@@ -115,6 +115,7 @@ ARG CUDA_PACKAGES="nvcc_${CUDA_MAJOR}.${CUDA_MINOR} \
curand_${CUDA_MAJOR}.${CUDA_MINOR} curand_dev_${CUDA_MAJOR}.${CUDA_MINOR} \
cusolver_${CUDA_MAJOR}.${CUDA_MINOR} cusolver_dev_${CUDA_MAJOR}.${CUDA_MINOR} \
cusparse_${CUDA_MAJOR}.${CUDA_MINOR} cusparse_dev_${CUDA_MAJOR}.${CUDA_MINOR} \
cupti_${CUDA_MAJOR}.${CUDA_MINOR} \
thrust_${CUDA_MAJOR}.${CUDA_MINOR} \
visual_studio_integration_${CUDA_MAJOR}.${CUDA_MINOR}"
ARG CUDA_INSTALL_ROOT_WP="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v${CUDA_MAJOR}.${CUDA_MINOR}"
12 changes: 12 additions & 0 deletions build.py
@@ -598,6 +598,18 @@ def backend_cmake_args(images, components, be, install_dir, library_paths):
cargs.append(
cmake_backend_enable(be, 'TRITON_ENABLE_METRICS', FLAGS.enable_metrics))

# [DLIS-4950] always enable below once Windows image is updated with CUPTI
# cargs.append(cmake_backend_enable(be, 'TRITON_ENABLE_MEMORY_TRACKER', True))
if (target_platform() == 'windows') and (not FLAGS.no_container_build):
print(
"Warning: Detected docker build is used for Windows, backend utility 'device memory tracker' will be disabled due to missing library in CUDA Windows docker image."
)
cargs.append(
cmake_backend_enable(be, 'TRITON_ENABLE_MEMORY_TRACKER', False))
elif FLAGS.enable_gpu:
cargs.append(
cmake_backend_enable(be, 'TRITON_ENABLE_MEMORY_TRACKER', True))

cargs += cmake_backend_extra_args(be)
cargs.append('..')
return cargs
52 changes: 50 additions & 2 deletions docs/protocol/extension_statistics.md
Expand Up @@ -78,7 +78,8 @@ $model_stat =
"inference_count" : $number,
"execution_count" : $number,
"inference_stats" : $inference_stats,
"batch_stats" : [ $batch_stat, ... ]
"batch_stats" : [ $batch_stat, ... ],
"memory_usage" : [ $memory_usage, ...]
}
```

@@ -119,6 +120,14 @@ $model_stat =
due to different batch size (for example, larger batches typically
take longer to compute).

- "memory_usage" : The memory usage detected during model loading, which may be
used to estimate the memory to be released once the model is unloaded. Note
that the estimation is inferenced by the profiling tools and framework's
memory schema, therefore it is advised to perform experiments to understand
the scenario that the reported memory usage can be relied on. As a starting
point, the GPU memory usage for models in ONNX Runtime backend and TensorRT
backend is usually aligned.

```
$inference_stats =
{
@@ -217,6 +226,22 @@ $duration_stat =

- “ns” : The total duration for the statistic in nanoseconds.

```
$memory_usage =
{
"type" : $string,
"id" : $number,
"byte_size" : $number
}
```

- "type" : The type of memory, the value can be "CPU", "CPU_PINNED", "GPU".

- "id" : The id of the memory, typically used with "type" to identify
a device that hosts the memory.

- "byte_size" : The byte size of the memory.

### Statistics Response JSON Error Object

A failed statistics request will be indicated by an HTTP error status
@@ -325,7 +350,16 @@ message ModelStatistics
// executed in the model. The batch statistics indicate how many actual
// model executions were performed and show differences due to different
// batch size (for example, larger batches typically take longer to compute).
InferBatchStatistics batch_stats = 7;
repeated InferBatchStatistics batch_stats = 7;
// The memory usage detected during model loading, which may be used to
// estimate the memory to be released once the model is unloaded. Note
// that this is an estimate inferred from profiling tools and the
// framework's memory allocation scheme, so it is advised to run
// experiments to understand the scenarios in which the reported memory
// usage can be relied on. As a starting point, the reported GPU memory
// usage for models in the ONNX Runtime and TensorRT backends is usually
// well aligned with the actual usage.
repeated MemoryUsage memory_usage = 8;
}
// Inference statistics.
@@ -416,4 +450,18 @@ message InferBatchStatistics
// tensor data from the GPU.
StatisticDuration compute_output = 4;
}
// Memory usage.
message MemoryUsage
{
// The type of memory; the value can be "CPU", "CPU_PINNED", or "GPU".
string type = 1;
// The id of the memory, typically used with "type" to identify
// a device that hosts the memory.
int64 id = 2;
// The byte size of the memory.
uint64 byte_size = 3;
}
```
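
The same data is available through the GRPC statistics API. The following is a
minimal sketch, not a normative example: it assumes the Python GRPC client
(`tritonclient.grpc`) and a server at `localhost:8001`, and simply prints every
`MemoryUsage` entry defined above.

```python
import tritonclient.grpc as grpcclient

# Minimal sketch: read the repeated MemoryUsage entries from the GRPC
# statistics response. The server address is an assumption.
client = grpcclient.InferenceServerClient(url="localhost:8001")
response = client.get_inference_statistics()

for model_stat in response.model_stats:
    for usage in model_stat.memory_usage:
        print("model={} type={} id={} byte_size={}".format(
            model_stat.name, usage.type, usage.id, usage.byte_size))
```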
111 changes: 111 additions & 0 deletions qa/L0_device_memory_tracker/test.py
@@ -0,0 +1,111 @@
#!/usr/bin/env python
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import unittest
import time
from functools import partial

import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient

import nvidia_smi


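# Thin adapter so the test body can use the HTTP and GRPC clients
# interchangeably: GRPC responses are requested as JSON, so both clients
# return dict-like results for get_model_config and get_inference_statistics.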
class UnifiedClientProxy:

def __init__(self, client):
self.client_ = client

def __getattr__(self, attr):
forward_attr = getattr(self.client_, attr)
if type(self.client_) == grpcclient.InferenceServerClient:
if attr == "get_model_config":
return lambda *args, **kwargs: forward_attr(
*args, **kwargs, as_json=True)["config"]
elif attr == "get_inference_statistics":
return partial(forward_attr, as_json=True)
return forward_attr


class MemoryUsageTest(unittest.TestCase):

def setUp(self):
nvidia_smi.nvmlInit()
self.gpu_handle_ = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
self.http_client_ = httpclient.InferenceServerClient(
url="localhost:8000")
self.grpc_client_ = grpcclient.InferenceServerClient(
url="localhost:8001")

def tearDown(self):
nvidia_smi.nvmlShutdown()

def report_used_gpu_memory(self):
info = nvidia_smi.nvmlDeviceGetMemoryInfo(self.gpu_handle_)
return info.used

def is_testing_backend(self, model_name, backend_name):
return self.client_.get_model_config(
model_name)["backend"] == backend_name

def verify_recorded_usage(self, model_stat):
recorded_gpu_usage = 0
for usage in model_stat["memory_usage"]:
if usage["type"] == "GPU":
recorded_gpu_usage += int(usage["byte_size"])
# unload and verify recorded usage
before_total_usage = self.report_used_gpu_memory()
self.client_.unload_model(model_stat["name"])
# unload can return before the model is fully unloaded;
# wait for it to finish
time.sleep(2)
usage_delta = before_total_usage - self.report_used_gpu_memory()
# check with tolerance as the GPU usage obtained from NVML is the overall device usage
self.assertTrue(
usage_delta * 0.9 <= recorded_gpu_usage <= usage_delta * 1.1,
msg=
"For model {}, expect recorded usage to be in range [{}, {}], got {}"
.format(model_stat["name"], usage_delta * 0.9, usage_delta * 1.1,
recorded_gpu_usage))

def test_onnx_http(self):
self.client_ = UnifiedClientProxy(self.http_client_)
model_stats = self.client_.get_inference_statistics()["model_stats"]
for model_stat in model_stats:
if self.is_testing_backend(model_stat["name"], "onnxruntime"):
self.verify_recorded_usage(model_stat)

def test_plan_grpc(self):
self.client_ = UnifiedClientProxy(self.grpc_client_)
model_stats = self.client_.get_inference_statistics()["model_stats"]
for model_stat in model_stats:
if self.is_testing_backend(model_stat["name"], "tensorrt"):
self.verify_recorded_usage(model_stat)


if __name__ == "__main__":
unittest.main()
128 changes: 128 additions & 0 deletions qa/L0_device_memory_tracker/test.sh
@@ -0,0 +1,128 @@
#!/bin/bash
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

REPO_VERSION=${NVIDIA_TRITON_SERVER_VERSION}
if [ "$#" -ge 1 ]; then
REPO_VERSION=$1
fi
if [ -z "$REPO_VERSION" ]; then
echo -e "Repository version must be specified"
echo -e "\n***\n*** Test Failed\n***"
exit 1
fi
if [ ! -z "$TEST_REPO_ARCH" ]; then
REPO_VERSION=${REPO_VERSION}_${TEST_REPO_ARCH}
fi

export CUDA_VISIBLE_DEVICES=0

TEST_LOG="./test.log"
TEST_PY=test.py

DATADIR=/data/inferenceserver/${REPO_VERSION}
rm -f *.log

TEST_RESULT_FILE='test_results.txt'
SERVER=/opt/tritonserver/bin/tritonserver
SERVER_LOG="./server.log"

source ../common/util.sh

RET=0

# Prepare the model repository; it only contains ONNX and TRT models as the
# corresponding backends are known to support device memory tracking.
rm -rf models && mkdir models
# ONNX
cp -r /data/inferenceserver/${REPO_VERSION}/onnx_model_store/* models/.
rm -r models/*cpu

# Convert to get TRT models against the system
CAFFE2PLAN=../common/caffe2plan
set +e
mkdir -p models/vgg19_plan/1 && rm -f models/vgg19_plan/1/model.plan && \
$CAFFE2PLAN -b32 -n prob -o models/vgg19_plan/1/model.plan \
$DATADIR/caffe_models/vgg19.prototxt $DATADIR/caffe_models/vgg19.caffemodel
if [ $? -ne 0 ]; then
echo -e "\n***\n*** Failed to generate vgg19 PLAN\n***"
exit 1
fi

mkdir -p models/resnet50_plan/1 && rm -f models/resnet50_plan/1/model.plan && \
$CAFFE2PLAN -b32 -n prob -o models/resnet50_plan/1/model.plan \
$DATADIR/caffe_models/resnet50.prototxt $DATADIR/caffe_models/resnet50.caffemodel
if [ $? -ne 0 ]; then
echo -e "\n***\n*** Failed to generate resnet50 PLAN\n***"
exit 1
fi

mkdir -p models/resnet152_plan/1 && rm -f models/resnet152_plan/1/model.plan && \
$CAFFE2PLAN -h -b32 -n prob -o models/resnet152_plan/1/model.plan \
$DATADIR/caffe_models/resnet152.prototxt $DATADIR/caffe_models/resnet152.caffemodel
if [ $? -ne 0 ]; then
echo -e "\n***\n*** Failed to generate resnet152 PLAN\n***"
exit 1
fi
set -e

# Set multiple instances on selected models to test instance-wise collection
# and accumulation.
echo "instance_group [{ count: 2; kind: KIND_GPU }]" >> models/resnet152_plan/config.pbtxt
echo "instance_group [{ count: 2; kind: KIND_GPU }]" >> models/densenet/config.pbtxt

# The test uses the nvidia-ml-py3 NVML bindings to validate the reported usage
pip install nvidia-ml-py3

# Start the server to load all models (in parallel), then gradually unload
# the models and expect the memory usage changes to match what is reported
# in the statistics.
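# (test.py sums the recorded GPU "memory_usage" for a model, unloads the model,
# and compares the recorded value against the drop in NVML-reported device memory)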
SERVER_ARGS="--backend-config=triton-backend-memory-tracker=true --model-repository=models --model-control-mode=explicit --load-model=*"
run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi

set +e
python $TEST_PY > $TEST_LOG 2>&1
if [ $? -ne 0 ]; then
RET=1
fi
set -e
kill $SERVER_PID
wait $SERVER_PID

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
else
cat $SERVER_LOG
cat $TEST_LOG
echo -e "\n***\n*** Test FAILED\n***"
fi

exit $RET
31 changes: 31 additions & 0 deletions src/grpc/grpc_server.cc
@@ -1183,6 +1183,37 @@ CommonHandler::RegisterModelStatistics()
batch_statistics->mutable_compute_output()->set_ns(ucnt);
}
}

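// Convert the per-model "memory_usage" JSON array from the core statistics
// into the repeated MemoryUsage entries of the protobuf response.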
triton::common::TritonJson::Value memory_usage_json;
err = model_stat.MemberAsArray("memory_usage", &memory_usage_json);
GOTO_IF_ERR(err, earlyexit);

for (size_t idx = 0; idx < memory_usage_json.ArraySize(); ++idx) {
triton::common::TritonJson::Value usage;
err = memory_usage_json.IndexAsObject(idx, &usage);
GOTO_IF_ERR(err, earlyexit);

auto memory_usage = statistics->add_memory_usage();
{
const char* type;
size_t type_len;
err = usage.MemberAsString("type", &type, &type_len);
GOTO_IF_ERR(err, earlyexit);
memory_usage->set_type(std::string(type, type_len));
}
{
int64_t id;
err = usage.MemberAsInt("id", &id);
GOTO_IF_ERR(err, earlyexit);
memory_usage->set_id(id);
}
{
uint64_t byte_size;
err = usage.MemberAsUInt("byte_size", &byte_size);
GOTO_IF_ERR(err, earlyexit);
memory_usage->set_byte_size(byte_size);
}
}
}
}

