From 2be0acf9652e7f490221380d4dcd1bb28470a8dc Mon Sep 17 00:00:00 2001 From: "Su, Tong" Date: Tue, 3 Dec 2024 17:48:37 +0800 Subject: [PATCH] Fix Performance Issue --- torchvision/csrc/ops/cpu/nms_kernel.cpp | 40 +++++++++++++++++++++++++ torchvision/csrc/ops/nms.cpp | 2 ++ torchvision/ops/__init__.py | 2 ++ torchvision/ops/boxes.py | 4 +++ torchvision/ops/triton/nms.py | 31 ++++++++++++++----- torchvision/ops/xpu/nms.py | 27 +++++++++-------- 6 files changed, 86 insertions(+), 20 deletions(-) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 50479066cbd..7756d577fbd 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -107,10 +107,50 @@ at::Tensor nms_kernel( return result; } + +at::Tensor nms_kernel_postprocess( + const at::Tensor& order, + const at::Tensor& iou_keep_out_mask, + const int64_t num_boxes) { + + // ceil div to 32. Which is the size of ulong type. + const int col_blocks = (num_boxes + 32 - 1) / 32; + std::vector remove_box(col_blocks); + std::memset(&remove_box[0], 0, sizeof(unsigned long) * col_blocks); + + + at::Tensor keep = at::empty({num_boxes}, order.options().dtype(at::kLong).device(at::kCPU)); + int64_t * keep_data_ptr = keep.data_ptr(); + + unsigned long long* iou_keep_out_mask_data_ptr = (unsigned long long*)iou_keep_out_mask.data_ptr(); + int num_to_keep = 0; + unsigned long long* iou_keep_out_mask_data_ptr0 = (unsigned long long*)iou_keep_out_mask[0].data_ptr(); + unsigned long long*iou_keep_out_mask_data_ptr1 = (unsigned long long*)iou_keep_out_mask[1].data_ptr(); + + // Note that the iou_keep_out_mask has the shape of (N, N//32) + for (int64_t i = 0; i < num_boxes; i++) { + int nblock = i / 32; + // module 32 + int inblock = (31 - i) & (32 -1); + + if (!(remove_box[nblock] & (1UL << inblock))){ + keep_data_ptr[num_to_keep++]=i; + unsigned long long*p = iou_keep_out_mask_data_ptr + i*col_blocks; + for (int j = nblock; j < col_blocks; j++){ + remove_box[j] |= p[j]; + } + } + } + return order.index({keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)}); +} + + + } // namespace TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms_kernel_postprocess"), TORCH_FN(nms_kernel_postprocess)); } } // namespace ops diff --git a/torchvision/csrc/ops/nms.cpp b/torchvision/csrc/ops/nms.cpp index 5ecf8812f1b..f1eb9c0ee0f 100644 --- a/torchvision/csrc/ops/nms.cpp +++ b/torchvision/csrc/ops/nms.cpp @@ -22,6 +22,8 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) { m.set_python_module("torchvision._meta_registrations"); m.def(TORCH_SELECTIVE_SCHEMA( "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::nms_kernel_postprocess(Tensor order, Tensor iou_keep_out_mask, int num_boxes) -> Tensor")); } } // namespace ops diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index f750b2ee2db..bb944347ce0 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -10,6 +10,7 @@ generalized_box_iou, masks_to_boxes, nms, + nms_kernel_postprocess, remove_small_boxes, ) from .ciou_loss import complete_box_iou_loss @@ -37,6 +38,7 @@ "DeformConv2d", "nms", "batched_nms", + "nms_kernel_postprocess", "remove_small_boxes", "clip_boxes_to_image", "box_convert", diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 309990ea03a..872a9efd2cb 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -41,6 +41,10 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: return torch.ops.torchvision.nms(boxes, scores, iou_threshold) +def nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes) -> Tensor: + return torch.ops.torchvision.nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes) + + def batched_nms( boxes: Tensor, scores: Tensor, diff --git a/torchvision/ops/triton/nms.py b/torchvision/ops/triton/nms.py index 40943934697..797b82283f8 100644 --- a/torchvision/ops/triton/nms.py +++ b/torchvision/ops/triton/nms.py @@ -3,7 +3,13 @@ @triton.jit -def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: tl.constexpr): +def _combine_bits(val0, val1): + tl.static_assert(val0.dtype == tl.int32, "input must be int32") + tl.static_assert(val1.dtype == tl.int32, "input must be int32") + return val0 | val1 + + +def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr): """ This nms_kernel computes the supressed mask of boxes [i, j]. mask[i, j]==1 means if we choose box 1, the box j will be supressed. @@ -14,6 +20,8 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t output_ptr (tl.pointer): A pointer to the output tensor where the mask will be stored. threshold (float): The IoU threshold for suppressing boxes. num_boxes (int): The total number of boxes. + stride_i (int): The stride of the output tensor along the first dimension. + stride_j (int): The stride of the output tensor along the second dimension. BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel. """ @@ -59,14 +67,23 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t area_b = (col_block_x2 - col_block_x1) * (col_block_y2 - col_block_y1) union = area_a + area_b - intersection - iou_keep_out_mask = ((intersection / union) > threshold).to(tl.int8) + iou_keep_out_bit_mask = ((intersection / union) > threshold).to(tl.int32) + + shift_offsets = tl.arange(0, BLOCK_SIZE) % 32 + shift_offsets = tl.flip(shift_offsets, 0)[None, :] + shift_offsets = tl.broadcast_to(shift_offsets.to(tl.int32), [BLOCK_SIZE, BLOCK_SIZE]) + iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets + + iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32)) + iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits) + iou_keep_out_combined = iou_keep_out_combined.to(tl.int64) output_block_ptr = tl.make_block_ptr( output_ptr, - shape=(num_boxes, num_boxes), - strides=(num_boxes, 1), - offsets=(row_block_start, col_block_start), - block_shape=(BLOCK_SIZE, BLOCK_SIZE), + shape=(num_boxes, (num_boxes + 32 - 1) // 32), + strides=(stride_i, stride_j), + offsets=(row_block_start, 0), + block_shape=(BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32), order=(0, 1), ) - tl.store(output_block_ptr, iou_keep_out_mask, boundary_check=(0, 1)) + tl.store(output_block_ptr, iou_keep_out_combined, boundary_check=(0, 1)) diff --git a/torchvision/ops/xpu/nms.py b/torchvision/ops/xpu/nms.py index 209109c3762..0e46b322ed7 100644 --- a/torchvision/ops/xpu/nms.py +++ b/torchvision/ops/xpu/nms.py @@ -1,5 +1,6 @@ import torch import triton +from torchvision.ops.boxes import nms_kernel_postprocess from torchvision.ops.triton.nms import triton_nms_IoU_kernel @@ -35,21 +36,21 @@ def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) # Triton does not support argsort yet, thus it needs to fallback to ATen Calls order = torch.argsort(scores, descending=True) boxes = boxes[order] - iou_keep_out_mask = torch.zeros(num_boxes, num_boxes, dtype=torch.bool, device=boxes.device) + iou_keep_out_mask = torch.zeros(num_boxes, (num_boxes + 32 - 1) // 32, dtype=torch.int64, device=boxes.device) grid = lambda meta: ( # noqa: E731 triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), ) - # TODO: We need to tune the config from different devices. - triton_nms_IoU_kernel[grid](boxes, iou_keep_out_mask, threshold, num_boxes, BLOCK_SIZE=64, num_warps=8) - - # # TODO: Need to improve performance for this reduction - picked = [] - remove_box = torch.zeros(num_boxes, dtype=torch.bool, device=boxes.device) - for i in range(num_boxes): - if not (remove_box[i]): - picked.append(order[i]) - remove_box[i:] |= iou_keep_out_mask[i][i:] - - return torch.as_tensor(picked) + triton_nms_IoU_kernel[grid]( + boxes, + iou_keep_out_mask, + threshold, + num_boxes, + iou_keep_out_mask.stride(0), + iou_keep_out_mask.stride(1), + BLOCK_SIZE=64, + num_warps=4, + ) + + return nms_kernel_postprocess(order.cpu(), iou_keep_out_mask.cpu(), num_boxes).to(order.device)