Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORTModule support for kwargs input that is a dict #13910

Merged
merged 9 commits into from
Dec 15, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ def execution_session_run_forward(execution_session, onnx_model, device, *inputs
run_options = C.RunOptions()

# Use IO binding
_utils._create_iobinding(io_binding, inputs, onnx_model, device)
forward_inputs = []
for input in inputs:
if not input.is_contiguous():
input = input.contiguous()
forward_inputs.append(input)
_utils._create_iobinding(io_binding, forward_inputs, onnx_model, device)
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved

# Run and return module outputs.
ort_output = execution_session.run_forward(io_binding, run_options)
Expand Down
66 changes: 44 additions & 22 deletions orttraining/orttraining/python/training/ortmodule/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,31 @@ def get_primitive_dtype(value):
return f"{str(type(value))}_{value}" if isinstance(value, bool) else str(type(value))


def flatten_kwargs(kwargs, device):
    """Flatten arbitrarily nested kwargs into a flat ``{name: tensor}`` dict.

    Primitive values are converted to tensors on ``device`` via
    ``_PrimitiveType.get_tensor``; tensors are kept as-is; sequences and
    mappings are expanded recursively, with child names formed by joining
    the parent name and the index/key with an underscore (matching the
    naming scheme the ONNX exporter uses when it expands such inputs).
    """
    flattened = {}

    def _collect(value, name):
        # Order matters: primitives (including str/bool) must be checked
        # before the Sequence branch, since str is also an abc.Sequence.
        if _PrimitiveType.is_primitive_type(value):
            flattened[name] = _PrimitiveType.get_tensor(value, device)
        elif isinstance(value, torch.Tensor):
            flattened[name] = value
        elif isinstance(value, abc.Sequence):
            # Expand list-like inputs: one flattened entry per element.
            for idx, item in enumerate(value):
                _collect(item, f"{name}_{idx}")
        elif isinstance(value, abc.Mapping):
            # Expand dict-like inputs: one flattened entry per value.
            for key, item in value.items():
                _collect(item, f"{name}_{key}")
        # Anything else (e.g. None) is silently dropped, as in the original.

    for name, value in kwargs.items():
        _collect(value, name)

    return flattened


class _InputInfo(object):
def __init__(
self,
Expand All @@ -105,6 +130,7 @@ def __init__(
self.num_positionals = num_positionals
self.num_expanded_positionals_non_none = num_expanded_positionals_non_none
self.keyword_names = keyword_names
self.kwargs = None

def __repr__(self) -> str:
return f"""_InputInfo class:
Expand All @@ -121,13 +147,9 @@ def flatten(self, args, kwargs, device):
"""Flatten args and kwargs in a single tuple of tensors with strict ordering"""

ret = [_PrimitiveType.get_tensor(arg, device) if _PrimitiveType.is_primitive_type(arg) else arg for arg in args]
ret += [
_PrimitiveType.get_tensor(kwargs[name], device)
if _PrimitiveType.is_primitive_type(kwargs[name])
else kwargs[name]
for name in self.names
if name in kwargs
]
flattened_kwargs = flatten_kwargs(kwargs, device)
ret += [flattened_kwargs[name] for name in self.names if name in flattened_kwargs]
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
self.kwargs = kwargs

# if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter
# happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise.
Expand All @@ -140,13 +162,8 @@ def unflatten(self, flat_args):
"""Unflatten tuple of tensors into args and kwargs"""

args = tuple(flat_args[: self.num_positionals])
kwargs = {
name: arg
for name, arg in zip(
self.names[self.num_expanded_positionals_non_none :], flat_args[self.num_positionals :]
)
if name in self.keyword_names
}
kwargs = self.kwargs
self.kwargs = None
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
return args, kwargs


Expand All @@ -158,7 +175,7 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
* Initializers: computed from original PyTorch model parameters
"""

def _expand_inputs(current_input, non_none_inputs):
def _expand_inputs(current_input, non_none_inputs, name=""):
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
# The exporter handles input lists by expanding them so that each
# element of the list is its own input.
# ORTModule must match this behavior by also expanding the inputs.
Expand All @@ -168,28 +185,33 @@ def _expand_inputs(current_input, non_none_inputs):
if isinstance(current_input, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list is an input by itself
for inp in current_input:
_expand_inputs(inp, non_none_inputs)
for i, inp in enumerate(current_input):
_expand_inputs(inp, non_none_inputs, f"{name}_{i}" if name else str(i))
elif isinstance(current_input, abc.Mapping):
# If the input is a mapping (like a dict), expand the dict so that
# each element of the dict is an input by itself
for _, val in current_input.items():
_expand_inputs(val, non_none_inputs)
for key, val in current_input.items():
_expand_inputs(val, non_none_inputs, f"{name}_{key}" if name else key)
else:
# else just collect all the non none inputs within non_none_inputs
non_none_inputs.append(current_input)
if isinstance(non_none_inputs, abc.Sequence):
non_none_inputs.append(current_input)
elif isinstance(non_none_inputs, abc.Mapping):
non_none_inputs[name] = current_input

# User inputs
non_none_inputs = []
_expand_inputs(inputs, non_none_inputs)
flattened_kwargs_inputs = {}
_expand_inputs(kwargs, flattened_kwargs_inputs)
buffer_names_dict = {buffer_name: inp for buffer_name, inp in buffer_names}
result = []

for input_idx, name in enumerate(onnx_input_names):
inp = None
if name in kwargs and kwargs[name] is not None:
if name in flattened_kwargs_inputs and flattened_kwargs_inputs[name] is not None:
# Only use keywords coming from user that are expected by ONNX model
inp = kwargs[name]
inp = flattened_kwargs_inputs[name]

if inp is None:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5535,3 +5535,39 @@ def forward(self, input):
# BatchNormInternal is for training, while BatchNormalization is for inference.
assert "BatchNormInternal" in [node.op_type for node in training_model.graph.node]
assert "BatchNormalization" in [node.op_type for node in eval_model.graph.node]

def test_kwargs_dict_input():
    """ORTModule must accept a kwargs input that is a (nested) dict of tensors.

    Builds a model whose forward reads tensors out of a nested dict passed
    as the ``batch`` keyword argument, and checks that ORTModule produces
    the same output as the plain PyTorch module.
    """

    class DictNet(torch.nn.Module):
        def __init__(self):
            # Python-3 style super() — consistent with a py3-only codebase.
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

        def forward(self, *args, **kwargs):
            batch = kwargs["batch"]
            a = batch["one_value"]
            b = batch["two_value"]["three_value"]
            c = batch["two_value"]["four_value"]
            d = batch["five_value"]["six_value"]
            e = batch["five_value"]["seven_value"]["eight_value"]
            return self.dummy + a + b + c + d + e

    device = "cuda"
    # Only the batch dimension and input width are used by this test.
    N, D_in = 64, 784
    pt_model = DictNet().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))
    x = torch.randn(N, D_in, device=device)
    batch = {
        "one_value": torch.randn(N, D_in, device=device),
        "two_value": {
            "three_value": torch.randn(N, D_in, device=device),
            "four_value": torch.randn(N, D_in, device=device),
        },
        "five_value": {
            "six_value": torch.randn(N, D_in, device=device),
            "seven_value": {"eight_value": torch.randn(N, D_in, device=device)},
        },
    }
    # Deep-copy inputs so ORTModule cannot see mutations made by the
    # PyTorch run (and vice versa).
    batch_copy = copy.deepcopy(batch)
    x_copy = copy.deepcopy(x)

    _test_helpers.assert_values_are_close(pt_model(x, batch=batch), ort_model(x_copy, batch=batch_copy))