Skip to content

Commit

Permalink
Fix Performance Issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Stonepia committed Dec 3, 2024
1 parent a023896 commit 2be0acf
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 20 deletions.
40 changes: 40 additions & 0 deletions torchvision/csrc/ops/cpu/nms_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,50 @@ at::Tensor nms_kernel(
return result;
}


at::Tensor nms_kernel_postprocess(
const at::Tensor& order,
const at::Tensor& iou_keep_out_mask,
const int64_t num_boxes) {

// ceil div to 32. Which is the size of ulong type.
const int col_blocks = (num_boxes + 32 - 1) / 32;
std::vector<unsigned long> remove_box(col_blocks);
std::memset(&remove_box[0], 0, sizeof(unsigned long) * col_blocks);


at::Tensor keep = at::empty({num_boxes}, order.options().dtype(at::kLong).device(at::kCPU));
int64_t * keep_data_ptr = keep.data_ptr<int64_t>();

unsigned long long* iou_keep_out_mask_data_ptr = (unsigned long long*)iou_keep_out_mask.data_ptr<int64_t>();
int num_to_keep = 0;
unsigned long long* iou_keep_out_mask_data_ptr0 = (unsigned long long*)iou_keep_out_mask[0].data_ptr<int64_t>();
unsigned long long*iou_keep_out_mask_data_ptr1 = (unsigned long long*)iou_keep_out_mask[1].data_ptr<int64_t>();

// Note that the iou_keep_out_mask has the shape of (N, N//32)
for (int64_t i = 0; i < num_boxes; i++) {
int nblock = i / 32;
// module 32
int inblock = (31 - i) & (32 -1);

if (!(remove_box[nblock] & (1UL << inblock))){
keep_data_ptr[num_to_keep++]=i;
unsigned long long*p = iou_keep_out_mask_data_ptr + i*col_blocks;
for (int j = nblock; j < col_blocks; j++){
remove_box[j] |= p[j];
}
}
}
return order.index({keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)});
}



} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms_kernel_postprocess"), TORCH_FN(nms_kernel_postprocess));
}

} // namespace ops
Expand Down
2 changes: 2 additions & 0 deletions torchvision/csrc/ops/nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.set_python_module("torchvision._meta_registrations");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::nms_kernel_postprocess(Tensor order, Tensor iou_keep_out_mask, int num_boxes) -> Tensor"));
}

} // namespace ops
Expand Down
2 changes: 2 additions & 0 deletions torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
generalized_box_iou,
masks_to_boxes,
nms,
nms_kernel_postprocess,
remove_small_boxes,
)
from .ciou_loss import complete_box_iou_loss
Expand Down Expand Up @@ -37,6 +38,7 @@
"DeformConv2d",
"nms",
"batched_nms",
"nms_kernel_postprocess",
"remove_small_boxes",
"clip_boxes_to_image",
"box_convert",
Expand Down
4 changes: 4 additions & 0 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)


def nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes) -> Tensor:
return torch.ops.torchvision.nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes)


def batched_nms(
boxes: Tensor,
scores: Tensor,
Expand Down
31 changes: 24 additions & 7 deletions torchvision/ops/triton/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@


@triton.jit
def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: tl.constexpr):
def _combine_bits(val0, val1):
tl.static_assert(val0.dtype == tl.int32, "input must be int32")
tl.static_assert(val1.dtype == tl.int32, "input must be int32")
return val0 | val1


def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr):
"""
This nms_kernel computes the supressed mask of boxes [i, j].
mask[i, j]==1 means if we choose box 1, the box j will be supressed.
Expand All @@ -14,6 +20,8 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t
output_ptr (tl.pointer): A pointer to the output tensor where the mask will be stored.
threshold (float): The IoU threshold for suppressing boxes.
num_boxes (int): The total number of boxes.
stride_i (int): The stride of the output tensor along the first dimension.
stride_j (int): The stride of the output tensor along the second dimension.
BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel.
"""

Expand Down Expand Up @@ -59,14 +67,23 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t
area_b = (col_block_x2 - col_block_x1) * (col_block_y2 - col_block_y1)
union = area_a + area_b - intersection

iou_keep_out_mask = ((intersection / union) > threshold).to(tl.int8)
iou_keep_out_bit_mask = ((intersection / union) > threshold).to(tl.int32)

shift_offsets = tl.arange(0, BLOCK_SIZE) % 32
shift_offsets = tl.flip(shift_offsets, 0)[None, :]
shift_offsets = tl.broadcast_to(shift_offsets.to(tl.int32), [BLOCK_SIZE, BLOCK_SIZE])
iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets

iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32))
iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits)

iou_keep_out_combined = iou_keep_out_combined.to(tl.int64)
output_block_ptr = tl.make_block_ptr(
output_ptr,
shape=(num_boxes, num_boxes),
strides=(num_boxes, 1),
offsets=(row_block_start, col_block_start),
block_shape=(BLOCK_SIZE, BLOCK_SIZE),
shape=(num_boxes, (num_boxes + 32 - 1) // 32),
strides=(stride_i, stride_j),
offsets=(row_block_start, 0),
block_shape=(BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32),
order=(0, 1),
)
tl.store(output_block_ptr, iou_keep_out_mask, boundary_check=(0, 1))
tl.store(output_block_ptr, iou_keep_out_combined, boundary_check=(0, 1))
27 changes: 14 additions & 13 deletions torchvision/ops/xpu/nms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import triton
from torchvision.ops.boxes import nms_kernel_postprocess

from torchvision.ops.triton.nms import triton_nms_IoU_kernel

Expand Down Expand Up @@ -35,21 +36,21 @@ def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float)
# Triton does not support argsort yet, thus it needs to fallback to ATen Calls
order = torch.argsort(scores, descending=True)
boxes = boxes[order]
iou_keep_out_mask = torch.zeros(num_boxes, num_boxes, dtype=torch.bool, device=boxes.device)
iou_keep_out_mask = torch.zeros(num_boxes, (num_boxes + 32 - 1) // 32, dtype=torch.int64, device=boxes.device)

grid = lambda meta: ( # noqa: E731
triton.cdiv(num_boxes, meta["BLOCK_SIZE"]),
triton.cdiv(num_boxes, meta["BLOCK_SIZE"]),
)
# TODO: We need to tune the config from different devices.
triton_nms_IoU_kernel[grid](boxes, iou_keep_out_mask, threshold, num_boxes, BLOCK_SIZE=64, num_warps=8)

# # TODO: Need to improve performance for this reduction
picked = []
remove_box = torch.zeros(num_boxes, dtype=torch.bool, device=boxes.device)
for i in range(num_boxes):
if not (remove_box[i]):
picked.append(order[i])
remove_box[i:] |= iou_keep_out_mask[i][i:]

return torch.as_tensor(picked)
triton_nms_IoU_kernel[grid](
boxes,
iou_keep_out_mask,
threshold,
num_boxes,
iou_keep_out_mask.stride(0),
iou_keep_out_mask.stride(1),
BLOCK_SIZE=64,
num_warps=4,
)

return nms_kernel_postprocess(order.cpu(), iou_keep_out_mask.cpu(), num_boxes).to(order.device)

0 comments on commit 2be0acf

Please sign in to comment.