diff --git a/docs/source/onnxruntime/package_reference/modeling_ort.mdx b/docs/source/onnxruntime/package_reference/modeling_ort.mdx
index 08896b42790..66786b58cf2 100644
--- a/docs/source/onnxruntime/package_reference/modeling_ort.mdx
+++ b/docs/source/onnxruntime/package_reference/modeling_ort.mdx
@@ -20,6 +20,10 @@ specific language governing permissions and limitations under the License.
 
 [[autodoc]] onnxruntime.ORTModelForCausalLM
 
+## ORTModelForCustomTasks
+
+[[autodoc]] onnxruntime.ORTModelForCustomTasks
+
 ## ORTModelForFeatureExtraction
 
 [[autodoc]] onnxruntime.ORTModelForFeatureExtraction
diff --git a/optimum/onnxruntime/io_binding/__init__.py b/optimum/onnxruntime/io_binding/__init__.py
index e0810d5e807..d218d7a700d 100644
--- a/optimum/onnxruntime/io_binding/__init__.py
+++ b/optimum/onnxruntime/io_binding/__init__.py
@@ -11,4 +11,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .io_binding_helper import TypeHelper
+from .io_binding_helper import IOBindingHelper, TypeHelper
diff --git a/optimum/onnxruntime/io_binding/io_binding_helper.py b/optimum/onnxruntime/io_binding/io_binding_helper.py
index e8005188bee..1911b1f8794 100644
--- a/optimum/onnxruntime/io_binding/io_binding_helper.py
+++ b/optimum/onnxruntime/io_binding/io_binding_helper.py
@@ -11,11 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import traceback
+
 import numpy as np
 import torch
+import onnxruntime as ort
+from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
 from onnxruntime.transformers.io_binding_helper import TypeHelper as ORTTypeHelper
 
+from ..utils import is_cupy_available, is_onnxruntime_training_available
+
+
+if is_cupy_available():
+    import cupy as cp
+
 
 # Adapted from https://github.com/microsoft/onnxruntime/blob/93e0a151177ad8222c2c95f814342bfa27f0a64d/onnxruntime/python/tools/transformers/io_binding_helper.py#L12
 class TypeHelper(ORTTypeHelper):
@@ -58,3 +69,81 @@ def ort_type_to_torch_type(ort_type: str):
         raise ValueError(
             f"{ort_type} is not supported. Here is a list of supported data type: {ort_type_to_torch_type_map.keys()}"
         )
+
+
+# Adapted from https://github.com/microsoft/onnxruntime/blob/1ab11a111ce0717bfbfaca964d04a017cb9b1752/onnxruntime/python/tools/transformers/io_binding_helper.py#L97
+class IOBindingHelper:
+    """
+    A helper class that enables `ORTModel` instances to prepare IO binding with dynamically shaped outputs for an inference session
+    and to transfer tensors from ONNX Runtime to other frameworks on device. It helps reduce memory copies between host and device.
+    """
+
+    def __init__(self, model: ort.InferenceSession, device, **kwargs):
+        self.model = model
+        self.device = device
+        # Create {name: idx} dicts for model inputs and outputs
+        self.model_inputs = {input_key.name: idx for idx, input_key in enumerate(model.get_inputs())}
+        self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(model.get_outputs())}
+        self.model_input_names = list(self.model_inputs.keys())
+        self.model_output_names = list(self.model_outputs.keys())
+
+    @staticmethod
+    def to_pytorch(ort_value: OrtValue) -> torch.Tensor:
+        """
+        Converts a tensor held by an OrtValue in the ONNX Runtime memory buffer to a torch tensor.
+        """
+
+        if is_onnxruntime_training_available():
+            return IOBindingHelper.to_pytorch_via_dlpack(ort_value)
+        else:
+            try:
+                return IOBindingHelper.to_pytorch_via_cupy(ort_value)
+            except Exception as e:
+                logging.error(traceback.format_exc())
+                logging.info("Unable to access output memory in CUDA, will offload to CPU")
+                return IOBindingHelper.to_pytorch_via_numpy(ort_value)
+
+    @staticmethod
+    def to_pytorch_via_numpy(ort_value: OrtValue) -> torch.Tensor:
+        ort_device = ort_value.device_name().lower()
+        return torch.tensor(ort_value.numpy()).to(ort_device)
+
+    @staticmethod
+    def to_pytorch_via_cupy(ort_value: OrtValue) -> torch.Tensor:
+        ort_device = ort_value.device_name().lower()
+        if ort_device != "cuda":
+            raise RuntimeError(f"Exchanging tensors to PyTorch via CuPy is only supported when the device is CUDA, got: {ort_device}")
+
+        ort_type = ort_value.data_type()
+        numpy_type = TypeHelper.ort_type_to_numpy_type(ort_type)
+
+        # Access CUDA memory via CuPy
+        memory = cp.cuda.UnownedMemory(ort_value.data_ptr(), 0, None)
+        memory_ptr = cp.cuda.MemoryPointer(memory, 0)
+        cp_array = cp.ndarray(shape=ort_value.shape(), memptr=memory_ptr, dtype=numpy_type)
+        torch_tensor = torch.from_dlpack(cp_array.toDlpack())
+
+        # If the type is boolean, the dtype will be uint8 and needs to be converted back to bool.
+        if "bool" in ort_type:
+            torch_tensor = torch_tensor.to(torch.bool)
+
+        torch_tensor = torch_tensor.clone()
+
+        return torch_tensor
+
+    @staticmethod
+    # DLPack support is available for OrtValue only when `onnxruntime-training` is installed
+    def to_pytorch_via_dlpack(ort_value: OrtValue) -> torch.Tensor:
+        from torch._C import _from_dlpack
+
+        torch_tensor = ort_value.to_dlpacks(_from_dlpack)
+        return torch_tensor
+
+    @staticmethod
+    def get_device_index(device):
+        if isinstance(device, str):
+            # The device could be 'cuda:0', 'cuda:1', or 'cpu'; for 'cpu', the index is set to 0
+            device = torch.device(device)
+        elif isinstance(device, int):
+            return device
+        return 0 if device.index is None else device.index
diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index 2debd3a077e..144a0c3f76f 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -48,7 +48,7 @@
 from ..exporters.onnx import export
 from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
 from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
