About the experiment in the paper #30

Open · 870231652 opened this issue Apr 16, 2023 · 8 comments

Comments
@870231652

I have a question for you. I ran the experiment from your paper, stacking 10 pure convolutions, but the results were not ideal: the achieved FLOPS of PConv on the GPU were even lower than on the CPU. I used a 3060 graphics card. Could you help me take a look? Is there a problem with my code? Thank you very much.
[screenshots of the test results attached]

@870231652
Author

Excuse me, do you still have the code for this experiment? I want to use your experiment to build a table, but I have not been able to get good results in my tests, and I am not sure whether the problem is in the model construction or in the FLOPS calculation. Would you mind sharing the code with me? I have been stuck on this experiment for several days without making any progress. I'm very sorry to have bothered you.

@JierunChen
Owner

JierunChen commented Apr 17, 2023

@870231652 Hi, you may consider the following:

  • The ReLU function should be removed.
  • .eval() mode should be enabled for model inference.
  • Benchmark mode can be enabled for GPU inference.
  • Slicing mode can be used for partial conv.
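
For reference, a minimal sketch of how the first three points could look together (the stack here uses plain convs just for illustration; the layer count, channel width, and variable names are not the exact ones from the paper, and 'slicing' is a forward option on the partial conv module itself):

import torch
import torch.nn as nn

# Sketch of the benchmarking settings above, assuming a CUDA device is available:
# a stack of plain 3x3 convs with no ReLU in between, eval() mode, and cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True                     # let cuDNN autotune the conv kernels

channels = 768
model = nn.Sequential(*[nn.Conv2d(channels, channels, 3, padding=1, bias=False)
                        for _ in range(10)])              # 10 stacked convs, no activation
model = model.cuda().eval()                               # eval() mode for inference

x = torch.randn(1, channels, 7, 7, device='cuda')
with torch.no_grad():                                     # no autograd overhead while measuring
    y = model(x)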

@870231652
Author

import time

import torch
import torch.nn as nn
from torch import Tensor
from thop import profile

class C_Net(nn.Module):  # 10 stacked plain 3x3 convolutions
    def __init__(self):
        super(C_Net, self).__init__()
        self.conv1 = nn.Conv2d(a[1], a[1], 3, padding=1)
        self.conv2 = nn.Conv2d(a[1], a[1], 3, padding=1)
        self.conv3 = nn.Conv2d(a[1], a[1], 3, padding=1)
        self.conv4 = nn.Conv2d(a[1], a[1], 3, padding=1)
        self.conv5 = nn.Conv2d(a[1], a[1], 3, padding=1)
        self.conv6 = nn.Conv2d(a[1], a[1], 3, padding=1)
        self.conv7 = nn.Conv2d(a[1], a[1], 3, padding=1)
        self.conv8 = nn.Conv2d(a[1], a[1], 3, padding=1)
        self.conv9 = nn.Conv2d(a[1], a[1], 3, padding=1)
        self.conv10 = nn.Conv2d(a[1], a[1], 3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        return x

class DW_Net(nn.Module):  # 10 stacked depthwise (grouped) 3x3 convolutions
    def __init__(self):
        super(DW_Net, self).__init__()
        self.conv1 = nn.Conv2d(a[1], a[1], 3, padding=1, groups=a[1])
        self.conv2 = nn.Conv2d(a[1], a[1], 3, padding=1, groups=a[1])
        self.conv3 = nn.Conv2d(a[1], a[1], 3, padding=1, groups=a[1])
        self.conv4 = nn.Conv2d(a[1], a[1], 3, padding=1, groups=a[1])
        self.conv5 = nn.Conv2d(a[1], a[1], 3, padding=1, groups=a[1])
        self.conv6 = nn.Conv2d(a[1], a[1], 3, padding=1, groups=a[1])
        self.conv7 = nn.Conv2d(a[1], a[1], 3, padding=1, groups=a[1])
        self.conv8 = nn.Conv2d(a[1], a[1], 3, padding=1, groups=a[1])
        self.conv9 = nn.Conv2d(a[1], a[1], 3, padding=1, groups=a[1])
        self.conv10 = nn.Conv2d(a[1], a[1], 3, padding=1, groups=a[1])
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        return x

class Partial_conv3(nn.Module):  # partial convolution: a 3x3 conv applied to only 1/n_div of the channels

    def __init__(self, dim, forward, n_div=8):
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim_untouched = dim - self.dim_conv3
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

        if forward == 'slicing':
            self.forward = self.forward_slicing
        elif forward == 'split_cat':
            self.forward = self.forward_split_cat
        else:
            raise NotImplementedError
    def forward_slicing(self, x: Tensor) -> Tensor:
        # only for inference
        x = x.clone()   # !!! Keep the original input intact for the residual connection later
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        return x
    def forward_split_cat(self, x: Tensor) -> Tensor:
        # for training/inference
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        return x

class P_Net(nn.Module):  # 10 stacked partial convolutions
    def __init__(self):
        super(P_Net, self).__init__()
        self.conv1 = Partial_conv3(a[1], forward=bb)
        self.conv2 = Partial_conv3(a[1], forward=bb)
        self.conv3 = Partial_conv3(a[1], forward=bb)
        self.conv4 = Partial_conv3(a[1], forward=bb)
        self.conv5 = Partial_conv3(a[1], forward=bb)
        self.conv6 = Partial_conv3(a[1], forward=bb)
        self.conv7 = Partial_conv3(a[1], forward=bb)
        self.conv8 = Partial_conv3(a[1], forward=bb)
        self.conv9 = Partial_conv3(a[1], forward=bb)
        self.conv10 = Partial_conv3(a[1], forward=bb)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        return x


#---------------------------------------------------------------------------------------------------------------------
a = (1, 768, 7, 7)     # input shape: (batch, channels, height, width)
bb = 'split_cat'       # forward mode for Partial_conv3: 'split_cat' or 'slicing'
torch.backends.cudnn.benchmark = True
model = C_Net().cuda() # swap in DW_Net() or P_Net() to test the other variants
model.eval()
x = torch.randn(a).cuda()
#---------------------------------------------------------------------------------------------------------------------

flops, params = profile(model, inputs=(x,))   # per-forward FLOPs count reported by thop

num_runs = 100
start_time = time.time()
for i in range(num_runs):
    model(x)
end_time = time.time()
elapsed_time = end_time - start_time


flops_per_run = flops
fps = num_runs / elapsed_time
flops_per_sec = flops_per_run * fps / 1e9     # achieved throughput in GFLOPS/s

once = elapsed_time / num_runs

print(f"FLOPs per forward: {flops / 1e6:.2f} MFLOPs")
# print(f"Parameters: {params / 1e6:.2f} M")
print(f"Latency: {once * 1000:.2f} ms per run")
print(f"FPS: {fps:.2f}")
print(f"FLOPS per second: {flops_per_sec:.2f} GFLOPS/s")
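
For what it's worth, time.time() around GPU calls can measure only the time to enqueue the kernels, since CUDA calls are asynchronous. A more careful version of the timing loop (continuing the same script, with warm-up and explicit CUDA synchronization) might look like this:

# Continuation of the script above: a more careful GPU timing loop (warm-up + synchronization).
with torch.no_grad():
    for _ in range(10):                   # warm-up so cudnn.benchmark can pick its kernels first
        model(x)
    torch.cuda.synchronize()              # make sure the warm-up work has finished

    start_time = time.time()
    for _ in range(num_runs):
        model(x)
    torch.cuda.synchronize()              # wait for all queued GPU work before stopping the clock
    end_time = time.time()

print(f"Latency with sync: {(end_time - start_time) / num_runs * 1000:.2f} ms per run")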

@870231652
Author

This is my experimental code. After testing, I found that the frame rate of PConv is similar to that of an ordinary Conv, but its computational throughput (FLOPS) is much lower. For example, with an input of (1, 768, 7, 7), Conv runs at 900 FPS with 2353 GFLOPS/s; PConv with 'split_cat' runs at 1110 FPS with 43 GFLOPS/s; PConv with 'slicing' runs at 666 FPS with 27 GFLOPS/s, which is even worse than 'split_cat'; and DWConv runs at 4545 FPS with 15 GFLOPS/s. Is there something wrong with my code?
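
For reference, a rough back-of-the-envelope check of those throughput numbers (a sketch; it assumes thop is counting one MAC per multiply-add, and it reuses the FPS values quoted above):

# Per-layer MAC counts for a (1, 768, 7, 7) input, multiplied by the measured FPS.
k, c, h, w, layers = 3, 768, 7, 7, 10

conv_macs   = layers * k * k * c * c * h * w              # plain 3x3 conv
pconv_macs  = layers * k * k * (c // 8) ** 2 * h * w      # partial conv with n_div=8 (r = 1/8)
dwconv_macs = layers * k * k * c * h * w                  # depthwise 3x3 conv

for name, macs, fps in [("Conv", conv_macs, 900),
                        ("PConv (split_cat)", pconv_macs, 1110),
                        ("PConv (slicing)", pconv_macs, 666),
                        ("DWConv", dwconv_macs, 4545)]:
    print(f"{name}: {macs / 1e6:.1f} M MACs per forward, "
          f"~{macs * fps / 1e9:.0f} G MAC/s at {fps} FPS")

With n_div=8 the partial conv only touches 1/8 of the channels, so each layer does (1/8)^2 = 1/64 of the plain conv's work, which would explain why the latency stays close to plain conv at this tiny 7x7 resolution while the achieved MAC/s looks much lower.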

@JierunChen
Owner

JierunChen commented Apr 29, 2023

@870231652 Hi, the PConv statistics in the paper are measured with the partial ratio r = 1/4, whereas in your case r = 1/8.

When using the 'slicing' mode, the line x = x.clone() can be removed from the forward_slicing function.
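
A sketch of what that might look like, adapting the Partial_conv3 posted above to n_div=4 (r = 1/4) with a slicing forward and no clone (inference only, since it writes into the input in place; the class name here is just for illustration):

import torch
import torch.nn as nn
from torch import Tensor

class Partial_conv3_r4(nn.Module):
    # Partial conv with partial ratio r = 1/4 (n_div=4) and a slicing forward without x.clone().
    # Inference only: the slice assignment modifies the input tensor in place.
    def __init__(self, dim: int, n_div: int = 4):
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        return x

# usage sketch
layer = Partial_conv3_r4(768).eval()
with torch.no_grad():
    y = layer(torch.randn(1, 768, 7, 7))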

@ahdxwg

ahdxwg commented Jun 5, 2023

Hello, I did the same experiment. I converted the model to ONNX format and tested it on CPU; it was not as fast as DWConv.
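
In case it helps reproduce that setup, one way to export such a stack to ONNX and time it on CPU with onnxruntime might look like this (the small wrapper, file name, and layer count are illustrative, not the exact code used above):

import time
import numpy as np
import torch
import torch.nn as nn
import onnxruntime as ort

class PConvBlock(nn.Module):
    # Illustrative split_cat-style partial conv block (n_div=4).
    def __init__(self, dim, n_div=4):
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.conv = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

    def forward(self, x):
        x1, x2 = torch.split(x, [self.dim_conv3, x.shape[1] - self.dim_conv3], dim=1)
        return torch.cat((self.conv(x1), x2), dim=1)

model = nn.Sequential(*[PConvBlock(768) for _ in range(10)]).eval()
dummy = torch.randn(1, 768, 7, 7)
torch.onnx.export(model, dummy, "pconv.onnx", input_names=["input"], output_names=["output"])

sess = ort.InferenceSession("pconv.onnx", providers=["CPUExecutionProvider"])
x = dummy.numpy().astype(np.float32)
start = time.time()
for _ in range(100):
    sess.run(None, {"input": x})
print(f"{(time.time() - start) / 100 * 1000:.2f} ms per run on CPU")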

@gaoxinghua951

Hello, I did the same experiment. I converted the model to ONNX format and tested it on CPU; it was not as fast as DWConv.

Hi, can you explain why you calculate the FLOPS the way you do? Thank you.

@JierunChen
Owner

Hello, I did the same experiment. I converted the model to ONNX format and tested it on CPU; it was not as fast as DWConv.

Hi, the statistics reported in the paper are measured without converting to ONNX, and they are intended for comparing FLOPS rather than latency.

Besides, PConv may not be well supported by ONNX and could be further optimized, particularly regarding the slicing and concatenation operations.
