Generalize label input sparsity check and refactor #20636

Merged
merged 2 commits into main from yangu/refactor_lable_check_pattern on May 10, 2024

Conversation

guyang3532
Contributor

Description

The InsertGatherBeforeSceLoss optimization is enabled when the density of the label input (the fraction of non-padding labels) is below 90%, so we need to measure that density to decide whether to enable the optimization.
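
For reference, a minimal sketch of such a density check, assuming padding positions are marked with an `ignore_index` value as in `torch.nn.CrossEntropyLoss` (the actual helper inside ORTModule may differ):

```python
import torch

def label_density(labels: torch.Tensor, ignore_index: int = -100) -> float:
    """Fraction of label positions that hold real (non-padding) tokens."""
    # ignore_index marks padded positions, following the convention of
    # torch.nn.CrossEntropyLoss; the value a given model uses may differ.
    valid = (labels != ignore_index).sum().item()
    return valid / labels.numel()

# Two valid tokens out of six -> density ~33%, far below the 90% threshold,
# so the InsertGatherBeforeSceLoss optimization would be enabled.
labels = torch.tensor([[5, 7, -100], [-100, -100, 3]])
print(f"{label_density(labels):.2%}")  # 33.33%
```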

Before this PR, we checked the graph inputs and correlated one of them with the SCE node by iterating the graph backward from the SCE node to a graph input. That approach is hard to generalize because arbitrarily complicated patterns can sit between the graph input and the SCE node.

This PR instead checks the padding density on the direct input of the SCE module, during the first graph execution when exporting the ONNX graph. If the density is below 90%, it inserts a flag PythonOp after the SCE node:

```
SoftmaxCrossEntropy
        |
PythonOp (func_name: FlagAndPrintDensity)   (inserted if density < 90%)
        |
Following graph
```
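
Conceptually, the flag op is an identity pass-through whose presence in the graph records that the density was below the threshold at export time. A minimal sketch of how such a PythonOp could be defined (the signature and side effects here are assumptions, not the actual ORTModule implementation):

```python
import os
import torch

class FlagAndPrintDensity(torch.autograd.Function):
    """Identity pass-through that flags the exported graph; optionally
    prints the label density each step."""

    @staticmethod
    def forward(ctx, loss: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        if os.environ.get("ORTMODULE_PRINT_INPUT_DENSITY") == "1":
            # -100 as padding marker is an assumption (CrossEntropyLoss default).
            density = (labels != -100).sum().item() / labels.numel()
            print(f"label input density: {density:.2%}")
        return loss

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Identity in the backward pass as well; no gradient for labels.
        return grad_output, None
```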

When InsertGatherBeforeSceLoss is invoked, it checks whether the flag PythonOp (func_name: FlagAndPrintDensity) is present after the SCE node; if it is, the pass removes it and performs the padding elimination optimization.
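
As an illustration, a minimal sketch of that check-and-remove step written against the standard `onnx` Python API (the real pass lives in ORT's graph transformers and is not implemented this way; storing the function name in a `func_name` attribute follows the description above):

```python
import onnx

def try_remove_density_flag(graph: onnx.GraphProto, sce_node: onnx.NodeProto) -> bool:
    """Return True and splice out the flag node if a PythonOp with
    func_name FlagAndPrintDensity directly consumes the SCE output."""
    for node in list(graph.node):
        if node.op_type != "PythonOp" or sce_node.output[0] not in node.input:
            continue
        func_name = next(
            (a.s.decode() for a in node.attribute if a.name == "func_name"), ""
        )
        if not func_name.endswith("FlagAndPrintDensity"):
            continue
        # Rewire every consumer of the flag op back to the SCE output.
        for consumer in graph.node:
            for i, inp in enumerate(consumer.input):
                if inp == node.output[0]:
                    consumer.input[i] = sce_node.output[0]
        graph.node.remove(node)
        return True
    return False
```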

If the environment variable ORTMODULE_PRINT_INPUT_DENSITY is set to 1, the PythonOp (func_name: FlagAndPrintDensity) prints the input density at each step; in that case the PythonOp is not removed.
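
For example, to keep the flag op in the graph and log the density every step, the variable can be set before the model is wrapped and exported (a usage sketch):

```python
import os

# Keep the FlagAndPrintDensity PythonOp and print label density each step.
os.environ["ORTMODULE_PRINT_INPUT_DENSITY"] = "1"
```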

@guyang3532 guyang3532 requested a review from pengwa May 10, 2024 05:52
@pengwa pengwa added the `training` label (issues related to ONNX Runtime training; typically submitted using template) May 10, 2024
@guyang3532 guyang3532 force-pushed the yangu/refactor_lable_check_pattern branch from 1ef1d50 to 2c29b27 May 10, 2024 07:43
@pengwa pengwa requested a review from wschin May 10, 2024 09:51
@pengwa pengwa changed the title generalize label input sparsity check and refactor Generalize label input sparsity check and refactor May 10, 2024
@guyang3532 guyang3532 force-pushed the yangu/refactor_lable_check_pattern branch from 4ac0c06 to f0e0134 May 10, 2024 11:16
@guyang3532 guyang3532 merged commit cfe830b into main May 10, 2024
95 checks passed
@guyang3532 guyang3532 deleted the yangu/refactor_lable_check_pattern branch May 10, 2024 13:55
poweiw pushed a commit to poweiw/onnxruntime that referenced this pull request Jun 25, 2024