ORTModule support for kwargs input that is a dict #13910

Merged
9 commits merged on Dec 15, 2022
101 changes: 52 additions & 49 deletions orttraining/orttraining/python/training/ortmodule/_io.py
@@ -85,6 +85,31 @@ def get_primitive_dtype(value):
return f"{str(type(value))}_{value}" if isinstance(value, bool) else str(type(value))


def flatten_kwargs(kwargs, device):
def _flatten_kwargs(value, name):
if _PrimitiveType.is_primitive_type(value):
flattened_kwargs[name] = _PrimitiveType.get_tensor(value, device)
elif isinstance(value, torch.Tensor):
flattened_kwargs[name] = value
elif isinstance(value, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list has a corresponding entry in the flattened
# kwargs dict
for idx, val in enumerate(value):
_flatten_kwargs(val, f"{name}_{idx}")
elif isinstance(value, abc.Mapping):
# If the input is a mapping (like a dict), expand the dict so that
# each element of the dict has an entry in the flattened kwargs dict
for key, val in value.items():
_flatten_kwargs(val, f"{name}_{key}")

flattened_kwargs = {}
for key, value in kwargs.items():
_flatten_kwargs(value, key)

return flattened_kwargs
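To make the naming convention concrete, here is a hypothetical usage sketch of flatten_kwargs. It assumes a training-enabled onnxruntime build where this private module is importable as onnxruntime.training.ortmodule._io; the kwarg names and values are illustrative, not taken from this PR. Nested dict entries get a "_<key>" suffix, list entries a "_<index>" suffix, and primitives are converted to tensors on the target device:

import torch
# Assumed import path for the private module edited above.
from onnxruntime.training.ortmodule._io import flatten_kwargs

kwargs = {
    "scale": 2.0,  # primitive: becomes a tensor on the target device
    "batch": {
        "ids": torch.ones(2),
        "extras": [torch.zeros(2), torch.full((2,), 3.0)],  # list nested in a dict
    },
}

flat = flatten_kwargs(kwargs, torch.device("cpu"))
# Flattened keys follow the "<parent>_<key-or-index>" convention:
print(sorted(flat))  # ['batch_extras_0', 'batch_extras_1', 'batch_ids', 'scale']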


class _InputInfo(object):
def __init__(
self,
@@ -94,17 +119,14 @@ def __init__(
dynamic_axes=None,
schema=None,
num_positionals=0,
num_expanded_positionals_non_none=0,
keyword_names=None,
):
self.names = names
self.shape = shape
self.require_grad_names = require_grad_names if require_grad_names else []
self.dynamic_axes = dynamic_axes if dynamic_axes else {}
self.schema = schema if schema else []
self.num_positionals = num_positionals
self.num_expanded_positionals_non_none = num_expanded_positionals_non_none
self.keyword_names = keyword_names
self.kwargs = None

def __repr__(self) -> str:
return f"""_InputInfo class:
@@ -113,21 +135,15 @@ def __repr__(self) -> str:
\tRequire gradient: {self.require_grad_names}
\tDynamic axes: {self.dynamic_axes}
\tSchema: {self.schema}
\t#Positionals (total): {self.num_positionals}
\t#Expanded Positionals (non-None): {self.num_expanded_positionals_non_none}
\tKeyword names: {self.keyword_names}"""
Contributor: By not printing the captured names, we make debugging of scenarios with kwargs harder.

Contributor (Author): I agree. Let me merge this change to unblock a model, and I will create a follow-up PR to add the flattened keyword names as part of the InputInfo.
\t#Positionals (total): {self.num_positionals}"""

def flatten(self, args, kwargs, device):
"""Flatten args and kwargs in a single tuple of tensors with strict ordering"""

ret = [_PrimitiveType.get_tensor(arg, device) if _PrimitiveType.is_primitive_type(arg) else arg for arg in args]
ret += [
_PrimitiveType.get_tensor(kwargs[name], device)
if _PrimitiveType.is_primitive_type(kwargs[name])
else kwargs[name]
for name in self.names
if name in kwargs
]
flattened_kwargs = flatten_kwargs(kwargs, device)
ret += [flattened_kwargs[name] for name in self.names if name in flattened_kwargs]
self.kwargs = kwargs

        # If kwargs is empty, append an empty dictionary at the end of the sample inputs to keep the exporter
        # happy. Otherwise, the exporter confuses kwargs with dictionary inputs.
@@ -140,13 +156,8 @@ def unflatten(self, flat_args):
"""Unflatten tuple of tensors into args and kwargs"""

args = tuple(flat_args[: self.num_positionals])
kwargs = {
name: arg
for name, arg in zip(
self.names[self.num_expanded_positionals_non_none :], flat_args[self.num_positionals :]
)
if name in self.keyword_names
}
kwargs = self.kwargs
self.kwargs = None
return args, kwargs
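The new round-trip contract is simpler than before: flatten records the user's original (possibly nested) kwargs on the instance, and unflatten hands that dict back verbatim instead of rebuilding it from keyword_names. Below is a minimal, hypothetical sketch of that stash-and-restore pattern; it is not the real _InputInfo and omits primitives, device handling, and the empty-dict exporter workaround:

import torch

class _StashingInfo:
    """Illustrative stand-in for the flatten/unflatten pairing above."""

    def __init__(self, names, num_positionals):
        self.names = names                # flattened kwarg names in export order
        self.num_positionals = num_positionals
        self.kwargs = None                # original nested kwargs, kept for unflatten

    def flatten(self, args, kwargs, flattened_kwargs):
        ret = list(args)
        ret += [flattened_kwargs[n] for n in self.names if n in flattened_kwargs]
        self.kwargs = kwargs              # stash the user's structure untouched
        return tuple(ret)

    def unflatten(self, flat_args):
        args = tuple(flat_args[: self.num_positionals])
        kwargs, self.kwargs = self.kwargs, None   # hand the stashed dict back
        return args, kwargs

info = _StashingInfo(names=["x", "batch_ids"], num_positionals=1)
x, t = torch.ones(1), torch.zeros(1)
flat = info.flatten((x,), {"batch": {"ids": t}}, {"batch_ids": t})
args2, kwargs2 = info.unflatten(flat)
assert args2 == (x,) and kwargs2["batch"]["ids"] is t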


@@ -158,7 +169,7 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
* Initializers: computed from original PyTorch model parameters
"""

def _expand_inputs(current_input, non_none_inputs):
def _expand_inputs(current_input, non_none_inputs, name=""):
# The exporter handles input lists by expanding them so that each
# element of the list is its own input.
# ORTModule must match this behavior by also expanding the inputs.
@@ -168,28 +179,33 @@ def _expand_inputs(current_input, non_none_inputs):
if isinstance(current_input, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list is an input by itself
for inp in current_input:
_expand_inputs(inp, non_none_inputs)
for i, inp in enumerate(current_input):
_expand_inputs(inp, non_none_inputs, f"{name}_{i}" if name else str(i))
elif isinstance(current_input, abc.Mapping):
# If the input is a mapping (like a dict), expand the dict so that
# each element of the dict is an input by itself
for _, val in current_input.items():
_expand_inputs(val, non_none_inputs)
for key, val in current_input.items():
_expand_inputs(val, non_none_inputs, f"{name}_{key}" if name else key)
else:
            # else just collect all the non-None inputs within non_none_inputs
non_none_inputs.append(current_input)
if isinstance(non_none_inputs, abc.Sequence):
non_none_inputs.append(current_input)
elif isinstance(non_none_inputs, abc.Mapping):
non_none_inputs[name] = current_input

# User inputs
non_none_inputs = []
_expand_inputs(inputs, non_none_inputs)
flattened_kwargs_inputs = {}
_expand_inputs(kwargs, flattened_kwargs_inputs)
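The extra name argument lets the same recursion feed two accumulator shapes: positional user inputs still collect into a flat list, while keyword inputs collect into a dict keyed by the expanded names that are matched against onnx_input_names below. A self-contained, hypothetical sketch of that dual behavior follows; the None/string guard and names are illustrative assumptions, approximating what the elided lines of _expand_inputs do:

import torch

def expand(current, acc, name=""):
    if current is None or isinstance(current, str):
        return  # dropped, matching the "non-None" naming above
    if isinstance(current, (list, tuple)):
        for i, val in enumerate(current):
            expand(val, acc, f"{name}_{i}" if name else str(i))
    elif isinstance(current, dict):
        for key, val in current.items():
            expand(val, acc, f"{name}_{key}" if name else key)
    elif isinstance(acc, list):
        acc.append(current)        # positional path: order only, no names
    else:
        acc[name] = current        # keyword path: keyed by expanded name

positional, keyword = [], {}
expand([torch.ones(1), None], positional)             # -> [tensor([1.])]
expand({"batch": {"ids": torch.zeros(1)}}, keyword)   # -> {"batch_ids": tensor([0.])}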
buffer_names_dict = {buffer_name: inp for buffer_name, inp in buffer_names}
result = []

for input_idx, name in enumerate(onnx_input_names):
inp = None
if name in kwargs and kwargs[name] is not None:
if name in flattened_kwargs_inputs and flattened_kwargs_inputs[name] is not None:
            # Only use keywords coming from the user that are expected by the ONNX model
inp = kwargs[name]
inp = flattened_kwargs_inputs[name]

if inp is None:
try:
@@ -470,29 +486,28 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):

if input is None or isinstance(input, str):
            # Drop all None and string inputs.
return 0
return

num_expanded_non_none_inputs = 0
if isinstance(input, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list is an input by itself.
for i, val in enumerate(input):
# Name each input with the index appended to the original name of the
# argument.
num_expanded_non_none_inputs += _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)
_add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)

# Return here since the list by itself is not a valid input.
# All the elements of the list have already been added as inputs individually.
return num_expanded_non_none_inputs
return
elif isinstance(input, abc.Mapping):
# If the input is a mapping (like a dict), expand the dict so that
# each element of the dict is an input by itself.
for key, val in input.items():
num_expanded_non_none_inputs += _add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names)
_add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names)

# Return here since the dict by itself is not a valid input.
# All the elements of the dict have already been added as inputs individually.
return num_expanded_non_none_inputs
return

# InputInfo should contain all the names irrespective of whether they are
# a part of the onnx graph or not.
@@ -504,9 +519,6 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
dynamic_axes.update(_add_dynamic_shape(name, input))
input_shape.append(list(input.size()))

# A single input non none input was processed, return 1
return 1

# Ignore optional inputs explicitly specified as None
# ONNX exporter may remove unused inputs
onnx_graph_input_names = []
@@ -518,7 +530,6 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
input_names_require_grad = []
input_shape = []
var_positional_idx = 0
num_expanded_non_none_positional_inputs = 0

for input_idx, input_parameter in enumerate(all_input_parameters):
if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
@@ -528,7 +539,7 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
name = f"{input_parameter.name}_{var_positional_idx}"
var_positional_idx += 1
inp = inputs[args_i]
num_expanded_non_none_positional_inputs += _add_input(name, inp, onnx_graph, onnx_graph_input_names)
_add_input(name, inp, onnx_graph, onnx_graph_input_names)
elif (
input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY
or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
@@ -538,32 +549,24 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
name = input_parameter.name
inp = None
input_idx += var_positional_idx
is_positional = True
if input_idx < len(inputs) and inputs[input_idx] is not None:
inp = inputs[input_idx]
elif name in kwargs and kwargs[name] is not None:
inp = kwargs[name]
is_positional = False
num_expanded_non_none_inputs_local = _add_input(name, inp, onnx_graph, onnx_graph_input_names)
if is_positional:
num_expanded_non_none_positional_inputs += num_expanded_non_none_inputs_local
_add_input(name, inp, onnx_graph, onnx_graph_input_names)
elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
# **kwargs is always the last argument of forward()
for name, inp in kwargs.items():
if name not in input_names:
_add_input(name, inp, onnx_graph, onnx_graph_input_names)

# input_names have been expanded so to get the correct number of non none
# positional names, we need to collect the num_expanded_non_none_positional_inputs.
return _InputInfo(
names=input_names,
shape=input_shape,
require_grad_names=input_names_require_grad,
dynamic_axes=dynamic_axes,
schema=schema,
num_positionals=len(inputs),
num_expanded_positionals_non_none=num_expanded_non_none_positional_inputs,
keyword_names=list(kwargs.keys()),
)


7 changes: 6 additions & 1 deletion orttraining/orttraining/python/training/ortmodule/_utils.py
@@ -194,7 +194,12 @@ def get_device_from_inputs(args, kwargs):
def _create_iobinding(io_binding, inputs, model, device):
"""Creates IO binding for a `model` inputs and output"""
for idx, value_info in enumerate(model.graph.input):
io_binding.bind_ortvalue_input(value_info.name, OrtValue(_ortvalue_from_torch_tensor(inputs[idx])))
io_binding.bind_ortvalue_input(
value_info.name,
OrtValue(
_ortvalue_from_torch_tensor(inputs[idx] if inputs[idx].is_contiguous() else inputs[idx].contiguous())
),
)
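A brief aside on why the contiguity guard was added: _ortvalue_from_torch_tensor presumably wraps the tensor's underlying buffer without stride information, so a non-contiguous view (a transpose or a strided slice) would otherwise be read in the wrong memory order; copying to a dense layout only when needed keeps the common path allocation-free. A hypothetical illustration, not part of this diff, of the kind of input the guard handles:

# Illustrative only: a strided slice shares storage with its base tensor but is
# not contiguous, so binding its raw buffer would yield wrong values.
import torch

base = torch.randn(4, 6)
view = base[:, ::2]                     # every other column: non-contiguous view
assert not view.is_contiguous()

# The guard above copies such inputs into a dense row-major buffer first.
dense = view if view.is_contiguous() else view.contiguous()
assert dense.is_contiguous() and torch.equal(dense, view)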

device_id = get_device_index(device)
for value_info in model.graph.output:
@@ -5535,3 +5535,61 @@ def forward(self, input):
# BatchNormInternal is for training, while BatchNormalization is for inference.
assert "BatchNormInternal" in [node.op_type for node in training_model.graph.node]
assert "BatchNormalization" in [node.op_type for node in eval_model.graph.node]


def test_kwargs_dict_input():
class DictNet(torch.nn.Module):
def __init__(self):
super(DictNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

def forward(self, *args, **kwargs):
batch = kwargs["batch"]
a = batch["one_value"]
b = batch["two_value"]["three_value"]
c = batch["two_value"]["four_value"]
d = batch["five_value"]["six_value"]
e = batch["five_value"]["seven_value"]["eight_value"]
return self.dummy + a + b + c + d + e

device = "cuda"
N, D_in, H, D_out = 64, 784, 500, 10
pt_model = DictNet().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
x = torch.randn(N, D_in, device=device)
batch = {
"one_value": torch.randn(N, D_in, device=device),
"two_value": {
"three_value": torch.randn(N, D_in, device=device),
"four_value": torch.randn(N, D_in, device=device),
},
"five_value": {
"six_value": torch.randn(N, D_in, device=device),
"seven_value": {"eight_value": torch.randn(N, D_in, device=device)},
},
}
batch_copy = copy.deepcopy(batch)
x_copy = copy.deepcopy(x)

_test_helpers.assert_values_are_close(pt_model(x, batch=batch), ort_model(x_copy, batch=batch_copy))


@pytest.mark.parametrize("training_mode", [False, True])
def test_non_contiguous_tensors_as_inputs(training_mode):
    class NonContiguousNet(torch.nn.Module):
def __init__(self):
            super(NonContiguousNet, self).__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

def forward(self, non_contiguous_tensor):
return self.dummy + non_contiguous_tensor

device = "cuda"
pt_model = NonContigousNet().to(device)
pt_model.train(training_mode)
ort_model = ORTModule(copy.deepcopy(pt_model))
ort_model.train(training_mode)
x = torch.arange(12).view(4, 3).t().to(device)
x_copy = copy.deepcopy(x)
assert not x.is_contiguous()
_test_helpers.assert_values_are_close(pt_model(x), ort_model(x_copy))