Skip to content

Commit

Permalink
Build gradient graph starting at the loss alone (#17240)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored and centwang committed Aug 28, 2023
1 parent 3913ef7 commit de20ce5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.

import copy
import logging
import os
from typing import List, Optional, Set, Tuple, Union

Expand Down Expand Up @@ -70,13 +71,16 @@ def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: Opti
def _gradient_model_for(
model: onnx.ModelProto,
requires_grad: Set[str],
output_names: List[str],
loss_name: str,
options: Optional[SessionOptions] = None,
) -> onnx.ModelProto:
"""Builds the gradient graph on top of the given input forward only graph."""

builder = GradientGraphBuilder(model.SerializeToString(), set(output_names), requires_grad, loss_name, options)
logging.debug(
"The loss output is %s. The gradient graph will be built starting from %s_grad.", loss_name, loss_name
)

builder = GradientGraphBuilder(model.SerializeToString(), {loss_name}, requires_grad, loss_name, options)
builder.build()
return onnx.load_from_string(builder.get_model())

Expand Down Expand Up @@ -123,7 +127,7 @@ def build_gradient_graph(
optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))

# Assumption is that the first graph output is the loss output
gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names, output_names[0], options)
gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options)

_reorder_outputs(gradient_model, output_names, requires_grad)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def __call__(self, *args, **kwargs):
model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
)

logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__)

_training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)

accessor._GLOBAL_ACCESSOR.model.CopyFrom(self._training_model)
Expand Down

0 comments on commit de20ce5

Please sign in to comment.