Generalize label input sparsity check and refactor #20636

Merged: 2 commits, May 10, 2024
24 changes: 11 additions & 13 deletions docs/ORTModule_Training_Guidelines.md
@@ -208,19 +208,6 @@ debugging).
 export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=0 # Disable
 ```
 
-#### ORTMODULE_ENABLE_SPARSE_OPTIMIZER
-
-- **Feature Area**: *ORTMODULE/Optimizations*
-- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the input data sparsity
-based performance optimizations, including embedding sparsity and label sparsity.
-This optimization is applicable when using optimum, which has an implementation of the ModuleWithLoss class that wraps the HuggingFace Training that allows loss computation inside ONNX Runtime (ORT).
-If you're not using optimum but want to implement a similar wrapper in your codebase to compute the loss inside ONNX Runtime (ORT), you can refer to this [Link](ORTModule_ModuleWithLoss_Wrapper.md) for detailed steps and guidelines on how to achieve this.
-
-```bash
-export ORTMODULE_ENABLE_SPARSE_OPTIMIZER=1 # Enable
-export ORTMODULE_ENABLE_SPARSE_OPTIMIZER=0 # Disable
-```
-
 #### ORTMODULE_PRINT_INPUT_DENSITY
 
 - **Feature Area**: *ORTMODULE/RuntimeInspector*
@@ -254,6 +241,17 @@ data sparsity based performance optimizations.
 export ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER=0 # Disable
 ```
 
+#### ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER
+
+- **Feature Area**: *ORTMODULE/Optimizations*
+- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the label input
+data sparsity based performance optimizations.
+
+```bash
+export ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER=1 # Enable
+export ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER=0 # Disable
+```
+
 #### ORTMODULE_CACHE_DIR
 
 - **Feature Area**: *ORTMODULE/RuntimeOptions*
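The new per-feature switch replaces the combined ORTMODULE_ENABLE_SPARSE_OPTIMIZER flag: `1` enables the label-sparsity optimization, `0` disables it. As a minimal sketch of how an on/off variable of this shape can be parsed (the helper name and the unset-means-enabled default are assumptions for illustration; ORTModule actually reads these variables on its Python side):

```cpp
#include <cstdlib>
#include <string>

// Illustrative only: parse an ORTMODULE_* on/off switch such as
// ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER. "0" disables the feature,
// any other value enables it, and an unset variable keeps the default.
bool IsOrtModuleFlagEnabled(const char* name, bool default_value = true) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) {
    return default_value;  // not set: fall back to the documented default (enabled)
  }
  return std::string(raw) != "0";
}
```

The remaining diffs generalize the C++ graph transformers to match.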
@@ -21,8 +21,8 @@ constexpr const char* kInspectActivationFuncName =
     "onnxruntime.training.utils.hooks._statistics_subscriber._InspectActivation";
 constexpr const char* kIncrementStepFuncName =
     "onnxruntime.training.utils.hooks._subscriber_manager._IncrementStep";
-constexpr const char* kFlagPaddingEliminationFuncName =
-    "onnxruntime.training.ortmodule._runtime_inspector.FlagPaddingElimination";
+constexpr const char* kFlagAndPrintDensityFuncName =
+    "onnxruntime.training.ortmodule._runtime_inspector.FlagAndPrintDensity";
 
 void PushAllOutputNode(Graph& graph, std::queue<Node*>& q, Node* node, std::unordered_set<Node*>& visited) {
   for (auto iter = node->OutputNodesBegin(); iter != node->OutputNodesEnd(); ++iter) {
@@ -396,26 +396,28 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
       if (outputNodeCount != 1) {
         continue;
       }
-      auto embedding_output_node = graph.GetNode(node.OutputNodesBegin()->Index());
-      if (embedding_output_node == nullptr ||
-          !graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_output_node, "PythonOp", {1}, kMSDomain) ||
-          static_cast<std::string>(embedding_output_node->GetAttributes().at("func_name").s()) !=
-              kFlagPaddingEliminationFuncName) {
+      Node* embedding_input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[1]->Name());
+      if (embedding_input_node == nullptr ||
+          !graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_input_node, "PythonOp", {1}, kMSDomain) ||
+          static_cast<std::string>(embedding_input_node->GetAttributes().at("func_name").s()) !=
+              kFlagAndPrintDensityFuncName) {
         LOG_DEBUG_INFO(logger, "not find PythonOp of flagPaddingElimination after embedding node");
         continue;
       }
-      if (graph_utils::CanRemoveNode(graph, *embedding_output_node, logger)) {
-        if (graph_utils::RemoveNode(graph, *embedding_output_node)) {
-          modified = true;
-        } else {
-          LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_output_node->Name() +
-                                     "(" + embedding_output_node->OpType() + ")");
-          continue;
-        }
-      } else {
-        LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_output_node->Name() +
-                                   "(" + embedding_output_node->OpType() + ")");
-        continue;
-      }
+      if (!print_density_) {
+        if (graph_utils::CanRemoveNode(graph, *embedding_input_node, logger)) {
+          if (graph_utils::RemoveNode(graph, *embedding_input_node)) {
+            modified = true;
+          } else {
+            LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_input_node->Name() +
+                                       "(" + embedding_input_node->OpType() + ")");
+            continue;
+          }
+        } else {
+          LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_input_node->Name() +
+                                     "(" + embedding_input_node->OpType() + ")");
+          continue;
+        }
+      }
       const ONNX_NAMESPACE::TensorProto* padding_initializer =
           graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
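The hunk above is from the padding-elimination transformer: instead of looking for the old FlagPaddingElimination PythonOp on the embedding's output, it now checks whether the producer of the embedding's input is the generalized FlagAndPrintDensity PythonOp. A condensed sketch of that detection step, using the same graph utilities as the diff (packaging it as a free function is illustrative, not code from this PR):

```cpp
// Sketch only: return the FlagAndPrintDensity PythonOp feeding `input_def`,
// or nullptr when the input is not guarded by the marker node.
Node* FindFlagAndPrintDensityProducer(Graph& graph, NodeArg* input_def) {
  Node* producer = graph.GetMutableProducerNode(input_def->Name());
  if (producer == nullptr ||
      !graph_utils::IsSupportedOptypeVersionAndDomain(*producer, "PythonOp", {1}, kMSDomain) ||
      static_cast<std::string>(producer->GetAttributes().at("func_name").s()) !=
          kFlagAndPrintDensityFuncName) {
    return nullptr;  // no marker: caller skips this node
  }
  return producer;
}
```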
@@ -127,10 +127,15 @@ namespace onnxruntime {
  */
 class PaddingElimination : public GraphTransformer {
  public:
-  explicit PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
-      : GraphTransformer("PaddingElimination", compatible_execution_providers) {}
+  explicit PaddingElimination(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
+                              const bool print_input_density = false) noexcept
+      : GraphTransformer("PaddingElimination", compatible_execution_providers),
+        print_density_(print_input_density) {}
 
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+
+ private:
+  bool print_density_ = false;
 };
 
 }  // namespace onnxruntime
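Callers now pass the flag at construction time; the registration diff further down threads it in from the training configuration. A small usage sketch, with `compatible_eps` standing in for whatever execution-provider set the caller already has:

```cpp
// Sketch: construct the transformer with density printing kept on.
InlinedHashSet<std::string_view> compatible_eps;
auto padding_elimination =
    std::make_unique<PaddingElimination>(compatible_eps, /*print_input_density=*/true);
```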
@@ -16,15 +16,16 @@
 
 namespace onnxruntime {
 
+namespace {
+
+constexpr const char* kFlagAndPrintDensityFuncName =
+    "onnxruntime.training.ortmodule._runtime_inspector.FlagAndPrintDensity";
+}  // namespace
+
 Status InsertGatherBeforeSceLoss::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/,
                                             const logging::Logger& logger) const {
   LOG_DEBUG_INFO(logger, "Enter InsertGatherBeforeSceLoss");
 
-  if (sparse_label_input_names_.size() == 0) {
-    LOG_DEBUG_INFO(logger, "Exit InsertGatherBeforeSceLoss, no sparse label input names.");
-    return Status::OK();
-  }
-
   GraphViewer graph_viewer(graph);
   [[maybe_unused]] size_t handled_sce_node_count = 0;  // For summary
   const auto& order = graph_viewer.GetNodesInTopologicalOrder();
@@ -48,7 +48,7 @@ Status InsertGatherBeforeSceLoss::ApplyImpl(Graph& graph, bool& modified, int /*
     const NodeArg* label_input_arg = node.InputDefs()[1];
 
     // Check whether this SCE node is handled or not.
-    const Node* labels_producer = graph.GetProducerNode(label_input_arg->Name());
+    Node* labels_producer = graph.GetMutableProducerNode(label_input_arg->Name());
     // Skip if already inserted a ShrunkenGather node.
     if (labels_producer && graph_utils::IsSupportedOptypeVersionAndDomain(
                                *labels_producer, "ShrunkenGather", {1}, kMSDomain)) {
@@ -57,18 +57,28 @@ Status InsertGatherBeforeSceLoss::ApplyImpl(Graph& graph, bool& modified, int /*
       continue;
     }
 
-    // Label input can be a graph input or from a Reshape node taking a graph input as its data input.
-    if (labels_producer && graph_utils::IsSupportedOptypeVersionAndDomain(
-                               *labels_producer, "Reshape", {1, 5, 13, 14}, kOnnxDomain)) {
-      label_input_arg = labels_producer->InputDefs()[0];
-    }
-    // Then check if the label input is graph input and in the sparse label input list.
-    if (!graph.IsInputsIncludingInitializers(label_input_arg) ||
-        std::find(sparse_label_input_names_.begin(), sparse_label_input_names_.end(),
-                  label_input_arg->Name()) == sparse_label_input_names_.end()) {
-      LOG_DEBUG_INFO(logger, "Skip node " + node.Name() + "(" + node.OpType() +
-                                 ") due to labels input is not a graph input or not in the sparse label input list.");
+    if (labels_producer == nullptr ||
+        !graph_utils::IsSupportedOptypeVersionAndDomain(*labels_producer, "PythonOp", {1}, kMSDomain) ||
+        static_cast<std::string>(labels_producer->GetAttributes().at("func_name").s()) !=
+            kFlagAndPrintDensityFuncName) {
+      LOG_DEBUG_INFO(logger, "Skip node " + node.Name() + "(" + node.OpType() +
+                                 ") due to labels input is not produced by a PythonOp node with flag " +
+                                 kFlagAndPrintDensityFuncName + ".");
       continue;
+    } else if (!print_density_) {
+      if (graph_utils::CanRemoveNode(graph, *labels_producer, logger)) {
+        if (graph_utils::RemoveNode(graph, *labels_producer)) {
+          modified = true;
+        } else {
+          LOG_DEBUG_INFO(logger, "Failed to remove node " + labels_producer->Name() +
+                                     "(" + labels_producer->OpType() + ")");
+          continue;
+        }
+      } else {
+        LOG_DEBUG_INFO(logger, "Can not remove node " + labels_producer->Name() +
+                                   "(" + labels_producer->OpType() + ")");
+        continue;
+      }
     }
 
     // Check shape requirements.
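With the sparse-label input-name list gone, both transformers share the same clean-up step: keep the FlagAndPrintDensity PythonOp in the graph when density printing is requested, otherwise splice it out. A condensed sketch of that step using the same graph utilities as the diff (the helper itself is illustrative, not code from this PR):

```cpp
// Sketch only: remove the density-marker PythonOp unless the user asked to
// keep printing input density. Returns true when the graph was modified.
bool MaybeRemoveDensityMarker(Graph& graph, Node& marker, bool print_density,
                              const logging::Logger& logger) {
  if (print_density) {
    return false;  // keep the PythonOp alive so it keeps reporting density
  }
  if (!graph_utils::CanRemoveNode(graph, marker, logger)) {
    return false;  // marker cannot be detached safely
  }
  return graph_utils::RemoveNode(graph, marker);
}
```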
@@ -68,9 +68,9 @@ namespace onnxruntime {
 class InsertGatherBeforeSceLoss : public GraphTransformer {
  public:
   InsertGatherBeforeSceLoss(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
-                            const std::vector<std::string>& sparse_label_input_names = {}) noexcept
+                            const bool print_input_density = false) noexcept
       : GraphTransformer("InsertGatherBeforeSceLoss", compatible_execution_providers),
-        sparse_label_input_names_{sparse_label_input_names} {
+        print_density_(print_input_density) {
   }
 
   /**
@@ -79,7 +79,7 @@ class InsertGatherBeforeSceLoss : public GraphTransformer {
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
 
  private:
-  std::vector<std::string> sparse_label_input_names_;
+  bool print_density_ = false;
 };
 
 }  // namespace onnxruntime
@@ -25,6 +25,8 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat
   // Enable compute optimizer.
   bool enable_compute_optimizer{false};
 
+  bool print_input_density{false};
+
   // Enable label sparsity compute optimization for the input names in the below list.
   std::vector<std::string> sparse_label_input_names;

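Setting the new field on the configuration is straightforward; the field names below come from this diff, while the surrounding setup is illustrative:

```cpp
// Sketch: enable the compute optimizer but keep the density markers so
// FlagAndPrintDensity continues to report input density.
TrainingGraphTransformerConfiguration config{};
config.enable_compute_optimizer = true;
config.print_input_density = true;
```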
@@ -195,11 +195,12 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
     transformers.emplace_back(std::make_unique<UpStreamGatherGraphTransformer>(compatible_eps));
     transformers.emplace_back(std::make_unique<UpStreamReshapeGraphTransformer>(compatible_eps));
     transformers.emplace_back(std::make_unique<InsertGatherBeforeSceLoss>(compatible_eps,
-                                                                          config.sparse_label_input_names));
+                                                                          config.print_input_density));
 #if defined(USE_CUDA) || defined(USE_ROCM)
     // Put this under CUDA/ROCM guard as it depends on PadAndUnflatten CUDA/ROCM kernel.
     // Once we have a CPU kernel for PadAndUnflatten, we can remove the guard.
-    transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps));
+    transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps,
+                                                                   config.print_input_density));
     transformers.emplace_back(std::make_unique<Conv1dReplacement>(compatible_eps));
 #endif
   }
1 change: 1 addition & 0 deletions orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -487,6 +487,7 @@ void addObjectMethodsForTraining(py::module& m) {
       .def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
       .def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
       .def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer)
+      .def_readwrite("print_input_density", &TrainingGraphTransformerConfiguration::print_input_density)
       .def_readwrite("sparse_label_input_names", &TrainingGraphTransformerConfiguration::sparse_label_input_names)
       .def_readwrite("optimized_pre_grad_filepath", &TrainingGraphTransformerConfiguration::optimized_pre_grad_filepath)
       .def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);