
Question for speed up #3

Open
jameslahm opened this issue Mar 24, 2023 · 1 comment

@jameslahm

Thank you for your great work! In Table 2 of the paper, I see that pruning DeiT-Tiny increases the throughput from 2648.7 to 4496.2 images/s.
But in my local test, I found that the pruned DeiT-Tiny's throughput (1819) is similar to the original DeiT-Tiny's (1760). I use the provided compressed DeiT-Tiny model (Acc@1: 71.6, https://drive.google.com/file/d/1NSq3SRxnObfl6oaFE5gHtjnhzm0Lfc6S/view?usp=sharing). My environment is an RTX 3090, and my throughput code is below:

import time

import torch

@torch.no_grad()
def throughput(data_loader, model, local_rank):
    model.eval()

    # Measure on the first batch only.
    for idx, (images, _) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        batch_size = images.shape[0]

        # Warm-up runs so CUDA kernels are compiled and caches are hot.
        for i in range(50):
            model(images)
        torch.cuda.synchronize()

        # Timed runs; synchronize before reading the clock so all queued
        # GPU work is included in the measurement.
        tic1 = time.time()
        for i in range(30):
            model(images)
        torch.cuda.synchronize()
        tic2 = time.time()

        throughput = 30 * batch_size / (tic2 - tic1)
        if local_rank == 0:
            print("throughput averaged over 30 runs")
            print(f"batch_size {batch_size} throughput {throughput:.1f}")
        return

I wonder if I did something wrong. Would you mind sharing your code for testing throughput? Thanks a lot.

@Daner-Wang (Owner)

Thank you for your comments. In our evaluation, we estimate throughput by timing only the MHSA and FFN modules in the model. We ran your code for comparison and found that your results are likely affected by the token selection function, which is not yet well optimized. Thank you for helping us find this problem; we will try to optimize this function to make it faster.
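For readers wondering what "timing only the MHSA and FFN modules" might look like, here is a minimal sketch of that measurement style. This is not the repository's actual benchmarking code; the helper names (`time_callable`, `estimate_throughput`) are illustrative, and the toy modules below stand in for real attention/FFN blocks (on GPU you would also call `torch.cuda.synchronize()` around the timed region, as in the snippet above).

```python
import time

def time_callable(fn, inputs, warmup=50, runs=30):
    """Return the average seconds per call of fn(inputs),
    after a warm-up phase."""
    for _ in range(warmup):
        fn(inputs)
    tic = time.perf_counter()
    for _ in range(runs):
        fn(inputs)
    toc = time.perf_counter()
    return (toc - tic) / runs

def estimate_throughput(modules, inputs, batch_size, warmup=50, runs=30):
    """Sum the per-module times (e.g. over every MHSA and FFN block)
    and convert the total per-forward time to images per second."""
    total = sum(time_callable(m, inputs, warmup, runs) for m in modules)
    return batch_size / total

# Toy stand-ins for MHSA/FFN blocks:
blocks = [lambda x: [v * 2 for v in x], lambda x: [v + 1 for v in x]]
print(estimate_throughput(blocks, [0.0] * 8, batch_size=8))
```

Note that summing isolated module times ignores everything outside those modules (token selection, residual adds, data movement), which is exactly why an end-to-end measurement like the one in the question can report lower numbers.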
