Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IO binding support for custom ORTModel #447

Merged
merged 8 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/onnxruntime/package_reference/modeling_ort.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ specific language governing permissions and limitations under the License.

[[autodoc]] onnxruntime.ORTModelForCausalLM

## ORTModelForCustomTasks

[[autodoc]] onnxruntime.ORTModelForCustomTasks

## ORTModelForFeatureExtraction

[[autodoc]] onnxruntime.ORTModelForFeatureExtraction
Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/io_binding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .io_binding_helper import TypeHelper
from .io_binding_helper import IOBindingHelper, TypeHelper
89 changes: 89 additions & 0 deletions optimum/onnxruntime/io_binding/io_binding_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import traceback

import numpy as np
import torch

import onnxruntime as ort
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
from onnxruntime.transformers.io_binding_helper import TypeHelper as ORTTypeHelper

from ..utils import is_cupy_available, is_onnxruntime_training_available


if is_cupy_available():
import cupy as cp


# Adapted from https://github.com/microsoft/onnxruntime/blob/93e0a151177ad8222c2c95f814342bfa27f0a64d/onnxruntime/python/tools/transformers/io_binding_helper.py#L12
class TypeHelper(ORTTypeHelper):
Expand Down Expand Up @@ -58,3 +69,81 @@ def ort_type_to_torch_type(ort_type: str):
raise ValueError(
f"{ort_type} is not supported. Here is a list of supported data type: {ort_type_to_torch_type_map.keys()}"
)


# Adapted from https://github.com/microsoft/onnxruntime/blob/1ab11a111ce0717bfbfaca964d04a017cb9b1752/onnxruntime/python/tools/transformers/io_binding_helper.py#L97
class IOBindingHelper:
    """
    A helper class to enable `ORTModel` instances to prepare IO binding with dynamic shaped outputs for an inference session and transfer
    tensors from ONNX Runtime to other frameworks on device. It helps reduce memory copy between the host and device.
    """

    def __init__(self, model: ort.InferenceSession, device, **kwargs):
        """
        Stores the session and device, and indexes the session's inputs and outputs by name.

        Args:
            model: inference session whose `get_inputs()` / `get_outputs()` are indexed.
            device: stored unchanged on `self.device`.
            **kwargs: accepted but not used by this constructor.
        """
        self.model = model
        self.device = device
        # Create {name: idx} dicts for model inputs and outputs
        self.model_inputs = {input_key.name: idx for idx, input_key in enumerate(model.get_inputs())}
        self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(model.get_outputs())}
        self.model_input_names = list(self.model_inputs.keys())
        self.model_output_names = list(self.model_outputs.keys())

    @staticmethod
    def to_pytorch(ort_value: OrtValue) -> torch.Tensor:
        """
        Converts tensors held by OrtValues in ONNX runtime memory buffer to torch tensor.

        The conversion strategy is picked in this order: DLPack when `onnxruntime-training` is installed,
        CuPy when it is installed and the value lives on CUDA, and a NumPy round-trip otherwise.
        """
        if is_onnxruntime_training_available():
            return IOBindingHelper.to_pytorch_via_dlpack(ort_value)

        # Only attempt the CuPy route when it can actually work. Trying it unconditionally would log an
        # error traceback on every call for CPU outputs or when CuPy is not installed.
        if is_cupy_available() and ort_value.device_name().lower() == "cuda":
            try:
                return IOBindingHelper.to_pytorch_via_cupy(ort_value)
            except Exception:
                # Deliberate best-effort: any failure on the CuPy route falls back to the NumPy route below.
                logging.error(traceback.format_exc())
                logging.info("Unable to access output memory in CUDA, will offload to CPU")

        return IOBindingHelper.to_pytorch_via_numpy(ort_value)

    @staticmethod
    def to_pytorch_via_numpy(ort_value: OrtValue) -> torch.Tensor:
        """
        Builds a torch tensor from `ort_value.numpy()` and moves it to the device named by the OrtValue.
        """
        ort_device = ort_value.device_name().lower()
        return torch.tensor(ort_value.numpy()).to(ort_device)

    @staticmethod
    def to_pytorch_via_cupy(ort_value: OrtValue) -> torch.Tensor:
        """
        Wraps the OrtValue's CUDA buffer in a CuPy array and hands it to torch through DLPack.

        `cp` is only imported at module level when CuPy is installed, so this method must not be called otherwise.

        Raises:
            RuntimeError: if the OrtValue is not on a CUDA device.
        """
        ort_device = ort_value.device_name().lower()
        if ort_device != "cuda":
            raise RuntimeError(f"Exchange tensors to PyTorch via CuPy only when device is CUDA, got: {ort_device}")

        ort_type = ort_value.data_type()
        numpy_type = TypeHelper.ort_type_to_numpy_type(ort_type)

        # Access CUDA memory via CuPy
        memory = cp.cuda.UnownedMemory(ort_value.data_ptr(), 0, None)
        memory_ptr = cp.cuda.MemoryPointer(memory, 0)
        cp_array = cp.ndarray(shape=ort_value.shape(), memptr=memory_ptr, dtype=numpy_type)
        torch_tensor = torch.from_dlpack(cp_array.toDlpack())

        # If the value is boolean, the dtype will be uint8 and needs to be converted back to bool.
        if "bool" in ort_type:
            torch_tensor = torch_tensor.to(torch.bool)

        # NOTE(review): presumably cloned so the returned tensor does not alias the unowned buffer; confirm.
        torch_tensor = torch_tensor.clone()

        return torch_tensor

    @staticmethod
    # dlpack support is available for OrtValue only when `onnxruntime-training` is installed
    def to_pytorch_via_dlpack(ort_value: OrtValue) -> torch.Tensor:
        """
        Converts the OrtValue to a torch tensor by calling its `to_dlpacks` with torch's `_from_dlpack`.
        """
        from torch._C import _from_dlpack

        # NOTE(review): confirm that the object passed here exposes `to_dlpacks` in the targeted ONNX Runtime build.
        torch_tensor = ort_value.to_dlpacks(_from_dlpack)
        return torch_tensor

    @staticmethod
    def get_device_index(device):
        """
        Returns an integer device index for `device`.

        An `int` is returned unchanged, a `str` is first converted with `torch.device`, and any other value
        is expected to expose an `index` attribute. A missing index (`None`) is reported as 0.
        """
        if isinstance(device, str):
            # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
            device = torch.device(device)
        elif isinstance(device, int):
            return device
        return 0 if device.index is None else device.index
