Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] overhaul memory profiling and fix backward compatibility #10511

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

youkaichao
Copy link
Member

@youkaichao youkaichao commented Nov 21, 2024

fixes #10451 , and clearly explain the memory classification and the procedure.

I also added the initial pytorch memory, to be aligned with the pytorch memory profiler.

the profiling procedure is extracted into vllm/utils , so that we can use it later in v1 too.

Signed-off-by: youkaichao <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
@mgoin mgoin self-requested a review November 21, 2024 02:58
Comment on lines +1663 to +1667
| cuda memory |
| | torch memory | |
Before profiling: | --------- | +++++++++ | |
During profiling (peak): | --------- | +++++++++++++ | *** |
After profiling: | --------- | +++++++++++ | *** |
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alignment here is a bit off.

Comment on lines +236 to 242
(result.before_profile.cuda_memory_in_bytes) / (1024**3),
(result.before_profile.torch_memory_in_bytes) / (1024**3),
(result.after_profile.cuda_memory_in_bytes) / (1024**3),
(result.after_profile.torch_memory_in_bytes) / (1024**3),
result.torch_peak_memory_in_bytes / (1024**3),
result.non_torch_memory_in_bytes / (1024**3),
available_kv_cache_memory / (1024**3),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw we can import GiB_bytes from vllm.utils.

@DarkLight1337
Copy link
Member

cc @joerunde

self.model_runner.profile_run()
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we encapsulate more of the code into the profiler? The device sychronization, garbage collection / cache clean up, and the call to torch.cuda.reset_peak_memory_stats() could be moved into the context.

result.after_profile.measure()

result.torch_peak_memory_in_bytes = torch.cuda.memory_stats(
)["allocated_bytes.all.peak"]
Copy link
Contributor

@tjohnson31415 tjohnson31415 Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This torch peak memory will include bytes from other objects in the Python process other than the vLLM worker / LLM instance that we want to profile. Hence, the gpu_memory_utilization parameter would limit the total usage of the python process, not the usage of just the LLM instance. To fix, we need to measure the peak relative to a baseline torch memory recorded as part of the profiling.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah, got your point. I think we need to record the baseline before we load the model.

Copy link

mergify bot commented Nov 23, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @youkaichao.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: Breaking Change in gpu_memory_utilization Behavior in vLLM 0.6.4
3 participants