Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TPU] Correctly profile peak memory usage & Upgrade PyTorch XLA #9438

Merged
merged 5 commits into from
Oct 30, 2024

Conversation

WoosukKwon
Copy link
Collaborator

Should be merged after #9437 and after the 10/17 version of PyTorch XLA nightly is available.

This PR upgrades the PyTorch XLA, and uses the peak_bytes_used to correctly profile the peak HBM usage during the dummy profile run.

@WoosukKwon WoosukKwon added the tpu Related to Google TPUs label Oct 17, 2024
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) October 20, 2024 17:59
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 20, 2024
@robertgshaw2-redhat
Copy link
Collaborator

robertgshaw2-redhat commented Oct 20, 2024

Hey @WoosukKwon - I have been using this PR for the compressed-tensors loading as well as #9437.

I noticed that even with this PR and #9437, value of m["peak_bytes_used"] >> m["bytes_used"] (it can be 2-4GB more) before we do the dummy run. This is often bigger than the activation size, which means we are wasting space for KV cache allocation. This extra memory occurs when we actually load the weights:

  def load_model(self, *, model_config: ModelConfig,
                 device_config: DeviceConfig,
                 lora_config: Optional[LoRAConfig],
                 parallel_config: ParallelConfig,
                 scheduler_config: SchedulerConfig,
                 cache_config: CacheConfig) -> nn.Module:
      target_device = torch.device(device_config.device)
      with set_default_torch_dtype(model_config.dtype):
          with target_device:
              model = _initialize_model(model_config, self.load_config,
                                        lora_config, cache_config,
                                        scheduler_config) # << here, value of peak_bytes_used == bytes_used == weight_size

          model.load_weights(self._get_all_weights(model_config, model)) # here, peak_bytes_used > bytes_used

So, I was thinking we might want to reset peak_bytes_used prior to running the dummy pass

@mergify mergify bot added documentation Improvements or additions to documentation ci/build labels Oct 30, 2024
@robertgshaw2-redhat robertgshaw2-redhat merged commit 211fe91 into main Oct 30, 2024
90 checks passed
@robertgshaw2-redhat robertgshaw2-redhat deleted the tpu-peak-mem branch October 30, 2024 09:41
rasmith pushed a commit to rasmith/vllm that referenced this pull request Oct 30, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
sergeykochetkov pushed a commit to sergeykochetkov/vllm_spec_decoding that referenced this pull request Dec 27, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed tpu Related to Google TPUs
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants