
Larger batch size makes slower inference time #3239

Closed
osirisFdragon opened this issue Aug 20, 2023 · 3 comments
Assignees
Labels
triaged Issue has been triaged by maintainers

Comments


osirisFdragon commented Aug 20, 2023

Description

I'm a newbie to TensorRT and want to use a YOLOv7 TensorRT engine to run inference on a batch of images. The problem is that when I increase the batch size, the inference time gets slower. My pipeline is to first convert the torch model to ONNX, then use the ONNX model to build the TensorRT engine.

Environment

TensorRT Version: 8.6.1.6-1+cuda11.8

NVIDIA GPU: NVIDIA RTX A5000

NVIDIA Driver Version: 530.41.03

CUDA Version: 11.8

CUDNN Version: 8.9.4.25-1+cuda11.8

Operating System: Ubuntu 22.04.1

Relevant Files

Here is my torch2onnx.py script:

import sys
from pathlib import Path
FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[1].as_posix())

import torch
import argparse
import yaml
import cv2
import math
import os.path as osp
import numpy as np

import os
import json
import onnx

from utils.torch_utils import select_device
from utils.general import check_img_size, scale_coords, landmark_non_max_suppression
from models.experimental import attempt_load_model
from utils.augmentations import letterbox

def load_image(img_path, img_size):
    im = cv2.imread(img_path)  # BGR
    h0, w0 = im.shape[:2]  # orig hw
    r = img_size / max(h0, w0)  # ratio
    if r != 1:  # if sizes are not equal
        im = cv2.resize(im, (int(w0 * r), int(h0 * r)),
                        interpolation=cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR)
    return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized

def preprocess_image(img_path, img_size):
    # Load image: returns img, hw_original, hw_resized; the resized shape is (img_size, T) or (T, img_size) with T <= img_size
    img, (h0, w0), (h, w) = load_image(img_path, img_size)
    
    # Letterbox: padding img to get shape (img_size, img_size)
    img, ratio, pad = letterbox(img, img_size, auto=False, scaleup=False)
    
    shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling
 
    # Convert
    img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    img = np.ascontiguousarray(img)

    return torch.from_numpy(img), shapes

def convert(opt):
    
    bs, weights, onnx_path, source, imgsz, device = \
        opt.batch_size, opt.weights, opt.onnx_path, opt.source_file, opt.imgsz, opt.device
    
    if not os.path.exists(source):
        print("Input file does not exist", file=sys.stderr)
        sys.exit(1)
        
    device = select_device(device)
    model = attempt_load_model(weights, map_location=device)
    stride = int(model.stride.max())  # model stride
    imgsz = check_img_size(imgsz, s=stride)  # check image size
    model.eval()
    model.model[-1].export = True
    # imgsz = check_img_size(imgsz, s=stride)  # check image size
    inp, shape = preprocess_image(source, imgsz)
    inp = inp.to(device).float()
    inp /= 255
    if len(inp.shape) == 3:
        inp = inp.repeat(bs, 1, 1, 1)
    
    # breakpoint()
    pred = model(inp)
    torch.onnx.export(model, inp, onnx_path, input_names=['images'],
                      output_names=['output'],
                      export_params=True,
                      opset_version=12)
    
    onnx_model = onnx.load_model(onnx_path)
    onnx.checker.check_model(onnx_model)

def parse_opt():
    parser = argparse.ArgumentParser(prog='torch2onnx.py')
    parser.add_argument('--source_file', type=str)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--data', type=str, default='', help='dataset.yaml path')
    parser.add_argument('--weights', default='./weights/yolov7.pt')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')  # convert() unpacks opt.device
    opt = parser.parse_args()
    return opt

def run():
    opt = parse_opt()
    onnx_path = opt.weights.replace(".pt", f"_bs{opt.batch_size}.onnx")
    opt.onnx_path = onnx_path
    convert(opt)
      
if __name__ == "__main__":
    run()
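
One thing to note about the export above: torch.onnx.export is called without dynamic_axes, so the resulting ONNX graph has a fixed batch dimension baked in (equal to --batch_size), and a separate ONNX file and engine are needed for every batch size. A minimal sketch of an export with a dynamic batch axis, assuming the same model, inp, and onnx_path objects already prepared inside convert() above:

import torch

# Sketch only: `model`, `inp`, and `onnx_path` are assumed to be the objects
# prepared in convert(). Marking axis 0 as dynamic lets a single ONNX file
# serve several batch sizes (the engine then needs a matching optimization profile).
torch.onnx.export(model, inp, onnx_path,
                  input_names=['images'],
                  output_names=['output'],
                  export_params=True,
                  opset_version=12,
                  dynamic_axes={'images': {0: 'batch'},
                                'output': {0: 'batch'}})

Whether the output's batch dimension is really axis 0 depends on the YOLOv7 export head, so treat that mapping as an assumption to verify on the exported graph.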

Here is my onnx2tensorrt.py script:

import tensorrt as trt
import sys
import argparse

import os.path as osp

    
def build_engine(onnx_file_path, output_engine_path, max_batch_size, imgsz):   
    TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH)
    parser = trt.OnnxParser(network, TRT_LOGGER)

    config = builder.create_builder_config()

    config.max_workspace_size = (1<<32)
    config.set_flag(trt.BuilderFlag.FP16)
    config.default_device_type = trt.DeviceType.GPU

    profile = builder.create_optimization_profile()
    # The ONNX input is named 'images' (see input_names in torch.onnx.export); min, opt and max are all the same fixed batch here.
    profile.set_shape('images', (max_batch_size, 3, imgsz, imgsz), (max_batch_size, 3, imgsz, imgsz), (max_batch_size, 3, imgsz, imgsz))
    config.add_optimization_profile(profile)

    with open(onnx_file_path, 'rb') as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error), file=sys.stderr)
            return

    engine = builder.build_engine(network, config)
    assert engine is not None, "Failed to build the TensorRT engine"
    buf = engine.serialize()
    with open(output_engine_path, 'wb') as f:
        f.write(buf)
    
def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str)
    parser.add_argument('--imgsz', type=int, default=640)
    parser.add_argument('--max-batch-size', type=int, default=32)
    opt = parser.parse_args()
    return opt
    
if __name__ == "__main__":
    opt = parse_opt()
    opt.onnx_path = opt.weights.replace(".pt", f"_bs{opt.max_batch_size}.onnx").replace(".pth", f"_bs{opt.max_batch_size}.onnx")
    build_engine(opt.onnx_path, opt.onnx_path.replace(".onnx", ".engine"), opt.max_batch_size, opt.imgsz)
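
Because the script above sets min = opt = max to the same fixed shape, the resulting engine still only accepts that one batch size. If the ONNX model is exported with a dynamic batch axis (as in the export sketch after torch2onnx.py), the optimization profile can instead cover a range of batch sizes. A minimal sketch under that assumption, with the input tensor named 'images':

import sys
import tensorrt as trt

def build_dynamic_engine(onnx_file_path, output_engine_path, opt_batch, max_batch, imgsz):
    # Sketch only: assumes the ONNX graph has a dynamic batch dimension on 'images'.
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)

    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)

    with open(onnx_file_path, 'rb') as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i), file=sys.stderr)
            return

    # Kernels are tuned for opt_batch, but any batch size in [1, max_batch] can be run.
    profile = builder.create_optimization_profile()
    profile.set_shape('images',
                      (1, 3, imgsz, imgsz),
                      (opt_batch, 3, imgsz, imgsz),
                      (max_batch, 3, imgsz, imgsz))
    config.add_optimization_profile(profile)

    engine = builder.build_engine(network, config)
    with open(output_engine_path, 'wb') as f:
        f.write(engine.serialize())

With a dynamic-shape engine, the C++ side also has to set the actual input shape on the execution context (setBindingDimensions) before enqueueing, so this is only a sketch of the build step.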

And the last one is the C++ script I used to run batch inference:


#include <fstream>
#include <common/utils.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <chrono>

#include <helper/BaseHelper.h>

#include <common/struct.h>
#include <process/TRTProcessor.h>

namespace bfj {
    TRTProcessor::TRTProcessor(std::string param_file, std::string log_dir){
        this->init_logger(log_dir);
        this->load_param(param_file);
        this->load_engine();
        this->load_buffers();
        this->init_preprocessor();
        this->init_postprocessor();
        
        this->warmup();
    }

    TRTProcessor::~TRTProcessor(){
        // Free allocated buffers:
        for (size_t i = 0; i < this->output_cnt; i++)
            delete[] host_outputs[i];
        delete[] host_outputs;

        for (auto &buffer: this->m_buffers)
            CUDA_CHECK(cudaFree(buffer));
        this->m_buffers.clear();

        // Destroy tensorrt contexts
        CUDA_CHECK(cudaStreamDestroy(m_stream));
        // delete m_context;       // delete first because m_context created from m_engine // no need to delete shared_ptr
        // delete m_engine;
    }
    
    void TRTProcessor::init_logger(std::string log_dir){
        logger::QDTLog::init(log_dir);
    }

    void TRTProcessor::init_preprocessor(){
        this->m_preprocessor = std::make_shared<PreProcessor>(params->imgsz);
        LOG_INFO("Initializing preprocessing object...");
    }

    void TRTProcessor::init_postprocessor(){
        this->m_postprocessor = std::make_shared<PostProcessor>(params->imgsz, params->num_offsets, params->num_lmks,
                                                                params->nc, params->labels,
                                                                params->conf_thres, params->conf_thres_part,
                                                                params->iou_thres, params->iou_thres_part, params->matching_iou_thres,
                                                                params->nms_method);
        this->m_postprocessor->setDimension(output_dims);
        LOG_INFO("Initializing postprocessing object...");
    }

    void TRTProcessor::load_param(std::string param_file){
        LOG_INFO("Loading configuration in config file {}...", param_file);
        this->params = std::make_shared<param::Params>(param_file);
    }

    void TRTProcessor::load_engine(){
        // Read engine file and save to a char * object.
        std::ifstream file(params->engine, std::ios::binary | std::ios::ate);
        LOG_CERROR(!file.good(), "Unable to read engine file ({})!", params->engine);

        size_t size = 0;
        file.seekg(0, file.end);
        size = file.tellg();
        file.seekg(0, file.beg);
        char *serialized_engine = new char[size];
        LOG_CERROR(!serialized_engine, "Unable to load engine file ({})!", params->engine);
        file.read(serialized_engine, size);
        file.close();

        // Using serialized_engine (format char*) to make ICudaEngine object.
        m_runtime = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(m_logger));
        LOG_CERROR(!m_runtime, "Could not create IRuntime object!");

        // Using runtime to make ICudaEngine object
        m_engine = std::unique_ptr<nvinfer1::ICudaEngine>(m_runtime->deserializeCudaEngine(serialized_engine, size));
        LOG_CERROR(!m_engine, "Could not create ICudaEngine object!");

        // Using engine to make IExecutionContext
        m_context = std::unique_ptr<nvinfer1::IExecutionContext>(m_engine->createExecutionContext());
        LOG_CERROR(!m_context, "Could not create IExecutionContext object!");

        // Create cudaStream:
        CUDA_CHECK(cudaStreamCreate(&m_stream));

        // Destroy:
        delete[] serialized_engine;
    }

    void TRTProcessor::load_buffers(){
        // Allocate the input and output buffers
        int nIOTensors = m_engine->getNbIOTensors();
        m_buffers.resize(static_cast<size_t>(nIOTensors));

        for (int i = 0; i < nIOTensors; i++) {
            std::string bname = m_engine->getIOTensorName(i);
            nvinfer1::Dims bsize = m_engine->getTensorShape(bname.c_str());    // input: bxcxhxw, output: bxnum_objsxobj_dim      
            // nvinfer1::DataType btype = m_engine->getTensorDataType(bname.c_str());

            if (m_engine->getTensorIOMode(bname.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
                this->input_names.push_back(bname);
                this->input_dims.push_back(bsize);      // 3, imgsz, imgsz
                LOG_CWARN((bname == "images") && (bsize.d[2] != params->imgsz || bsize.d[3] != params->imgsz), "The `imgsz` param is expected to be equal to the network input's shapes ({}, {}, {}).  Receive {} instead.", bsize.d[0], bsize.d[1], bsize.d[2], params->imgsz);
                LOG_DEBUG("The engine contains input binding `{}` (dimension {} - size {})", bname, getDimensionStr(bsize), std::to_string(getSizeByDim(bsize)));
            } else {
                this->output_names.push_back(bname);
                this->output_dims.push_back(bsize);
                LOG_DEBUG("The engine contains output binding `{}` (dimension {} - size {})", bname, getDimensionStr(bsize), std::to_string(getSizeByDim(bsize)));
            }
            CUDA_CHECK(cudaMalloc(&m_buffers[static_cast<size_t>(i)], getSizeByDim(bsize) * sizeof(float) * params->batch_size));
            LOG_CINFO((i == 1), BOLD("Allocating {} bytes for gpu output 0 in `load_buffers`"), getSizeByDim(bsize) * sizeof(float) * params->batch_size);
        }

        this->input_cnt = input_names.size();
        this->output_cnt = output_names.size();
        LOG_CWARN((this->input_cnt != 1), "In the context of Body-Face-Joint problem, we only accept one binding input, that is 'images' binding.")
        LOG_CWARN((this->output_cnt != 1), "In the context of Body-Face-Joint problem, we only accept one binding output, that is 'output' binding.")

        // Init output data size
        this->host_outputs = new float*[output_cnt];
        for (size_t index = 0; index < output_names.size(); index++) 
            this->host_outputs[index] = new float[getSizeByDim(output_dims[index]) * params->batch_size];
            LOG_DEBUG("Host outputs: ")
    }

    void TRTProcessor::warmup(){
        LOG_INFO(BOLD("Warming engine up..."))
        auto start = std::chrono::high_resolution_clock::now();
        cv::Mat fake_img({input_dims[0].d[1], input_dims[0].d[2], input_dims[0].d[3]}, CV_32F, cv::Scalar{0, 0, 0});
        std::vector<cv::Mat> imgs;
        for (int i = 0; i < params->batch_size; i++)
            imgs.push_back(fake_img);

        this->execute_trt(imgs);
        auto end = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double> elapsed = end - start;
        LOG_INFO(BOLD("Completed in {} secs."), elapsed);

    }

    void TRTProcessor::inference(std::string source_dir, std::vector<IMatches> &results, float &fps){
        std::vector<std::string> img_paths = helper::BaseHelper::listDir(source_dir, true);
        std::sort(img_paths.begin(), img_paths.end());

        std::vector<cv::Mat> imgs;
        for (std::string img_path : img_paths){
            cv::Mat img = cv::imread(img_path);
            if (img.empty()){
                LOG_CWARN("Unable to read image file {}", img_path);
                continue;
            }
            imgs.push_back(img); 
        }
        this->inference(imgs, results, fps);
        for (size_t i = 0; i < results.size(); i++)
            results[i].setImgPath(img_paths[i]);
    }

    void TRTProcessor::execute_trt(std::vector<cv::Mat> &imgs){
        // copy input from host2device: In the context of this problem, we only have one binding input, named "images". 
        for (size_t i = 0; i < imgs.size(); i++){
            size_t c = static_cast<size_t>(imgs[i].size[0]);
            size_t h = static_cast<size_t>(imgs[i].size[1]);
            size_t w = static_cast<size_t>(imgs[i].size[2]);
            size_t size = c * h * w;
            size_t volume = size * sizeof(float);
            // logger::ILogger::logFloatCVMat(imgs[i], "chw");
            // LOG_DEBUG(BOLD("Volume size: {}, size = {}, batchsize = {}, imgs = {}, c = {}, h = {}, w = {}"), volume, size, params->batch_size, imgs.size(), c, h, w);
            CUDA_CHECK(cudaMemcpy((float*)(m_buffers[0]) + i * size, imgs[i].ptr<float>(0, 0, 0), volume, cudaMemcpyHostToDevice));
            
        }
        // int dims[] = {3, 640, 640};
        // logger::ILogger::logGpuInputBuffer(m_buffers[0], params->batch_size, 3, dims);

        // execute enqueue
        // this->m_context->executeV2(m_buffers.data());
        this->m_context->enqueueV2(m_buffers.data(), m_stream, nullptr);

        // copy output from device2host
        for (size_t i = 0; i < output_cnt; i++){
            size_t volume = getSizeByDim(output_dims[i]) * sizeof(float) * params->batch_size;
            CUDA_CHECK(cudaMemcpy(host_outputs[i], m_buffers[input_cnt + i], volume, cudaMemcpyDeviceToHost));
        }
        // int dims[] = {output_dims[0].d[1], output_dims[0].d[2]};    // .d[0] = batch_size when running onnx2tensorrt
        // logger::ILogger::logCpuOutputBuffer(host_outputs[0], params->batch_size, 2, dims);
    }

    void TRTProcessor::inference(std::vector<cv::Mat> imgs, std::vector<IMatches> &results, float &fps){
        cv::Mat img;
        ScaleInfo info;

        std::vector<IMatches> batch_imatches;
        std::vector<cv::Mat> batch_imgs;
        std::vector<ScaleInfo> batch_infos;
        std::vector<int> batch_invalids;

        auto start = std::chrono::high_resolution_clock::now();
        double pre_time_total = 0.0;
        double exe_time_total = 0.0;
        double post_time_total = 0.0;
        LOG_DEBUG(BOLD("Starting running inference in BGR format..."))

        for (size_t i = 0; i < imgs.size(); ){
            size_t cur = i;
            auto start_pre = std::chrono::high_resolution_clock::now();
            for (size_t b = cur; b < imgs.size(); b++){
                img = this->m_preprocessor->run(imgs[b], info);
                if (img.empty()){
                    LOG_CWARN("Image with index {} could not be preprocessed! It will be ignored!", b);
                    batch_invalids.push_back(b - cur);      // position of the skipped image within this batch
                    i = b + 1;
                    continue;
                }

                batch_imgs.push_back(img);
                batch_infos.push_back(info);
                i = b + 1;
                if (batch_imgs.size() == static_cast<size_t>(params->batch_size))
                    break;
            }
            auto end_pre = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> pre_time = (end_pre - start_pre);

            this->execute_trt(batch_imgs);                      // outputs are written into host_outputs
            auto end_exe = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> exe_time = (end_exe - end_pre);

            this->m_postprocessor->run(host_outputs, batch_infos, batch_imatches);
            for (int invalid_index : batch_invalids)
                batch_imatches.insert(batch_imatches.begin() + invalid_index, IMatches());

            results.insert(results.end(), batch_imatches.begin(), batch_imatches.end());
            size_t bsize = batch_imgs.size();
            batch_imgs.clear();
            batch_infos.clear();
            batch_invalids.clear();
            batch_imatches.clear();

            auto end_post = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> post_time = (end_post - end_exe);
            // LOG_DEBUG("Preprocessing time: {}, Execution time: {}, Postprocessing time: {}", pre_time.count() / bsize, exe_time.count() / bsize, post_time.count() / bsize);

            pre_time_total += pre_time.count();
            exe_time_total += exe_time.count();
            post_time_total += post_time.count();
        }
        LOG_CWARN("TOTAL: Preprocessing: {}, Execution: {}, Postprocessing: {}", pre_time_total / imgs.size(), 
                                                                                         exe_time_total / imgs.size(), 
                                                                                         post_time_total / imgs.size());

        auto end = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double> elapsed = end - start;
        fps = 1.0f * imgs.size() / elapsed.count();
        LOG_DEBUG(BOLD("Complete!"))
        LOG_DEBUG(BOLD("Process {} images in time: {}, fps = {}"), imgs.size(), elapsed.count(), fps);
    }
}

The main function for running inference is void TRTProcessor::inference(std::vector<cv::Mat> imgs, std::vector<IMatches> &results, float &fps) (I'm running it on 500 images). It logs three timings, each averaged per image: preprocessing, TensorRT execution, and postprocessing.
With params->batch_size = 1: preprocess = 0.016 sec, execution = 0.006 sec, postprocess = 0.0001 sec
With params->batch_size = 2: preprocess = 0.0148 sec, execution = 0.0041 sec, postprocess = 0.0001 sec
With params->batch_size = 4: preprocess = 0.0143 sec, execution = 0.0047 sec, postprocess = 0.0001 sec
With params->batch_size = 8: preprocess = 0.016 sec, execution = 0.0058 sec, postprocess = 0.0001 sec

As you can see, the per-image execution time increases as the batch size grows beyond 2. Am I missing something?
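
For reference, converting those per-image execution times into throughput makes the trade-off easier to see; an illustrative calculation (Python) using only the numbers above:

# Per-image TensorRT execution times reported above, in seconds (illustrative only).
exec_time_per_image = {1: 0.006, 2: 0.0041, 4: 0.0047, 8: 0.0058}

for bs, t in exec_time_per_image.items():
    print(f"batch_size={bs}: {t * 1000:.1f} ms/image, ~{1.0 / t:.0f} images/sec")

# batch_size=1: 6.0 ms/image, ~167 images/sec
# batch_size=2: 4.1 ms/image, ~244 images/sec
# batch_size=4: 4.7 ms/image, ~213 images/sec
# batch_size=8: 5.8 ms/image, ~172 images/sec

So on this setup per-image execution time (and therefore throughput) is best at batch size 2 and degrades afterwards, which fits the explanation below: once the GPU is saturated, a larger batch mostly adds proportional work.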

zerollzeng (Collaborator) commented:

> I'm a newbie to TensorRT and want to use a YOLOv7 TensorRT engine to run inference on a batch of images. The problem is that when I increase the batch size, the inference time gets slower. My pipeline is to first convert the torch model to ONNX, then use the ONNX model to build the TensorRT engine.

That's expected: a larger batch size requires proportionally more computation, unless the GPU is still under-utilized.

zerollzeng self-assigned this Aug 22, 2023
zerollzeng added the triaged (Issue has been triaged by maintainers) label Aug 22, 2023
ttyio (Collaborator) commented Sep 12, 2023

Closing since there has been no activity for more than 3 weeks; please reopen if you still have questions. Thanks!

ttyio closed this as completed Sep 12, 2023
watertianyi commented:

@ttyio @osirisFdragon

I also found that the inference time increases roughly linearly as the batch size increases. I observed that at batch_size = 20 the GPU utilization is 100% while GPU memory usage is only about 5 GB; at that point increasing the batch size further makes no sense, it only increases the inference time. Besides batching, is there any other way to reduce inference time with TensorRT?