78 changes: 62 additions & 16 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from ..exporters.onnx import export
from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .io_binding import TypeHelper
from .io_binding import IOBindingHelper, TypeHelper
from .utils import (
ONNX_WEIGHTS_NAME,
get_device_for_provider,
Expand Down Expand Up @@ -1546,18 +1546,47 @@ def forward(
)
class ORTModelForCustomTasks(ORTModel):
"""
Onnx Model for any custom tasks.
Onnx Model for any custom tasks using encoder or decoder-only models.
"""

def __init__(self, model=None, config=None, **kwargs):
super().__init__(model, config, **kwargs)
if kwargs.pop("use_io_binding", False):
logger.warning(
"ORTModelForCustomTasks doesn't support IO Binding yet, and the inference will be done without IO binding which could cause"
" significant overhead on data copying. If you want us to enable IO binding for custom use case, please open an issue in "
"Optimum: https://github.com/huggingface/optimum."
def __init__(self, model=None, config=None, use_io_binding=True, **kwargs):
    """
    Initializes the model and indexes the session's inputs and outputs by name.

    Args:
        model: inference session, forwarded to the parent constructor and read back from `self.model`.
        config: forwarded to the parent constructor.
        use_io_binding: whether IO binding should be used; forwarded to the parent constructor.
        **kwargs: forwarded to the parent constructor.
    """
    # Forward the caller's choice. Hard-coding `True` here silently ignored `use_io_binding=False`.
    super().__init__(model, config, use_io_binding=use_io_binding, **kwargs)
    # Create {name: idx} dicts for model inputs and outputs
    self.model_inputs = {input_key.name: idx for idx, input_key in enumerate(self.model.get_inputs())}
    self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())}
    self.model_input_names = list(self.model_inputs.keys())
    self.model_output_names = list(self.model_outputs.keys())

def prepare_io_binding(self, **kwargs) -> ort.IOBinding:
    """
    Returns IOBinding object for an inference session. This method is created for general purpose, if the inputs and outputs
    are determined, you can prepare data buffers directly to avoid tensor transfers across frameworks.

    Args:
        **kwargs: must contain one tensor per name in `self.model_input_names`; a missing name raises `KeyError`.
    """

    name_to_np_type = TypeHelper.get_io_numpy_type_map(self.model)

    # A device without an explicit index (e.g. `torch.device("cuda")`) has `index` set to None;
    # normalize it to an integer before handing it to ONNX Runtime.
    device_id = IOBindingHelper.get_device_index(self.device)

    # Bind inputs and outputs to onnxruntime session
    io_binding = self.model.io_binding()

    # Bind inputs
    for input_name in self.model_input_names:
        # NOTE(review): when `.contiguous()` has to copy, only this local name references the copy; confirm
        # the buffer stays alive until inference has run.
        onnx_input = kwargs.pop(input_name).contiguous()

        io_binding.bind_input(
            input_name,
            onnx_input.device.type,
            device_id,
            name_to_np_type[input_name],
            list(onnx_input.size()),
            onnx_input.data_ptr(),
        )

    # Bind outputs by name and device only; no shape or buffer is supplied here.
    for name in self.model_output_names:
        io_binding.bind_output(name, self.device.type, device_id=device_id)

    return io_binding

