[TPU] Suppress import custom_ops warning (vllm-project#7458)
WoosukKwon authored Aug 13, 2024
1 parent 4d2dc50 · commit d6e634f
Showing 2 changed files with 7 additions and 5 deletions.
vllm/_custom_ops.py (10 changes: 6 additions & 4 deletions)

@@ -6,13 +6,15 @@

 from vllm._core_ext import ScalarType
 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

-try:
-    import vllm._C
-except ImportError as e:
-    logger.warning("Failed to import from vllm._C with %r", e)
+if not current_platform.is_tpu():
+    try:
+        import vllm._C
+    except ImportError as e:
+        logger.warning("Failed to import from vllm._C with %r", e)

 with contextlib.suppress(ImportError):
     # ruff: noqa: F401
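The fix gates the native-extension import on a platform check: the vllm._C C++/CUDA extension is never built for TPU, so attempting the import there can only produce a spurious warning. Below is a minimal standalone sketch of the same gating pattern; the _is_tpu probe is a simplified stand-in of my own (not vllm's actual detection logic, which lives in vllm.platforms.current_platform.is_tpu()):

import importlib.util
import logging

logger = logging.getLogger(__name__)


def _is_tpu() -> bool:
    # Assumed, simplified stand-in for current_platform.is_tpu():
    # treat the host as a TPU machine if the libtpu package is importable.
    return importlib.util.find_spec("libtpu") is not None


if not _is_tpu():
    # Attempt the native extension only where it can exist; on TPU hosts
    # the import is skipped entirely, so no warning is ever logged.
    try:
        import vllm._C  # noqa: F401
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)

Skipping the import outright, rather than silencing the logger, keeps the warning meaningful on platforms where a missing vllm._C is a genuine problem.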
vllm/utils.py (2 changes: 1 addition & 1 deletion)

@@ -29,7 +29,6 @@
 from typing_extensions import ParamSpec, TypeIs, assert_never

 import vllm.envs as envs
-from vllm import _custom_ops as ops
 from vllm.logger import enable_trace_function_call, init_logger

 logger = init_logger(__name__)

@@ -359,6 +358,7 @@ def is_xpu() -> bool:
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
+    from vllm import _custom_ops as ops
     max_shared_mem = (
         ops.get_max_shared_memory_per_block_device_attribute(gpu))
     # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
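The utils.py change is the complementary half: moving the _custom_ops import from module scope into the one function that needs it means importing vllm.utils no longer executes _custom_ops' import-time platform check at all. A condensed sketch of this deferred-import pattern, with names taken from the diff above and assuming a vllm install:

import functools


@functools.lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    # Deferred, function-scope import: _custom_ops (and anything its own
    # import logs) is loaded on the first call, never at import time.
    from vllm import _custom_ops as ops
    return ops.get_max_shared_memory_per_block_device_attribute(gpu)

Python's module cache makes the repeated function-scope import essentially free after the first call, and lru_cache memoizes the returned value per gpu index, so the device query also runs at most once per device.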
