Skip to content

Commit

Permalink
bugfix: bugfix for torch library annotation (#562)
Browse files Browse the repository at this point in the history
Fix bugs introduced in #554 

1. Function signature change for `chain_speculative_sampling()` pybind
in aot.
2. `packbits()` uses a str default value, which is not supported by
PyTorch 2.4. This PR added a workaround.
3. For Pytorch < 2.4, the two decorators (`register_custom_op()` and
`register_fake_op()`) should return identity function instead of `None`.
  • Loading branch information
abcdabcd987 authored Oct 26, 2024
1 parent 7a7ad46 commit 9d2996d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
6 changes: 3 additions & 3 deletions flashinfer-aot/csrc_aot/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
unsigned int top_k_val);

std::vector<torch::Tensor> chain_speculative_sampling(
torch::Tensor chain_speculative_sampling(
torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic);
torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
torch::Tensor output_emitted_token_num, bool deterministic);

void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps);

Expand Down
16 changes: 10 additions & 6 deletions python/flashinfer/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ def get_quantization_module():


@register_custom_op("flashinfer::packbits", mutates_args=())
def _packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor:
return get_quantization_module().packbits(x, bitorder)


@register_fake_op("flashinfer::packbits")
def _fake_packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor:
return torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=x.device)


def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor:
r"""Pack the elements of a binary-valued array into bits in a uint8 array.
Expand Down Expand Up @@ -74,12 +83,7 @@ def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor:
--------
segment_packbits
"""
return get_quantization_module().packbits(x, bitorder)


@register_fake_op("flashinfer::packbits")
def _fake_packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor:
return torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=x.device)
return _packbits(x, bitorder)


def segment_packbits(
Expand Down
4 changes: 2 additions & 2 deletions python/flashinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def register_custom_op(
schema: Optional[str] = None,
) -> Callable:
if TorchVersion(torch_version) < TorchVersion("2.4"):
return fn
return lambda x: x
return torch.library.custom_op(
name, fn, mutates_args=mutates_args, device_types=device_types, schema=schema
)
Expand All @@ -223,5 +223,5 @@ def register_fake_op(
fn: Optional[Callable] = None,
) -> Callable:
if TorchVersion(torch_version) < TorchVersion("2.4"):
return fn
return lambda x: x
return torch.library.register_fake(name, fn)

0 comments on commit 9d2996d

Please sign in to comment.