-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generalize label input sparsity check and refactor #20636
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pengwa
added
the
training
issues related to ONNX Runtime training; typically submitted using template
label
May 10, 2024
guyang3532
force-pushed
the
yangu/refactor_lable_check_pattern
branch
from
May 10, 2024 07:43
1ef1d50
to
2c29b27
Compare
pengwa
reviewed
May 10, 2024
orttraining/orttraining/core/optimizer/graph_transformer_config.h
Outdated
Show resolved
Hide resolved
pengwa
reviewed
May 10, 2024
orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
Outdated
Show resolved
Hide resolved
pengwa
reviewed
May 10, 2024
orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
Outdated
Show resolved
Hide resolved
pengwa
reviewed
May 10, 2024
orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
Outdated
Show resolved
Hide resolved
pengwa
reviewed
May 10, 2024
orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
Outdated
Show resolved
Hide resolved
pengwa
changed the title
generalize label input sparsity check and refactor
Generalize label input sparsity check and refactor
May 10, 2024
guyang3532
force-pushed
the
yangu/refactor_lable_check_pattern
branch
from
May 10, 2024 11:16
4ac0c06
to
f0e0134
Compare
pengwa
approved these changes
May 10, 2024
poweiw
pushed a commit
to poweiw/onnxruntime
that referenced
this pull request
Jun 25, 2024
### Description The InsertGatherBeforeSceLoss optimization is enabled when the density of label padding less than 90%. We need to check the density of the label padding to decide whether enable the optimization. Before this pr, we just check the inputs of graph and correlate one with the SCE node by iterate graph from the SCE node back to one graph input. This is hard to be general because there may be complicated pattern between graph input and SCE node. This pr check padding density by the direct input of SCE module rather than the input of graph at the first graph execution when exporting onnx graph. And if the density < 90%, insert a flag PythonOp after the SCE node as: ``` SoftmaxCrossEntropy | PythonOp (func_name: FlagAndPrintDensity) (insert if density < 90%) | Following graph ``` When the InsertGatherBeforeSceLoss is invoked, it check if there is the flag PythonOp(func_name: FlagAndPrintDensity) after the SCE node and if it is, remove it and do the padding elimination optimization. If the env of ORTMODULE_PRINT_INPUT_DENSITY is 1, we will print input density each step by the PythonOp (func_name: FlagAndPrintDensity). In this case the PythonOp will not be removed.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
The InsertGatherBeforeSceLoss optimization is enabled when the density of label padding less than 90%. We need to check the density of the label padding to decide whether enable the optimization.
Before this pr, we just check the inputs of graph and correlate one with the SCE node by iterate graph from the SCE node back to one graph input.
This is hard to be general because there may be complicated pattern between graph input and SCE node.
This pr check padding density by the direct input of SCE module rather than the input of graph at the first graph execution when exporting onnx graph.
And if the density < 90%, insert a flag PythonOp after the SCE node as:
When the InsertGatherBeforeSceLoss is invoked, it check if there is the flag PythonOp(func_name: FlagAndPrintDensity) after the SCE node and if it is, remove it and do the padding elimination optimization.
If the env of ORTMODULE_PRINT_INPUT_DENSITY is 1, we will print input density each step by the PythonOp (func_name: FlagAndPrintDensity). In this case the PythonOp will not be removed.