diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py
index 26f0ce515157b..dd9220e5c082a 100644
--- a/orttraining/orttraining/python/training/ortmodule/_io.py
+++ b/orttraining/orttraining/python/training/ortmodule/_io.py
@@ -85,6 +85,31 @@ def get_primitive_dtype(value):
         return f"{str(type(value))}_{value}" if isinstance(value, bool) else str(type(value))
 
 
+def flatten_kwargs(kwargs, device):
+    def _flatten_kwargs(value, name):
+        if _PrimitiveType.is_primitive_type(value):
+            flattened_kwargs[name] = _PrimitiveType.get_tensor(value, device)
+        elif isinstance(value, torch.Tensor):
+            flattened_kwargs[name] = value
+        elif isinstance(value, abc.Sequence):
+            # If the input is a sequence (like a list), expand the list so that
+            # each element of the list has a corresponding entry in the flattened
+            # kwargs dict
+            for idx, val in enumerate(value):
+                _flatten_kwargs(val, f"{name}_{idx}")
+        elif isinstance(value, abc.Mapping):
+            # If the input is a mapping (like a dict), expand the dict so that
+            # each element of the dict has an entry in the flattened kwargs dict
+            for key, val in value.items():
+                _flatten_kwargs(val, f"{name}_{key}")
+
+    flattened_kwargs = {}
+    for key, value in kwargs.items():
+        _flatten_kwargs(value, key)
+
+    return flattened_kwargs
+
+
 class _InputInfo(object):
     def __init__(
         self,
@@ -94,8 +119,6 @@ def __init__(
         dynamic_axes=None,
         schema=None,
         num_positionals=0,
-        num_expanded_positionals_non_none=0,
-        keyword_names=None,
     ):
         self.names = names
         self.shape = shape
@@ -103,8 +126,7 @@ def __init__(
         self.dynamic_axes = dynamic_axes if dynamic_axes else {}
         self.schema = schema if schema else []
         self.num_positionals = num_positionals
-        self.num_expanded_positionals_non_none = num_expanded_positionals_non_none
-        self.keyword_names = keyword_names
+        self.kwargs = None
 
     def __repr__(self) -> str:
         return f"""_InputInfo class:
@@ -113,21 +135,15 @@ def __repr__(self) -> str:
             \tRequire gradient: {self.require_grad_names}
             \tDynamic axes: {self.dynamic_axes}
             \tSchema: {self.schema}
-            \t#Positionals (total): {self.num_positionals}
-            \t#Expanded Positionals (non-None): {self.num_expanded_positionals_non_none}
-            \tKeyword names: {self.keyword_names}"""
+            \t#Positionals (total): {self.num_positionals}"""
 
     def flatten(self, args, kwargs, device):
         """Flatten args and kwargs in a single tuple of tensors with strict ordering"""
 
         ret = [_PrimitiveType.get_tensor(arg, device) if _PrimitiveType.is_primitive_type(arg) else arg for arg in args]
-        ret += [
-            _PrimitiveType.get_tensor(kwargs[name], device)
-            if _PrimitiveType.is_primitive_type(kwargs[name])
-            else kwargs[name]
-            for name in self.names
-            if name in kwargs
-        ]
+        flattened_kwargs = flatten_kwargs(kwargs, device)
+        ret += [flattened_kwargs[name] for name in self.names if name in flattened_kwargs]
+        self.kwargs = kwargs
 
         # if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter
         # happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise.
@@ -140,13 +156,8 @@ def unflatten(self, flat_args):
         """Unflatten tuple of tensors into args and kwargs"""
 
         args = tuple(flat_args[: self.num_positionals])
-        kwargs = {
-            name: arg
-            for name, arg in zip(
-                self.names[self.num_expanded_positionals_non_none :], flat_args[self.num_positionals :]
-            )
-            if name in self.keyword_names
-        }
+        kwargs = self.kwargs
+        self.kwargs = None
 
         return args, kwargs
 
@@ -158,7 +169,7 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
         * Initializers: computed from original PyTorch model parameters
     """
 
-    def _expand_inputs(current_input, non_none_inputs):
+    def _expand_inputs(current_input, non_none_inputs, name=""):
         # The exporter handles input lists by expanding them so that each
         # element of the list is its own input.
         # ORTModule must match this behavior by also expanding the inputs.
@@ -168,28 +179,33 @@ def _expand_inputs(current_input, non_none_inputs):
         if isinstance(current_input, abc.Sequence):
             # If the input is a sequence (like a list), expand the list so that
             # each element of the list is an input by itself
-            for inp in current_input:
-                _expand_inputs(inp, non_none_inputs)
+            for i, inp in enumerate(current_input):
+                _expand_inputs(inp, non_none_inputs, f"{name}_{i}" if name else str(i))
         elif isinstance(current_input, abc.Mapping):
             # If the input is a mapping (like a dict), expand the dict so that
             # each element of the dict is an input by itself
-            for _, val in current_input.items():
-                _expand_inputs(val, non_none_inputs)
+            for key, val in current_input.items():
+                _expand_inputs(val, non_none_inputs, f"{name}_{key}" if name else key)
         else:
             # else just collect all the non none inputs within non_none_inputs
-            non_none_inputs.append(current_input)
+            if isinstance(non_none_inputs, abc.Sequence):
+                non_none_inputs.append(current_input)
+            elif isinstance(non_none_inputs, abc.Mapping):
+                non_none_inputs[name] = current_input
 
     # User inputs
     non_none_inputs = []
     _expand_inputs(inputs, non_none_inputs)
+    flattened_kwargs_inputs = {}
+    _expand_inputs(kwargs, flattened_kwargs_inputs)
     buffer_names_dict = {buffer_name: inp for buffer_name, inp in buffer_names}
     result = []
 
     for input_idx, name in enumerate(onnx_input_names):
         inp = None
-        if name in kwargs and kwargs[name] is not None:
+        if name in flattened_kwargs_inputs and flattened_kwargs_inputs[name] is not None:
             # Only use keywords coming from user that are expected by ONNX model
-            inp = kwargs[name]
+            inp = flattened_kwargs_inputs[name]
 
         if inp is None:
             try:
@@ -470,29 +486,28 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
 
         if input is None or isinstance(input, str):
             # Drop all None and string inputs and return 0.
-            return 0
+            return
 
-        num_expanded_non_none_inputs = 0
         if isinstance(input, abc.Sequence):
             # If the input is a sequence (like a list), expand the list so that
             # each element of the list is an input by itself.
             for i, val in enumerate(input):
                 # Name each input with the index appended to the original name of the
                 # argument.
-                num_expanded_non_none_inputs += _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)
+                _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names)
 
             # Return here since the list by itself is not a valid input.
             # All the elements of the list have already been added as inputs individually.
-            return num_expanded_non_none_inputs
+            return
         elif isinstance(input, abc.Mapping):
             # If the input is a mapping (like a dict), expand the dict so that
             # each element of the dict is an input by itself.
             for key, val in input.items():
-                num_expanded_non_none_inputs += _add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names)
+                _add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names)
 
             # Return here since the dict by itself is not a valid input.
             # All the elements of the dict have already been added as inputs individually.
-            return num_expanded_non_none_inputs
+            return
 
         # InputInfo should contain all the names irrespective of whether they are
         # a part of the onnx graph or not.
@@ -504,9 +519,6 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
             dynamic_axes.update(_add_dynamic_shape(name, input))
             input_shape.append(list(input.size()))
 
-        # A single input non none input was processed, return 1
-        return 1
-
     # Ignore optional inputs explicitly specified as None
     # ONNX exporter may remove unused inputs
     onnx_graph_input_names = []
@@ -518,7 +530,6 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
     input_names_require_grad = []
     input_shape = []
     var_positional_idx = 0
-    num_expanded_non_none_positional_inputs = 0
 
     for input_idx, input_parameter in enumerate(all_input_parameters):
         if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
@@ -528,7 +539,7 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
                 name = f"{input_parameter.name}_{var_positional_idx}"
                 var_positional_idx += 1
                 inp = inputs[args_i]
-                num_expanded_non_none_positional_inputs += _add_input(name, inp, onnx_graph, onnx_graph_input_names)
+                _add_input(name, inp, onnx_graph, onnx_graph_input_names)
         elif (
             input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY
             or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
@@ -538,23 +549,17 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
             name = input_parameter.name
             inp = None
             input_idx += var_positional_idx
-            is_positional = True
             if input_idx < len(inputs) and inputs[input_idx] is not None:
                 inp = inputs[input_idx]
             elif name in kwargs and kwargs[name] is not None:
                 inp = kwargs[name]
-                is_positional = False
-            num_expanded_non_none_inputs_local = _add_input(name, inp, onnx_graph, onnx_graph_input_names)
-            if is_positional:
-                num_expanded_non_none_positional_inputs += num_expanded_non_none_inputs_local
+            _add_input(name, inp, onnx_graph, onnx_graph_input_names)
         elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
             # **kwargs is always the last argument of forward()
             for name, inp in kwargs.items():
                 if name not in input_names:
                     _add_input(name, inp, onnx_graph, onnx_graph_input_names)
 
-    # input_names have been expanded so to get the correct number of non none
-    # positional names, we need to collect the num_expanded_non_none_positional_inputs.
     return _InputInfo(
         names=input_names,
         shape=input_shape,
@@ -562,8 +567,6 @@
         dynamic_axes=dynamic_axes,
         schema=schema,
         num_positionals=len(inputs),
-        num_expanded_positionals_non_none=num_expanded_non_none_positional_inputs,
-        keyword_names=list(kwargs.keys()),
     )
 
 
diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py
index feb5ed7d12e64..be256d047d1cc 100644
--- a/orttraining/orttraining/python/training/ortmodule/_utils.py
+++ b/orttraining/orttraining/python/training/ortmodule/_utils.py
@@ -194,7 +194,12 @@ def get_device_from_inputs(args, kwargs):
 def _create_iobinding(io_binding, inputs, model, device):
     """Creates IO binding for a `model` inputs and output"""
     for idx, value_info in enumerate(model.graph.input):
-        io_binding.bind_ortvalue_input(value_info.name, OrtValue(_ortvalue_from_torch_tensor(inputs[idx])))
+        io_binding.bind_ortvalue_input(
+            value_info.name,
+            OrtValue(
+                _ortvalue_from_torch_tensor(inputs[idx] if inputs[idx].is_contiguous() else inputs[idx].contiguous())
+            ),
+        )
 
     device_id = get_device_index(device)
     for value_info in model.graph.output:
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index ce6412f5f485e..554184ef7d787 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -5535,3 +5535,61 @@ def forward(self, input):
     # BatchNormInternal is for training, while BatchNormalization is for inference.
     assert "BatchNormInternal" in [node.op_type for node in training_model.graph.node]
     assert "BatchNormalization" in [node.op_type for node in eval_model.graph.node]
+
+
+def test_kwargs_dict_input():
+    class DictNet(torch.nn.Module):
+        def __init__(self):
+            super(DictNet, self).__init__()
+            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
+
+        def forward(self, *args, **kwargs):
+            batch = kwargs["batch"]
+            a = batch["one_value"]
+            b = batch["two_value"]["three_value"]
+            c = batch["two_value"]["four_value"]
+            d = batch["five_value"]["six_value"]
+            e = batch["five_value"]["seven_value"]["eight_value"]
+            return self.dummy + a + b + c + d + e
+
+    device = "cuda"
+    N, D_in, H, D_out = 64, 784, 500, 10
+    pt_model = DictNet().to(device)
+    ort_model = ORTModule(copy.deepcopy(pt_model))
+    x = torch.randn(N, D_in, device=device)
+    batch = {
+        "one_value": torch.randn(N, D_in, device=device),
+        "two_value": {
+            "three_value": torch.randn(N, D_in, device=device),
+            "four_value": torch.randn(N, D_in, device=device),
+        },
+        "five_value": {
+            "six_value": torch.randn(N, D_in, device=device),
+            "seven_value": {"eight_value": torch.randn(N, D_in, device=device)},
+        },
+    }
+    batch_copy = copy.deepcopy(batch)
+    x_copy = copy.deepcopy(x)
+
+    _test_helpers.assert_values_are_close(pt_model(x, batch=batch), ort_model(x_copy, batch=batch_copy))
+
+
+@pytest.mark.parametrize("training_mode", [False, True])
+def test_non_contiguous_tensors_as_inputs(training_mode):
+    class NonContigousNet(torch.nn.Module):
+        def __init__(self):
+            super(NonContigousNet, self).__init__()
+            self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
+
+        def forward(self, non_contiguous_tensor):
+            return self.dummy + non_contiguous_tensor
+
+    device = "cuda"
+    pt_model = NonContigousNet().to(device)
+    pt_model.train(training_mode)
+    ort_model = ORTModule(copy.deepcopy(pt_model))
+    ort_model.train(training_mode)
+    x = torch.arange(12).view(4, 3).t().to(device)
+    x_copy = copy.deepcopy(x)
+    assert not x.is_contiguous()
+    _test_helpers.assert_values_are_close(pt_model(x), ort_model(x_copy))
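Note (not part of the patch): a minimal standalone sketch of the naming scheme that the new `flatten_kwargs` helper in `_io.py` applies to nested keyword inputs, assuming a hypothetical nested `batch` dict; the helper name `flatten_kwargs_sketch` and the string guard are illustrative simplifications, not code from this change.

from collections import abc

import torch


def flatten_kwargs_sketch(kwargs):
    # Mirrors the idea of _io.flatten_kwargs: nested dicts/lists are expanded into a
    # flat name -> tensor mapping using the f"{name}_{key}" / f"{name}_{idx}" convention,
    # matching how _expand_inputs / _add_input name the expanded ONNX graph inputs.
    flattened = {}

    def _flatten(value, name):
        if isinstance(value, torch.Tensor):
            flattened[name] = value
        elif isinstance(value, abc.Sequence) and not isinstance(value, str):
            # strings are Sequences too; excluded here for simplicity
            for idx, val in enumerate(value):
                _flatten(val, f"{name}_{idx}")
        elif isinstance(value, abc.Mapping):
            for key, val in value.items():
                _flatten(val, f"{name}_{key}")

    for key, value in kwargs.items():
        _flatten(value, key)
    return flattened


kwargs = {"batch": {"one_value": torch.zeros(2), "two_value": {"three_value": torch.ones(2)}}}
print(sorted(flatten_kwargs_sketch(kwargs)))
# ['batch_one_value', 'batch_two_value_three_value']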