doc: fix sphinx (#573)
Here's why the docs fail to build after #552: as specified in
`conf.py`, Sphinx mocks `torch`. The mock makes the following predicate
behave badly: `TorchVersion(torch_version) < TorchVersion("2.4")`.

The fix is to explicitly set an env var that marks the build as a docs
build.

Also change how `prefill.py` imports the compiled `_kernels` module, so
that it's consistent with the other files.
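For illustration only, here is a sketch of why a version comparison against a mocked `torch` can't be trusted. It uses `unittest.mock.MagicMock` as a stand-in; Sphinx's autodoc mock differs in detail (depending on the mock implementation, the comparison either raises TypeError or silently yields a truthy mock), but the point is the same: the operands are mocks, not real versions.

    from unittest.mock import MagicMock

    # Stand-in for `from torch.torch_version import __version__ as torch_version`
    # when `torch` is mocked during the docs build.
    torch_version = MagicMock()

    result = torch_version < "2.4"  # a MagicMock, not a real bool
    print(bool(result))             # True, regardless of any actual torch version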
abcdabcd987 authored Oct 31, 2024
1 parent f19e308 commit 06a922f
Showing 3 changed files with 48 additions and 23 deletions.
docs/Makefile (1 change: 1 addition & 0 deletions)

@@ -16,5 +16,6 @@ help:
 
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: export FLASHINFER_BUILDING_DOCS=1
 %: Makefile
 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
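A note on the mechanism: in GNU make, `%: export VAR=value` is a pattern-specific exported variable, so every goal routed through the catch-all `%` rule runs its recipe (and hence Sphinx) with the variable in the environment. Anything imported during the build can then check it; a minimal sketch of the consuming side, mirroring the `utils.py` change below:

    import os

    # True exactly when the build was launched via docs/Makefile, which exports
    # FLASHINFER_BUILDING_DOCS=1 for every target handled by the catch-all rule.
    IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"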
python/flashinfer/prefill.py (7 changes: 4 additions & 3 deletions)

@@ -48,9 +48,6 @@
     register_fake_op,
 )
 
-if has_prebuilt_ops:
-    from . import _kernels  # type: ignore[attr-defined]
-
 
 def compile_single_prefill_module(
     *args,
@@ -85,6 +82,8 @@ def get_single_prefill_module(*args):
     if args not in _single_prefill_modules:
         uri = get_single_prefill_uri(*args)
         if has_prebuilt_ops and uri in prebuilt_ops_uri:
+            from . import _kernels
+
             # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later
             mask_mode = args[5]
             run_func = lambda *run_args: _kernels.single_prefill_with_kv_cache(
@@ -157,6 +156,8 @@ def get_batch_prefill_module(*args):
     if args not in _batch_prefill_modules:
         uri = get_batch_prefill_uri(*args)
         if has_prebuilt_ops and uri in prebuilt_ops_uri:
+            from . import _kernels
+
             # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later
             head_dim = args[4]
             plan_func = lambda *plan_args: _kernels.batch_prefill_with_kv_cache_plan(
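The pattern here is a deferred (function-local) import: moving `from . import _kernels` out of module scope means importing `flashinfer.prefill` no longer requires the compiled extension to exist; it is loaded only when a prebuilt kernel is actually requested. A self-contained sketch of the idea, with hypothetical names and `importlib` standing in for the extension import:

    import importlib

    _modules = {}  # cache of resolved kernel modules, keyed by uri

    def get_module(uri, module_name="math"):
        # module_name="math" is a stand-in for the compiled `_kernels` extension.
        if uri not in _modules:
            # Deferred import: happens only when a kernel is actually requested,
            # never at package-import time (e.g. during a docs build).
            _modules[uri] = importlib.import_module(module_name)
        return _modules[uri]

    print(get_module("demo").sqrt(2.0))  # the module is loaded on first use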
python/flashinfer/utils.py (63 changes: 43 additions & 20 deletions)

@@ -15,13 +15,16 @@
 """
 
 import math
+import os
 from enum import Enum
 from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
 
 import torch
 from torch.torch_version import TorchVersion
 from torch.torch_version import __version__ as torch_version
 
+IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"
+
 
 class PosEncodingMode(Enum):
     NONE = 0
@@ -202,26 +205,46 @@ def _check_cached_qkv_data_type(
         )
 
 
-def register_custom_op(
-    name: str,
-    fn: Optional[Callable] = None,
-    /,
-    *,
-    mutates_args: Union[str, Iterable[str]],
-    device_types: Optional[Union[str, Sequence[str]]] = None,
-    schema: Optional[str] = None,
-) -> Callable:
-    if TorchVersion(torch_version) < TorchVersion("2.4"):
+if IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"):
+
+    def register_custom_op(
+        name: str,
+        fn: Optional[Callable] = None,
+        /,
+        *,
+        mutates_args: Union[str, Iterable[str]],
+        device_types: Optional[Union[str, Sequence[str]]] = None,
+        schema: Optional[str] = None,
+    ) -> Callable:
         return lambda x: x
-    return torch.library.custom_op(
-        name, fn, mutates_args=mutates_args, device_types=device_types, schema=schema
-    )
 
+    def register_fake_op(
+        name: str,
+        fn: Optional[Callable] = None,
+    ) -> Callable:
+        return lambda x: x
 
-def register_fake_op(
-    name: str,
-    fn: Optional[Callable] = None,
-) -> Callable:
-    if TorchVersion(torch_version) < TorchVersion("2.4"):
-        return lambda x: x
-    return torch.library.register_fake(name, fn)
+else:
+
+    def register_custom_op(
+        name: str,
+        fn: Optional[Callable] = None,
+        /,
+        *,
+        mutates_args: Union[str, Iterable[str]],
+        device_types: Optional[Union[str, Sequence[str]]] = None,
+        schema: Optional[str] = None,
+    ) -> Callable:
+        return torch.library.custom_op(
+            name,
+            fn,
+            mutates_args=mutates_args,
+            device_types=device_types,
+            schema=schema,
+        )
+
+    def register_fake_op(
+        name: str,
+        fn: Optional[Callable] = None,
+    ) -> Callable:
+        return torch.library.register_fake(name, fn)
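What the fallback branch means at a use site: under a docs build (or torch < 2.4), `register_custom_op` returns the identity decorator, so decorated functions stay plain Python functions instead of becoming `torch.library` custom ops. A self-contained sketch, reproducing only the fallback branch, with a hypothetical op name:

    # Fallback branch only, standalone for illustration.
    def register_custom_op(name, fn=None, /, *, mutates_args=(),
                           device_types=None, schema=None):
        return lambda x: x  # identity decorator: no torch.library registration

    @register_custom_op("flashinfer::mul", mutates_args=())
    def mul(a, b):
        return a * b

    print(mul(3, 4))  # 12 -- still an ordinary function, safe to autodoc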