Commit 8a05b8d

Cherry-pick #1439 to fix 'topk' on different devices for onnxruntime-gpu inference (#1603)

Co-authored-by: grimoire <[email protected]>
hanrui1sensetime and grimoire authored Jan 4, 2023
1 parent c67e2db commit 8a05b8d
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions mmdeploy/pytorch/functions/topk.py
@@ -17,7 +17,7 @@ def topk__dynamic(input: torch.Tensor,
                   sorted: bool = True):
     """Rewrite `topk` for default backend.
 
-    Cast k to tensor and makesure k is smaller than input.shape[dim].
+    Cast k to tensor and make sure k is smaller than input.shape[dim].
     """
 
     ctx = FUNCTION_REWRITER.get_context()
@@ -28,7 +28,8 @@ def topk__dynamic(input: torch.Tensor,
         k = torch.tensor(k, device=input.device, dtype=torch.long)
     # Always keep topk op for dynamic input
     if isinstance(size, torch.Tensor):
-        size = size.to(input.device)
+        # size would be treated as cpu tensor, trick to avoid that.
+        size = k.new_zeros(()) + size
     k = torch.where(k < size, k, size)
     return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
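For context, here is a minimal standalone sketch (not part of the commit) of why the replacement line keeps `size` on the same device as `k`: `k.new_zeros(())` creates a zero-dim tensor with k's device and dtype, and PyTorch allows a zero-dim CPU tensor to mix with it, so the sum lands on k's device. The device string and tensor values below are illustrative assumptions, and the snippet needs a CUDA build of PyTorch to run as written.

import torch

# Assumed setup: k lives on the GPU, while size (as derived from
# input.shape[dim] during tracing) comes out as a zero-dim CPU tensor.
k = torch.tensor(3, device='cuda', dtype=torch.long)
size = torch.tensor(10, dtype=torch.long)

# The committed fix: anchor the computation on k's device. A zero-dim
# CPU tensor may participate in ops with a CUDA tensor, so the result
# of the addition is placed on 'cuda'.
size = k.new_zeros(()) + size
assert size.device == k.device

# The comparison and clamp now stay on a single device.
k = torch.where(k < size, k, size)

The old `size.to(input.device)` form could still leave `size` recorded as a CPU tensor in the traced graph (per the in-code comment above), which is what broke `topk` under onnxruntime-gpu inference.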
