From 8eca8b1655d3ff3ecc16e7a0a8cfdc24803c38fe Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Sun, 1 Dec 2024 22:54:36 +0200 Subject: [PATCH] [Core]: Support destroying all KV cache during runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements #10714 API Design: - Destroy (this PR implements): `vllm.LLM().llm_engine._destroy_kv_caches()` - ReInitialize (already have): `vllm.LLM().llm_engine._initialize_kv_caches()` - Stop loop (already have): `vllm.LLM().llm_engine.model_executor.stop_remote_worker_execution_loop()` This PR only implements `_destroy_kv_caches` for GPU executor and workers, as I don’t have other available hardware, feel free to take over this PR to implement others, and once we finish all the implementations, we can make `destroy_cache()` an abstract method. Also, since the engine won’t generate without KV Caches (will throw errors), this PR assumes that the developers will handle everything on their side so that no request will be sent to generate after `_destroy_kv_caches()` and before `_initialize_kv_caches()` (in sleep mode) Code for testing: ```python import ray, time from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @ray.remote class LLMRayActor: def __init__(self, *args, **kwargs): import vllm if not kwargs["tensor_parallel_size"] == 1: kwargs["worker_use_ray"] = True self.llm = vllm.LLM(*args, **kwargs) def generate(self, *args, **kwargs): return self.llm.generate(*args, **kwargs) def destroy_cache(self): self.stop_remote_worker_execution_loop() self.llm.llm_engine._destroy_kv_caches() def load_cache(self): self.stop_remote_worker_execution_loop() self.llm.llm_engine._initialize_kv_caches() def stop_remote_worker_execution_loop(self): self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop() def create_vllm_engines( num_engines: int, tensor_parallel_size: int, model: str, ): vllm_engines = [] for _ in range(num_engines): num_gpus = int(tensor_parallel_size == 1) scheduling_strategy = None if tensor_parallel_size > 1: bundles = [{"GPU": 1, "CPU": 1}] * tensor_parallel_size pg = placement_group(bundles) ray.get(pg.ready()) scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0 ) vllm_engines.append( LLMRayActor.options( num_cpus=1, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, ).remote( model, tensor_parallel_size=tensor_parallel_size, ) ) return vllm_engines if __name__ == "__main__": # engines = create_vllm_engines(2, 2, "meta-llama/Llama-3.1-8B-Instruct") engines = create_vllm_engines(4, 1, "meta-llama/Llama-3.1-8B-Instruct") ref = [] for engine in engines: ref.append(engine.generate.remote("San Francisco is a")) print(f"output: {ray.get(ref)}") ref = [] for engine in engines: ref.append(engine.destroy_cache.remote()) ray.get(ref) time.sleep(5) ref = [] for engine in engines: ref.append(engine.load_cache.remote()) ray.get(ref) ref = [] for engine in engines: ref.append(engine.generate.remote("New York is a")) print(f"output: {ray.get(ref)}") ``` Signed-off-by: Hollow Man --- vllm/engine/llm_engine.py | 6 ++++++ vllm/executor/distributed_gpu_executor.py | 5 +++++ vllm/executor/executor_base.py | 7 +++++++ vllm/executor/gpu_executor.py | 5 +++++ vllm/worker/cache_engine.py | 13 +++++++++++++ vllm/worker/worker.py | 9 +++++++++ vllm/worker/worker_base.py | 7 +++++++ 7 files changed, 52 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7911dc8d04500..414d9b30b1a4b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -488,6 +488,12 @@ def _initialize_kv_caches(self) -> None: self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def _destroy_kv_caches(self) -> None: + """Destroy the KV cache in the worker(s) without shutting down. + """ + self.model_executor.stop_remote_worker_execution_loop() + self.model_executor.destroy_cache() + @classmethod def _get_executor_cls(cls, engine_config: VllmConfig) -> Type[ExecutorBase]: diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index deb7cb1c97ef5..fd2a9b5def8fc 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -68,6 +68,11 @@ def initialize_cache(self, num_gpu_blocks: int, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + def destroy_cache(self) -> None: + """Destroy the KV cache in all workers. + """ + self._run_workers("destroy_cache") + def execute_model( self, execute_model_req: ExecuteModelRequest, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 9cba189dd57f9..eeb51206e7d49 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -62,6 +62,13 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise NotImplementedError + # TODO: Make this an abstract method and all executors should implement it. + # @abstractmethod + def destroy_cache(self) -> None: + """Destroy the KV cache. + """ + raise NotImplementedError + @abstractmethod def execute_model( self, execute_model_req: ExecuteModelRequest diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 7fa34456028dd..87975251b9369 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -82,6 +82,11 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def destroy_cache(self) -> None: + """Destroy the KV cache by invoking the underlying worker. + """ + self.driver_worker.destroy_cache() + def execute_model( self, execute_model_req: ExecuteModelRequest ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index ac3270d1c9909..79c8d480e2b8a 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -84,6 +84,19 @@ def _allocate_kv_cache( device=device)) return kv_cache + def destroy(self) -> None: + # Iterate over all the caches and destroy them. + while self.gpu_cache: + tensor = self.gpu_cache.pop() + del tensor + + while self.cpu_cache: + tensor = self.cpu_cache.pop() + del tensor + + import gc + gc.collect() + def swap_in(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_attention_layers): self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d58cb029618e9..f819646a901bd 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -285,6 +285,15 @@ def initialize_cache(self, num_gpu_blocks: int, self._init_cache_engine() self._warm_up_model() + def destroy_cache(self) -> None: + self.cache_config.num_gpu_blocks = 0 + self.cache_config.num_cpu_blocks = 0 + while self.cache_engine: + cache_engine = self.cache_engine.pop() + cache_engine.destroy() + self.gpu_cache = None + torch.cuda.empty_cache() + def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 7aaa8b453cff1..2b072262fbe36 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -73,6 +73,13 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise NotImplementedError + # TODO: Make this an abstract method and all workers should implement it. + # @abstractmethod + def destroy_cache(self) -> None: + """Clear out all the KV cache in the current worker. + """ + raise NotImplementedError + @current_platform.inference_mode() def start_worker_execution_loop(self) -> None: """Execute model loop in parallel worker.