
[Feature Request] Add EfficientNet #980

Closed
timonbimon opened this issue Jun 2, 2019 · 28 comments · Fixed by #4293

Comments

@timonbimon

timonbimon commented Jun 2, 2019

https://ai.googleblog.com/2019/05/efficientnet-improving-accuracy-and.html

It seems like EfficientNet is faster and better than most of the other nets out there currently. Would be awesome to see it in torchvision soon! :)

@soumith
Member

soumith commented Jun 2, 2019

FYI, it's available in https://github.com/rwightman/pytorch-image-models/blob/master/models/gen_efficientnet.py already

@cdluminate

Would be awesome if it were available in torchvision.

@fmassa
Member

fmassa commented Jun 3, 2019

It would also be great to have a training script that reproduces EfficientNet. From what I saw in the paper, they also use some different training hyperparameters / tricks to improve the trained models, so it might take some effort to get reproducible training scripts for those.

@bhack

bhack commented Jun 4, 2019

https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/README.md /cc @lukemelas

@lukemelas

lukemelas commented Jun 4, 2019

Thanks for looping me in, @bhack. Yes, it would be great to merge EfficientNet into torchvision!

I don't think we're quite ready yet, though -- as @fmassa mentioned, we have to first have reliable training scripts. I would also like to rewrite EfficientNet in a more PyTorchic style; my implementation and the others I have seen generally mirror the TensorFlow implementation, which is great for quickly reproducing results, but could be improved for merging into torchvision.

@bhack

bhack commented Jun 4, 2019

/cc @rwightman @zsef123

@rwightman
Contributor

rwightman commented Jun 4, 2019

@bhack thanks for the vote, I definitely think a variant of it should be here, but it should be 100% PyTorch trained. The TF-ported weights with the necessary padding are not suitable, as the padding adds unacceptable overhead.

Given the nature of these models, they take a long time to train, and it is fiddly to find the right hyper-params. I finally just got some MobileNet-V3 weights (very close to an EfficientNet-B0) that match/slightly exceed the paper after two weeks of attempts. I can do it with my training scripts, but it takes much more time than training something like a ResNet to great results.

@fmassa
Member

fmassa commented Jun 4, 2019

@rwightman do you have the scripts for training MobileNet-V3 from scratch in PyTorch?

The MobileNetV2 pre-trained weights that I made available also had to be trained on a lot of epochs, and it indeed seems to take longer to train than a ResNet.

@rwightman
Contributor

@fmassa TLDR 'sort of' :)

My train script can handle most of the training right now, basically using RMSprop, Google's crazy high epsilons for the optimizer and batch norm, and a different BN momentum, for 350-450+ epochs. That gets you most of the way, but the final necessary bit is EMA decay of the model weights, which PyTorch doesn't have a built-in default for. I have a hack for that working right now, and found it is very effective to enable it at the late stages of training, but validating the smoothed copy of the weights doesn't play nicely with distributed GPU training, so it adds an extra CPU-only validation step. Trying to see if there is a cleaner way of doing it...
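A minimal, framework-free sketch of what an EMA of model weights does, assuming the weights are exposed as a dict of flat float lists (a stand-in for a PyTorch `state_dict`). `WeightEma`, `decay`, and `start_step` are hypothetical names, and the 0.9998 default is an illustrative assumption, not the value used in this thread. Before `start_step` the shadow copy simply tracks the raw weights, which mimics only enabling EMA late in training.

```python
from typing import Dict, List


class WeightEma:
    """Hypothetical helper: exponential moving average of model weights."""

    def __init__(self, weights: Dict[str, List[float]], decay: float = 0.9998,
                 start_step: int = 0) -> None:
        self.decay = decay
        self.start_step = start_step
        self.step = 0
        # Shadow copy; this is what would be validated instead of the raw weights.
        self.shadow = {name: list(vals) for name, vals in weights.items()}

    def update(self, weights: Dict[str, List[float]]) -> None:
        """Call once after every optimizer step."""
        self.step += 1
        # Before start_step, decay 0.0 just copies the raw weights (EMA effectively off).
        d = self.decay if self.step > self.start_step else 0.0
        for name, vals in weights.items():
            shadow = self.shadow[name]
            for i, v in enumerate(vals):
                # shadow <- d * shadow + (1 - d) * current
                shadow[i] = d * shadow[i] + (1.0 - d) * v


ema = WeightEma({"w": [0.0]}, decay=0.5)
ema.update({"w": [1.0]})
ema.update({"w": [1.0]})
print(ema.shadow)  # {'w': [0.75]}
```

In a real PyTorch setup the same update would run over the model's parameters (and usually buffers) after each optimizer step; the small decay here is only so the arithmetic is easy to follow.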

@rwightman
Contributor

I also suspect that the odd optimizer/BN params could be relaxed back to PyTorch defaults, and that the EMA weight decay is more significant, as it alone adds a solid 1.5-2% top-1 bump once the model is getting close to converged.
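Part of what makes the TF-style BN settings look "odd" is a convention mismatch: TensorFlow's BN `momentum` is the weight kept on the old running statistic, while PyTorch's `momentum` is the weight given to the new batch statistic. A minimal sketch of the conversion and of both update rules, assuming TF-style values of momentum 0.99 and epsilon 1e-3 (treated here as an assumption); PyTorch's `nn.BatchNorm2d` defaults are `momentum=0.1` and `eps=1e-5`. The function names are hypothetical.

```python
import math


def tf_to_torch_bn_momentum(tf_momentum: float) -> float:
    """TF keeps `tf_momentum` of the old stat; PyTorch's momentum weights the new one."""
    return 1.0 - tf_momentum


def tf_running_update(running: float, batch_stat: float, tf_momentum: float) -> float:
    return tf_momentum * running + (1.0 - tf_momentum) * batch_stat


def torch_running_update(running: float, batch_stat: float, momentum: float) -> float:
    return (1.0 - momentum) * running + momentum * batch_stat


# Assumed TF-style settings expressed in PyTorch terms vs. PyTorch BatchNorm2d defaults.
TF_STYLE_BN = {"momentum": tf_to_torch_bn_momentum(0.99), "eps": 1e-3}
TORCH_DEFAULT_BN = {"momentum": 0.1, "eps": 1e-5}

# Once converted, both conventions produce the same running mean.
a = tf_running_update(1.0, 3.0, 0.99)
b = torch_running_update(1.0, 3.0, tf_to_torch_bn_momentum(0.99))
print(math.isclose(a, b))  # True
```

Either dict could be unpacked into `nn.BatchNorm2d(num_features, **cfg)`; a TF-style momentum of 0.99 corresponds to a much slower-moving running estimate than PyTorch's default of 0.1.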

@lukemelas

Agreed on all counts with @rwightman. It's important to have a fully PyTorchic model and training procedure before merging into torchvision. I'm also finding the training hyperparameters to be very sensitive, but we should get them right before officially releasing.

EMA weight decay ... alone adds a solid 1.5-2% top-1 bump once the model is getting close to converged.

I have not done extensive experimentation yet, but from early impressions, it seems to me as well that weight EMA could make quite a significant difference. @rwightman Do you feel this could be something worth trying to build into a core optimizer at some point in the future?

@rwightman
Contributor

It's not torchvision, but it's pretty convenient with the spiffy new hub feature...

import torch
torch.hub.list('rwightman/gen-efficientnet-pytorch', force_reload=True)
model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True)

That model is native PyTorch, with no TF padding hacks, and tested to be fully exportable to ONNX/Caffe2: 76.9% top-1, 93.2% top-5. It's going to be a while before any of the larger ones are satisfactory. Heh.

If anyone wants to take a crack at B1-B3, the model weight EMA smoothing impl I've been using for training is on a dev branch at https://github.com/rwightman/pytorch-image-models/tree/ema-cleanup
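The top-1/top-5 numbers quoted above are the standard ImageNet metrics: a prediction counts as correct if the true class is among the k highest-scoring classes. A minimal pure-Python sketch of the metric; the function name and toy data are illustrative, and this is not the evaluation code behind those numbers.

```python
from typing import Sequence


def topk_accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int], k: int = 1) -> float:
    """Percentage of samples whose target class is among the k highest scores."""
    correct = 0
    for scores, target in zip(logits, targets):
        # Indices of the k largest scores for this sample.
        top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        correct += target in top_k
    return 100.0 * correct / len(targets)


logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
targets = [2, 0]
print(topk_accuracy(logits, targets, k=1))  # 50.0
print(topk_accuracy(logits, targets, k=2))  # 100.0
```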

@bhack

bhack commented Jun 10, 2019

Thanks for the update @rwightman

@bermanmaxim

@rwightman Thanks for the details on training efficientnet_b0. May I ask what parameters you used for your training script (i.e., did you use the defaults or Google's RMSprop, etc.)? Thank you!

@bermanmaxim

@rwightman I noticed your comment on this in huggingface/pytorch-image-models#11. Thanks for the great work.

@abhuse

abhuse commented Oct 28, 2019

May I suggest my implementation of EfficientNet: abhuse/pytorch-efficientnet. It is a pythonic implementation that uses PyTorch modules (i.e. nn.Module). I provide pretrained weights converted from those in the official TensorFlow repository, and an evaluation script to run the models with the provided pretrained weights against the ImageNet validation set.

@rwightman
Contributor

rwightman commented Oct 28, 2019

@abhuse I don't think there is a lack of pythonic, nn.Module-based EfficientNet implementations. The problem is that there is no set of pretrained weights that doesn't require the extra asymmetric padding like the TensorFlow native ones, and no PyTorch-native reproduction of training from B0-B7 to matching accuracy without the padding.
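To make the padding issue concrete: TensorFlow's `'SAME'` padding depends on the input size and, for stride-2 convolutions, often pads one more pixel on the bottom/right than on the top/left. PyTorch's `nn.Conv2d` takes a fixed symmetric `padding`, so running TF-trained weights exactly needs an extra input-dependent pad (e.g. via `torch.nn.functional.pad`) before such convs. A minimal sketch of the TF rule for one spatial dimension; the function name is hypothetical.

```python
import math
from typing import Tuple


def tf_same_padding(in_size: int, kernel: int, stride: int, dilation: int = 1) -> Tuple[int, int]:
    """(before, after) padding for one spatial dim under TF 'SAME' rules."""
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + (kernel - 1) * dilation + 1 - in_size, 0)
    # TF puts the smaller half before and any extra pixel after.
    return total // 2, total - total // 2


print(tf_same_padding(224, kernel=3, stride=2))  # (0, 1) -> asymmetric
print(tf_same_padding(112, kernel=5, stride=2))  # (1, 2) -> asymmetric
print(tf_same_padding(56, kernel=3, stride=1))   # (1, 1) -> symmetric
```

Because the padding depends on `in_size`, it cannot be baked into a static symmetric `padding` argument, which is where the runtime overhead comes from.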

@bhack

bhack commented Oct 29, 2019

Pycls seems to have EfficientNet https://github.com/facebookresearch/pycls

@rwightman
Contributor

Pycls seems to have EfficientNet https://github.com/facebookresearch/pycls

Doesn't help with the weights situation though... and looking at the included h-param config for training B0, I don't think it'll match the paper or AutoAug/RandAug results.

@mattans

mattans commented Dec 30, 2019

Hi @rwightman, is it possible to use your impl as a backbone for torchvision's segmentation models? Similar to how ResNet is used today.

@rwightman
Contributor

@mattans Yes, it should be possible. I have been slowly working on module backbone instantiation that will work with various repos like torchvision, detectron2, and mmdetection. The WIP for this is in my pytorch-image-models (not in the gen-efficientnet version).

Note that it's not trivial getting the features from the correct location for these networks, they come from a post expansion point inside the IR blocks, not between blocks like ResNets, etc. Hence, I use hooks...

In [3]: m = timm.create_model('efficientnet_b0', features_only=True, pretrained=True, out_indices=(2,4))
In [4]: o = m(torch.randn(1, 3, 224, 224))
In [5]: o[0].shape
Out[5]: torch.Size([1, 1152, 7, 7])
In [6]: o[1].shape
Out[6]: torch.Size([1, 240, 28, 28])
In [7]: m.feature_channels()
Out[7]: [240, 1152]

In [8]: m = timm.create_model('efficientnet_b0', features_only=True, pretrained=True, out_indices=(1,2,3,4))
In [9]: m.feature_channels()
Out[9]: [144, 240, 672, 1152]
In [10]: [print(x.shape) for x in m(torch.randn(1, 3, 224, 224))]
torch.Size([1, 1152, 7, 7])
torch.Size([1, 672, 14, 14])
torch.Size([1, 240, 28, 28])
torch.Size([1, 144, 56, 56])
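The channel counts above follow from the inverted-residual (IR) block structure: a block's 1x1 expansion conv multiplies its input channels by the expansion ratio before the depthwise conv applies the block's stride. A small sketch, assuming the standard B0 stage layout (expansion 6 after the first stage, stage widths 16/24/40/80/112/192/320) and reading each tapped point as the expansion output of a stage's first block, reproduces the 144/240/672/1152 channels and 56/28/14/7 resolutions printed above. This is a reading of the numbers, not timm's feature-location code; the names are hypothetical.

```python
from typing import List, Tuple

# Assumed EfficientNet-B0 stage layout: (expand_ratio, out_channels, first_block_stride).
B0_STAGES: List[Tuple[int, int, int]] = [
    (1, 16, 1),
    (6, 24, 2),
    (6, 40, 2),
    (6, 80, 2),
    (6, 112, 1),
    (6, 192, 2),
    (6, 320, 1),
]
STEM_CHANNELS, STEM_STRIDE = 32, 2


def first_block_expansion(stage_idx: int, img_size: int = 224) -> Tuple[int, int]:
    """(channels, spatial size) right after the expansion conv of a stage's first block."""
    in_chs = B0_STAGES[stage_idx - 1][1] if stage_idx > 0 else STEM_CHANNELS
    # Only strides of earlier stages apply: the 1x1 expansion runs before the strided depthwise conv.
    stride = STEM_STRIDE
    for _, _, s in B0_STAGES[:stage_idx]:
        stride *= s
    expand, _, _ = B0_STAGES[stage_idx]
    return in_chs * expand, img_size // stride


print([first_block_expansion(i) for i in (2, 3, 5, 6)])
# [(144, 56), (240, 28), (672, 14), (1152, 7)]
```

Tapping before the strided depthwise conv is what gives these features 6x more channels than the block outputs at the same resolution, which is why they cannot be taken simply between blocks.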

@rwightman
Contributor

An update on EfficientNet training: last month I managed to reproduce EfficientNet-B2 training from scratch using RandAugment, hitting 80.4% top-1 with pytorch-image-models (https://github.com/rwightman/pytorch-image-models). I tried the same with EfficientNet-B3 but fell a bit short at 81.5% (it only hits > 81.8% when I bump the resolution and crop factor at test time).

I feel confident I could reproduce training eventually given time and resources but sadly I don't have the GPU compute to move beyond B3. The B3 attempt took almost a month with dual Titan RTX.

I'm open to bringing in my EfficientNet, MixNet, MobileNetV3 from (https://github.com/rwightman/gen-efficientnet-pytorch) to here with some cleanup or simplification. But the training code necessary to replicate results varies quite a bit from the templates in torchvision.
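For readers unfamiliar with the "crop factor" mentioned above: ImageNet evaluation commonly resizes the image's short side so that the target size is a fraction (`crop_pct`) of it, then center-crops. Raising the test resolution or `crop_pct` changes how much of the image the network sees. A minimal sketch of that arithmetic, assuming a square image for the crop box; the function name and values are illustrative (224 with 0.875 is the common ResNet-style default, not the B3 settings from the comment).

```python
import math
from typing import Tuple


def eval_resize_and_crop(img_size: int, crop_pct: float) -> Tuple[int, Tuple[int, int, int, int]]:
    """Short-side resize target and center-crop box (left, top, right, bottom).

    Assumes a square image after resizing, for simplicity.
    """
    scale_size = int(math.floor(img_size / crop_pct))
    offset = (scale_size - img_size) // 2
    return scale_size, (offset, offset, offset + img_size, offset + img_size)


print(eval_resize_and_crop(224, 0.875))  # (256, (16, 16, 240, 240))
print(eval_resize_and_crop(300, 1.0))    # (300, (0, 0, 300, 300))
```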

@1e100
Contributor

1e100 commented May 20, 2020

To be fair though, these difficulties mostly pertain to classifiers. I've been using EfficientNets as feature extractors for object detection quite successfully, even training from scratch. My datasets are not COCO, they're for the actual, practical problems I'm solving, so there are fewer classes, less occlusion, but also noisier labels, and fewer samples.

EfficientNets do take longer to train for this, of course, but not dramatically longer. ImageNet is just crazy hard as a dataset, with all the dog breeds and such; you're unlikely to encounter something as hard as this in a practical setting.

So TL;DR: for folks who are scared away by all these difficulties, if you just want a feature extractor, do try it out - it works better than you'd think.

@francisco-simoes

@lukemelas Hi, thanks for the great port.
Any idea when it will be ready to port to torchvision?

@marcusturewicz

@rwightman how many GPUs, and what spec of GPU, do you think would be required to train all models B0 through to B7 to the required accuracy? We could use AWS Deep Learning AMI with the required GPUs to train these models. The question then becomes about cost/funding for those instances. I would be willing to cough up some cash and I'm sure other people who are keen to use EfficientNet in PyTorch would be too. Maybe a Go Fund Me, or a sponsored GitHub repository would be the right avenue? I'm keen to help out here, whether that's sorting out funding, provisioning instances, or training models. Let me know your thoughts.

@rwightman
Contributor

@marcusturewicz ... from B3 to B8 the scaling is roughly 1/2 the throughput and batch size for each model step. For the B3 hparams I've trained from scratch with (which were an improvement on both the official RandAugment and AdvProp weights), it took almost 4 weeks (26 days) on 2 Titan RTX cards. Doing some rough calcs and translating that (ballpark) to V100s, I'd estimate about $45-55K in P3 instance costs for the B7 if you managed it in a single attempt.

However, the batch sizes are small there, so you'd likely need some degree of sync BN past the B5 model size (or 48GB cards). That could require a bit more hparam search and failed trials. Plus, it slows things down a bit more than my calcs account for...

All in all, something that an org with lots of GPUs or TPUs should probably tackle :)
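A back-of-the-envelope version of the scaling estimate above, assuming each model step from B3 upward roughly halves throughput (so doubles training time on the same hardware), starting from the quoted 26 days on 2 GPUs for B3. It only estimates GPU-days on the original cards; it does not reproduce the V100 or dollar conversion, and the helper name is hypothetical.

```python
def estimated_gpu_days(base_days: float, base_gpus: int, base_step: int, target_step: int,
                       slowdown_per_step: float = 2.0) -> float:
    """Scale GPU-days from a base model (e.g. B3) to a larger one (e.g. B7)."""
    steps = target_step - base_step
    return base_days * base_gpus * slowdown_per_step ** steps


for b in range(3, 8):
    print(f"B{b}: ~{estimated_gpu_days(26, 2, 3, b):.0f} GPU-days")
# B3: ~52, B4: ~104, B5: ~208, B6: ~416, B7: ~832 GPU-days
```

Four doublings make B7 roughly 16x the B3 budget under this assumption, before accounting for the extra sync BN overhead and failed trials mentioned above.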

@1e100
Contributor

1e100 commented Jan 9, 2021

All in all, something that an org with lots of GPUs or TPUs should probably tackle :)

Alternatively, someone with that $100K free AWS "startup" quota could take this on, especially if it's soon going to expire anyway, since it's only offered for a year. One of my clients had their quota expire last week, with tens of thousands of dollars worth of it unused, so I thought I'd mention that here as a possibility. I think Google offers something similar, but they're stingier with it.

@rwightman
Contributor

@1e100 agreed, that would be a good way to do it, but you'd need to make the connections to a startup that's not going to burn through their credits, and at the right time.

Someone recently handed me the keys to an 8xV100 GCP machine that was funded by expiring startup credits... I trained some large ResNets and a medium-sized EfficientDet. Tackling something like the B7 would have been too risky, given the chance of little to no payoff for the attempt and wait time.

If I do end up training these nets at some point, it will likely be via a TPU Pod w/ research credits using JAX (which can easily be ported to PyTorch). Partially synchronized BN (syncing a subset of processes via grouping) is easier to spec in JAX and fits the TPU arch well.
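Partially synchronized BN means computing batch statistics over a subset of devices rather than all of them, which raises the per-statistic batch size without a full all-device sync. A minimal sketch of the bookkeeping, assuming contiguous groups of ranks. In PyTorch, each group's rank list could be passed to `torch.distributed.new_group`, and each rank would hand the group it belongs to to `torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=...)`; in JAX the analogous knob is `axis_index_groups` on collectives such as `jax.lax.pmean`. The helper names and numbers are illustrative.

```python
from typing import List


def bn_sync_groups(world_size: int, group_size: int) -> List[List[int]]:
    """Split ranks 0..world_size-1 into contiguous BN sync groups."""
    if world_size % group_size != 0:
        raise ValueError("world_size must be divisible by group_size")
    return [list(range(start, start + group_size))
            for start in range(0, world_size, group_size)]


def effective_bn_batch(per_device_batch: int, group_size: int) -> int:
    """Number of samples contributing to each set of batch statistics."""
    return per_device_batch * group_size


print(bn_sync_groups(8, 4))      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(effective_bn_batch(8, 4))  # 32
```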
