
Improve perf for mem efficient grad mgmt #20480

Merged 4 commits into main on May 10, 2024

Conversation

pengwa
Contributor

@pengwa pengwa commented Apr 26, 2024

Improve perf for mem efficient grad mgmt

When the memory-efficient gradient management feature is enabled, the weight-retrieval PythonOp for every layer is launched at the beginning of the forward pass, which leaves the GPU stream idle for a few milliseconds. The root cause is that the reversed-DFS ordering cannot always handle such input branching well, so we introduce a distance-to-input-leaf concept into the reversed DFS. This moves not only the problematic PythonOp to the place where it is needed, but also the Cast ops that follow the weight retrieval.
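The idea can be illustrated with a small standalone sketch (not the ONNX Runtime source; all node names and helper functions here are hypothetical): order a toy DAG with a reversed DFS from the outputs, and among a node's predecessors visit the one farthest from the graph inputs first. Nodes sitting right next to the inputs, like the weight-retrieval ops, then land late in the resulting order, close to their consumers, instead of all being scheduled up front.

```python
def leaf_distance(node, preds, memo=None):
    """Longest path from `node` back to a graph input (a node with no predecessors)."""
    if memo is None:
        memo = {}
    if node in memo:
        return memo[node]
    ps = preds.get(node, [])
    d = 0 if not ps else 1 + max(leaf_distance(p, preds, memo) for p in ps)
    memo[node] = d
    return d

def reversed_dfs_order(outputs, preds):
    """Post-order DFS from the outputs toward the inputs.

    Among a node's predecessors, the one farthest from the inputs is
    visited (and thus emitted) first, so near-input ops such as
    weight-retrieval PythonOps are emitted last among their siblings,
    i.e. right before the node that consumes them.
    """
    memo = {}
    order, seen = [], set()

    def visit(n):
        if n in seen:
            return
        seen.add(n)
        for p in sorted(preds.get(n, []),
                        key=lambda p: -leaf_distance(p, preds, memo)):
            visit(p)
        order.append(n)

    for out in outputs:
        visit(out)
    return order

# Toy graph: node -> list of predecessor nodes. "w*_retrieve" stand in for
# the weight-retrieval PythonOps; the names are illustrative only.
preds = {
    "matmul1": ["input", "w1_retrieve"],
    "matmul2": ["matmul1", "w2_retrieve"],
    "out": ["matmul2"],
}
order = reversed_dfs_order(["out"], preds)
print(order)  # "w2_retrieve" is emitted just before "matmul2", not at the start
```

With a plain DFS that ignores distance, both retrieval ops could be emitted at the very beginning of the order; the distance-based tiebreak is what pushes each one next to its consumer.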

Main branch: 102.19s - 26.35s = 75.84s for 260 steps (4,627 samples), 61.04 samples/second
This PR: 100.28s - 25.10s = 75.18s for 260 steps, 61.54 samples/second (+0.8% gain)

Main branch:

![image](https://github.com/microsoft/onnxruntime/assets/10530022/75c4131e-dade-49b0-aa8b-ee1c637ad9a8)

This PR:

![image](https://github.com/microsoft/onnxruntime/assets/10530022/e590a536-3b80-4f51-b89f-f25a55ddd7e2)

Motivation and Context

@pengwa pengwa added the training issues related to ONNX Runtime training; typically submitted using template label Apr 26, 2024
Contributor

@AdamLouly AdamLouly left a comment


LGTM

@pengwa
Contributor Author

pengwa commented May 10, 2024

Thanks @AdamLouly !!

@pengwa pengwa merged commit 56f7035 into main May 10, 2024
95 checks passed
@pengwa pengwa deleted the pengwa/perf_efficient_mem branch May 10, 2024 00:09
poweiw pushed a commit to poweiw/onnxruntime that referenced this pull request Jun 25, 2024