diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index dd9220e5c082a..4fc339a5d806e 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -481,17 +481,17 @@ def _add_dynamic_shape(name, input): dynamic_axes[name].update({dim_idx: f"{name}_dim{dim_idx}"}) return dynamic_axes - def _add_input(name, input, onnx_graph, onnx_graph_input_names): + def _add_input(name, input_value, onnx_graph, onnx_graph_input_names): """Returns number of expanded non none inputs that _add_input processed""" - if input is None or isinstance(input, str): + if name in input_names or input_value is None or isinstance(input_value, str): # Drop all None and string inputs and return 0. return - if isinstance(input, abc.Sequence): + if isinstance(input_value, abc.Sequence): # If the input is a sequence (like a list), expand the list so that # each element of the list is an input by itself. - for i, val in enumerate(input): + for i, val in enumerate(input_value): # Name each input with the index appended to the original name of the # argument. _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names) @@ -499,10 +499,10 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names): # Return here since the list by itself is not a valid input. # All the elements of the list have already been added as inputs individually. return - elif isinstance(input, abc.Mapping): + elif isinstance(input_value, abc.Mapping): # If the input is a mapping (like a dict), expand the dict so that # each element of the dict is an input by itself. - for key, val in input.items(): + for key, val in input_value.items(): _add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names) # Return here since the dict by itself is not a valid input. 
@@ -513,11 +513,11 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names): # a part of the onnx graph or not. input_names.append(name) - if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input, torch.Tensor): - if input.requires_grad: + if (onnx_graph is None or name in onnx_graph_input_names) and isinstance(input_value, torch.Tensor): + if input_value.requires_grad: input_names_require_grad.append(name) - dynamic_axes.update(_add_dynamic_shape(name, input)) - input_shape.append(list(input.size())) + dynamic_axes.update(_add_dynamic_shape(name, input_value)) + input_shape.append(list(input_value.size())) # Ignore optional inputs explicitly specified as None # ONNX exporter may remove unused inputs diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 7758603c484fc..e9576848a20cb 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5545,23 +5545,39 @@ def forward(self, input): def test_kwargs_dict_input(): class DictNet(torch.nn.Module): def __init__(self): - super(DictNet, self).__init__() + super().__init__() self.dummy = torch.nn.Parameter(torch.FloatTensor([0])) - def forward(self, *args, **kwargs): + def forward(self, *args, named_kwarg, **kwargs): + a = named_kwarg["named_one"] + b = named_kwarg["named_two"]["named_three"] + c = named_kwarg["named_two"]["named_four"] + d = named_kwarg["named_five"]["named_six"] + e = named_kwarg["named_five"]["named_seven"]["named_eight"] batch = kwargs["batch"] - a = batch["one_value"] - b = batch["two_value"]["three_value"] - c = batch["two_value"]["four_value"] - d = batch["five_value"]["six_value"] - e = batch["five_value"]["seven_value"]["eight_value"] - return self.dummy + a + b + c + d + e + f = batch["one_value"] + g = batch["two_value"]["three_value"] + h = 
batch["two_value"]["four_value"] + i = batch["five_value"]["six_value"] + j = batch["five_value"]["seven_value"]["eight_value"] + return self.dummy + a + b + c + d + e + f + g + h + i + j device = "cuda" - N, D_in, H, D_out = 64, 784, 500, 10 + N, D_in = 64, 784 pt_model = DictNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) x = torch.randn(N, D_in, device=device) + named_kwarg = { + "named_one": torch.randn(N, D_in, device=device), + "named_two": { + "named_three": torch.randn(N, D_in, device=device), + "named_four": torch.randn(N, D_in, device=device), + }, + "named_five": { + "named_six": torch.randn(N, D_in, device=device), + "named_seven": {"named_eight": torch.randn(N, D_in, device=device)}, + }, + } batch = { "one_value": torch.randn(N, D_in, device=device), "two_value": { @@ -5576,7 +5592,10 @@ def forward(self, *args, **kwargs): batch_copy = copy.deepcopy(batch) x_copy = copy.deepcopy(x) + named_kwarg_copy = copy.deepcopy(named_kwarg) - _test_helpers.assert_values_are_close(pt_model(x, batch=batch), ort_model(x_copy, batch=batch_copy)) + _test_helpers.assert_values_are_close( + pt_model(x, named_kwarg=named_kwarg, batch=batch), ort_model(x_copy, named_kwarg=named_kwarg_copy, batch=batch_copy) + ) @pytest.mark.parametrize("training_mode", [False, True])