[Core]: Support destroying all KV cache during runtime #10810

Status: Closed · wants to merge 1 commit
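This PR threads a new `destroy_cache` hook through the engine, executor, and worker layers so that all GPU/CPU KV cache blocks can be freed at runtime without shutting the engine down. A minimal sketch of the intended call pattern, assuming the private `LLMEngine` methods shown in the diffs below (the model name is illustrative only):

```python
from vllm import LLM

llm = LLM(model="facebook/opt-125m")
engine = llm.llm_engine

# Added in this PR: stop any remote worker execution loops, then free
# every GPU/CPU KV cache block so the memory is available to others.
engine._destroy_kv_caches()

# Existing method: re-profile free memory and re-allocate the KV cache
# before serving requests again.
engine._initialize_kv_caches()
```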
6 changes: 6 additions & 0 deletions vllm/engine/llm_engine.py
@@ -488,6 +488,12 @@ def _initialize_kv_caches(self) -> None:

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def _destroy_kv_caches(self) -> None:
        """Destroy the KV cache in the worker(s) without shutting down."""
        self.model_executor.stop_remote_worker_execution_loop()
        self.model_executor.destroy_cache()

    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
5 changes: 5 additions & 0 deletions vllm/executor/distributed_gpu_executor.py
@@ -68,6 +68,11 @@ def initialize_cache(self, num_gpu_blocks: int,
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def destroy_cache(self) -> None:
        """Destroy the KV cache in all workers."""
        self._run_workers("destroy_cache")

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest,
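`destroy_cache` reaches every worker through the executor's generic fan-out helper. A minimal sketch of that pattern (not vLLM's actual `_run_workers`, which additionally handles Ray remote calls and driver/worker coordination):

```python
from typing import Any, List

def run_workers(workers: List[Any], method: str, *args, **kwargs) -> List[Any]:
    # Resolve the method by name on each worker; a new operation such as
    # "destroy_cache" then needs no extra executor-side plumbing.
    return [getattr(worker, method)(*args, **kwargs) for worker in workers]
```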
7 changes: 7 additions & 0 deletions vllm/executor/executor_base.py
@@ -62,6 +62,13 @@ def initialize_cache(self, num_gpu_blocks: int,
        """
        raise NotImplementedError

    # TODO: Make this an abstract method that all executors must implement.
    # @abstractmethod
    def destroy_cache(self) -> None:
        """Destroy the KV cache."""
        raise NotImplementedError

    @abstractmethod
    def execute_model(
            self, execute_model_req: ExecuteModelRequest
5 changes: 5 additions & 0 deletions vllm/executor/gpu_executor.py
@@ -82,6 +82,11 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:

        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def destroy_cache(self) -> None:
        """Destroy the KV cache by invoking the underlying worker."""
        self.driver_worker.destroy_cache()

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
13 changes: 13 additions & 0 deletions vllm/worker/cache_engine.py
@@ -84,6 +84,19 @@ def _allocate_kv_cache(
                            device=device))
        return kv_cache

    def destroy(self) -> None:
        # Drop every reference to the allocated KV cache tensors so the
        # allocator can reclaim them.
        while self.gpu_cache:
            tensor = self.gpu_cache.pop()
            del tensor

        while self.cpu_cache:
            tensor = self.cpu_cache.pop()
            del tensor

        # Collect immediately so the freed tensors do not linger until
        # the next automatic GC cycle.
        import gc
        gc.collect()

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        for i in range(self.num_attention_layers):
            self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
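Dropping the last reference is what actually frees the GPU memory here: once a CUDA tensor becomes unreachable, PyTorch's caching allocator reclaims its blocks, and the `torch.cuda.empty_cache()` call in `Worker.destroy_cache` below hands them back to the driver. A self-contained sketch, with an illustrative dummy cache shape:

```python
import gc
import torch

# Stand-in for CacheEngine.gpu_cache: one tensor per attention layer.
cache = [torch.empty(2, 1024, 16, 128, device="cuda") for _ in range(4)]
print(torch.cuda.memory_allocated())  # large: the cache tensors are live

while cache:
    tensor = cache.pop()
    del tensor  # drop the last reference to this layer's cache
gc.collect()

print(torch.cuda.memory_allocated())  # near zero: blocks reclaimed
torch.cuda.empty_cache()              # hand cached blocks back to the driver
print(torch.cuda.memory_reserved())   # reduced after empty_cache()
```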
9 changes: 9 additions & 0 deletions vllm/worker/worker.py
@@ -285,6 +285,15 @@ def initialize_cache(self, num_gpu_blocks: int,
        self._init_cache_engine()
        self._warm_up_model()

    def destroy_cache(self) -> None:
        # Reset the block counts so the config reflects that no KV cache
        # is allocated anymore.
        self.cache_config.num_gpu_blocks = 0
        self.cache_config.num_cpu_blocks = 0
        # Destroy one CacheEngine per pipeline-parallel virtual engine.
        while self.cache_engine:
            cache_engine = self.cache_engine.pop()
            cache_engine.destroy()
        self.gpu_cache = None
        # Return the freed blocks from PyTorch's caching allocator to
        # the CUDA driver so other processes can use the memory.
        torch.cuda.empty_cache()

    def _init_cache_engine(self):
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = [
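One way to exercise the whole stack would be a smoke test asserting that destruction really lowers allocated GPU memory. This is a hypothetical test, not part of the PR; the model name, `enforce_eager`, and the use of the private engine method are illustrative assumptions:

```python
import torch
from vllm import LLM

def test_destroy_kv_caches_frees_memory():
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    engine = llm.llm_engine

    before = torch.cuda.memory_allocated()
    engine._destroy_kv_caches()
    after = torch.cuda.memory_allocated()

    # The KV cache dominates allocated memory, so expect a large drop.
    assert after < before
```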
7 changes: 7 additions & 0 deletions vllm/worker/worker_base.py
@@ -73,6 +73,13 @@ def initialize_cache(self, num_gpu_blocks: int,
        """
        raise NotImplementedError

    # TODO: Make this an abstract method that all workers must implement.
    # @abstractmethod
    def destroy_cache(self) -> None:
        """Clear out all KV cache blocks in the current worker."""
        raise NotImplementedError

    @current_platform.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Execute model loop in parallel worker.