-from .io_binding import TypeHelper
+from .io_binding import IOBindingHelper, TypeHelper
 from .utils import (
     ONNX_WEIGHTS_NAME,
     get_device_for_provider,
@@ -1546,18 +1546,47 @@
 )
 class ORTModelForCustomTasks(ORTModel):
     """
-    Onnx Model for any custom tasks.
+    ONNX Model for any custom task using an encoder or decoder-only model.
     """
 
-    def __init__(self, model=None, config=None, **kwargs):
-        super().__init__(model, config, **kwargs)
-        if kwargs.pop("use_io_binding", False):
-            logger.warning(
-                "ORTModelForCustomTasks doesn't support IO Binding yet, and the inference will be done without IO binding which could cause"
-                " significant overhead on data copying. If you want us to enable IO binding for custom use case, please open an issue in "
-                "Optimum: https://github.com/huggingface/optimum."
+    def __init__(self, model=None, config=None, use_io_binding=True, **kwargs):
+        super().__init__(model, config, use_io_binding=use_io_binding, **kwargs)
+        self.model_inputs = {input_key.name: idx for idx, input_key in enumerate(self.model.get_inputs())}
+        self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())}
+        self.model_input_names = list(self.model_inputs.keys())
+        self.model_output_names = list(self.model_outputs.keys())
+
+    def prepare_io_binding(self, **kwargs) -> ort.IOBinding:
+        """
+        Returns an IOBinding object for an inference session. This method is kept general purpose; if the input and output shapes
+        are known in advance, you can prepare data buffers directly to avoid tensor transfers across frameworks.
+        """
+
+        name_to_np_type = TypeHelper.get_io_numpy_type_map(self.model)
+
+        # Bind inputs and outputs to the ONNX Runtime session
+        io_binding = self.model.io_binding()
+
+        # Bind inputs
+        for input_name in self.model_input_names:
+            onnx_input = kwargs.pop(input_name)
+            onnx_input = onnx_input.contiguous()
+
+            io_binding.bind_input(
+                input_name,
+                onnx_input.device.type,
+                self.device.index,
+                name_to_np_type[input_name],
+                list(onnx_input.size()),
+                onnx_input.data_ptr(),
             )
+        # Bind outputs
+        for name in self.model_output_names:
+            io_binding.bind_output(name, self.device.type, device_id=self.device.index)
+
+        return io_binding
+
 
     @add_start_docstrings_to_model_forward(
         CUSTOM_TASKS_EXAMPLE.format(
             processor_class=_TOKENIZER_FOR_DOC,
@@ -1566,13 +1595,30 @@ def __init__(self, model=None, config=None, **kwargs):
         )
     )
     def forward(self, **kwargs):
