From 06a922f7a5958cc06a993c9928d3c49d5094e777 Mon Sep 17 00:00:00 2001
From: Lequn Chen
Date: Thu, 31 Oct 2024 01:54:35 -0700
Subject: [PATCH] doc: fix sphinx (#573)

Here is why the docs fail to build after #552: as specified in `conf.py`,
Sphinx mocks `torch`, and the mock makes the predicate
`TorchVersion(torch_version) < TorchVersion("2.4")` behave badly. The fix is
to explicitly pass in an env var that indicates a docs build.

Also change the way that `prefill.py` imports the compiled `_kernels` module
so that it is consistent with the other files.
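For reference, a rough sketch of the failure mode described above (illustrative
only, not part of the patch; it assumes `conf.py` mocks `torch` via Sphinx's
autodoc mock machinery, and the exact error seen during the real docs build
may differ):

```python
# Illustrative sketch: run in a fresh interpreter where the real torch has not
# been imported yet, so the Sphinx mock finder intercepts the import.
from sphinx.ext.autodoc.mock import mock  # Sphinx's autodoc mocking helper

with mock(["torch"]):
    # Under the mock, both names below are mock objects, not the real
    # TorchVersion class and version string.
    from torch.torch_version import TorchVersion
    from torch.torch_version import __version__ as torch_version

    print(type(torch_version))  # a Sphinx mock type, not a str like "2.5.0"

    # The version gate therefore compares two mock objects; it cannot reflect
    # the installed torch version and may raise outright, which is why
    # utils.py now checks FLASHINFER_BUILDING_DOCS before evaluating it.
    # TorchVersion(torch_version) < TorchVersion("2.4")
```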
---
 docs/Makefile                |  1 +
 python/flashinfer/prefill.py |  7 ++--
 python/flashinfer/utils.py   | 63 ++++++++++++++++++++++++------------
 3 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/docs/Makefile b/docs/Makefile
index d4bb2cbb..077f4fb0 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -16,5 +16,6 @@ help:
 
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: export FLASHINFER_BUILDING_DOCS=1
 %: Makefile
 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py
index 47f5d1a5..3d550576 100644
--- a/python/flashinfer/prefill.py
+++ b/python/flashinfer/prefill.py
@@ -48,9 +48,6 @@
     register_fake_op,
 )
 
-if has_prebuilt_ops:
-    from . import _kernels  # type: ignore[attr-defined]
-
 
 def compile_single_prefill_module(
     *args,
@@ -85,6 +82,8 @@ def get_single_prefill_module(*args):
     if args not in _single_prefill_modules:
         uri = get_single_prefill_uri(*args)
         if has_prebuilt_ops and uri in prebuilt_ops_uri:
+            from . import _kernels
+
             # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later
             mask_mode = args[5]
             run_func = lambda *run_args: _kernels.single_prefill_with_kv_cache(
@@ -157,6 +156,8 @@ def get_batch_prefill_module(*args):
     if args not in _batch_prefill_modules:
         uri = get_batch_prefill_uri(*args)
         if has_prebuilt_ops and uri in prebuilt_ops_uri:
+            from . import _kernels
+
             # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later
             head_dim = args[4]
             plan_func = lambda *plan_args: _kernels.batch_prefill_with_kv_cache_plan(
diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py
index c7d1ec5a..55ee3603 100644
--- a/python/flashinfer/utils.py
+++ b/python/flashinfer/utils.py
@@ -15,6 +15,7 @@
 """
 
 import math
+import os
 from enum import Enum
 from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
 
@@ -22,6 +23,8 @@
 from torch.torch_version import TorchVersion
 from torch.torch_version import __version__ as torch_version
 
+IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"
+
 
 class PosEncodingMode(Enum):
     NONE = 0
@@ -202,26 +205,46 @@ def _check_cached_qkv_data_type(
         )
 
 
-def register_custom_op(
-    name: str,
-    fn: Optional[Callable] = None,
-    /,
-    *,
-    mutates_args: Union[str, Iterable[str]],
-    device_types: Optional[Union[str, Sequence[str]]] = None,
-    schema: Optional[str] = None,
-) -> Callable:
-    if TorchVersion(torch_version) < TorchVersion("2.4"):
-        return lambda x: x
-    return torch.library.custom_op(
-        name, fn, mutates_args=mutates_args, device_types=device_types, schema=schema
-    )
+if IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"):
 
+    def register_custom_op(
+        name: str,
+        fn: Optional[Callable] = None,
+        /,
+        *,
+        mutates_args: Union[str, Iterable[str]],
+        device_types: Optional[Union[str, Sequence[str]]] = None,
+        schema: Optional[str] = None,
+    ) -> Callable:
+        return lambda x: x
 
-def register_fake_op(
-    name: str,
-    fn: Optional[Callable] = None,
-) -> Callable:
-    if TorchVersion(torch_version) < TorchVersion("2.4"):
+    def register_fake_op(
+        name: str,
+        fn: Optional[Callable] = None,
+    ) -> Callable:
         return lambda x: x
-    return torch.library.register_fake(name, fn)
+
+else:
+
+    def register_custom_op(
+        name: str,
+        fn: Optional[Callable] = None,
+        /,
+        *,
+        mutates_args: Union[str, Iterable[str]],
+        device_types: Optional[Union[str, Sequence[str]]] = None,
+        schema: Optional[str] = None,
+    ) -> Callable:
+        return torch.library.custom_op(
+            name,
+            fn,
+            mutates_args=mutates_args,
+            device_types=device_types,
+            schema=schema,
+        )
+
+    def register_fake_op(
+        name: str,
+        fn: Optional[Callable] = None,
+    ) -> Callable:
+        return torch.library.register_fake(name, fn)