Optimize the logic for the num_gpus parameter when loading vllm using ray
jieguangzhou committed Jan 10, 2024
1 parent 5dd01f9 commit 672407b
Showing 1 changed file with 26 additions and 20 deletions.
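
The change makes the Ray actor's num_gpus option follow tensor_parallel_size instead of being left to the caller: with tensor_parallel_size == 1 the actor itself hosts the engine and must reserve exactly one GPU, while with tensor_parallel_size > 1 vLLM schedules its own worker GPUs, so a user-supplied num_gpus is dropped with a warning (presumably why the old code saw the process block when the config was passed through). A minimal sketch of that rule, separate from the commit; the helper name resolve_num_gpus is invented for illustration:

from typing import Any, Dict

def resolve_num_gpus(tensor_parallel_size: int, ray_config: Dict[str, Any]) -> Dict[str, Any]:
    """Sketch of the commit's rule for normalizing num_gpus in a Ray actor config."""
    config = dict(ray_config)
    if tensor_parallel_size == 1:
        # The engine runs inside the actor itself, so it must hold one GPU.
        config["num_gpus"] = 1
    else:
        # vLLM allocates tensor-parallel worker GPUs itself; reserving them
        # on the wrapper actor as well would double-book the resources.
        config.pop("num_gpus", None)
    return config

assert resolve_num_gpus(1, {"num_gpus": 4}) == {"num_gpus": 1}
assert resolve_num_gpus(4, {"num_gpus": 4}) == {}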
46 changes: 26 additions & 20 deletions superduperdb/ext/llm/vllm.py
@@ -3,9 +3,10 @@
 
 import requests
 
+from superduperdb import logging
 from superduperdb.ext.llm.base import BaseLLMAPI, BaseLLMModel
 
-__all__ = ["VllmAPI", "VllmModel", "VllmOpenAI"]
+__all__ = ["VllmAPI", "VllmModel"]
 
 VLLM_INFERENCE_PARAMETERS_LIST = [
     "n",
@@ -78,14 +79,14 @@ class VllmModel(BaseLLMModel):
 
     def __post_init__(self):
         self.on_ray = self.on_ray or bool(self.ray_address)
-        if 'tensor_parallel_size' not in self.vllm_kwargs:
-            self.vllm_kwargs['tensor_parallel_size'] = self.tensor_parallel_size
+        if "tensor_parallel_size" not in self.vllm_kwargs:
+            self.vllm_kwargs["tensor_parallel_size"] = self.tensor_parallel_size
 
-        if 'trust_remote_code' not in self.vllm_kwargs:
-            self.vllm_kwargs['trust_remote_code'] = self.trust_remote_code
+        if "trust_remote_code" not in self.vllm_kwargs:
+            self.vllm_kwargs["trust_remote_code"] = self.trust_remote_code
 
-        if 'model' not in self.vllm_kwargs:
-            self.vllm_kwargs['model'] = self.model_name
+        if "model" not in self.vllm_kwargs:
+            self.vllm_kwargs["model"] = self.model_name
 
         super().__post_init__()

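For reference, a hypothetical construction showing how __post_init__ backfills vllm_kwargs. The field names come from the diff; the checkpoint name is an arbitrary example, and this assumes the fields are plain dataclass constructor arguments:

from superduperdb.ext.llm.vllm import VllmModel

model = VllmModel(
    identifier="my-llm",  # assuming an identifier is required, as for superduperdb components generally
    model_name="facebook/opt-125m",  # arbitrary example checkpoint
    tensor_parallel_size=1,
    trust_remote_code=True,
    vllm_kwargs={},  # deliberately empty; defaults are filled in below
)
# __post_init__ copies the fields into vllm_kwargs without overwriting
# any keys the caller already supplied:
assert model.vllm_kwargs["model"] == "facebook/opt-125m"
assert model.vllm_kwargs["tensor_parallel_size"] == 1
assert model.vllm_kwargs["trust_remote_code"] is True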
@@ -119,21 +120,26 @@ def generate(self, prompts: List[str], **kwargs) -> List[str]:
             except ImportError:
                 raise Exception("You must install vllm with command 'pip install ray'")
 
-            runtime_env = {
-                "pip": [
-                    "vllm",
-                ]
-            }
-            if not ray.is_initialized():
-                ray.init(address=self.ray_address, runtime_env=runtime_env)
+            self.ray_config.setdefault("runtime_env", {"pip": ["vllm"]})
 
-            if self.vllm_kwargs.get('tensor_parallel_size') == 1:
-                # must set num_gpus to 1 to avoid error
-                self.ray_config["num_gpus"] = 1
-                LLM = ray.remote(**self.ray_config)(_VLLMCore).remote
+            if not ray.is_initialized():
+                ray.init(address=self.ray_address, ignore_reinit_error=True)
+
+            # fix num_gpus for tensor parallel when using ray
+            if self.tensor_parallel_size == 1:
+                if self.ray_config.get("num_gpus", 1) != 1:
+                    logging.warn(
+                        "tensor_parallel_size == 1, num_gpus will be set to 1. "
+                        "If you want to use more gpus, "
+                        "please set tensor_parallel_size > 1."
+                    )
+                self.ray_config["num_gpus"] = self.tensor_parallel_size
             else:
-                # Don't know why using config will block the process, need to figure out
-                LLM = ray.remote(_VLLMCore).remote
+                if "num_gpus" in self.ray_config:
+                    logging.warn("tensor_parallel_size > 1, num_gpus will be ignored.")
+                self.ray_config.pop("num_gpus", None)
+
+            LLM = ray.remote(**self.ray_config)(_VLLMCore).remote
         else:
             LLM = _VLLMCore

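The LLM = ray.remote(**self.ray_config)(_VLLMCore).remote line above uses Ray's options-first form: ray.remote(**options) returns a decorator, applying it to a class yields an actor class, and .remote(...) spawns the actor. A standalone illustration with a toy class (Counter is invented; only the call pattern mirrors the diff):

import ray

class Counter:
    def __init__(self) -> None:
        self.n = 0

    def incr(self) -> int:
        self.n += 1
        return self.n

ray.init(ignore_reinit_error=True)
# Same shape as ray.remote(**self.ray_config)(_VLLMCore).remote(...):
# options first, then the class, then actor construction.
actor = ray.remote(num_cpus=1)(Counter).remote()
print(ray.get(actor.incr.remote()))  # -> 1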
