From cfb909907744534ad43f4ccf7ff3c0e2558aea21 Mon Sep 17 00:00:00 2001 From: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Date: Wed, 19 Jan 2022 19:12:29 +0800 Subject: [PATCH] fix ort wrap about input type (#81) --- mmdeploy/backend/onnxruntime/wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmdeploy/backend/onnxruntime/wrapper.py b/mmdeploy/backend/onnxruntime/wrapper.py index 680c9c0c63..51116716cd 100644 --- a/mmdeploy/backend/onnxruntime/wrapper.py +++ b/mmdeploy/backend/onnxruntime/wrapper.py @@ -2,7 +2,6 @@ import os.path as osp from typing import Dict, Optional, Sequence -import numpy as np import onnxruntime as ort import torch @@ -80,11 +79,12 @@ def forward(self, inputs: Dict[str, input_tensor = input_tensor.contiguous() if not self.is_cuda_available: input_tensor = input_tensor.cpu() + element_type = input_tensor.numpy().dtype self.io_binding.bind_input( name=name, device_type=self.device_type, device_id=self.device_id, - element_type=np.float32, + element_type=element_type, shape=input_tensor.shape, buffer_ptr=input_tensor.data_ptr())