Skip to content

Commit

Permalink
Fix named kwargs being double-counted in graph input names
Browse files Browse the repository at this point in the history
  • Loading branch information
guyang3532 committed Jan 10, 2023
1 parent 3d8b596 commit a86a2e0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 20 deletions.
20 changes: 10 additions & 10 deletions orttraining/orttraining/python/training/ortmodule/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,28 +481,28 @@ def _add_dynamic_shape(name, input):
dynamic_axes[name].update({dim_idx: f"{name}_dim{dim_idx}"})
return dynamic_axes

def _add_input(name, input, onnx_graph, onnx_graph_input_names):
def _add_input(name, input_value, onnx_graph, onnx_graph_input_names):
"""Returns number of expanded non none inputs that _add_input processed"""

if input is None or isinstance(input, str):
if name in input_names or input_value is None or isinstance(input_value, str):
# Drop all None and string inputs and return 0.
return

if isinstance(input, abc.Sequence):
if isinstance(input_value, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list is an input by itself.
for i, val in enumerate(input):
for i, val in enumerate(input_value):
# Name each input with the index appended to the original name of the
# argument.
_add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)

# Return here since the list by itself is not a valid input.
# All the elements of the list have already been added as inputs individually.
return
elif isinstance(input, abc.Mapping):
elif isinstance(input_value, abc.Mapping):
# If the input is a mapping (like a dict), expand the dict so that
# each element of the dict is an input by itself.
for key, val in input.items():
for key, val in input_value.items():
_add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names)

# Return here since the dict by itself is not a valid input.
Expand All @@ -513,11 +513,11 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
# a part of the onnx graph or not.
input_names.append(name)

if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input, torch.Tensor):
if input.requires_grad:
if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input_value, torch.Tensor):
if input_value.requires_grad:
input_names_require_grad.append(name)
dynamic_axes.update(_add_dynamic_shape(name, input))
input_shape.append(list(input.size()))
dynamic_axes.update(_add_dynamic_shape(name, input_value))
input_shape.append(list(input_value.size()))

# Ignore optional inputs explicitly specified as None
# ONNX exporter may remove unused inputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5545,23 +5545,39 @@ def forward(self, input):
def test_kwargs_dict_input():
class DictNet(torch.nn.Module):
def __init__(self):
super(DictNet, self).__init__()
super().__init__()
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))

def forward(self, *args, **kwargs):
def forward(self, *args, named_kwarg, **kwargs):
a = named_kwarg["named_one"]
b = named_kwarg["named_two"]["named_three"]
c = named_kwarg["named_two"]["named_four"]
d = named_kwarg["named_five"]["named_six"]
e = named_kwarg["named_five"]["named_seven"]["named_eight"]
batch = kwargs["batch"]
a = batch["one_value"]
b = batch["two_value"]["three_value"]
c = batch["two_value"]["four_value"]
d = batch["five_value"]["six_value"]
e = batch["five_value"]["seven_value"]["eight_value"]
return self.dummy + a + b + c + d + e
f = batch["one_value"]
g = batch["two_value"]["three_value"]
h = batch["two_value"]["four_value"]
i = batch["five_value"]["six_value"]
j = batch["five_value"]["seven_value"]["eight_value"]
return self.dummy + a + b + c + d + e + f + g + h + i + j

device = "cuda"
N, D_in, H, D_out = 64, 784, 500, 10
N, D_in = 64, 784
pt_model = DictNet().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
x = torch.randn(N, D_in, device=device)
named_kwarg = {
"named_one": torch.randn(N, D_in, device=device),
"named_two": {
"named_three": torch.randn(N, D_in, device=device),
"named_four": torch.randn(N, D_in, device=device),
},
"named_five": {
"named_six": torch.randn(N, D_in, device=device),
"named_seven": {"named_eight": torch.randn(N, D_in, device=device)},
},
}
batch = {
"one_value": torch.randn(N, D_in, device=device),
"two_value": {
Expand All @@ -5576,7 +5592,9 @@ def forward(self, *args, **kwargs):
batch_copy = copy.deepcopy(batch)
x_copy = copy.deepcopy(x)

_test_helpers.assert_values_are_close(pt_model(x, batch=batch), ort_model(x_copy, batch=batch_copy))
_test_helpers.assert_values_are_close(
pt_model(x, named_kwarg=named_kwarg, batch=batch), ort_model(x_copy, named_kwarg=named_kwarg, batch=batch_copy)
)


@pytest.mark.parametrize("training_mode", [False, True])
Expand Down

0 comments on commit a86a2e0

Please sign in to comment.