ORTModule support for kwargs input that is a dict (#13910)
baijumeswani authored Dec 15, 2022
1 parent 3b17ab7 commit 1fd6348
Showing 3 changed files with 116 additions and 50 deletions.
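For context, here is a minimal sketch of the call pattern this commit enables: an ORTModule-wrapped model receiving a nested dict through `**kwargs`. The model and parameter names below are illustrative only, not part of the commit; the sketch assumes a CUDA build of onnxruntime-training and the public `ORTModule` API.

```python
import torch
from onnxruntime.training import ORTModule

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x, **kwargs):
        batch = kwargs["batch"]
        # Nested dict entries are flattened by ORTModule into individual graph inputs.
        return self.bias + x + batch["inner"]["value"]

model = ORTModule(Net().to("cuda"))
x = torch.randn(4, 3, device="cuda")
out = model(x, batch={"inner": {"value": torch.randn(4, 3, device="cuda")}})
```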
101 changes: 52 additions & 49 deletions orttraining/orttraining/python/training/ortmodule/_io.py
@@ -85,6 +85,31 @@ def get_primitive_dtype(value):
return f"{str(type(value))}_{value}" if isinstance(value, bool) else str(type(value))


def flatten_kwargs(kwargs, device):
def _flatten_kwargs(value, name):
if _PrimitiveType.is_primitive_type(value):
flattened_kwargs[name] = _PrimitiveType.get_tensor(value, device)
elif isinstance(value, torch.Tensor):
flattened_kwargs[name] = value
elif isinstance(value, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list has a corresponding entry in the flattened
# kwargs dict
for idx, val in enumerate(value):
_flatten_kwargs(val, f"{name}_{idx}")
elif isinstance(value, abc.Mapping):
# If the input is a mapping (like a dict), expand the dict so that
# each element of the dict has an entry in the flattened kwargs dict
for key, val in value.items():
_flatten_kwargs(val, f"{name}_{key}")

flattened_kwargs = {}
for key, value in kwargs.items():
_flatten_kwargs(value, key)

return flattened_kwargs


class _InputInfo(object):
def __init__(
self,
@@ -94,17 +119,14 @@ def __init__(
dynamic_axes=None,
schema=None,
num_positionals=0,
num_expanded_positionals_non_none=0,
keyword_names=None,
):
self.names = names
self.shape = shape
self.require_grad_names = require_grad_names if require_grad_names else []
self.dynamic_axes = dynamic_axes if dynamic_axes else {}
self.schema = schema if schema else []
self.num_positionals = num_positionals
self.num_expanded_positionals_non_none = num_expanded_positionals_non_none
self.keyword_names = keyword_names
self.kwargs = None

def __repr__(self) -> str:
return f"""_InputInfo class:
@@ -113,21 +135,15 @@ def __repr__(self) -> str:
\tRequire gradient: {self.require_grad_names}
\tDynamic axes: {self.dynamic_axes}
\tSchema: {self.schema}
\t#Positionals (total): {self.num_positionals}
\t#Expanded Positionals (non-None): {self.num_expanded_positionals_non_none}
\tKeyword names: {self.keyword_names}"""
\t#Positionals (total): {self.num_positionals}"""

def flatten(self, args, kwargs, device):
"""Flatten args and kwargs in a single tuple of tensors with strict ordering"""

ret = [_PrimitiveType.get_tensor(arg, device) if _PrimitiveType.is_primitive_type(arg) else arg for arg in args]
ret += [
_PrimitiveType.get_tensor(kwargs[name], device)
if _PrimitiveType.is_primitive_type(kwargs[name])
else kwargs[name]
for name in self.names
if name in kwargs
]
flattened_kwargs = flatten_kwargs(kwargs, device)
ret += [flattened_kwargs[name] for name in self.names if name in flattened_kwargs]
self.kwargs = kwargs

# if kwargs is empty, append an empty dictionary at the end of the sample inputs to make the exporter
# happy. Otherwise, the exporter confuses kwargs with dictionary inputs.
@@ -140,13 +156,8 @@ def unflatten(self, flat_args):
"""Unflatten tuple of tensors into args and kwargs"""

args = tuple(flat_args[: self.num_positionals])
kwargs = {
name: arg
for name, arg in zip(
self.names[self.num_expanded_positionals_non_none :], flat_args[self.num_positionals :]
)
if name in self.keyword_names
}
kwargs = self.kwargs
self.kwargs = None
return args, kwargs


@@ -158,7 +169,7 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
* Initializers: computed from original PyTorch model parameters
"""

def _expand_inputs(current_input, non_none_inputs):
def _expand_inputs(current_input, non_none_inputs, name=""):
# The exporter handles input lists by expanding them so that each
# element of the list is its own input.
# ORTModule must match this behavior by also expanding the inputs.
@@ -168,28 +179,33 @@ def _expand_inputs(current_input, non_none_inputs):
if isinstance(current_input, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list is an input by itself
for inp in current_input:
_expand_inputs(inp, non_none_inputs)
for i, inp in enumerate(current_input):
_expand_inputs(inp, non_none_inputs, f"{name}_{i}" if name else str(i))
elif isinstance(current_input, abc.Mapping):
# If the input is a mapping (like a dict), expand the dict so that
# each element of the dict is an input by itself
for _, val in current_input.items():
_expand_inputs(val, non_none_inputs)
for key, val in current_input.items():
_expand_inputs(val, non_none_inputs, f"{name}_{key}" if name else key)
else:
# else just collect all the non-None inputs within non_none_inputs
non_none_inputs.append(current_input)
if isinstance(non_none_inputs, abc.Sequence):
non_none_inputs.append(current_input)
elif isinstance(non_none_inputs, abc.Mapping):
non_none_inputs[name] = current_input

# User inputs
non_none_inputs = []
_expand_inputs(inputs, non_none_inputs)
flattened_kwargs_inputs = {}
_expand_inputs(kwargs, flattened_kwargs_inputs)
buffer_names_dict = {buffer_name: inp for buffer_name, inp in buffer_names}
result = []

for input_idx, name in enumerate(onnx_input_names):
inp = None
if name in kwargs and kwargs[name] is not None:
if name in flattened_kwargs_inputs and flattened_kwargs_inputs[name] is not None:
# Only use keywords coming from user that are expected by ONNX model
inp = kwargs[name]
inp = flattened_kwargs_inputs[name]

if inp is None:
try:
@@ -470,29 +486,28 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):

if input is None or isinstance(input, str):
# Drop all None and string inputs and return 0.
return 0
return

num_expanded_non_none_inputs = 0
if isinstance(input, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list is an input by itself.
for i, val in enumerate(input):
# Name each input with the index appended to the original name of the
# argument.
num_expanded_non_none_inputs += _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)
_add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)

# Return here since the list by itself is not a valid input.
# All the elements of the list have already been added as inputs individually.
return num_expanded_non_none_inputs
return
elif isinstance(input, abc.Mapping):
# If the input is a mapping (like a dict), expand the dict so that
# each element of the dict is an input by itself.
for key, val in input.items():
num_expanded_non_none_inputs += _add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names)
_add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names)

# Return here since the dict by itself is not a valid input.
# All the elements of the dict have already been added as inputs individually.
return num_expanded_non_none_inputs
return

# InputInfo should contain all the names irrespective of whether they are
# a part of the onnx graph or not.
@@ -504,9 +519,6 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
dynamic_axes.update(_add_dynamic_shape(name, input))
input_shape.append(list(input.size()))

# A single input non none input was processed, return 1
return 1

# Ignore optional inputs explicitly specified as None
# ONNX exporter may remove unused inputs
onnx_graph_input_names = []
@@ -518,7 +530,6 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
input_names_require_grad = []
input_shape = []
var_positional_idx = 0
num_expanded_non_none_positional_inputs = 0

for input_idx, input_parameter in enumerate(all_input_parameters):
if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
@@ -528,7 +539,7 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
name = f"{input_parameter.name}_{var_positional_idx}"
var_positional_idx += 1
inp = inputs[args_i]
num_expanded_non_none_positional_inputs += _add_input(name, inp, onnx_graph, onnx_graph_input_names)
_add_input(name, inp, onnx_graph, onnx_graph_input_names)
elif (
input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY
or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
@@ -538,32 +549,24 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
name = input_parameter.name
inp = None
input_idx += var_positional_idx
is_positional = True
if input_idx < len(inputs) and inputs[input_idx] is not None:
inp = inputs[input_idx]
elif name in kwargs and kwargs[name] is not None:
inp = kwargs[name]
is_positional = False
num_expanded_non_none_inputs_local = _add_input(name, inp, onnx_graph, onnx_graph_input_names)
if is_positional:
num_expanded_non_none_positional_inputs += num_expanded_non_none_inputs_local
_add_input(name, inp, onnx_graph, onnx_graph_input_names)
elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
# **kwargs is always the last argument of forward()
for name, inp in kwargs.items():
if name not in input_names:
_add_input(name, inp, onnx_graph, onnx_graph_input_names)

# input_names have been expanded so to get the correct number of non none
# positional names, we need to collect the num_expanded_non_none_positional_inputs.
return _InputInfo(
names=input_names,
shape=input_shape,
require_grad_names=input_names_require_grad,
dynamic_axes=dynamic_axes,
schema=schema,
num_positionals=len(inputs),
num_expanded_positionals_non_none=num_expanded_non_none_positional_inputs,
keyword_names=list(kwargs.keys()),
)


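To make the new flattening behavior in `_io.py` concrete, here is a simplified, torch-free sketch of the naming convention that `flatten_kwargs` and `_expand_inputs` apply: nested mapping keys and sequence indices are appended to the argument name with underscores. The real code additionally converts Python primitives to tensors on the target device; `flatten_names` below is a hypothetical helper written for illustration, not part of the commit.

```python
from collections import abc

def flatten_names(kwargs):
    """Flatten nested kwargs into {flattened_name: leaf_value}, mirroring the
    f"{name}_{key}" / f"{name}_{idx}" scheme used in _io.py."""
    flat = {}

    def _walk(value, name):
        if isinstance(value, abc.Mapping):
            # Expand dict entries: append the key to the running name.
            for key, val in value.items():
                _walk(val, f"{name}_{key}")
        elif isinstance(value, abc.Sequence) and not isinstance(value, (str, bytes)):
            # Expand list/tuple entries: append the index to the running name.
            for idx, val in enumerate(value):
                _walk(val, f"{name}_{idx}")
        else:
            flat[name] = value

    for key, value in kwargs.items():
        _walk(value, key)
    return flat

print(flatten_names({"batch": {"a": 1, "b": {"c": 2}}, "masks": [10, 20]}))
# -> {'batch_a': 1, 'batch_b_c': 2, 'masks_0': 10, 'masks_1': 20}
```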
7 changes: 6 additions & 1 deletion orttraining/orttraining/python/training/ortmodule/_utils.py
@@ -194,7 +194,12 @@ def get_device_from_inputs(args, kwargs):
def _create_iobinding(io_binding, inputs, model, device):
"""Creates IO binding for a `model` inputs and output"""
for idx, value_info in enumerate(model.graph.input):
io_binding.bind_ortvalue_input(value_info.name, OrtValue(_ortvalue_from_torch_tensor(inputs[idx])))
io_binding.bind_ortvalue_input(
value_info.name,
OrtValue(
_ortvalue_from_torch_tensor(inputs[idx] if inputs[idx].is_contiguous() else inputs[idx].contiguous())
),
)

device_id = get_device_index(device)
for value_info in model.graph.output:
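The `_utils.py` change guards against non-contiguous inputs when creating the IO binding. A quick illustration of the case it handles, using the same construction as the new test: a transposed view shares storage with its source tensor and is not contiguous, so it is copied into contiguous memory before being wrapped in an `OrtValue`.

```python
import torch

x = torch.arange(12).view(4, 3).t()          # transposed view, same storage as the original
print(x.is_contiguous())                      # False
y = x if x.is_contiguous() else x.contiguous()  # mirrors the check added in _create_iobinding
print(y.is_contiguous())                      # True
```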
@@ -5535,3 +5535,61 @@ def forward(self, input):
# BatchNormInternal is for training, while BatchNormalization is for inference.
assert "BatchNormInternal" in [node.op_type for node in training_model.graph.node]
assert "BatchNormalization" in [node.op_type for node in eval_model.graph.node]


def test_kwargs_dict_input():
class DictNet(torch.nn.Module):
def __init__(self):
super(DictNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

def forward(self, *args, **kwargs):
batch = kwargs["batch"]
a = batch["one_value"]
b = batch["two_value"]["three_value"]
c = batch["two_value"]["four_value"]
d = batch["five_value"]["six_value"]
e = batch["five_value"]["seven_value"]["eight_value"]
return self.dummy + a + b + c + d + e

device = "cuda"
N, D_in, H, D_out = 64, 784, 500, 10
pt_model = DictNet().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
x = torch.randn(N, D_in, device=device)
batch = {
"one_value": torch.randn(N, D_in, device=device),
"two_value": {
"three_value": torch.randn(N, D_in, device=device),
"four_value": torch.randn(N, D_in, device=device),
},
"five_value": {
"six_value": torch.randn(N, D_in, device=device),
"seven_value": {"eight_value": torch.randn(N, D_in, device=device)},
},
}
batch_copy = copy.deepcopy(batch)
x_copy = copy.deepcopy(x)

_test_helpers.assert_values_are_close(pt_model(x, batch=batch), ort_model(x_copy, batch=batch_copy))


@pytest.mark.parametrize("training_mode", [False, True])
def test_non_contiguous_tensors_as_inputs(training_mode):
class NonContigousNet(torch.nn.Module):
def __init__(self):
super(NonContigousNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

def forward(self, non_contiguous_tensor):
return self.dummy + non_contiguous_tensor

device = "cuda"
pt_model = NonContigousNet().to(device)
pt_model.train(training_mode)
ort_model = ORTModule(copy.deepcopy(pt_model))
ort_model.train(training_mode)
x = torch.arange(12).view(4, 3).t().to(device)
x_copy = copy.deepcopy(x)
assert not x.is_contiguous()
_test_helpers.assert_values_are_close(pt_model(x), ort_model(x_copy))
