diff --git a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py
index 7b24bb400b162..1213342004d48 100644
--- a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py
+++ b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py
@@ -2,6 +2,7 @@
 # Licensed under the MIT License.
 
 import copy
+import logging
 import os
 from typing import List, Optional, Set, Tuple, Union
 
@@ -70,13 +71,16 @@ def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: Opti
 def _gradient_model_for(
     model: onnx.ModelProto,
     requires_grad: Set[str],
-    output_names: List[str],
     loss_name: str,
     options: Optional[SessionOptions] = None,
 ) -> onnx.ModelProto:
     """Builds the gradient graph on top of the given input forward only graph."""
 
-    builder = GradientGraphBuilder(model.SerializeToString(), set(output_names), requires_grad, loss_name, options)
+    logging.debug(
+        "The loss output is %s. The gradient graph will be built starting from %s_grad.", loss_name, loss_name
+    )
+
+    builder = GradientGraphBuilder(model.SerializeToString(), {loss_name}, requires_grad, loss_name, options)
     builder.build()
 
     return onnx.load_from_string(builder.get_model())
@@ -123,7 +127,7 @@ def build_gradient_graph(
     optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))
 
     # Assumption is that the first graph output is the loss output
-    gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names, output_names[0], options)
+    gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options)
 
     _reorder_outputs(gradient_model, output_names, requires_grad)
 
diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py
index 9f90a5a0c30cd..a2922353ac70e 100644
--- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py
+++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py
@@ -205,6 +205,8 @@ def __call__(self, *args, **kwargs):
             model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
         )
 
+        logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__)
+
         _training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)
 
         accessor._GLOBAL_ACCESSOR.model.CopyFrom(self._training_model)
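
Usage note (outside the diff): the new messages are emitted through the root logger via module-level logging.debug(...) calls, and the root logger defaults to WARNING, so they stay silent unless a caller opts in. A minimal sketch of surfacing them, using only the standard library; the loss_name value below is hypothetical, not taken from the patch:

    import logging

    # The root logger defaults to WARNING; raise it to DEBUG so the new
    # messages from _training_graph_utils.py and onnxblock.py are emitted.
    logging.basicConfig(level=logging.DEBUG)

    # Illustration of the same pattern the diff adds for the loss output.
    loss_name = "loss"  # hypothetical loss output name
    logging.debug(
        "The loss output is %s. The gradient graph will be built starting from %s_grad.", loss_name, loss_name
    )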