From 4563e25d0eaeaf14063f09096e61b87f457bab63 Mon Sep 17 00:00:00 2001
From: youkaichao <youkaichao@gmail.com>
Date: Mon, 30 Dec 2024 20:24:45 +0800
Subject: [PATCH] [platforms] enable platform plugins (#11602)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: xcnick <xcnick0412@gmail.com>
---
 .buildkite/test-pipeline.yaml                 |  25 +-
 docs/source/design/plugin_system.md           |   6 +-
 tests/conftest.py                             |   2 +-
 tests/kernels/test_attention_selector.py      |  16 +-
 .../plugins/vllm_add_dummy_platform/setup.py  |  11 +
 .../vllm_add_dummy_platform/__init__.py       |   5 +
 .../vllm_add_dummy_platform/dummy_platform.py |   5 +
 tests/plugins_tests/test_platform_plugins.py  |  16 +
 vllm/config.py                                |  15 +-
 vllm/distributed/parallel_state.py            |   3 +-
 vllm/engine/arg_utils.py                      |   2 +-
 vllm/executor/ray_utils.py                    |   2 +-
 .../guided_decoding/__init__.py               |   3 +-
 vllm/model_executor/models/registry.py        |   2 +-
 vllm/model_executor/utils.py                  |   4 +-
 vllm/platforms/__init__.py                    | 320 ++++++++++++------
 vllm/plugins/__init__.py                      |  72 ++--
 vllm/spec_decode/metrics.py                   |   2 +-
 vllm/usage/usage_lib.py                       |   2 +-
 vllm/utils.py                                 |   8 +-
 vllm/worker/model_runner_base.py              |   5 +-
 vllm/worker/multi_step_model_runner.py        |   1 +
 vllm/worker/worker_base.py                    |  14 +-
 23 files changed, 360 insertions(+), 181 deletions(-)
 create mode 100644 tests/plugins/vllm_add_dummy_platform/setup.py
 create mode 100644 tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
 create mode 100644 tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
 create mode 100644 tests/plugins_tests/test_platform_plugins.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index b563c96343f92..bee968b4d2e43 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -106,14 +106,12 @@ steps:
   source_file_dependencies:
   - vllm/
   commands:
-  - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
   - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
   - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
-  - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
   - pytest -v -s entrypoints/test_chat_utils.py
   - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
 
@@ -333,8 +331,6 @@ steps:
   - vllm/
   - tests/models
   commands:
-    - pip install -e ./plugins/vllm_add_dummy_model
-    - pytest -v -s models/test_oot_registration.py # it needs a clean process
     - pytest -v -s models/test_registry.py
     - pytest -v -s models/test_initialization.py
 
@@ -469,11 +465,28 @@ steps:
   - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
   - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
-  - pip install -e ./plugins/vllm_add_dummy_model
-  - pytest -v -s distributed/test_distributed_oot.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
 
+- label: Plugin Tests (2 GPUs) # 40min
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 2
+  fast_check: true
+  source_file_dependencies:
+  - vllm/plugins/
+  - tests/plugins/
+  commands:
+  # begin platform plugin tests, all the code in-between runs on dummy platform
+  - pip install -e ./plugins/vllm_add_dummy_platform
+  - pytest -v -s plugins_tests/test_platform_plugins.py
+  - pip uninstall vllm_add_dummy_platform -y
+  # end platform plugin tests
+  # other tests continue here:
+  - pip install -e ./plugins/vllm_add_dummy_model
+  - pytest -v -s distributed/test_distributed_oot.py
+  - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
+  - pytest -v -s models/test_oot_registration.py # it needs a clean process
+
 - label: Multi-step Tests (4 GPUs) # 36min
   working_dir: "/vllm-workspace/tests"
   num_gpus: 4
diff --git a/docs/source/design/plugin_system.md b/docs/source/design/plugin_system.md
index 79aff757518f2..225030885f629 100644
--- a/docs/source/design/plugin_system.md
+++ b/docs/source/design/plugin_system.md
@@ -41,9 +41,11 @@ Every plugin has three parts:
 2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
 3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.
 
-## What Can Plugins Do?
+## Types of supported plugins
 
-Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
+- **General plugins** (with group name `vllm.general_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model inside the plugin function.
+
+- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.
 
 ## Guidelines for Writing Plugins
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 4e939221329cd..6e2f75e33654f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -31,7 +31,6 @@
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
-from vllm.platforms import current_platform
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity)
@@ -242,6 +241,7 @@ def video_assets() -> _VideoAssets:
 class HfRunner:
 
     def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
+        from vllm.platforms import current_platform
         if x is None or isinstance(x, (bool, )):
             return x
 
diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index d37f95d48d5b2..916cc2efa3895 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -5,7 +5,10 @@
 
 from tests.kernels.utils import override_backend_env_variable
 from vllm.attention.selector import which_attn_to_use
-from vllm.platforms import cpu, cuda, openvino, rocm
+from vllm.platforms.cpu import CpuPlatform
+from vllm.platforms.cuda import CudaPlatform
+from vllm.platforms.openvino import OpenVinoPlatform
+from vllm.platforms.rocm import RocmPlatform
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
@@ -20,26 +23,23 @@ def test_env(name: str, device: str, monkeypatch):
     override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform",
-                   cpu.CpuPlatform()):
+        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform",
-                   rocm.RocmPlatform()):
+        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
-                   openvino.OpenVinoPlatform()):
+                   OpenVinoPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform",
-                   cuda.CudaPlatform()):
+        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == name
diff --git a/tests/plugins/vllm_add_dummy_platform/setup.py b/tests/plugins/vllm_add_dummy_platform/setup.py
new file mode 100644
index 0000000000000..31639906898db
--- /dev/null
+++ b/tests/plugins/vllm_add_dummy_platform/setup.py
@@ -0,0 +1,11 @@
+from setuptools import setup
+
+setup(
+    name='vllm_add_dummy_platform',
+    version='0.1',
+    packages=['vllm_add_dummy_platform'],
+    entry_points={
+        'vllm.platform_plugins': [
+            "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"  # noqa
+        ]
+    })
diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
new file mode 100644
index 0000000000000..594cef520a7de
--- /dev/null
+++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
@@ -0,0 +1,5 @@
+from typing import Optional
+
+
+def dummy_platform_plugin() -> Optional[str]:
+    return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
new file mode 100644
index 0000000000000..fde93142f1103
--- /dev/null
+++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
@@ -0,0 +1,5 @@
+from vllm.platforms.cuda import CudaPlatform
+
+
+class DummyPlatform(CudaPlatform):
+    device_name = "DummyDevice"
diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py
new file mode 100644
index 0000000000000..0d27cf9f152e0
--- /dev/null
+++ b/tests/plugins_tests/test_platform_plugins.py
@@ -0,0 +1,16 @@
+def test_platform_plugins():
+    # simulate workload by running an example
+    import runpy
+    current_file = __file__
+    import os
+    example_file = os.path.join(
+        os.path.dirname(os.path.dirname(os.path.dirname(current_file))),
+        "examples", "offline_inference.py")
+    runpy.run_path(example_file)
+
+    # check if the plugin is loaded correctly
+    from vllm.platforms import _init_trace, current_platform
+    assert current_platform.device_name == "DummyDevice", (
+        f"Expected DummyDevice, got {current_platform.device_name}, "
+        "possibly because current_platform is imported before the plugin"
+        f" is loaded. The first import:\n{_init_trace}")
diff --git a/vllm/config.py b/vllm/config.py
index 765a46e6aeee3..e72c53b6130d0 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -22,7 +22,7 @@
 from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
-from vllm.platforms import current_platform, interface
+from vllm.platforms import CpuArchEnum
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -349,6 +349,7 @@ def __init__(self,
         self.is_hybrid = self._init_is_hybrid()
         self.has_inner_state = self._init_has_inner_state()
 
+        from vllm.platforms import current_platform
         if current_platform.is_neuron():
             self.override_neuron_config = override_neuron_config
         else:
@@ -589,6 +590,7 @@ def _verify_quantization(self) -> None:
                 raise ValueError(
                     f"Unknown quantization method: {self.quantization}. Must "
                     f"be one of {supported_quantization}.")
+            from vllm.platforms import current_platform
             current_platform.verify_quantization(self.quantization)
             if self.quantization not in optimized_quantization_methods:
                 logger.warning(
@@ -644,6 +646,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
 
         # Reminder: Please update docs/source/usage/compatibility_matrix.md
         # If the feature combo become valid
+        from vllm.platforms import current_platform
         if not current_platform.is_async_output_supported(self.enforce_eager):
             logger.warning(
                 "Async output processing is not supported on the "
@@ -1012,6 +1015,7 @@ def _verify_args(self) -> None:
             raise ValueError(
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
+        from vllm.platforms import current_platform
         if (current_platform.is_cuda() and self.block_size is not None
                 and self.block_size > 32):
             raise ValueError("CUDA Paged Attention kernel only supports "
@@ -1279,6 +1283,7 @@ def __post_init__(self) -> None:
                                  f"distributed executor backend "
                                  f"'{self.distributed_executor_backend}'.")
         ray_only_devices = ["tpu", "hpu"]
+        from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
                 and self.world_size > 1):
             if self.distributed_executor_backend is None:
@@ -1327,7 +1332,7 @@ def use_ray(self) -> bool:
     def _verify_args(self) -> None:
         # Lazy import to avoid circular import
         from vllm.executor.executor_base import ExecutorBase
-
+        from vllm.platforms import current_platform
         if self.distributed_executor_backend not in (
                 "ray", "mp", None) and not (isinstance(
                     self.distributed_executor_backend, type) and issubclass(
@@ -1528,6 +1533,7 @@ def compute_hash(self) -> str:
     def __init__(self, device: str = "auto") -> None:
         if device == "auto":
             # Automated device type detection
+            from vllm.platforms import current_platform
             self.device_type = current_platform.device_type
             if not self.device_type:
                 raise RuntimeError("Failed to infer device type")
@@ -2241,9 +2247,10 @@ def _get_and_verify_dtype(
             else:
                 torch_dtype = config_dtype
 
+            from vllm.platforms import current_platform
             if (current_platform.is_cpu()
                     and current_platform.get_cpu_architecture()
-                    == interface.CpuArchEnum.POWERPC
+                    == CpuArchEnum.POWERPC
                     and (config_dtype == torch.float16
                          or config_dtype == torch.float32)):
                 logger.info(
@@ -3083,6 +3090,7 @@ def _get_quantization_config(
             model_config: ModelConfig,
             load_config: LoadConfig) -> Optional[QuantizationConfig]:
         """Get the quantization config."""
+        from vllm.platforms import current_platform
         if model_config.quantization is not None:
             from vllm.model_executor.model_loader.weight_utils import (
                 get_quant_config)
@@ -3145,6 +3153,7 @@ def __post_init__(self):
             self.quant_config = VllmConfig._get_quantization_config(
                 self.model_config, self.load_config)
 
+        from vllm.platforms import current_platform
         if self.scheduler_config is not None and \
             self.model_config is not None and \
             self.scheduler_config.chunked_prefill_enabled and \
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 5b9236f8c56b6..e6768467f4c27 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -39,7 +39,6 @@
 import vllm.envs as envs
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op, supports_custom_op
 
 if TYPE_CHECKING:
@@ -194,6 +193,7 @@ def __init__(
         assert self.cpu_group is not None
         assert self.device_group is not None
 
+        from vllm.platforms import current_platform
         if current_platform.is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
         else:
@@ -1188,6 +1188,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         import ray  # Lazy import Ray
         ray.shutdown()
     gc.collect()
+    from vllm.platforms import current_platform
     if not current_platform.is_cpu():
         torch.cuda.empty_cache()
 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 21966d003c7ef..69c7c5077fe32 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -18,7 +18,6 @@
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
-from vllm.platforms import current_platform
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, StoreBoolean
@@ -1094,6 +1093,7 @@ def create_engine_config(self,
                 use_sliding_window = (model_config.get_sliding_window()
                                       is not None)
                 use_spec_decode = self.speculative_model is not None
+                from vllm.platforms import current_platform
                 if (is_gpu and not use_sliding_window and not use_spec_decode
                         and not self.enable_lora
                         and not self.enable_prompt_adapter
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index 426aa1b5c728f..8d766bad1a072 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -8,7 +8,6 @@
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -229,6 +228,7 @@ def initialize_ray_cluster(
             the default Ray cluster address.
     """
     assert_ray_available()
+    from vllm.platforms import current_platform
 
     # Connect to a ray cluster.
     if current_platform.is_rocm() or current_platform.is_xpu():
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 694c5b68b1cbd..18b435a42544a 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -6,7 +6,7 @@
 from vllm.model_executor.guided_decoding.utils import (
     convert_lark_to_gbnf, grammar_is_likely_lark,
     has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
-from vllm.platforms import CpuArchEnum, current_platform
+from vllm.platforms import CpuArchEnum
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -39,6 +39,7 @@ def maybe_backend_fallback(
 
     if guided_params.backend == "xgrammar":
         # xgrammar only has x86 wheels for linux, fallback to outlines
+        from vllm.platforms import current_platform
         if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
             logger.warning("xgrammar is only supported on x86 CPUs. "
                            "Falling back to use outlines instead.")
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 67268eb4bb85f..07f4b5a3b3bc8 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -18,7 +18,6 @@
 import torch.nn as nn
 
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 
 from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                          supports_cross_encoding, supports_multimodal,
@@ -273,6 +272,7 @@ def _try_load_model_cls(
     model_arch: str,
     model: _BaseRegisteredModel,
 ) -> Optional[Type[nn.Module]]:
+    from vllm.platforms import current_platform
     current_platform.verify_model_arch(model_arch)
     try:
         return model.load_model_cls()
diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index 39ead08c238ce..6f1cc9d5e0c30 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -3,10 +3,9 @@
 
 import torch
 
-from vllm.platforms import current_platform
-
 
 def set_random_seed(seed: int) -> None:
+    from vllm.platforms import current_platform
     current_platform.seed_everything(seed)
 
 
@@ -38,6 +37,7 @@ def set_weight_attrs(
         # This sometimes causes OOM errors during model loading. To avoid this,
         # we sync the param tensor after its weight loader is called.
         # TODO(woosuk): Remove this hack once we have a better solution.
+        from vllm.platforms import current_platform
         if current_platform.is_tpu() and key == "weight_loader":
             value = _make_synced_weight_loader(value)
         setattr(weight, key, value)
diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index 419237c252ffd..f6ac14446c021 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -1,123 +1,223 @@
+import logging
+import traceback
+from itertools import chain
+from typing import TYPE_CHECKING, Optional
+
+from vllm.plugins import load_plugins_by_group
+from vllm.utils import resolve_obj_by_qualname
+
 from .interface import _Backend  # noqa: F401
-from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
+from .interface import CpuArchEnum, Platform, PlatformEnum
 
-current_platform: Platform
+logger = logging.getLogger(__name__)
 
-# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
-# they only indicate the build configuration, not the runtime environment.
-# For example, people can install a cuda build of pytorch but run on tpu.
 
-is_tpu = False
-try:
-    # While it's technically possible to install libtpu on a non-TPU machine,
-    # this is a very uncommon scenario. Therefore, we assume that libtpu is
-    # installed if and only if the machine has TPUs.
-    import libtpu  # noqa: F401
-    is_tpu = True
-except Exception:
-    pass
+def tpu_platform_plugin() -> Optional[str]:
+    is_tpu = False
+    try:
+        # While it's technically possible to install libtpu on a
+        # non-TPU machine, this is a very uncommon scenario. Therefore,
+        # we assume that libtpu is installed if and only if the machine
+        # has TPUs.
+        import libtpu  # noqa: F401
+        is_tpu = True
+    except Exception:
+        pass
+
+    return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None
 
-is_cuda = False
 
-try:
-    import pynvml
-    pynvml.nvmlInit()
+def cuda_platform_plugin() -> Optional[str]:
+    is_cuda = False
+
     try:
-        if pynvml.nvmlDeviceGetCount() > 0:
+        import pynvml
+        pynvml.nvmlInit()
+        try:
+            if pynvml.nvmlDeviceGetCount() > 0:
+                is_cuda = True
+        finally:
+            pynvml.nvmlShutdown()
+    except Exception:
+        # CUDA is supported on Jetson, but NVML may not be.
+        import os
+
+        def cuda_is_jetson() -> bool:
+            return os.path.isfile("/etc/nv_tegra_release") \
+                or os.path.exists("/sys/class/tegra-firmware")
+
+        if cuda_is_jetson():
             is_cuda = True
-    finally:
-        pynvml.nvmlShutdown()
-except Exception:
-    # CUDA is supported on Jetson, but NVML may not be.
-    import os
 
-    def cuda_is_jetson() -> bool:
-        return os.path.isfile("/etc/nv_tegra_release") \
-            or os.path.exists("/sys/class/tegra-firmware")
+    return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None
+
+
+def rocm_platform_plugin() -> Optional[str]:
+    is_rocm = False
+
+    try:
+        import amdsmi
+        amdsmi.amdsmi_init()
+        try:
+            if len(amdsmi.amdsmi_get_processor_handles()) > 0:
+                is_rocm = True
+        finally:
+            amdsmi.amdsmi_shut_down()
+    except Exception:
+        pass
+
+    return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None
+
+
+def hpu_platform_plugin() -> Optional[str]:
+    is_hpu = False
+    try:
+        from importlib import util
+        is_hpu = util.find_spec('habana_frameworks') is not None
+    except Exception:
+        pass
+
+    return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None
+
+
+def xpu_platform_plugin() -> Optional[str]:
+    is_xpu = False
+
+    try:
+        # installed IPEX if the machine has XPUs.
+        import intel_extension_for_pytorch  # noqa: F401
+        import oneccl_bindings_for_pytorch  # noqa: F401
+        import torch
+        if hasattr(torch, 'xpu') and torch.xpu.is_available():
+            is_xpu = True
+    except Exception:
+        pass
+
+    return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None
+
+
+def cpu_platform_plugin() -> Optional[str]:
+    is_cpu = False
+    try:
+        from importlib.metadata import version
+        is_cpu = "cpu" in version("vllm")
+    except Exception:
+        pass
+
+    return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
+
+
+def neuron_platform_plugin() -> Optional[str]:
+    is_neuron = False
+    try:
+        import transformers_neuronx  # noqa: F401
+        is_neuron = True
+    except ImportError:
+        pass
 
-    if cuda_is_jetson():
-        is_cuda = True
+    return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None
 
-is_rocm = False
 
-try:
-    import amdsmi
-    amdsmi.amdsmi_init()
+def openvino_platform_plugin() -> Optional[str]:
+    is_openvino = False
     try:
-        if len(amdsmi.amdsmi_get_processor_handles()) > 0:
-            is_rocm = True
-    finally:
-        amdsmi.amdsmi_shut_down()
-except Exception:
-    pass
-
-is_hpu = False
-try:
-    from importlib import util
-    is_hpu = util.find_spec('habana_frameworks') is not None
-except Exception:
-    pass
-
-is_xpu = False
-
-try:
-    # installed IPEX if the machine has XPUs.
-    import intel_extension_for_pytorch  # noqa: F401
-    import oneccl_bindings_for_pytorch  # noqa: F401
-    import torch
-    if hasattr(torch, 'xpu') and torch.xpu.is_available():
-        is_xpu = True
-except Exception:
-    pass
-
-is_cpu = False
-try:
-    from importlib.metadata import version
-    is_cpu = "cpu" in version("vllm")
-except Exception:
-    pass
-
-is_neuron = False
-try:
-    import transformers_neuronx  # noqa: F401
-    is_neuron = True
-except ImportError:
-    pass
-
-is_openvino = False
-try:
-    from importlib.metadata import version
-    is_openvino = "openvino" in version("vllm")
-except Exception:
-    pass
-
-if is_tpu:
-    # people might install pytorch built with cuda but run on tpu
-    # so we need to check tpu first
-    from .tpu import TpuPlatform
-    current_platform = TpuPlatform()
-elif is_cuda:
-    from .cuda import CudaPlatform
-    current_platform = CudaPlatform()
-elif is_rocm:
-    from .rocm import RocmPlatform
-    current_platform = RocmPlatform()
-elif is_hpu:
-    from .hpu import HpuPlatform
-    current_platform = HpuPlatform()
-elif is_xpu:
-    from .xpu import XPUPlatform
-    current_platform = XPUPlatform()
-elif is_cpu:
-    from .cpu import CpuPlatform
-    current_platform = CpuPlatform()
-elif is_neuron:
-    from .neuron import NeuronPlatform
-    current_platform = NeuronPlatform()
-elif is_openvino:
-    from .openvino import OpenVinoPlatform
-    current_platform = OpenVinoPlatform()
-else:
-    current_platform = UnspecifiedPlatform()
-
-__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
+        from importlib.metadata import version
+        is_openvino = "openvino" in version("vllm")
+    except Exception:
+        pass
+
+    return "vllm.platforms.openvino.OpenVinoPlatform" if is_openvino else None
+
+
+builtin_platform_plugins = {
+    'tpu': tpu_platform_plugin,
+    'cuda': cuda_platform_plugin,
+    'rocm': rocm_platform_plugin,
+    'hpu': hpu_platform_plugin,
+    'xpu': xpu_platform_plugin,
+    'cpu': cpu_platform_plugin,
+    'neuron': neuron_platform_plugin,
+    'openvino': openvino_platform_plugin,
+}
+
+
+def resolve_current_platform_cls_qualname() -> str:
+    platform_plugins = load_plugins_by_group('vllm.platform_plugins')
+
+    activated_plugins = []
+
+    for name, func in chain(builtin_platform_plugins.items(),
+                            platform_plugins.items()):
+        try:
+            assert callable(func)
+            platform_cls_qualname = func()
+            if platform_cls_qualname is not None:
+                activated_plugins.append(name)
+        except Exception:
+            pass
+
+    activated_builtin_plugins = list(
+        set(activated_plugins) & set(builtin_platform_plugins.keys()))
+    activated_oot_plugins = list(
+        set(activated_plugins) & set(platform_plugins.keys()))
+
+    if len(activated_oot_plugins) >= 2:
+        raise RuntimeError(
+            "Only one platform plugin can be activated, but got: "
+            f"{activated_oot_plugins}")
+    elif len(activated_oot_plugins) == 1:
+        platform_cls_qualname = platform_plugins[activated_oot_plugins[0]]()
+        logger.info("Platform plugin %s is activated",
+                    activated_oot_plugins[0])
+    elif len(activated_builtin_plugins) >= 2:
+        raise RuntimeError(
+            "Only one platform plugin can be activated, but got: "
+            f"{activated_builtin_plugins}")
+    elif len(activated_builtin_plugins) == 1:
+        platform_cls_qualname = builtin_platform_plugins[
+            activated_builtin_plugins[0]]()
+        logger.info("Automatically detected platform %s.",
+                    activated_builtin_plugins[0])
+    else:
+        platform_cls_qualname = "vllm.interface.UnspecifiedPlatform"
+        logger.info(
+            "No platform detected, vLLM is running on UnspecifiedPlatform")
+    return platform_cls_qualname
+
+
+_current_platform = None
+_init_trace: str = ''
+
+if TYPE_CHECKING:
+    current_platform: Platform
+
+
+def __getattr__(name: str):
+    if name == 'current_platform':
+        # lazy init current_platform.
+        # 1. out-of-tree platform plugins need `from vllm.platforms import
+        #    Platform` so that they can inherit `Platform` class. Therefore,
+        #    we cannot resolve `current_platform` during the import of
+        #    `vllm.platforms`.
+        # 2. when users use out-of-tree platform plugins, they might run
+        #    `import vllm`, some vllm internal code might access
+        #    `current_platform` during the import, and we need to make sure
+        #    `current_platform` is only resolved after the plugins are loaded
+        #    (we have tests for this, if any developer violate this, they will
+        #    see the test failures).
+        global _current_platform
+        if _current_platform is None:
+            platform_cls_qualname = resolve_current_platform_cls_qualname()
+            _current_platform = resolve_obj_by_qualname(
+                platform_cls_qualname)()
+            global _init_trace
+            _init_trace = "".join(traceback.format_stack())
+        return _current_platform
+    else:
+        return globals()[name]
+
+
+__all__ = [
+    'Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum',
+    "_init_trace"
+]
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
index 17f604ea0e202..c50eb2cef4cd5 100644
--- a/vllm/plugins/__init__.py
+++ b/vllm/plugins/__init__.py
@@ -1,10 +1,10 @@
 import logging
 import os
+from typing import Callable, Dict
 
 import torch
 
 import vllm.envs as envs
-from vllm.platforms import current_platform
 
 logger = logging.getLogger(__name__)
 
@@ -12,6 +12,39 @@
 plugins_loaded = False
 
 
+def load_plugins_by_group(group: str) -> Dict[str, Callable]:
+    import sys
+    if sys.version_info < (3, 10):
+        from importlib_metadata import entry_points
+    else:
+        from importlib.metadata import entry_points
+
+    allowed_plugins = envs.VLLM_PLUGINS
+
+    discovered_plugins = entry_points(group=group)
+    if len(discovered_plugins) == 0:
+        logger.debug("No plugins for group %s found.", group)
+        return {}
+    logger.info("Available plugins for group %s:", group)
+    for plugin in discovered_plugins:
+        logger.info("name=%s, value=%s", plugin.name, plugin.value)
+    if allowed_plugins is None:
+        logger.info("all available plugins for group %s will be loaded.",
+                    group)
+        logger.info("set environment variable VLLM_PLUGINS to control"
+                    " which plugins to load.")
+    plugins = {}
+    for plugin in discovered_plugins:
+        if allowed_plugins is None or plugin.name in allowed_plugins:
+            try:
+                func = plugin.load()
+                plugins[plugin.name] = func
+                logger.info("plugin %s loaded.", plugin.name)
+            except Exception:
+                logger.exception("Failed to load plugin %s", plugin.name)
+    return plugins
+
+
 def load_general_plugins():
     """WARNING: plugins can be loaded for multiple times in different
     processes. They should be designed in a way that they can be loaded
@@ -26,6 +59,9 @@ def load_general_plugins():
     os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
     # see https://github.com/vllm-project/vllm/issues/10619
     torch._inductor.config.compile_threads = 1
+
+    from vllm.platforms import current_platform
+
     if current_platform.is_xpu():
         # see https://github.com/pytorch/pytorch/blob/8cada5cbe5450e17c26fb8b358116785324537b2/torch/_dynamo/config.py#L158  # noqa
         os.environ['TORCH_COMPILE_DISABLE'] = 'True'
@@ -47,33 +83,7 @@ def load_general_plugins():
     if plugins_loaded:
         return
     plugins_loaded = True
-    import sys
-    if sys.version_info < (3, 10):
-        from importlib_metadata import entry_points
-    else:
-        from importlib.metadata import entry_points
-
-    allowed_plugins = envs.VLLM_PLUGINS
-
-    discovered_plugins = entry_points(group='vllm.general_plugins')
-    if len(discovered_plugins) == 0:
-        logger.debug("No plugins found.")
-        return
-    logger.info("Available plugins:")
-    for plugin in discovered_plugins:
-        logger.info("name=%s, value=%s, group=%s", plugin.name, plugin.value,
-                    plugin.group)
-    if allowed_plugins is None:
-        logger.info("all available plugins will be loaded.")
-        logger.info("set environment variable VLLM_PLUGINS to control"
-                    " which plugins to load.")
-    else:
-        logger.info("plugins to load: %s", allowed_plugins)
-    for plugin in discovered_plugins:
-        if allowed_plugins is None or plugin.name in allowed_plugins:
-            try:
-                func = plugin.load()
-                func()
-                logger.info("plugin %s loaded.", plugin.name)
-            except Exception:
-                logger.exception("Failed to load plugin %s", plugin.name)
+    plugins = load_plugins_by_group(group='vllm.general_plugins')
+    # general plugins, we only need to execute the loaded functions
+    for func in plugins.values():
+        func()
diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py
index 03dc46600d8a9..d678f4578499b 100644
--- a/vllm/spec_decode/metrics.py
+++ b/vllm/spec_decode/metrics.py
@@ -6,7 +6,6 @@
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler)
-from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 
 
@@ -94,6 +93,7 @@ def init_tensors(self,
     def maybe_collect_rejsample_metrics(
             self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
         # currently using cuda.Event, skip for any non_cuda_alike platform
+        from vllm.platforms import current_platform
         if not current_platform.is_cuda_alike():
             return None
 
diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py
index 9ae46ff43a916..a9deee881f41a 100644
--- a/vllm/usage/usage_lib.py
+++ b/vllm/usage/usage_lib.py
@@ -17,7 +17,6 @@
 
 import vllm.envs as envs
 from vllm.connections import global_http_connection
-from vllm.platforms import current_platform
 from vllm.version import __version__ as VLLM_VERSION
 
 _config_home = envs.VLLM_CONFIG_ROOT
@@ -152,6 +151,7 @@ def _report_usage_once(self, model_architecture: str,
                            usage_context: UsageContext,
                            extra_kvs: Dict[str, Any]) -> None:
         # Platform information
+        from vllm.platforms import current_platform
         if current_platform.is_cuda_alike():
             device_property = torch.cuda.get_device_properties(0)
             self.gpu_count = torch.cuda.device_count()
diff --git a/vllm/utils.py b/vllm/utils.py
index 2b46c1fef0d09..8ef07d2c326a3 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -50,7 +50,6 @@
 
 import vllm.envs as envs
 from vllm.logger import enable_trace_function_call, init_logger
-from vllm.platforms import current_platform
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -609,6 +608,7 @@ def create_kv_caches_with_random_flash(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    from vllm.platforms import current_platform
     current_platform.seed_everything(seed)
 
     torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
@@ -650,7 +650,7 @@ def create_kv_caches_with_random(
         raise ValueError(
             f"Does not support key cache of type fp8 with head_size {head_size}"
         )
-
+    from vllm.platforms import current_platform
     current_platform.seed_everything(seed)
 
     torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
@@ -703,6 +703,7 @@ def print_warning_once(msg: str) -> None:
 
 @lru_cache(maxsize=None)
 def is_pin_memory_available() -> bool:
+    from vllm.platforms import current_platform
     return current_platform.is_pin_memory_available()
 
 
@@ -713,6 +714,7 @@ def __init__(self, device: Optional[torch.types.Device] = None):
 
     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
+        from vllm.platforms import current_platform
         if current_platform.is_cuda_alike():
             torch.cuda.reset_peak_memory_stats(self.device)
             mem = torch.cuda.max_memory_allocated(self.device)
@@ -1066,6 +1068,7 @@ def _cuda_device_count_stateless(
     import torch.cuda
     import torch.version
 
+    from vllm.platforms import current_platform
     if not torch.cuda._is_compiled():
         return 0
     if current_platform.is_rocm():
@@ -1673,6 +1676,7 @@ def direct_register_custom_op(
         return
 
     if not supports_custom_op():
+        from vllm.platforms import current_platform
         assert not current_platform.is_cuda_alike(), (
             "cuda platform needs torch>=2.4 to support custom op, "
             "chances are you are using an old version of pytorch "
diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
index cd4770202a186..c7abad7e0258d 100644
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -12,7 +12,6 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 
 if TYPE_CHECKING:
@@ -265,13 +264,13 @@ def prepare_model_input(
         """
         raise NotImplementedError
 
-    @current_platform.inference_mode()
     def execute_model(
         self,
         model_input: T,
         kv_caches: Optional[List[torch.Tensor]],
-        intermediate_tensors: Optional[IntermediateTensors],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
+        **kwargs,
     ) -> Optional[List[SamplerOutput]]:
         """
         Execute the model on the given input.
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index 65d9bab0e2822..dee63a75c0605 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -544,6 +544,7 @@ def execute_model(
         model_input.record_step_event(current_stream)
 
         if get_pp_group().is_last_rank and self.is_driver_worker:
+            assert isinstance(output, list)
             assert len(
                 output
             ) == 1, "MultiStepModelRunner requires single-step base_models"
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 3ac7fb8dfb766..249b3ed2dfd37 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -11,7 +11,6 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
                         resolve_obj_by_qualname, update_environment_variables)
@@ -44,6 +43,8 @@ def __init__(
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
         self.kv_transfer_config = vllm_config.kv_transfer_config
+        from vllm.platforms import current_platform
+        self.current_platform = current_platform
 
     @abstractmethod
     def init_device(self) -> None:
@@ -74,17 +75,17 @@ def initialize_cache(self, num_gpu_blocks: int,
         """
         raise NotImplementedError
 
-    @current_platform.inference_mode()
     def start_worker_execution_loop(self) -> None:
         """Execute model loop in parallel worker.
 
         You can stop the loop by executing a driver worker with an empty output.
         See `stop_remote_worker_execution_loop` for more details.
         """
-        while True:
-            output = self.execute_model(execute_model_req=None)
-            if output is None:
-                return None
+        with self.current_platform.inference_mode():
+            while True:
+                output = self.execute_model(execute_model_req=None)
+                if output is None:
+                    return None
 
     @abstractmethod
     def execute_model(
@@ -352,6 +353,7 @@ def execute_model(
         model_execute_time = time.perf_counter() - start_time
         if not get_pp_group().is_last_rank:
             # output is IntermediateTensors
+            assert isinstance(output, IntermediateTensors)
             if (self.observability_config is not None
                     and self.observability_config.collect_model_execute_time):
                 output.tensors["model_execute_time"] = torch.tensor(