alpha should not be optimized in updating weight #5
Comments
@yangsenius you reminded me! Thank you.
What's the performance when you update the code with the above statement? |
will make
is OK? or just
What do you think? Is there a better code implementation for this issue? |
Since we usually set the device to 'cuda:0', the
would be an OK option. |
Hi, I also noticed this problem yesterday. I think that splitting the parameters into two groups may be a good choice. When training a ConvNet (e.g. MobileNet), we usually give the conv weights a weight decay of 5e-4 but no decay for BN, so we define the parameter groups accordingly (see the sketch below). I think we can separate the alphas and weights in the same way. |
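For illustration, that grouping could look roughly like this (a sketch only; the helper name split_params and the hyperparameters are placeholders, not code from this repo):

import torch
import torch.nn as nn

def split_params(model):
    # put BN parameters in a no-decay group and everything else in a decay group
    decay, no_decay = [], []
    for module in model.modules():
        params = list(module.parameters(recurse=False))
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            no_decay += params
        else:
            decay += params
    return decay, no_decay

decay, no_decay = split_params(model)
optimizer = torch.optim.SGD(
    [{'params': decay, 'weight_decay': 5e-4},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=0.025, momentum=0.9)

The same pattern applies to the alphas: keep them in their own group (or their own optimizer) so the weight update never touches them.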
Yeah, you got it. @zh583007354

self.alpha_normal = nn.Parameter(torch.randn(k, num_ops))

def filter(model):
    # yield only the network weights, skipping the architecture alphas
    for name, param in model.named_parameters():
        if 'alpha' in name:
            continue
        yield param

# two parameter groups, so the weights and the alphas can get different settings
optimizer = torch.optim.Adam([
    {'params': list(filter(model))},
    {'params': [model.alpha_normal]},
]) |
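As a follow-up, the original DARTS code goes further and keeps two separate optimizers, roughly like this (a sketch; the learning rates are the paper's defaults and may differ from this repo, and filter(model) refers to the generator above):

# weight optimizer: only the network weights, alphas excluded
w_optimizer = torch.optim.SGD(list(filter(model)), lr=0.025,
                              momentum=0.9, weight_decay=3e-4)

# architecture optimizer: only the alphas (add alpha_reduce etc. if present)
alpha_optimizer = torch.optim.Adam([model.alpha_normal], lr=3e-4,
                                   betas=(0.5, 0.999), weight_decay=1e-3)

The inner training step then calls only w_optimizer.step(), and the architecture step calls only alpha_optimizer.step(), so updating the weights can never change the alphas.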
It takes about one hour per epoch to search the architecture. However, the paper says "a small network of 8 cells is trained using DARTS for 50 epochs. The search takes one day on a single GPU". If I train for 50 epochs, it will take more than two days. |
@yangsenius hi, I have another question. I want to know whether the clip_grad_norm_() in train_search.py is necessary. If it is necessary to clip the gradient, should it be applied to only the weight params or to all params? Thank you. |
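If the clipping is kept, one option is to clip only the weight gradients by passing the filtered parameters instead of all of model.parameters() (a sketch; weight_list and max_norm=5.0 are placeholders, not this repo's settings):

# clip only the weight gradients; alpha gradients are left untouched
weight_list = [p for n, p in model.named_parameters() if 'alpha' not in n]
torch.nn.utils.clip_grad_norm_(weight_list, max_norm=5.0)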
I think the necessity of this |
DARTS-PyTorch/model_search.py
Line 217 in cfcdd02
nn.Parameter() will make the alpha and beta registered in model.parameters(), so your optimizer will also update the alpha and beta when it optimizes the weights of the operations. So I think nn.Parameter() should not be used here; it is not consistent with the paper or the original code.
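For comparison, the original DARTS code avoids this by creating the alphas as plain tensors with requires_grad=True instead of nn.Parameter, so they never show up in model.parameters(). A rough sketch (reusing k and num_ops from the snippet above; the 'cuda:0' device string echoes the earlier discussion):

# inside the model's __init__: a plain tensor, not nn.Parameter,
# so model.parameters() will not return it
self.alpha_normal = 1e-3 * torch.randn(k, num_ops, device='cuda:0')
self.alpha_normal.requires_grad_(True)

def arch_parameters(self):
    # expose the alphas (and betas, if any) separately for the architecture optimizer
    return [self.alpha_normal]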