-        # converts pytorch inputs into numpy inputs for onnx
-        onnx_inputs = self._prepare_onnx_inputs(**kwargs)
-        # run inference
-        onnx_outputs = self.model.run(None, onnx_inputs)
-        outputs = self._prepare_onnx_outputs(onnx_outputs)
-        # converts outputs to namedtuple for pipelines post-processing if applicable
-        return ModelOutput(outputs)
+        if self.device.type == "cuda" and self.use_io_binding:
+            io_binding = self.prepare_io_binding(**kwargs)
+
+            # run inference with binding
+            io_binding.synchronize_inputs()
+            self.model.run_with_iobinding(io_binding)
+            io_binding.synchronize_outputs()
+
+            outputs = {}
+            for name, output in zip(self.model_output_names, io_binding._iobinding.get_outputs()):
+                outputs[name] = IOBindingHelper.to_pytorch(output)
+
+            # converts output to namedtuple for pipelines post-processing
+            return ModelOutput(**outputs)
+        else:
+            # converts pytorch inputs into numpy inputs for onnx
+            onnx_inputs = self._prepare_onnx_inputs(**kwargs)
+
+            # run inference
+            onnx_outputs = self.model.run(None, onnx_inputs)
+            outputs = self._prepare_onnx_outputs(onnx_outputs)
+
+            # converts output to namedtuple for pipelines post-processing
+            return ModelOutput(outputs)
 
     def _prepare_onnx_inputs(self, **kwargs):
         model_inputs = {input_key.name: idx for idx, input_key in enumerate(self.model.get_inputs())}
diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py
index 8fe41410dd8..13e404672b8 100644
--- a/optimum/onnxruntime/utils.py
+++ b/optimum/onnxruntime/utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Utility functions, classes and constants for ONNX Runtime."""
+import importlib.util
 import os
 from enum import Enum
 from typing import Dict, Tuple, Type, Union
@@ -23,6 +24,7 @@
 
 import onnx
 import onnxruntime as ort
+import pkg_resources
 
 from ..onnx import OnnxConfigWithLoss, OnnxConfigWithPastAndLoss, OnnxSeq2SeqConfigWithPastAndLoss
 from ..utils import NormalizedTextConfig
@@ -39,7 +41,7 @@
 
 def _is_gpu_available():
     """
-    checks if a gpu is available.
+    Checks if a GPU is available.
     """
     available_providers = ort.get_available_providers()
     if "CUDAExecutionProvider" in available_providers and torch.cuda.is_available():
@@ -48,6 +50,24 @@
         return False
 
 
+def is_onnxruntime_training_available():
+    """
+    Checks if onnxruntime-training is available.
+    """
+    path_training_dependency = os.path.join(ort.__path__[0], "training")
+    if os.path.exists(path_training_dependency):
+        return True
+    else:
+        return False
+
+
+def is_cupy_available():
+    """
+    Checks if CuPy is available.
+    """
+    return importlib.util.find_spec("cupy") is not None
+
+
 class ORTConfigManager:
     """
     A class that contains all the information needed by ONNX Runtime optimization for a given model type.
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index ab91bbbcecb..4841d01efad 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -1684,3 +1684,24 @@ def test_default_pipeline_and_model_device(self, *args, **kwargs):
         tokenizer = get_preprocessor(model_id)
         pipe = pipeline("feature-extraction", model=onnx_model, tokenizer=tokenizer)
         self.assertEqual(pipe.device, onnx_model.device)
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    @require_torch_gpu
+    def test_compare_to_io_binding(self, *args, **kwargs):
+        model_arch, model_id = args
+        set_seed(SEED)
+        onnx_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=False)
+        set_seed(SEED)
+        io_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=True)
+        tokenizer = get_preprocessor(model_id)
+        tokens = tokenizer("This is a sample output", return_tensors="pt")
+        onnx_outputs = onnx_model(**tokens)
+        io_outputs = io_model(**tokens)
+
+        self.assertTrue("pooler_output" in io_outputs)
+        self.assertIsInstance(io_outputs.pooler_output, torch.Tensor)
+
+        # compare tensor outputs
+        self.assertTrue(torch.equal(onnx_outputs.pooler_output, io_outputs.pooler_output))
+
+        gc.collect()