alpha should not be optimized in updating weight #5

Open
yangsenius opened this issue Mar 26, 2019 · 8 comments

@yangsenius

self.alpha_normal = nn.Parameter(torch.randn(k, num_ops))

nn.Parameter() registers alpha and beta in model.parameters(), so the optimizer will also update alpha and beta when it updates the operation weights. I think nn.Parameter() should not be used here; it is not consistent with the paper or the original code.
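A minimal sketch of the effect (the TinyNetwork class and layer names are made up for illustration; only the alpha line mirrors the code above):

import torch
import torch.nn as nn

class TinyNetwork(nn.Module):  # illustrative stand-in for the search network
    def __init__(self, k=14, num_ops=8):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, 3, padding=1)
        # nn.Parameter registers alpha alongside the operation weights
        self.alpha_normal = nn.Parameter(torch.randn(k, num_ops))

model = TinyNetwork()
w_optim = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9)
# the weight optimizer now also holds alpha_normal and will update it
print(any(p is model.alpha_normal
          for g in w_optim.param_groups for p in g['params']))  # prints True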

@yangsenius yangsenius changed the title alpha and beta should not be optimized in update weight alpha should not be optimized in update weight Mar 26, 2019
@yangsenius yangsenius changed the title alpha should not be optimized in update weight alpha should not be optimized in updating weight Mar 26, 2019
@dragen1860
Copy link
Owner

dragen1860 commented Apr 1, 2019

@yangsenius you reminded me! Thank you.
Have you tried 👍

self.alpha_normal = torch.randn(k, num_ops)
self.alpha_reduce = torch.randn(k, num_ops)

What is the performance when you update the code with the statements above?
Please tell me if you re-run the experiment.

@yangsenius
Author

self.alpha_normal = torch.randn(k, num_ops)
self.alpha_reduce = torch.randn(k, num_ops)

will make self.alpha_normal and self.alpha_reduce plain torch.FloatTensor, which sometimes causes errors with model.cuda(); this is a little troublesome.
Maybe


self.alpha_normal = torch.randn(k, num_ops, dtype=self.your_conv.weight.dtype)
self.alpha_reduce = torch.randn(k, num_ops, dtype=self.your_conv.weight.dtype)

is OK?

Or just:

self.alpha_normal = nn.Parameter(torch.randn(k, num_ops))

def filter(model):
    # yield every parameter except the alphas
    for name, param in model.named_parameters():
        if 'alpha' in name:
            continue
        yield param

optimizer = torch.optim.Adam(filter(model))

What do you think? Is there a better code implementation for this issue?
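For comparison, if I remember the original DARTS code correctly, it avoids nn.Parameter for the alphas entirely: they are plain tensors with requires_grad=True, exposed through an arch_parameters() method, and optimized by a separate Adam optimizer. A rough sketch under that assumption (the class, names, and hyperparameters below are illustrative):

import torch
import torch.nn as nn

class Network(nn.Module):  # minimal stand-in for the search network
    def __init__(self, k=14, num_ops=8, device='cpu'):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, 3, padding=1)
        # plain tensors never appear in model.parameters(), and are NOT moved by
        # model.cuda(), so they are created directly on the target device
        self.alpha_normal = (1e-3 * torch.randn(k, num_ops, device=device)).requires_grad_()
        self.alpha_reduce = (1e-3 * torch.randn(k, num_ops, device=device)).requires_grad_()
        self._arch_parameters = [self.alpha_normal, self.alpha_reduce]

    def arch_parameters(self):
        return self._arch_parameters

model = Network()
# one optimizer per group, stepped alternately as in DARTS
w_optim = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9, weight_decay=3e-4)
a_optim = torch.optim.Adam(model.arch_parameters(), lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3)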

@dragen1860
Owner

dragen1860 commented Apr 2, 2019

Since we usually set the device to 'cuda:0',

self.alpha_reduce = torch.randn(k, num_ops, device=torch.device("cuda"))

would be an OK option.
Let's see if there are any problems.

@yangsenius

@zh583007354

Hi, I also noticed this problem yesterday. I think splitting the parameters into two groups may be a good choice. When training a ConvNet (e.g. MobileNet), we usually give the conv weights a weight decay of 5e-4 but no decay for the BN parameters, so we define something like
optimizer = SGD([{param group 1: conv, with decay}, {param group 2: BN, without decay}])

I think we can separate the alphas and weights in the same way (see the sketch below).

@dragen1860 @yangsenius
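A hedged sketch of that grouping idea (split_params, the tiny model, and all hyperparameters here are illustrative, not this repo's code):

import torch
import torch.nn as nn

def split_params(model):
    # BN parameters and biases are 1-D; conv/linear weights have >= 2 dims
    decay, no_decay = [], []
    for param in model.parameters():
        (no_decay if param.dim() == 1 else decay).append(param)
    return decay, no_decay

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))
decay, no_decay = split_params(model)
optimizer = torch.optim.SGD(
    [{'params': decay, 'weight_decay': 5e-4},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=0.1, momentum=0.9)

The same idea can separate the alphas from the operation weights: filter on 'alpha' in the parameter name and give the alphas their own group (or their own optimizer).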

@yangsenius
Author

Yeah, you got it. @zh583007354

self.alpha_normal = nn.Parameter(torch.randn(k, num_ops))

def filter(model):
    # yield every parameter except the alphas
    for name, param in model.named_parameters():
        if 'alpha' in name:
            continue
        yield param

optimizer = torch.optim.Adam([
    {'params': filter(model)},                              # weights
    {'params': [model.alpha_normal, model.alpha_reduce]},   # alphas
])

@skx6

skx6 commented Jun 26, 2019

It takes one hour per epoch to search the architecture. However, the paper states: "a small network of 8 cells is trained using DARTS for 50 epochs. The search takes one day on a single GPU". If I train for 50 epochs, it will take more than two days.

@zh583007354

@yangsenius hi, I have another question.

I want to know whether the clip_grad_norm_() in train_search.py is necessary:

loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()

If it is necessary to clip the gradient, should it be applied only to the weight params or to all params?

Thank you.

@yangsenius
Author

yangsenius commented Jun 30, 2019

I think it is hard to say whether this clip_grad_norm_() is necessary. We cannot know the gradient range of the parameters in advance; clipping is done to avoid gradient explosions (just in case), even though they may never happen. So this code snippet may turn out to be unnecessary. @zh583007354
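If the clipping is kept, one option (a sketch continuing the snippet quoted above, and assuming the alphas are named 'alpha*' as in this repo) is to clip only the weight parameters so the architecture parameters are left untouched:

# collect everything except the architecture parameters
weight_params = [p for n, p in model.named_parameters() if 'alpha' not in n]

loss.backward()
nn.utils.clip_grad_norm_(weight_params, args.grad_clip)  # clip weights only
optimizer.step()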
