Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep NMS index gathering on cuda device #8766

Merged
merged 6 commits into from
Feb 20, 2025

Conversation

Ghelfi
Copy link
Contributor

@Ghelfi Ghelfi commented Nov 29, 2024

Performs the unwrap of IoU mask directly on the cuda device in NMS.

This prevents device -> host -> device data transfer.

fixes #8713

Copy link

pytorch-bot bot commented Nov 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8766

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 41 Pending

As of commit b3f51f9 with merge base ae9bd7e (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link

Hi @Ghelfi!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@Ghelfi Ghelfi force-pushed the nms-unwrap-on-cuda branch from 2db6ab2 to 3f40bdb Compare December 20, 2024 12:29
@Ghelfi Ghelfi marked this pull request as ready for review December 20, 2024 13:54
@Ghelfi
Copy link
Contributor Author

Ghelfi commented Dec 20, 2024

@NicolasHug for review

@CNOCycle
Copy link

CNOCycle commented Jan 7, 2025

I would like to share an important finding regarding the root cause of slow performance in NMS when dealing with a large number of boxes (#8713). The issue is primarily due to the occurrence of minor page faults, as detailed in [1]. When data is transferred from the GPU to the host, the operating system must locate an appropriate memory space to store these temporary variables, and this operation can be quite costly in terms of execution time. As shown in Fig. 2 of [1], execution time curves diverge into two distinct groups as the number of objects increases, with varying results across different hardware configurations. Further analysis is provided in Table 1 and Section Performance Analysis of [1].

To summarize, the key takeaways are:

  • The slow execution time of NMS can be reproduced on both x86-64 and ARM systems, across various generations of Nvidia GPU combanations.
  • To reduce the execution time of NMS, we should aim to minimize the occurrence of minor page faults.

Since end users typically do not have the privilege to modify the operating system, and adjusting the lifecycle of these temporary variables may fall outside the scope of torchvision, I suggest the following approaches, which are detailed in [1]:

  1. CPU-free NMS: This method is the same as similar to the approach proposed in PR (Keep NMS index gathering on cuda device #8766), and its mechanism has been extensively studied.
  2. Async NMS: The performance of Async NMS depends on various factors, including the version of CUDA, GPU driver, the operating system, and current memory usage. More details can be found in the Overhead of Minor Page Faults section of the Discussion in [1].
  3. Hybrid NMS: This approach is more complex, requiring meta-information about the currently used system. It is thus more suitable for advanced users.

The experimental comparison among three approaches can be found in [1]. Personally, I highly recommend CPU-free NMS. Nevertheless, the simplest implementation of Async NMS could potentially provide performance benefits if data copy is set to non-blocking mode, along with adding system synchronization before the CPU accesses the data:

//  nms_kernel_impl on GPU
at::Tensor mask_cpu = mask.to(at::kCPU, /* non_blocking*/ true);
// some non-relevant CPU codes
cudaStreamSynchronize(stream)
// unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();

The following is result of default NMS and CPU-free NMS executed on V100 with the latest docker image nvcr.io/nvidia/pytorch:24.12-py3. It can be seen that the execution time of CPU-free NMS is half of the default NMS. In the worst-case scenario, we simulated a situation in which all objects output by the object detection model survive. The best-case scenario returned only one object, while the random case randomly assigned properties to all objects. By implementing these three distinct cases, we are able to effectively evaluate the performance of the proposed methods under a range of circumstances, and make meaningful comparisons between them.
fig_nms_objs_scatter_default
fig_nms_objs_scatter_gpu

I would strongly recommend that the community can cite the paper [1] to help users understand that the code modifications involved in the PR (#8766) are based not just on experimental results, but also on clear explanations and insights drawn from computing architecture.

[UPDATED-1] I carefully compared the implementation of CPU-free NMS in this PR with my implementation and found slight differences. I can do another comparison to access performance If necessary.

References:

[1] Chen, E. C., Chen, P. Y., Chung, I., & Lee, C. R. (2024). Latency Attack Resilience in Object Detectors: Insights from Computing Architecture. In Proceedings of the Asian Conference on Computer Vision (pp. 3206-3222).

@Ghelfi
Copy link
Contributor Author

Ghelfi commented Jan 7, 2025

I would like to share an important finding regarding the root cause of slow performance in NMS when dealing with a large number of boxes (#8713). The issue is primarily due to the occurrence of minor page faults, as detailed in [1]. When data is transferred from the GPU to the host, the operating system must locate an appropriate memory space to store these temporary variables, and this operation can be quite costly in terms of execution time. As shown in Fig. 2 of [1], execution time curves diverge into two distinct groups as the number of objects increases, with varying results across different hardware configurations. Further analysis is provided in Table 1 and Section Performance Analysis of [1].

To summarize, the key takeaways are:

  • The slow execution time of NMS can be reproduced on both x86-64 and ARM systems, across various generations of Nvidia GPU combanations.
  • To reduce the execution time of NMS, we should aim to minimize the occurrence of minor page faults.

Since end users typically do not have the privilege to modify the operating system, and adjusting the lifecycle of these temporary variables may fall outside the scope of torchvision, I suggest the following approaches, which are detailed in [1]:

  1. CPU-free NMS: This method is the same as similar to the approach proposed in PR (Keep NMS index gathering on cuda device #8766), and its mechanism has been extensively studied.
  2. Async NMS: The performance of Async NMS depends on various factors, including the version of CUDA, GPU driver, the operating system, and current memory usage. More details can be found in the Overhead of Minor Page Faults section of the Discussion in [1].
  3. Hybrid NMS: This approach is more complex, requiring meta-information about the currently used system. It is thus more suitable for advanced users.

The experimental comparison among three approaches can be found in [1]. Personally, I highly recommend CPU-free NMS. Nevertheless, the simplest implementation of Async NMS could potentially provide performance benefits if data copy is set to non-blocking mode, along with adding system synchronization before the CPU accesses the data:

//  nms_kernel_impl on GPU
at::Tensor mask_cpu = mask.to(at::kCPU, /* non_blocking*/ true);
// some non-relevant CPU codes
cudaStreamSynchronize(stream)
// unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();

The following is result of default NMS and CPU-free NMS executed on V100 with the latest docker image nvcr.io/nvidia/pytorch:24.12-py3. It can be seen that the execution time of CPU-free NMS is half of the default NMS. In the worst-case scenario, we simulated a situation in which all objects output by the object detection model survive. The best-case scenario returned only one object, while the random case randomly assigned properties to all objects. By implementing these three distinct cases, we are able to effectively evaluate the performance of the proposed methods under a range of circumstances, and make meaningful comparisons between them. fig_nms_objs_scatter_default fig_nms_objs_scatter_gpu

I would strongly recommend that the community can cite the paper [1] to help users understand that the code modifications involved in the PR (#8766) are based not just on experimental results, but also on clear explanations and insights drawn from computing architecture.

[UPDATED-1] I carefully compared the implementation of CPU-free NMS in this PR with my implementation and found slight differences. I can do another comparison to access performance If necessary.

References:

[1] Chen, E. C., Chen, P. Y., Chung, I., & Lee, C. R. (2024). Latency Attack Resilience in Object Detectors: Insights from Computing Architecture. In Proceedings of the Asian Conference on Computer Vision (pp. 3206-3222).

Thanks for your contribution and the sources shared. It is nice to be backed by a paper with a more thorough analysis. I'll read the paper and reference it in the code as comment to bring context for future users.

@antoinebrl
Copy link
Contributor

antoinebrl commented Jan 9, 2025

Well done, the improvement looks quite significant ! @NicolasHug any chance this makes it to the next release of torchvision ?

@CNOCycle
Copy link

CNOCycle commented Jan 10, 2025

To provide a proper NMS implementation for end users, I would like to offer the following personal comments.

As I mentioned earlier, my NMS-free NMS implementation differs slightly from the proposed one. In my approach, the function gather_keep_from_mask is fused into nms_kernel_impl, which results in only one GPU kernel being launched. Also, this kernel is executed by a single warp to ensure correctness.

From what I understand, most object detectors limit the top-k objects fed into the NMS function, with k generally set between 1000 and 5000. In this range, the elapsed time for NMS is under 2 ms in both the random and best-case scenarios. (The worst case is unlikely to occur with real-world datasets.) However, the overhead of launching a second GPU kernel is approximately 1-5 ms, which depends heavily on the CPU's status.

That being said, I completely agree that executing nms_kernel_impl with only one warp becomes inefficient when dealing with more than 20,000 objects. Even in a suboptimal CPU-free NMS implementation, its performance exceeds that of Async NMS. For this reason, I highly recommend replacing the default NMS implementation with the CPU-free NMS.

In summary, I believe that feeding 20,000+ objects into NMS is a rare occurrence, and the fused NMS implementation can offer slightly better performance in most scenarios.

Another issue to consider is whether we should maintain two NMS implementations in torchvision (_batched_nms_coordinate_trick and _batched_nms_vanilla) once the performance of the NMS kernel has been refined.

This is my personal comment, and I completely understand that the maintainers may have other considerations. I wholeheartedly respect any decisions made by the maintainers.

@NicolasHug
Copy link
Member

NicolasHug commented Jan 10, 2025

@NicolasHug any chance this makes it to the next release of torchvision ?

Sorry, we've already done the branch cut for 0.21 last month, so we won't be able to land it with the release coming up later in January. I will make sure to review it for the following one though!

@Ghelfi Ghelfi force-pushed the nms-unwrap-on-cuda branch from 3f40bdb to 5c2a5bd Compare February 11, 2025 05:32
@Ghelfi
Copy link
Contributor Author

Ghelfi commented Feb 11, 2025

@NicolasHug Can we try to ship this with the next release if we decide to move forward.

@NicolasHug
Copy link
Member

NicolasHug commented Feb 20, 2025

Yes, I'll try to make this happen for the next release. Sorry, time has been scarce lately. Thank you for your patience on this

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR.

In all honesty I didn't meticulously checked the correctness of the new kernel - I'm mostly relying on our existing tests for that.

For my own record, I ran the following benchmarks where I could reproduce the massive improvements as num_boxes grows.

import torch
from time import perf_counter_ns
from torchvision.ops import nms


def bench(f, *args, num_exp=100, warmup=0, **kwargs):

    for _ in range(warmup):
        f(*args, **kwargs)

    times = []
    for _ in range(num_exp):
        start = perf_counter_ns()
        f(*args, **kwargs)
        torch.cuda.synchronize()
        end = perf_counter_ns()
        times.append(end - start)
    return torch.tensor(times).float()

def report_stats(times, unit="ms", prefix=""):
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    times = times * mul
    std = times.std().item()
    med = times.median().item()
    print(f"{prefix}{med = :.2f}{unit} +- {std:.2f}")
    return med


def make_boxes(num_boxes, num_classes=4):
    boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1).to("cuda")
    assert max(boxes[:, 0]) < min(boxes[:, 2])  # x1 < x2
    assert max(boxes[:, 1]) < min(boxes[:, 3])  # y1 < y2

    scores = torch.rand(num_boxes).to("cuda")
    idxs = torch.randint(0, num_classes, size=(num_boxes,)).to("cuda")
    return boxes, scores, idxs

NUM_EXP = 30
for num_boxes in (10, 100, 1000, 10000, 100000):
    boxes, scores, _ = make_boxes(num_boxes)
    times = bench(nms, boxes, scores, iou_threshold=.7, warmup=1, num_exp=NUM_EXP)
    report_stats(times, prefix=f"{num_boxes = } ")
On main
num_boxes = 10 med = 0.24ms +- 0.10
num_boxes = 100 med = 0.25ms +- 0.02
num_boxes = 1000 med = 0.31ms +- 0.02
num_boxes = 10000 med = 3.18ms +- 3.78
num_boxes = 100000 med = 1408.21ms +- 27.27

This PR
num_boxes = 10 med = 0.25ms +- 0.10
num_boxes = 100 med = 0.26ms +- 0.02
num_boxes = 1000 med = 0.29ms +- 0.03
num_boxes = 10000 med = 0.95ms +- 0.04
num_boxes = 100000 med = 15.32ms +- 1.50

I'll merge this PR when it's ready and probably will follow-up with some update to the batched-nms heuristic, as @CNOCycle suggested.

@NicolasHug NicolasHug merged commit e239710 into pytorch:main Feb 20, 2025
48 of 54 checks passed
Copy link

Hey @NicolasHug!

You merged this PR, but no labels were added.
The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

github-merge-queue bot pushed a commit to intel/torch-xpu-ops that referenced this pull request Mar 5, 2025
Based on the NMS updates in pytorch/vision#8766,
this PR moves the gather-keep section of the `nms` op from CPU to XPU.
This causes a very minor slowdown for small num_boxes `< 400` but
drastically increases performance for large num_boxes by eliminating
data transfer between XPU and CPU. Since the number of boxes is
typically `> 1000`, this is a reasonable change.

<details><summary>Details</summary>
<p>

```
XPU New Code Timings
num_boxes = 10 med = 0.60ms +- 0.10  # _batched_nms_coordinate_trick
num_boxes = 10 med = 2.73ms +- 0.04  # _batched_nms_vanilla
num_boxes = 100 med = 0.60ms +- 0.03
num_boxes = 100 med = 2.75ms +- 0.07
num_boxes = 200 med = 0.60ms +- 0.03
num_boxes = 200 med = 2.77ms +- 0.05
num_boxes = 400 med = 0.61ms +- 0.03
num_boxes = 400 med = 2.80ms +- 0.03
num_boxes = 800 med = 0.61ms +- 0.02
num_boxes = 800 med = 2.81ms +- 0.03
num_boxes = 1000 med = 0.62ms +- 0.01
num_boxes = 1000 med = 2.15ms +- 0.12
num_boxes = 2000 med = 0.54ms +- 0.01
num_boxes = 2000 med = 2.15ms +- 0.01
num_boxes = 10000 med = 1.76ms +- 0.02
num_boxes = 10000 med = 3.25ms +- 0.02
num_boxes = 20000 med = 2.83ms +- 0.03
num_boxes = 20000 med = 4.74ms +- 0.02
num_boxes = 80000 med = 17.79ms +- 0.05
num_boxes = 80000 med = 12.27ms +- 0.03
num_boxes = 100000 med = 25.76ms +- 0.04
num_boxes = 100000 med = 15.43ms +- 0.04
num_boxes = 200000 med = 85.42ms +- 0.26
num_boxes = 200000 med = 36.35ms +- 0.04

XPU - main
num_boxes = 10 med = 0.47ms +- 0.08
num_boxes = 10 med = 2.35ms +- 0.07
num_boxes = 100 med = 0.59ms +- 0.03
num_boxes = 100 med = 2.40ms +- 0.09
num_boxes = 200 med = 0.60ms +- 0.04
num_boxes = 200 med = 2.46ms +- 0.06
num_boxes = 400 med = 0.60ms +- 0.03
num_boxes = 400 med = 2.98ms +- 0.03
num_boxes = 800 med = 0.61ms +- 0.01
num_boxes = 800 med = 2.98ms +- 0.02
num_boxes = 1000 med = 0.62ms +- 0.01
num_boxes = 1000 med = 3.01ms +- 0.02
num_boxes = 2000 med = 0.66ms +- 0.01
num_boxes = 2000 med = 3.34ms +- 0.02
num_boxes = 10000 med = 3.82ms +- 3.67
num_boxes = 10000 med = 5.31ms +- 1.82
num_boxes = 20000 med = 20.92ms +- 1.70
num_boxes = 20000 med = 7.22ms +- 1.43
num_boxes = 80000 med = 119.85ms +- 5.65
num_boxes = 80000 med = 90.21ms +- 3.99
num_boxes = 100000 med = 168.14ms +- 4.02
num_boxes = 100000 med = 123.07ms +- 1.49
num_boxes = 200000 med = 457.85ms +- 70.04
num_boxes = 200000 med = 254.54ms +- 5.27
```

```python
import torch
from time import perf_counter_ns
from torchvision.ops import nms
from torchvision.ops.boxes import _batched_nms_coordinate_trick, _batched_nms_vanilla

def bench(f, *args, num_exp=1000, warmup=0, **kwargs):

    for _ in range(warmup):
        f(*args, **kwargs)

    times = []
    for _ in range(num_exp):
        start = perf_counter_ns()
        f(*args, **kwargs)
        torch.xpu.synchronize()
        end = perf_counter_ns()
        times.append(end - start)
    return torch.tensor(times).float()

def report_stats(times, unit="ms", prefix=""):
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    times = times * mul
    std = times.std().item()
    med = times.median().item()
    print(f"{prefix}{med = :.2f}{unit} +- {std:.2f}")
    return med


def make_boxes(num_boxes, num_classes=4, device="xpu"):
    boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1).to(device)
    assert max(boxes[:, 0]) < min(boxes[:, 2])  # x1 < x2
    assert max(boxes[:, 1]) < min(boxes[:, 3])  # y1 < y2

    scores = torch.rand(num_boxes).to(device)
    idxs = torch.randint(0, num_classes, size=(num_boxes,)).to(device)
    return boxes, scores, idxs

NUM_EXP = 30
for num_boxes in (10, 100, 200, 400, 600, 800, 1000, 1400, 2000, 10000, 20_000, 80_000, 100000, 200_000):
    for f in (_batched_nms_coordinate_trick, _batched_nms_vanilla):
        boxes, scores, idxs = make_boxes(num_boxes)
        times = bench(f, boxes, scores, idxs, iou_threshold=.7, warmup=1, num_exp=NUM_EXP)
        report_stats(times, prefix=f"{num_boxes = } ")

```

</p>
</details>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

torchvision.ops.boxes.batched_nms slow on large box numbers
5 participants