Output mismatch of torch.add due to an intermediate result output when running on GPU #3452

Closed
Azyka opened this issue Nov 14, 2023 · 4 comments
Labels: internal-bug-tracked (Tracked internally, will be fixed in a future release.), triaged (Issue has been triaged by maintainers)

Comments

Azyka commented Nov 14, 2023

Description

When the intermediate result of torch.neg is added as an extra output of this model:
[image: original model graph]

New:
[image: new model graph]

The output of torch.add is expected to be the same for the same input in these 2 graphs. However, it mismatches between the 2 models.

Environment

TensorRT Version: 8.6.1.post1

NVIDIA GPU: RTX 1660

NVIDIA Driver Version: 525.147.05

CUDA Version: 12.0

CUDNN Version: 8.9.4.25

Operating System: Ubuntu 22.04.3 LTS (x86_64)

Python Version (if applicable): 3.10.12

Tensorflow Version (if applicable): 2.13.0

PyTorch Version (if applicable): 2.1.0+cu118

Relevant Files

Model link:
models.zip

Input data file:
input_data.zip

Steps To Reproduce

Script:

from dataclasses import dataclass
from numpy import testing
import numpy as np
import torch
import tensorrt as trt
import pycuda.driver as cuda
from pycuda.driver import DeviceAllocation
import pickle


@dataclass
class HostDeviceMem:
    host: np.ndarray
    device: DeviceAllocation


class ONNXClassifierWrapper():
    def __init__(self, engine):
        self.engine = engine

    def allocate_memory(self):
        engine = self.engine
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()
        onames = []
        name2idx = {}
        for idx, binding in enumerate(engine):
            name2idx[binding] = idx
            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            # Allocate host and device buffers
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            # Append the device buffer to device bindings.
            bindings.append(int(device_mem))
            # Append to the appropriate list.
            if engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))
                onames.append(binding)
        return inputs, outputs, bindings, stream, onames, name2idx

    def predict(self, inputs):  # result gets copied into output
        (
            trt_inputs,
            trt_outputs,
            trt_bindings,
            stream,
            onames,
            name2idx,
        ) = self.allocate_memory()
        context = self.engine.create_execution_context()
        # print(name2idx)
        # print(inputs.keys())
        for iname in inputs:
            np.copyto(
                trt_inputs[name2idx[iname]].host,
                inputs[iname]
                .astype(trt.nptype(self.engine.get_binding_dtype(iname)))
                .ravel(),
            )

        # Copy inputs to the device, run inference, then copy outputs back.
        for inp in trt_inputs:
            cuda.memcpy_htod_async(inp.device, inp.host, stream)
        context.execute_async_v2(bindings=trt_bindings, stream_handle=stream.handle)
        for out in trt_outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, stream)
        stream.synchronize()
        trt_outputs = [out.host for out in trt_outputs]
        return {
            n: v.reshape(self.engine.get_binding_shape(n))
            for n, v in zip(onames, trt_outputs)
        }


def convert_onnx_to_engine(onnx_filename):
    logger = trt.Logger(trt.Logger.WARNING)
    with trt.Builder(logger) as builder, builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \
            trt.OnnxParser(network, logger) as parser:
        # builder.max_workspace_size = max_workspace_size
        # builder.fp16_mode = fp16_mode
        # builder.max_batch_size = max_batch_size

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 2 << 30)

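        with open(onnx_filename, 'rb') as model:
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                # Stop here: building from a partially parsed network fails later with a less clear error.
                raise RuntimeError(f'Failed to parse ONNX file: {onnx_filename}')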

        engine_bytes = builder.build_serialized_network(network, config)

        return trt.Runtime(trt.Logger(trt.Logger.WARNING)).deserialize_cuda_engine(
            engine_bytes
        )

import pycuda.autoinit  # imported for its side effect: creates the CUDA context used below
class Model0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        _args = args
        getitem = _args[0]
        gelu = torch._C._nn.gelu(getitem)
        neg = torch.neg(gelu)
        add = torch.add(neg, getitem)
        return (add,)

model_0 = Model0()
output_names_0 = ['v2_0']
with open('0.pickle', 'rb') as f:
    input_dict_0 = pickle.load(f)
inputs_0 = tuple(torch.from_numpy(v).to('cuda') for _, v in input_dict_0.items())
torch.onnx.export(model_0, inputs_0, '0.onnx', verbose=False, input_names=['v0_0'], output_names=output_names_0, opset_version=14, do_constant_folding=False)

class Model1(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        _args = args
        getitem = _args[0]
        gelu = torch._C._nn.gelu(getitem)
        neg = torch.neg(gelu)
        add = torch.add(getitem, neg)
        return (neg, add)

model_1 = Model1()
output_names_1 = ['v1_0',  'v2_0']
with open('0.pickle', 'rb') as f:
    input_dict_1 = pickle.load(f)
inputs_1 = tuple(torch.from_numpy(v).to('cuda') for _, v in input_dict_1.items())
torch.onnx.export(model_1, inputs_1, '1.onnx', verbose=False, input_names=['v0_0'], output_names=output_names_1, opset_version=14, do_constant_folding=False)

engine_0 = convert_onnx_to_engine('0.onnx')
wrapper_0 = ONNXClassifierWrapper(engine_0)
output_0 = wrapper_0.predict(input_dict_0)

engine_1 = convert_onnx_to_engine('1.onnx')
wrapper_1 = ONNXClassifierWrapper(engine_1)
output_1 = wrapper_1.predict(input_dict_1)
output_name_dict = {'v2_0': 'v2_0'}

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], rtol=1, err_msg=f'at {tensor_name_0}, {tensor_name_1}')
    print("tensorRT does not trigger assertion")
except AssertionError as e:
    print("tensorRT triggers assertion")
    print(e)
print('=========================')

Steps to repro:

  1. Download the input data file and put it in the same directory as the script.
  2. Run the script.

Output assertion:

=========================
tensorRT triggers assertion

Not equal to tolerance rtol=1, atol=0
at v2_0, v2_0
Mismatched elements: 9194 / 19488 (47.2%)
Max absolute difference: 0.000977
Max relative difference: 0.5
 x: array([[[[1.9073e-06, 4.7684e-07, 1.9073e-06, ..., 2.4796e-05,
          2.1458e-05, 0.0000e+00],
         [0.0000e+00, 6.1512e-05, 0.0000e+00, ..., 8.7547e-04,...
 y: array([[[[0.      , 0.      , 0.      , ..., 0.      , 0.      ,
          0.      ],
         [0.      , 0.      , 0.      , ..., 0.      , 0.      ,...
=========================
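
A note on reading this assertion: numpy.testing.assert_allclose passes an element when |x - y| <= atol + rtol * |y|. With rtol=1 and atol=0, any element where y is exactly 0 and x is nonzero counts as a mismatch, however small x is. A minimal sketch using a few values copied from the printout above (the helper name `passes` and the atol of 1e-3 are only illustrative, not a recommended threshold):

```python
import numpy as np
from numpy import testing

# A few values taken from the x / y arrays printed in the assertion message.
x = np.array([1.9073e-06, 2.1458e-05, 0.0], dtype=np.float16)
y = np.zeros(3, dtype=np.float16)

# Element-wise rule used by assert_allclose: |x - y| <= atol + rtol * |y|
rtol, atol = 1.0, 0.0
allowed = atol + rtol * np.abs(y.astype(np.float32))
mismatched = np.abs(x.astype(np.float32) - y.astype(np.float32)) > allowed
print("mismatched mask:", mismatched.tolist())


def passes(atol_value):
    """Hypothetical helper: True if assert_allclose accepts x vs y at this atol."""
    try:
        testing.assert_allclose(x, y, rtol=1, atol=atol_value)
        return True
    except AssertionError:
        return False


print("passes with atol=0:   ", passes(0.0))
print("passes with atol=1e-3:", passes(1e-3))
```

So the 47.2% mismatch count mostly reflects how many elements are exactly 0 in one output and tiny but nonzero in the other; the absolute differences themselves stay below 1e-3.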

Have you tried the latest release?: No

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): Yes


Azyka commented Nov 16, 2023

I can repro in TRT 9.1.0.post12.dev4. @zerollzeng

zerollzeng self-assigned this Nov 18, 2023
zerollzeng added the triaged and internal-bug-tracked labels Nov 18, 2023

zerollzeng (Collaborator) commented

I'll file an internal bug to track this. We need to debug further to decide whether this is a bug or not; the diff may be caused by layer fusion.

$ polygraphy run 1.onnx --onnxrt --trt --data-loader-script data_loader.py
[I] RUNNING | Command: /home/scratch.zeroz_sw/miniconda3/bin/polygraphy run 1.onnx --onnxrt --trt --data-loader-script data_loader.py
{'v0_0': array([[[[4.95 , 5.33 , 4.92 , ..., 4.39 , 4.42 , 6.715],
         [6.906, 4.176, 5.363, ..., 3.479, 5.285, 5.44 ],
         [4.793, 4.316, 6.754, ..., 5.312, 3.057, 6.816],
         ...,
         [5.996, 4.965, 5.594, ..., 6.055, 6.164, 5.67 ],
         [6.754, 4.906, 3.184, ..., 6.184, 3.838, 3.016],
         [5.277, 5.152, 3.113, ..., 5.17 , 3.076, 5.152]],

        [[3.637, 3.576, 3.836, ..., 4.973, 3.662, 3.66 ],
         [4.93 , 4.473, 4.816, ..., 5.746, 3.982, 4.508],
         [3.506, 3.928, 4.39 , ..., 4.78 , 6.055, 5.613],
         ...,
         [6.098, 6.715, 6.242, ..., 3.62 , 3.176, 5.773],
         [6.016, 6.89 , 4.85 , ..., 4.33 , 4.21 , 6.973],
         [6.51 , 3.371, 4.74 , ..., 6.824, 3.193, 5.043]],

        [[6.38 , 4.16 , 6.79 , ..., 3.342, 4.25 , 6.176],
         [6.758, 5.395, 6.504, ..., 6.8  , 3.648, 4.652],
         [6.58 , 5.668, 5.754, ..., 4.13 , 3.719, 3.223],
         ...,
         [3.316, 5.34 , 3.666, ..., 6.547, 4.867, 5.867],
         [6.164, 5.305, 6.848, ..., 6.973, 3.912, 6.758],
         [3.695, 5.066, 3.623, ..., 5.207, 6.56 , 6.086]],

        ...,

        [[5.984, 5.164, 4.227, ..., 6.594, 6.56 , 5.816],
         [5.24 , 3.324, 3.104, ..., 6.203, 5.375, 5.484],
         [6.68 , 3.396, 3.79 , ..., 3.766, 3.03 , 4.93 ],
         ...,
         [5.367, 5.01 , 6.605, ..., 4.277, 4.48 , 4.117],
         [4.844, 3.637, 3.855, ..., 6.258, 3.164, 3.902],
         [4.176, 6.965, 4.95 , ..., 4.914, 4.63 , 6.37 ]],

        [[6.434, 5.832, 4.363, ..., 6.16 , 4.527, 4.742],
         [5.59 , 4.887, 3.354, ..., 3.543, 6.613, 6.062],
         [6.4  , 6.965, 4.523, ..., 5.36 , 5.56 , 5.008],
         ...,
         [5.93 , 6.35 , 3.627, ..., 4.492, 4.45 , 3.865],
         [5.1  , 4.527, 6.188, ..., 4.582, 6.17 , 3.588],
         [6.35 , 3.16 , 6.86 , ..., 5.5  , 3.928, 3.328]],

        [[3.66 , 4.574, 4.188, ..., 4.65 , 5.562, 4.258],
         [3.781, 5.64 , 6.707, ..., 6.457, 6.21 , 4.676],
         [6.664, 6.984, 3.145, ..., 4.04 , 3.318, 3.018],
         ...,
         [4.305, 6.96 , 5.508, ..., 4.938, 5.25 , 4.92 ],
         [3.793, 5.44 , 4.438, ..., 3.459, 6.805, 3.816],
         [4.215, 4.91 , 3.756, ..., 3.115, 6.47 , 4.953]]]], dtype=float16)}
[I] onnxrt-runner-N0-11/18/23-10:29:03  | Activating and starting inference
[I] Creating ONNX-Runtime Inference Session with providers: ['CPUExecutionProvider']
[I] onnxrt-runner-N0-11/18/23-10:29:03
    ---- Inference Input(s) ----
    {v0_0 [dtype=float16, shape=(1, 24, 28, 29)]}
[I] onnxrt-runner-N0-11/18/23-10:29:03
    ---- Inference Output(s) ----
    {v1_0 [dtype=float16, shape=(1, 24, 28, 29)],
     v2_0 [dtype=float16, shape=(1, 24, 28, 29)]}
[I] onnxrt-runner-N0-11/18/23-10:29:03  | Completed 1 iteration(s) in 0.3161 ms | Average inference time: 0.3161 ms.
[I] trt-runner-N0-11/18/23-10:29:03     | Activating and starting inference
[W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[I]     Configuring with profiles: [Profile().add('v0_0', min=[1, 24, 28, 29], opt=[1, 24, 28, 29], max=[1, 24, 28, 29])]
[I] Building engine with configuration:
    Flags                  | []
    Engine Capability      | EngineCapability.DEFAULT
    Memory Pools           | [WORKSPACE: 22551.38 MiB, TACTIC_DRAM: 22551.38 MiB]
    Tactic Sources         | [CUBLAS, CUDNN, EDGE_MASK_CONVOLUTIONS, JIT_CONVOLUTIONS]
    Profiling Verbosity    | ProfilingVerbosity.DETAILED
    Preview Features       | [FASTER_DYNAMIC_SHAPES_0805, DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
[I] Finished engine building in 18.267 seconds
[I] trt-runner-N0-11/18/23-10:29:03
    ---- Inference Input(s) ----
    {v0_0 [dtype=float16, shape=(1, 24, 28, 29)]}
[I] trt-runner-N0-11/18/23-10:29:03
    ---- Inference Output(s) ----
    {v1_0 [dtype=float16, shape=(1, 24, 28, 29)],
     v2_0 [dtype=float16, shape=(1, 24, 28, 29)]}
[I] trt-runner-N0-11/18/23-10:29:03     | Completed 1 iteration(s) in 0.5124 ms | Average inference time: 0.5124 ms.
[I] Accuracy Comparison | onnxrt-runner-N0-11/18/23-10:29:03 vs. trt-runner-N0-11/18/23-10:29:03
[I]     Comparing Output: 'v1_0' (dtype=float16, shape=(1, 24, 28, 29)) with 'v1_0' (dtype=float16, shape=(1, 24, 28, 29))
[I]         Tolerance: [abs=1e-05, rel=1e-05] | Checking elemwise error
[I]         onnxrt-runner-N0-11/18/23-10:29:03: v1_0 | Stats: mean=-4.9962, std-dev=1.1535, var=1.3305, median=-5.0039, min=-7 at (0, 0, 1, 4), max=-2.9961 at (0, 12, 16, 19), avg-magnitude=4.9962
[I]         trt-runner-N0-11/18/23-10:29:03: v1_0 | Stats: mean=-4.9962, std-dev=1.1535, var=1.3305, median=-5.0039, min=-7 at (0, 0, 1, 4), max=-2.9961 at (0, 12, 16, 19), avg-magnitude=4.9962
[I]         Error Metrics: v1_0
[I]             Minimum Required Tolerance: elemwise error | [abs=0] OR [rel=0] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0, 0, 0), max=0 at (0, 0, 0, 0), avg-magnitude=0
[I]             Relative Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0, 0, 0), max=0 at (0, 0, 0, 0), avg-magnitude=0
[I]         PASSED | Output: 'v1_0' | Difference is within tolerance (rel=1e-05, abs=1e-05)
[I]     Comparing Output: 'v2_0' (dtype=float16, shape=(1, 24, 28, 29)) with 'v2_0' (dtype=float16, shape=(1, 24, 28, 29))
[I]         Tolerance: [abs=1e-05, rel=1e-05] | Checking elemwise error
[W]         trt-runner-N0-11/18/23-10:29:03     | Output: v2_0: Some values are 0. Will add a small epsilon quantity to these when computing relative difference. Note that this may cause some relative differences to be extremely high.
[I]         onnxrt-runner-N0-11/18/23-10:29:03: v2_0 | Stats: mean=0.00031189, std-dev=0.00074569, var=5.5605e-07, median=1.4305e-06, min=0 at (0, 0, 0, 14), max=0.0040436 at (0, 12, 16, 19), avg-magnitude=0.00031189
[I]             ---- Histogram ----
                Bin Range            |  Num Elems | Visualization
                (0       , 0.000404) |      16076 | ########################################
                (0.000404, 0.000809) |        968 | ##
                (0.000809, 0.00121 ) |        570 | #
                (0.00121 , 0.00162 ) |        446 | #
                (0.00162 , 0.00202 ) |        321 |
                (0.00202 , 0.00243 ) |        287 |
                (0.00243 , 0.00283 ) |        264 |
                (0.00283 , 0.00323 ) |        211 |
                (0.00323 , 0.00364 ) |        176 |
                (0.00364 , 0.00404 ) |        169 |
[I]         trt-runner-N0-11/18/23-10:29:03: v2_0 | Stats: mean=0.0002708, std-dev=0.00080704, var=6.5132e-07, median=0, min=0 at (0, 0, 0, 0), max=0.0039062 at (0, 0, 0, 3), avg-magnitude=0.0002708
[I]             ---- Histogram ----
                Bin Range            |  Num Elems | Visualization
                (0       , 0.000404) |      17286 | ########################################
                (0.000404, 0.000809) |          0 |
                (0.000809, 0.00121 ) |          0 |
                (0.00121 , 0.00162 ) |          0 |
                (0.00162 , 0.00202 ) |       1702 | ###
                (0.00202 , 0.00243 ) |          0 |
                (0.00243 , 0.00283 ) |          0 |
                (0.00283 , 0.00323 ) |          0 |
                (0.00323 , 0.00364 ) |          0 |
                (0.00364 , 0.00404 ) |        500 | #
[I]         Error Metrics: v2_0
[I]             Minimum Required Tolerance: elemwise error | [abs=0.00097656] OR [rel=4.3895e+12] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.00012312, std-dev=0.00023616, var=5.5772e-08, median=1.4305e-06, min=0 at (0, 0, 0, 14), max=0.00097656 at (0, 6, 22, 27), avg-magnitude=0.00012312
[I]                 ---- Histogram ----
                    Bin Range            |  Num Elems | Visualization
                    (0       , 9.77e-05) |      14495 | ########################################
                    (9.77e-05, 0.000195) |       1034 | ##
                    (0.000195, 0.000293) |        729 | ##
                    (0.000293, 0.000391) |        614 | #
                    (0.000391, 0.000488) |        512 | #
                    (0.000488, 0.000586) |        461 | #
                    (0.000586, 0.000684) |        420 | #
                    (0.000684, 0.000781) |        466 | #
                    (0.000781, 0.000879) |        374 | #
                    (0.000879, 0.000977) |        383 | #
[I]             Relative Difference | Stats: mean=2.9801e+11, std-dev=7.613e+11, var=5.7958e+23, median=0.3623, min=0 at (0, 0, 0, 14), max=4.3895e+12 at (0, 1, 22, 4), avg-magnitude=2.9801e+11
[I]                 ---- Histogram ----
                    Bin Range            |  Num Elems | Visualization
                    (0       , 4.39e+11) |      16474 | ########################################
                    (4.39e+11, 8.78e+11) |        831 | ##
                    (8.78e+11, 1.32e+12) |        527 | #
                    (1.32e+12, 1.76e+12) |        396 |
                    (1.76e+12, 2.19e+12) |        316 |
                    (2.19e+12, 2.63e+12) |        253 |
                    (2.63e+12, 3.07e+12) |        197 |
                    (3.07e+12, 3.51e+12) |        208 |
                    (3.51e+12, 3.95e+12) |        156 |
                    (3.95e+12, 4.39e+12) |        130 |
[E]         FAILED | Output: 'v2_0' | Difference exceeds tolerance (rel=1e-05, abs=1e-05)
[E]     FAILED | Mismatched outputs: ['v2_0']
[E] Accuracy Summary | onnxrt-runner-N0-11/18/23-10:29:03 vs. trt-runner-N0-11/18/23-10:29:03 | Passed: 0/1 iterations | Pass Rate: 0.0%
[E] FAILED | Runtime: 21.381s | Command: /home/scratch.zeroz_sw/miniconda3/bin/polygraphy run 1.onnx --onnxrt --trt --data-loader-script data_loader.py
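
The data_loader.py passed to --data-loader-script is not shown in this thread. Polygraphy looks for a function in that script (named load_data by default) that takes no arguments and returns or yields feed dicts mapping input names to NumPy arrays. A minimal sketch of what such a script could look like, assuming the pickle from the issue holds a dict keyed by the input name (v0_0); the `path` parameter and its default are assumptions for illustration:

```python
import pickle

import numpy as np


def load_data(path="0.pickle"):
    """Yield one feed dict per inference iteration.

    Assumption: `path` points to a pickled dict {input_name: array},
    like the 0.pickle file attached to the issue.
    """
    with open(path, "rb") as f:
        feed = pickle.load(f)
    # One iteration, keeping whatever dtype the pickle stored (float16 in the log above).
    yield {name: np.asarray(value) for name, value in feed.items()}
```

Because the default argument is used when Polygraphy calls load_data() with no arguments, the same function can also be called directly with another path for local checks.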

zerollzeng (Collaborator) commented

Filed internal bug 4383767 for this.

zerollzeng (Collaborator) commented

The diff is expected for FP16; this is not a bug. Closing.
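
For readers wondering why FP16 alone can produce this diff, here is a rough NumPy emulation. It is not TensorRT's actual kernel behavior, only one plausible reading consistent with the layer-fusion remark earlier in the thread: the assumption is that when neg is exposed as an FP16 output, the GELU result is rounded to FP16 before the add, whereas a fused graph can keep the intermediate in higher precision and round only once at the end. The inputs are random values in [3, 7], similar to the range in the log; for those, gelu(x) is very close to x, so x - gelu(x) is smaller than the FP16 spacing of the operands.

```python
import math

import numpy as np


def gelu_fp32(x):
    """Exact (erf-based) GELU, evaluated in higher precision and returned as float32."""
    x = x.astype(np.float32)
    erf = np.vectorize(math.erf)
    return (0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))).astype(np.float32)


rng = np.random.default_rng(0)
# Assumption: random FP16 inputs in [3, 7], roughly the range visible in the log.
x = rng.uniform(3.0, 7.0, size=10_000).astype(np.float16)
g32 = gelu_fp32(x)

# Path A: the intermediate neg is materialized in FP16, then added in FP16.
neg16 = (-g32).astype(np.float16)
add_a = x + neg16

# Path B: the intermediate stays in FP32; only the final result is rounded to FP16.
add_b = (x.astype(np.float32) - g32).astype(np.float16)

step = 2.0 ** -9  # FP16 spacing for values in [2, 4)
on_grid = bool(np.all(np.mod(add_a.astype(np.float32), step) == 0))
max_abs_diff = float(np.max(np.abs(add_a.astype(np.float32) - add_b.astype(np.float32))))

print("path A values all multiples of 2**-9:", on_grid)
print("distinct path A values:", np.unique(add_a))
print("max |A - B|:", max_abs_diff)
```

In this emulation path A can only land on a coarse grid of values, which resembles the TensorRT histogram above (everything in the first, fifth, and last bins), while path B keeps the small fractional results like the ONNX Runtime output. The gap between the two stays on the order of the FP16 rounding error of the intermediate.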
