
Commit

[Platform] move current_memory_usage() into platform (#11369)
Signed-off-by: Shanshan Shen <[email protected]>
shen-shanshan authored Jan 15, 2025
1 parent 1a51b9f commit 9ddac56
Showing 5 changed files with 31 additions and 7 deletions.
7 changes: 7 additions & 0 deletions vllm/platforms/cuda.py
@@ -143,6 +143,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1) -> str:
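Note on the CUDA implementation above: resetting the peak-memory counter immediately before reading torch.cuda.max_memory_allocated() makes the tracked peak equal to the memory allocated at that moment, which is how the method reports current usage in bytes. A minimal standalone sketch of the same pattern (illustrative only, not part of this commit):

import torch

if torch.cuda.is_available():
    # After the reset, the tracked peak equals the current allocation,
    # so max_memory_allocated() returns "current" usage in bytes.
    torch.cuda.reset_peak_memory_stats()
    current_bytes = torch.cuda.max_memory_allocated()
    print(f"current allocated: {current_bytes / (1024 ** 2):.1f} MiB")
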
9 changes: 9 additions & 0 deletions vllm/platforms/interface.py
@@ -277,6 +277,15 @@ def is_pin_memory_available(cls) -> bool:
             return False
         return True
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        """
+        Return the memory usage in bytes.
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_punica_wrapper(cls) -> str:
         """
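The interface method above is the hook that each platform backend overrides. A hedged sketch of how an out-of-tree platform might implement it (the class name, the psutil-based RSS measurement, and the CPU-style backend are illustrative assumptions, not part of this commit):

from typing import Optional

import psutil
import torch

from vllm.platforms.interface import Platform


class HypotheticalCpuPlatform(Platform):

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        # Report the resident set size of the current process in bytes;
        # a CPU-style backend has no torch.cuda/torch.xpu counters to read.
        return psutil.Process().memory_info().rss
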
7 changes: 7 additions & 0 deletions vllm/platforms/rocm.py
@@ -157,3 +157,10 @@ def verify_quantization(cls, quant: str) -> None:
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
7 changes: 7 additions & 0 deletions vllm/platforms/xpu.py
@@ -94,3 +94,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on XPU.")
         return False
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.xpu.reset_peak_memory_stats(device)
+        return torch.xpu.max_memory_allocated(device)
8 changes: 1 addition & 7 deletions vllm/utils.py
@@ -710,13 +710,7 @@ def __init__(self, device: Optional[torch.types.Device] = None):
     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
         from vllm.platforms import current_platform
-        if current_platform.is_cuda_alike():
-            torch.cuda.reset_peak_memory_stats(self.device)
-            mem = torch.cuda.max_memory_allocated(self.device)
-        elif current_platform.is_xpu():
-            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
-            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
-        return mem
+        return current_platform.get_current_memory_usage(self.device)
 
     def __enter__(self):
         self.initial_memory = self.current_memory_usage()
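With this change, callers query memory usage through the platform abstraction instead of branching on the device type. A short usage sketch; the DeviceMemoryProfiler class name and its consumed_memory attribute are assumed from vllm/utils.py and are not visible in this hunk:

import torch

from vllm.platforms import current_platform
from vllm.utils import DeviceMemoryProfiler  # assumed class name

# Direct query through the hook added by this commit.
print(f"current usage: {current_platform.get_current_memory_usage()} bytes")

# Typical context-manager usage; `consumed_memory` is assumed to be set
# in __exit__, which is outside the lines shown above.
with DeviceMemoryProfiler() as profiler:
    weights = torch.empty(1024, 1024, device="cuda")
print(f"consumed: {profiler.consumed_memory / (1024 ** 2):.1f} MiB")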
