diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py
index b46116391db26..07b9c6b3c6be6 100644
--- a/tests/tensorizer_loader/conftest.py
+++ b/tests/tensorizer_loader/conftest.py
@@ -1,10 +1,12 @@
 import contextlib
 import functools
 import gc
+from typing import Callable, TypeVar
 
 import pytest
 import ray
 import torch
+from typing_extensions import ParamSpec
 
 from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
@@ -22,12 +24,16 @@ def cleanup():
     torch.cuda.empty_cache()
 
 
-def retry_until_skip(n):
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
 
-    def decorator_retry(func):
+
+def retry_until_skip(n: int):
+
+    def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:
 
         @functools.wraps(func)
-        def wrapper_retry(*args, **kwargs):
+        def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
             for i in range(n):
                 try:
                     return func(*args, **kwargs)
@@ -35,7 +41,9 @@ def wrapper_retry(*args, **kwargs):
                     gc.collect()
                     torch.cuda.empty_cache()
                     if i == n - 1:
-                        pytest.skip("Skipping test after attempts..")
+                        pytest.skip(f"Skipping test after {n} attempts.")
+
+            raise AssertionError("Code should not be reached")
 
         return wrapper_retry
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 8d22c20bb1977..c157be1c08f81 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,10 +1,8 @@
 import asyncio
 import os
 import socket
-import sys
 from functools import partial
-from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
-                    Tuple, TypeVar)
+from typing import AsyncIterator, Tuple
 
 import pytest
 
@@ -13,26 +11,11 @@
 
 from .utils import error_on_warning
 
-if sys.version_info < (3, 10):
-    if TYPE_CHECKING:
-        _AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
-        _AwaitableT_co = TypeVar("_AwaitableT_co",
-                                 bound=Awaitable[Any],
-                                 covariant=True)
-
-        class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):
-
-            def __anext__(self) -> _AwaitableT_co:
-                ...
-
-    def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT":
-        return i.__anext__()
-
 
 @pytest.mark.asyncio
 async def test_merge_async_iterators():
 
-    async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
+    async def mock_async_iterator(idx: int):
         try:
             while True:
                 yield f"item from iterator {idx}"
@@ -41,8 +24,10 @@ async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
             print(f"iterator {idx} cancelled")
 
     iterators = [mock_async_iterator(i) for i in range(3)]
-    merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
-        *iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))
+    merged_iterator = merge_async_iterators(*iterators,
+                                            is_cancelled=partial(asyncio.sleep,
+                                                                 0,
+                                                                 result=False))
 
     async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
         async for idx, output in generator:
@@ -56,7 +41,8 @@ async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
 
     for iterator in iterators:
         try:
-            await asyncio.wait_for(anext(iterator), 1)
+            # Can use anext() in python >= 3.10
+            await asyncio.wait_for(iterator.__anext__(), 1)
         except StopAsyncIteration:
             # All iterators should be cancelled and print this message.
             print("Iterator was cancelled normally")
diff --git a/tests/utils.py b/tests/utils.py
index e3d04cc638a95..697bf7d93c36e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -7,12 +7,13 @@
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import openai
 import ray
 import requests
 from transformers import AutoTokenizer
+from typing_extensions import ParamSpec
 
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -360,13 +361,17 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
         time.sleep(5)
 
 
-def fork_new_process_for_each_test(f):
+_P = ParamSpec("_P")
+
+
+def fork_new_process_for_each_test(
+        f: Callable[_P, None]) -> Callable[_P, None]:
     """Decorator to fork a new process for each test function.
     See https://github.com/vllm-project/vllm/issues/7053 for more details.
     """
 
     @functools.wraps(f)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
         # Make the process the leader of its own process group
         # to avoid sending SIGTERM to the parent process
         os.setpgrp()
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index 4a7e5c5832917..006dc8e146a6c 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -1,10 +1,10 @@
 import functools
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
-                    TypeVar)
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type
 
 from torch import nn
 from transformers import PretrainedConfig
+from typing_extensions import TypeVar
 
 from vllm.logger import init_logger
 
@@ -17,7 +17,7 @@
 
 logger = init_logger(__name__)
 
-C = TypeVar("C", bound=PretrainedConfig)
+C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
 
 
 @dataclass(frozen=True)
@@ -44,7 +44,7 @@ def get_multimodal_config(self) -> "MultiModalConfig":
 
         return multimodal_config
 
-    def get_hf_config(self, hf_config_type: Type[C]) -> C:
+    def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
         """
         Get the HuggingFace configuration
         (:class:`transformers.PretrainedConfig`) of the model,
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 26c02d46a18ed..3a8e4baccc6fa 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -165,7 +165,7 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
 
 
 def get_max_internvl_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config
 
     use_thumbnail = hf_config.use_thumbnail
@@ -187,7 +187,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs
 
     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config
 
     image_size = vision_config.image_size
@@ -260,7 +260,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
 
     image_feature_size = get_max_internvl_image_tokens(ctx)
     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index fc962434cab0b..85522beb0f204 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -34,7 +34,7 @@
 from PIL import Image
 from torch import nn
 from torch.nn.init import trunc_normal_
-from transformers.configuration_utils import PretrainedConfig
+from transformers import PretrainedConfig
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -404,7 +404,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
 
 
 def get_max_minicpmv_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     return getattr(hf_config, "query_num", 64)
 
 
@@ -420,7 +420,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
 
 
 def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
     seq_data = dummy_seq_data_for_minicpmv(seq_len)
     mm_data = dummy_image_for_minicpmv(hf_config)
 
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 51d3a75ea6ff9..e0e427218bdd4 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -341,7 +341,7 @@ def get_phi3v_image_feature_size(
 
 def get_max_phi3v_image_tokens(ctx: InputContext):
     return get_phi3v_image_feature_size(
-        ctx.get_hf_config(PretrainedConfig),
+        ctx.get_hf_config(),
         input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
     )
@@ -391,7 +391,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs
 
     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PretrainedConfig)
+    hf_config = ctx.get_hf_config()
 
     image_data = multi_modal_data["image"]
     if isinstance(image_data, Image.Image):
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index a1c16d6957c2f..3e83c9ef381ac 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -3,13 +3,12 @@
 
 import torch
 from PIL import Image
-from transformers import PreTrainedTokenizerBase
 
 from vllm.config import ModelConfig
 from vllm.inputs.registry import InputContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.image_processor import get_image_processor
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
 from vllm.utils import is_list_of
 
 from .base import MultiModalInputs, MultiModalPlugin
@@ -40,7 +39,7 @@ def repeat_and_pad_token(
 
 
 def repeat_and_pad_image_tokens(
-    tokenizer: PreTrainedTokenizerBase,
+    tokenizer: AnyTokenizer,
     prompt: Optional[str],
     prompt_token_ids: List[int],
     *,
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index a7e760cc16408..6642bb8e71bd1 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -4,9 +4,10 @@
 
 import os
 from functools import lru_cache, wraps
-from typing import List, Tuple
+from typing import Callable, List, Tuple, TypeVar
 
 import pynvml
+from typing_extensions import ParamSpec
 
 from vllm.logger import init_logger
 
@@ -14,16 +15,19 @@
 
 logger = init_logger(__name__)
 
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
 # all the related functions work on real physical device ids.
 # the major benefit of using NVML is that it will not initialize CUDA
 
 
-def with_nvml_context(fn):
+def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 
     @wraps(fn)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
         pynvml.nvmlInit()
         try:
             return fn(*args, **kwargs)
diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
index 001af67f3bb9e..b7624c471cdb2 100644
--- a/vllm/transformers_utils/detokenizer.py
+++ b/vllm/transformers_utils/detokenizer.py
@@ -1,10 +1,9 @@
-from typing import Dict, List, Optional, Tuple, Union
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import Dict, List, Optional, Tuple
 
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
+
+from .tokenizer import AnyTokenizer
+from .tokenizer_group import BaseTokenizerGroup
 
 # Used eg. for marking rejected tokens in spec decoding.
 INVALID_TOKEN_ID = -1
@@ -16,8 +15,7 @@ class Detokenizer:
     def __init__(self, tokenizer_group: BaseTokenizerGroup):
         self.tokenizer_group = tokenizer_group
 
-    def get_tokenizer_for_seq(self,
-                              sequence: Sequence) -> "PreTrainedTokenizer":
+    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
         """Returns the HF tokenizer to use for a given sequence."""
         return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
 
@@ -174,7 +172,7 @@ def _replace_none_with_empty(tokens: List[Optional[str]]):
 
 
 def _convert_tokens_to_string_with_added_encoders(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    tokenizer: AnyTokenizer,
    output_tokens: List[str],
     skip_special_tokens: bool,
     spaces_between_special_tokens: bool,
@@ -213,7 +211,7 @@ def _convert_tokens_to_string_with_added_encoders(
 
 
 def convert_prompt_ids_to_tokens(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    tokenizer: AnyTokenizer,
     prompt_ids: List[int],
     skip_special_tokens: bool = False,
 ) -> Tuple[List[str], int, int]:
@@ -240,7 +238,7 @@ def convert_prompt_ids_to_tokens(
 # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
 # under Apache 2.0 license
 def detokenize_incrementally(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    tokenizer: AnyTokenizer,
     all_input_ids: List[int],
     prev_tokens: Optional[List[str]],
     prefix_offset: int,
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 25e4c41592c68..0271aa809320e 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -12,10 +12,10 @@
 from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async
 
-from .tokenizer_group import AnyTokenizer
-
 logger = init_logger(__name__)
 
+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+
 
 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     """Get tokenizer with cached properties.
@@ -141,7 +141,7 @@ def get_tokenizer(
 
 
 def get_lora_tokenizer(lora_request: LoRARequest, *args,
-                       **kwargs) -> Optional[PreTrainedTokenizer]:
+                       **kwargs) -> Optional[AnyTokenizer]:
     if lora_request is None:
         return None
     try:
diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py
index eeab19899b022..9a4149251d747 100644
--- a/vllm/transformers_utils/tokenizer_group/__init__.py
+++ b/vllm/transformers_utils/tokenizer_group/__init__.py
@@ -8,8 +8,7 @@
 from .tokenizer_group import TokenizerGroup
 
 if ray:
-    from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
-        RayTokenizerGroupPool)
+    from .ray_tokenizer_group import RayTokenizerGroupPool
 else:
     RayTokenizerGroupPool = None  # type: ignore
 
diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
index abbcdf2807f6f..8f78ef65bbf1a 100644
--- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -1,12 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import List, Optional
 
 from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
-
-AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 
 class BaseTokenizerGroup(ABC):
@@ -24,9 +21,10 @@ def ping(self) -> bool:
         pass
 
     @abstractmethod
-    def get_max_input_len(self,
-                          lora_request: Optional[LoRARequest] = None
-                          ) -> Optional[int]:
+    def get_max_input_len(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> Optional[int]:
         """Get the maximum input length for the LoRA request."""
         pass
 
diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
index 79081c04ddc14..9a999a0d6067d 100644
--- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -13,8 +13,9 @@
 from vllm.executor.ray_utils import ray
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
-from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+from .base_tokenizer_group import BaseTokenizerGroup
 from .tokenizer_group import TokenizerGroup
 
 logger = init_logger(__name__)
diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
index a5186e48068e9..e516eeabaadef 100644
--- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -2,12 +2,13 @@
 
 from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
+from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+                                               get_lora_tokenizer,
                                                get_lora_tokenizer_async,
                                                get_tokenizer)
 from vllm.utils import LRUCache
 
-from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+from .base_tokenizer_group import BaseTokenizerGroup
 
 
 class TokenizerGroup(BaseTokenizerGroup):
diff --git a/vllm/utils.py b/vllm/utils.py
index 9b5f5589340e2..30bb81722aa04 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1101,9 +1101,9 @@ def cuda_device_count_stateless() -> int:
 
 
 #From: https://stackoverflow.com/a/4104188/2749989
-def run_once(f):
+def run_once(f: Callable[P, None]) -> Callable[P, None]:
 
-    def wrapper(*args, **kwargs) -> Any:
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
         if not wrapper.has_run:  # type: ignore[attr-defined]
             wrapper.has_run = True  # type: ignore[attr-defined]
             return f(*args, **kwargs)
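Note (illustration only, not part of the patch): the decorators retyped above (`retry_until_skip`, `fork_new_process_for_each_test`, `with_nvml_context`, `run_once`) all follow the same `ParamSpec` pattern. The sketch below shows that pattern in isolation; `log_calls` and `add` are made-up names, and `typing_extensions` is assumed to be installed (on Python 3.10+ `ParamSpec` is also available from `typing`).

```python
import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


def log_calls(func: Callable[_P, _R]) -> Callable[_P, _R]:
    """Hypothetical decorator with the same shape as the ones typed in this diff."""

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # The wrapper forwards *args/**kwargs unchanged; the ParamSpec
        # annotations tell type checkers it accepts exactly func's parameters
        # and returns func's return type.
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


@log_calls
def add(x: int, y: int) -> int:
    return x + y


print(add(1, 2))  # OK; the inferred return type is int
# add(1, "2")     # flagged by mypy/pyright, because the wrapped signature is preserved
```

Without `ParamSpec`, an untyped `def wrapper(*args, **kwargs)` erases the decorated function's signature, so type checkers stop validating call sites of decorated tests and utilities.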
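Also for illustration (not part of the patch): `vllm/inputs/registry.py` switches to `typing_extensions.TypeVar` so that `C` can declare `default=PretrainedConfig`, which is what lets the model files call `ctx.get_hf_config()` with no argument. A minimal sketch of that mechanism, using hypothetical `BaseConfig`/`load_config` names rather than the real vLLM classes:

```python
from typing import Type

from typing_extensions import TypeVar  # supports the PEP 696 `default=` argument


class BaseConfig:
    """Stand-in for transformers.PretrainedConfig."""


class VisionConfig(BaseConfig):
    """Stand-in for a model-specific config subclass."""


# Without `default=`, every caller would have to spell out the config type.
C = TypeVar("C", bound=BaseConfig, default=BaseConfig)


def load_config(config_type: Type[C] = BaseConfig) -> C:
    # The real method also validates the loaded config against config_type at
    # runtime; this sketch only demonstrates the typing behaviour.
    return config_type()


base_cfg = load_config()                # inferred as BaseConfig
vision_cfg = load_config(VisionConfig)  # inferred as VisionConfig
```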