Skip to content

Commit

Permalink
bugfix: backward compatibility (#542)
Browse files Browse the repository at this point in the history
We recently changed the `plan` function signature and removed the
`data_type` argument, which broke compatibility with code written against older versions.

This PR keeps the `data_type` argument (but marks it as deprecated in the
documentation) for backward compatibility.

Also fixes a bug in the `gen_single_decode_cu` function (it now returns the uri instead of
the file name).
  • Loading branch information
yzh119 authored Oct 20, 2024
1 parent b878508 commit 78e26e4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
8 changes: 8 additions & 0 deletions python/flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def plan(
pos_encoding_mode: str = "NONE",
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
data_type: Optional[Union[str, torch.dtype]] = "float16",
q_data_type: Optional[Union[str, torch.dtype]] = "float16",
kv_data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
Expand Down Expand Up @@ -536,6 +537,9 @@ def plan(
kv_data_type : Optional[Union[str, torch.dtype]]
The data type of the key/value tensor. If None, will be set to
``q_data_type``. Defaults to ``None``.
data_type: Optional[Union[str, torch.dtype]]
The data type of both the query and key/value tensors. Defaults to torch.float16.
data_type is deprecated, please use q_data_type and kv_data_type instead.
Note
----
Expand Down Expand Up @@ -580,6 +584,10 @@ def plan(
qo_indptr = qo_indptr.to("cpu", non_blocking=True)
indptr = indptr.to("cpu", non_blocking=True)

if data_type is not None:
q_data_type = data_type
kv_data_type = data_type

q_data_type = canonicalize_torch_dtype(q_data_type)
if kv_data_type is None:
kv_data_type = q_data_type
Expand Down
2 changes: 1 addition & 1 deletion python/flashinfer/jit/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def gen_single_decode_cu(*args) -> Tuple[str, pathlib.Path]:
path,
get_single_decode_cu_str(*args),
)
return file_name, path
return uri, path


def get_batch_decode_cu_str(
Expand Down

0 comments on commit 78e26e4

Please sign in to comment.