Different inference sessions for different input schemas.
Before this PR, one GraphModule is exported to one ONNX
model, and that model is invoked for all inputs. However,
different input schemas can result in different ONNX graphs
(e.g., the ONNX exporter generates different graphs
depending on the input tensor's rank). This commit maps one
GraphModule to several inference sessions and picks a
usable one according to the input schema.
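
To make the dispatch idea above concrete, here is a minimal sketch of schema-based session reuse. All names (CompiledVariant, _variants, find_reusable) are hypothetical illustrations, not the ort_backend.py API: each GraphModule maps to a list of compiled variants, and the first variant whose recorded input schema (dtype and rank per positional input) matches the runtime arguments is reused; if none matches, a new export would be created and cached.

from typing import Callable, Dict, List, Tuple

import torch
import torch.fx


class CompiledVariant:
    """One exported model plus the input schema (dtype, rank) it was built for."""

    def __init__(self, schema: Tuple[Tuple[torch.dtype, int], ...], run: Callable):
        self.schema = schema
        self.run = run  # stands in for onnxruntime.InferenceSession.run

    def matches(self, *args) -> bool:
        # Accept the call only if every positional input is a tensor with the
        # recorded dtype and rank.
        if len(args) != len(self.schema):
            return False
        return all(
            isinstance(arg, torch.Tensor) and arg.dtype == dtype and arg.dim() == rank
            for arg, (dtype, rank) in zip(args, self.schema)
        )


# One GraphModule -> many compiled variants, each exported for a different input schema.
_variants: Dict[torch.fx.GraphModule, List[CompiledVariant]] = {}


def find_reusable(graph_module: torch.fx.GraphModule, *args):
    """Return the first cached variant that accepts this input schema, or None."""
    for variant in _variants.get(graph_module, []):
        if variant.matches(*args):
            return variant
    return None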
wschin committed Jul 18, 2023
1 parent b5a9b5f commit 78e736d
Showing 1 changed file with 131 additions and 37 deletions.
168 changes: 131 additions & 37 deletions orttraining/orttraining/python/training/torchdynamo/ort_backend.py
@@ -5,7 +5,7 @@

import dataclasses
import logging
from typing import Any, Dict, Mapping, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union

import numpy as np
import onnx
@@ -45,6 +45,19 @@
torch.bool: np.bool_,
}

# Maps ONNX TensorProto element types (e.g., 1 == FLOAT, 7 == INT64) to torch dtypes.
# The reverse mapping below is used to compare runtime tensor dtypes against an
# exported model's input schema.
_ONNX_ELEMENT_TYPE_TO_TORCH_DTYPE = {
1: torch.float32,
2: torch.uint8,
3: torch.int8,
5: torch.int16,
6: torch.int32,
7: torch.int64,
9: torch.bool,
10: torch.float16,
}

_TORCH_DTYPE_TO_ONNX_ELEMENT_TYPE = {value: key for key, value in _ONNX_ELEMENT_TYPE_TO_TORCH_DTYPE.items()}


def _nvtx_range_push(name: str):
"""If PyTorch is installed with CUDA support, this starts NVTX range.
@@ -198,6 +211,15 @@ def _infer_ep_from_device(*args) -> Tuple[str, ...]:
return tuple(eps)


def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]:
"""Collect all placeholder nodes (i.e., graph inputs) in this torch.fx.GraphModule."""
placeholders = []
for node in graph_module.graph.nodes:
if node.op == "placeholder":
if hasattr(node, "meta") and "val" in node.meta:
assert isinstance(node.meta["val"], torch.Tensor)
placeholders.append(node)
return tuple(placeholders)


def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any:
"""Collect "val" fields from outputs metadata in this torch.fx.GraphModule."""
for node in graph_module.graph.nodes:
@@ -340,28 +362,87 @@ def _assert_allclose_with_detailed_error_message(
)


@dataclasses.dataclass
class OrtExecutionInfo:
class OrtExecutionInfoPerSession:
"""Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession"""

def __init__(
self,
session: onnxruntime.InferenceSession,
input_names: Tuple[str, ...],
input_value_infos: Tuple[onnx.ValueInfoProto, ...],
output_names: Tuple[str, ...],
output_value_infos: Tuple[onnx.ValueInfoProto, ...],
input_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore
output_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore
example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor],
):
# Carrier of ONNX model and its executor.
self.session: onnxruntime.InferenceSession = session
# For the ONNX model stored in self.session, self.input_names[i] is the
# name of the i-th positional input.
self.input_names: Tuple[str, ...] = input_names
# self.input_names[i]'s type information is stored in self.input_value_infos[i].
self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos
# Similar to self.input_names, but for outputs.
self.output_names: Tuple[str, ...] = output_names
# Similar to self.input_value_infos but for outputs.
self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos
# For the ONNX model stored in self.session, self.input_devices[i] is the
# i-th positional input's device.
self.input_devices: Tuple[ORTC.OrtDevice, ...] = input_devices # type: ignore
# Similar to self.input_devices, but for outputs.
self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices # type: ignore
# These are the outputs of executing the original torch.fx.GraphModule with example inputs
# (i.e., args passed into OrtBackend._ort_acclerated_call).
self.example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor] = example_outputs

def is_supported(self, *args):
# Compare the args against the input schema of the ONNX model in this session
# and return True only if they are compatible.
if len(args) != len(self.input_value_infos):
return False
for arg, value_info in zip(args, self.input_value_infos):
if not isinstance(arg, torch.Tensor):
return False
onnx_dtype = _TORCH_DTYPE_TO_ONNX_ELEMENT_TYPE[arg.dtype]
if onnx_dtype != value_info.type.tensor_type.elem_type:
return False
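# A static (int) dimension matches either the same concrete dim_value or a
# symbolic dim_param in the ONNX shape; a torch.SymInt dimension requires a dim_param.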
for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim):
if isinstance(dim, int) and (onnx_dim.dim_value == dim or onnx_dim.dim_param):
continue
elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param:
continue
else:
return False
return True


@dataclasses.dataclass
class OrtExecutionInfoForAllGraphModules:
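"""All execution information (one entry per exported ONNX model) for every torch.fx.GraphModule seen by this backend."""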
def __init__(self):
# session self.sessions[mod] is created for computing the graph in mod.
self.sessions: Dict[torch.fx.GraphModule, onnxruntime.InferenceSession] = {}
# self.input_names[mod] contains all input names in the ONNX model exported from mod.
# self.input_names[mod][i] is the name of the i-th positional input of the graph in mod.
self.input_names: Dict[torch.fx.GraphModule, Tuple[str, ...]] = {}
# Similar to self.input_names, but for outputs of the graph.
self.output_names: Dict[torch.fx.GraphModule, Tuple[str, ...]] = {}
# self.input_devices[mod] contains devices of inputs fed to mod.forward (excluding self).
# self.input_devices[mod][i] is the i-th positional input's device.
self.input_devices: Dict[torch.fx.GraphModule, Tuple[ORTC.OrtDevice, ...]] = {} # type: ignore
# Similar to self.input_devices, but for outputs of the graph.
self.output_devices: Dict[torch.fx.GraphModule, Tuple[ORTC.OrtDevice, ...]] = {} # type: ignore
# This is a debug flag. When True, this backend compares its results against a
# baseline computed by running the original torch.fx.GraphModule with PyTorch.
self.assert_allclose_to_baseline: bool = False
# We need example outputs to determine the output schema of an ORT run.
# self.example_outputs[mod] is the outputs of mod.forward(*self.example_inputs[mod]).
self.example_outputs: Dict[torch.fx.GraphModule, Union[Tuple[torch.Tensor, ...], torch.Tensor]] = {}
# All sessions (and their related information) created by exporting the same GraphModule
# with different inputs.
self.execution_info_per_graph_module: Dict[torch.fx.GraphModule, List[OrtExecutionInfoPerSession]] = {}

def search_reusable_session_execution_info(self, graph_module: torch.fx.GraphModule, *args):
if graph_module not in self.execution_info_per_graph_module:
return
# All execution information for ONNX models exported from the same `graph_module`
# with different inputs.
candidates = self.execution_info_per_graph_module[graph_module]

for candidate in candidates:
if candidate.is_supported(*args):
# Returns the first session that accepts this input schema.
return candidate
# No reusable session found.
return None

def cache_session_execution_info(self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession):
if graph_module not in self.execution_info_per_graph_module:
self.execution_info_per_graph_module[graph_module] = [info]
else:
self.execution_info_per_graph_module[graph_module].append(info)


class OrtBackend:
@@ -415,7 +496,11 @@ def __init__(
# TODO: this is a naive implementation of cache without proper guard
self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
# TODO: this is a naive implementation of cache without proper guard, this will only work for identical inputs
self._ort_execution_info = OrtExecutionInfo()
# self._ort_execution_info = OrtExecutionInfo()

self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules()

self._assert_allclose_to_baseline = False

self.ep = ep
self.session_options = session_options
@@ -431,14 +516,16 @@ def __init__(
self.preallocate_output = preallocate_output

def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
if graph_module in self._ort_execution_info.sessions:
# We have seen this graph before, so we can use cached objects including session.
onnx_session = self._ort_execution_info.sessions[graph_module]
input_names = self._ort_execution_info.input_names[graph_module]
output_names = self._ort_execution_info.output_names[graph_module]
input_devices = self._ort_execution_info.input_devices[graph_module]
output_devices = self._ort_execution_info.output_devices[graph_module]
prim_outputs = self._ort_execution_info.example_outputs[graph_module]
cached_execution_info_per_session = self._all_ort_execution_info.search_reusable_session_execution_info(
graph_module, *args
)
if cached_execution_info_per_session:
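# Reuse the previously created session whose exported input schema accepts these args.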
onnx_session = cached_execution_info_per_session.session
input_names = cached_execution_info_per_session.input_names
output_names = cached_execution_info_per_session.output_names
input_devices = cached_execution_info_per_session.input_devices
output_devices = cached_execution_info_per_session.output_devices
prim_outputs = cached_execution_info_per_session.example_outputs
else:
# It's the first time seeing such a graph. Let's make a new session
# (type: onnxruntime.InferenceSession) for it.
@@ -476,7 +563,6 @@ def maybe_map_to_meta_val(value):

# Rethrow FakeTensorProp failure because it is not currently handled.
raise
self._ort_execution_info.example_outputs[graph_module] = prim_outputs

from torch.onnx._internal.fx import fx_onnx_interpreter

@@ -521,7 +607,6 @@ def maybe_map_to_meta_val(value):

onnx_session = _create_onnx_session(onnx_proto, selected_eps, self.session_options)
# Cache ORT session. It's reused for the same "graph_module".
self._ort_execution_info.sessions[graph_module] = onnx_session
# Generate ONNX model and extract its input and output names.
onnx_model = _create_onnx_model(onnx_proto)
# TODO(wechi): ORT session should provide an API to extract
@@ -536,10 +621,19 @@ def maybe_map_to_meta_val(value):
output_devices = _get_onnx_devices(prim_outputs)
else:
output_devices = _get_onnx_devices((prim_outputs,))
self._ort_execution_info.input_names[graph_module] = input_names
self._ort_execution_info.output_names[graph_module] = output_names
self._ort_execution_info.input_devices[graph_module] = input_devices
self._ort_execution_info.output_devices[graph_module] = output_devices

execution_info_per_session = OrtExecutionInfoPerSession(
session=onnx_session,
input_names=input_names,
input_value_infos=tuple(input for input in onnx_model.graph.input),
output_names=output_names,
output_value_infos=tuple(output for output in onnx_model.graph.output),
input_devices=input_devices,
output_devices=output_devices,
example_outputs=prim_outputs,
)

self._all_ort_execution_info.cache_session_execution_info(graph_module, execution_info_per_session)

if isinstance(prim_outputs, tuple):
assert all(isinstance(elem, torch.Tensor) for elem in prim_outputs)
@@ -557,7 +651,7 @@ def maybe_map_to_meta_val(value):
self.preallocate_output,
)
_nvtx_range_pop()
if self._ort_execution_info.assert_allclose_to_baseline:
if self._assert_allclose_to_baseline:
# Compute baseline.
baseline_outputs = torch._prims.executor.execute(graph_module, *args, executor="aten")
# Ensure every output tensor is close to the corresponding baseline.
@@ -580,7 +674,7 @@ def maybe_map_to_meta_val(value):
self.preallocate_output,
)
assert len(onnx_outputs) == 1
if self._ort_execution_info.assert_allclose_to_baseline:
if self._assert_allclose_to_baseline:
# Compute baseline.
baseline_outputs = torch._prims.executor.execute(graph_module, *args, executor="aten")
# Ensure output tensor is close to the corresponding baseline.