@add_start_docstrings_to_model_forward(
CUSTOM_TASKS_EXAMPLE.format(
processor_class=_TOKENIZER_FOR_DOC,
Expand All @@ -1566,13 +1595,30 @@ def __init__(self, model=None, config=None, **kwargs):
)
)
def forward(self, **kwargs):
# Runs inference on the keyword inputs and returns the results wrapped in a `ModelOutput`.
# NOTE(review): this span comes from a diff view; the seven lines that follow look like the
# previous body of `forward`, with the newer body starting at the `if` below - confirm.
# converts pytorch inputs into numpy inputs for onnx
onnx_inputs = self._prepare_onnx_inputs(**kwargs)
# run inference
onnx_outputs = self.model.run(None, onnx_inputs)
outputs = self._prepare_onnx_outputs(onnx_outputs)
# converts outputs to namedtuple for pipelines post-processing if applicable
return ModelOutput(outputs)
# IO binding is used only when the model device is CUDA and `use_io_binding` is set.
if self.device.type == "cuda" and self.use_io_binding:
io_binding = self.prepare_io_binding(**kwargs)

# run inference with binding
io_binding.synchronize_inputs()
self.model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

# Pair each bound output with its name by position and convert it to a torch tensor.
# This reads the private `_iobinding` attribute of the binding object.
outputs = {}
for name, output in zip(self.model_output_names, io_binding._iobinding.get_outputs()):
outputs[name] = IOBindingHelper.to_pytorch(output)

# converts output to namedtuple for pipelines post-processing
return ModelOutput(**outputs)
else:
# converts pytorch inputs into numpy inputs for onnx
onnx_inputs = self._prepare_onnx_inputs(**kwargs)

# run inference
onnx_outputs = self.model.run(None, onnx_inputs)
outputs = self._prepare_onnx_outputs(onnx_outputs)

# converts output to namedtuple for pipelines post-processing
# Unlike the IO binding branch, `outputs` is passed here as a single positional argument.
return ModelOutput(outputs)

def _prepare_onnx_inputs(self, **kwargs):
model_inputs = {input_key.name: idx for idx, input_key in enumerate(self.model.get_inputs())}
Expand Down
22 changes: 21 additions & 1 deletion optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Utility functions, classes and constants for ONNX Runtime."""

import importlib.util
import os
from enum import Enum
from typing import Dict, Tuple, Type, Union
Expand All @@ -23,6 +24,7 @@

import onnx
import onnxruntime as ort
import pkg_resources

from ..onnx import OnnxConfigWithLoss, OnnxConfigWithPastAndLoss, OnnxSeq2SeqConfigWithPastAndLoss
from ..utils import NormalizedTextConfig
Expand All @@ -39,7 +41,7 @@

def _is_gpu_available():
"""
checks if a gpu is available.
Checks if a gpu is available.
"""
available_providers = ort.get_available_providers()
if "CUDAExecutionProvider" in available_providers and torch.cuda.is_available():
Expand All @@ -48,6 +50,24 @@ def _is_gpu_available():
return False


def is_onnxruntime_training_available():
    """
    Checks if onnxruntime-training is available.

    Returns `True` when a `training` entry exists inside the first path of the installed `onnxruntime` package.
    """
    path_training_dependency = os.path.join(ort.__path__[0], "training")
    return os.path.exists(path_training_dependency)


def is_cupy_available():
    """
    Checks if cupy is available.

    Only looks up the module spec; `cupy` itself is not imported by this check.
    """
    cupy_spec = importlib.util.find_spec("cupy")
    return cupy_spec is not None


class ORTConfigManager:
"""
A class that contains all the information needed by ONNX Runtime optimization for a given model type.
Expand Down
21 changes: 21 additions & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,3 +1684,24 @@ def test_default_pipeline_and_model_device(self, *args, **kwargs):
tokenizer = get_preprocessor(model_id)
pipe = pipeline("feature-extraction", model=onnx_model, tokenizer=tokenizer)
self.assertEqual(pipe.device, onnx_model.device)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
@require_torch_gpu
def test_compare_to_io_binding(self, *args, **kwargs):
    """
    Loads the same checkpoint with and without `use_io_binding` and checks that both produce the same `pooler_output`.
    """
    # Only the model id is needed; the architecture name is part of the parameterization.
    _, model_id = args
    set_seed(SEED)
    onnx_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=False)
    set_seed(SEED)
    io_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=True)
    tokenizer = get_preprocessor(model_id)
    tokens = tokenizer("This is a sample output", return_tensors="pt")
    onnx_outputs = onnx_model(**tokens)
    io_outputs = io_model(**tokens)

    # `assertIn` reports the container contents on failure, unlike `assertTrue(x in y)`.
    self.assertIn("pooler_output", io_outputs)
    self.assertIsInstance(io_outputs.pooler_output, torch.Tensor)

    # compare tensor outputs
    self.assertTrue(torch.equal(onnx_outputs.pooler_output, io_outputs.pooler_output))

    gc.collect()