[Performance] Improve segment_matmul by reducing launching overheads #213

Merged · 3 commits merged into pyg-team:master on Apr 19, 2023

Conversation

@yaox12 (Contributor) commented on Mar 16, 2023

CUTLASS grouped GEMM requires copying matrix pointers and layouts to device memory, which brings significant "launch" overheads: more concretely, 7 pageable H2D copies. This PR sets up the arguments for the grouped GEMM manually in a pinned CPU buffer and copies them to device memory in a single transfer to reduce these overheads.

Other changes include setting the CUDA stream for the grouped GEMM and adding proper C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK calls.
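
The actual change lives in the C++/CUDA backend, but the staging pattern can be sketched in Python/PyTorch terms: collect all scalar grouped-GEMM arguments (device pointers, leading dimensions, problem sizes) into one pinned host buffer and issue a single non-blocking H2D copy, instead of several small pageable copies. The function name and argument layout below are illustrative only and do not mirror pyg-lib's implementation.

import torch

def stage_grouped_gemm_args(scalar_args, device):
    # scalar_args: a flat Python list of integers (e.g. device pointers from
    # tensor.data_ptr(), leading dimensions, problem sizes) for all segments.
    # Layout and name are hypothetical, for illustration only.
    host_buf = torch.empty(len(scalar_args), dtype=torch.int64, pin_memory=True)
    host_buf.copy_(torch.tensor(scalar_args, dtype=torch.int64))
    # One non-blocking H2D copy on the current CUDA stream instead of many
    # small pageable copies (one per argument array).
    return host_buf.to(device, non_blocking=True)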

Performance

Benchmarked with the following script, this PR reduces the op time from 0.29 ms to 0.05 ms on my desktop (RTX 3090).

import torch
import time
import pyg_lib

def bench_pyg(a, b, seg):
    # Segment boundaries as a CPU tensor of row offsets into `a`.
    seg = torch.tensor(seg)
    # Warm-up call so that one-time initialization is not measured.
    pyg_lib.ops.segment_matmul(a, seg, b)
    torch.cuda.synchronize()

    tic = time.time()
    for _ in range(10):
        pyg_lib.ops.segment_matmul(a, seg, b)
    torch.cuda.synchronize()
    # Average time per call in ms (seconds * 1000 / 10 iterations).
    print(f"{(time.time() - tic) * 100:.2f} ms")

if __name__ == "__main__":
    num_seg = 20
    hid_dim = 64
    num_ele = 10000

    torch.manual_seed(42)
    device = torch.device("cuda:0")

    # Input features and one weight matrix per segment.
    a = torch.rand((num_ele, hid_dim), device=device)
    b = torch.rand((num_seg, hid_dim, hid_dim), device=device)

    # Random sorted segment offsets spanning [0, num_ele].
    seg = torch.randint(num_ele, (num_seg - 1,)).sort()[0].tolist()
    seg = [0] + seg + [num_ele]

    bench_pyg(a, b, seg)
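
For reference, the same loop could also be timed with CUDA events rather than host wall-clock time. This is only a sketch and not part of the PR; it assumes the same imports and inputs as the script above.

def bench_pyg_events(a, b, seg, iters=10):
    # Same workload as bench_pyg, but timed on the GPU with CUDA events.
    seg = torch.tensor(seg)
    pyg_lib.ops.segment_matmul(a, seg, b)  # warm-up
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        pyg_lib.ops.segment_matmul(a, seg, b)
    end.record()
    torch.cuda.synchronize()
    print(f"{start.elapsed_time(end) / iters:.3f} ms per call")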

cc @rusty1s @puririshi98

@codecov-commenter

Codecov Report

Merging #213 (2a76c9e) into master (c04fb60) will not change coverage.
The diff coverage is n/a.

@@           Coverage Diff           @@
##           master     #213   +/-   ##
=======================================
  Coverage   83.49%   83.49%           
=======================================
  Files          26       26           
  Lines         848      848           
=======================================
  Hits          708      708           
  Misses        140      140           


yaox12 changed the title from "[Performance] Improve grouped gemm" to "[Performance] Improve segment_matmul by reducing launching overheads" on Mar 16, 2023
rusty1s requested a review from puririshi98 on Mar 16, 2023, 08:33
@rusty1s (Member) commented on Mar 16, 2023

Thank you! @puririshi98, do you mind taking a look?

@yaox12 (Contributor, Author) commented on Apr 6, 2023

@puririshi98 Can I get your review?

@puririshi98 (Contributor) left a comment

LGTM!

@puririshi98 (Contributor) commented

Thanks for this! @yaox12

puririshi98 enabled auto-merge (squash) on Apr 19, 2023, 03:39
puririshi98 merged commit a7c7742 into pyg-team:master on Apr 19, 2023