BoF (Bag of Freebies) - Visually Coherent Image Mixup ~+4 AP@[.5, .95] #3272
BoF (Bag of Freebies) can be used in both the Classifier and the Detector; it includes 5 features:
@dkashkin Hi, Mixup should also be used for the darknet53.conv pre-trained weights:
@AlexeyAB Yes, this sounds great. I like the simplicity of this idea - it's just an additional augmentation strategy that can be implemented in a few lines of code. The question is: can it be easily supported in Darknet training? Ideally, this should be an optional line in the config file...
@dkashkin
But it can be implemented in several lines only in Python. I am just trying to understand what they do for mixup (bag of freebies), so as not to miss important details and points: https://arxiv.org/pdf/1710.09412v2.pdf

The Beta distribution: https://www.astroml.org/book_figures/chapter3/fig_beta_distribution.html

A Beta distribution in C would be something like this - it isn't even the most complex implementation :) numpy/numpy#688

Beta distribution in C++: https://gist.github.com/sftrabbit/5068941

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <random>

namespace sftrabbit {

  template <typename RealType = double>
  class beta_distribution
  {
  public:
    typedef RealType result_type;

    class param_type
    {
    public:
      typedef beta_distribution distribution_type;

      explicit param_type(RealType a = 2.0, RealType b = 2.0)
        : a_param(a), b_param(b) { }

      RealType a() const { return a_param; }
      RealType b() const { return b_param; }

      bool operator==(const param_type& other) const
      {
        return (a_param == other.a_param &&
                b_param == other.b_param);
      }

      bool operator!=(const param_type& other) const
      {
        return !(*this == other);
      }

    private:
      RealType a_param, b_param;
    };

    explicit beta_distribution(RealType a = 2.0, RealType b = 2.0)
      : a_gamma(a), b_gamma(b) { }
    explicit beta_distribution(const param_type& param)
      : a_gamma(param.a()), b_gamma(param.b()) { }

    void reset() { }

    param_type param() const
    {
      return param_type(a(), b());
    }

    void param(const param_type& param)
    {
      a_gamma = gamma_dist_type(param.a());
      b_gamma = gamma_dist_type(param.b());
    }

    template <typename URNG>
    result_type operator()(URNG& engine)
    {
      return generate(engine, a_gamma, b_gamma);
    }

    template <typename URNG>
    result_type operator()(URNG& engine, const param_type& param)
    {
      gamma_dist_type a_param_gamma(param.a()),
                      b_param_gamma(param.b());
      return generate(engine, a_param_gamma, b_param_gamma);
    }

    result_type min() const { return 0.0; }
    result_type max() const { return 1.0; }

    RealType a() const { return a_gamma.alpha(); }
    RealType b() const { return b_gamma.alpha(); }

    bool operator==(const beta_distribution<result_type>& other) const
    {
      return (param() == other.param() &&
              a_gamma == other.a_gamma &&
              b_gamma == other.b_gamma);
    }

    bool operator!=(const beta_distribution<result_type>& other) const
    {
      return !(*this == other);
    }

  private:
    typedef std::gamma_distribution<result_type> gamma_dist_type;

    gamma_dist_type a_gamma, b_gamma;

    // A Beta(a, b) variate is X / (X + Y), where X ~ Gamma(a, 1) and Y ~ Gamma(b, 1)
    template <typename URNG>
    result_type generate(URNG& engine,
                         gamma_dist_type& x_gamma,
                         gamma_dist_type& y_gamma)
    {
      result_type x = x_gamma(engine);
      return x / (x + y_gamma(engine));
    }
  };

  template <typename CharT, typename RealType>
  std::basic_ostream<CharT>& operator<<(std::basic_ostream<CharT>& os,
                                        const beta_distribution<RealType>& beta)
  {
    os << "~Beta(" << beta.a() << "," << beta.b() << ")";
    return os;
  }

  template <typename CharT, typename RealType>
  std::basic_istream<CharT>& operator>>(std::basic_istream<CharT>& is,
                                        beta_distribution<RealType>& beta)
  {
    std::string str;
    RealType a, b;
    if (std::getline(is, str, '(') && str == "~Beta" &&
        is >> a && is.get() == ',' && is >> b && is.get() == ')') {
      beta = beta_distribution<RealType>(a, b);
    } else {
      is.setstate(std::ios::failbit);
    }
    return is;
  }

}

void data_augmentation(...) {
  std::random_device rd;
  std::mt19937 gen(rd());

  const double beta_val1 = 1.5, beta_val2 = 1.5;  // for B(1.5, 1.5)
  sftrabbit::beta_distribution<double> beta_distr_obj(beta_val1, beta_val2);
  double beta_sample = beta_distr_obj(gen);

  float alpha_blend = beta_sample;
  float beta_blend = 1 - beta_sample;

  cv::addWeighted(src1, alpha_blend, src2, beta_blend, 0.0, dst);          // mixup images
  fuse_labels(src_label1, alpha_blend, src_label2, beta_blend, new_label); // mixup labels
}
```
Also, what they do for LSR (class label smoothing) is written very academically: https://arxiv.org/pdf/1512.00567v3.pdf
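For reference, the smoothing rule in that paper (my transcription, with $K$ classes, ground-truth class $y$, and smoothing parameter $\epsilon$) replaces the one-hot target with:

```latex
q'(k) = (1 - \epsilon)\,\delta_{k,y} + \frac{\epsilon}{K}
```

i.e. the true class keeps probability $1 - \epsilon + \epsilon/K$ and every other class gets $\epsilon/K$.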
@AlexeyAB Sorry for the delay! I missed your reply. It might be more reliable to discuss this over email (kashkin at gmail). I agree with you - some papers describe mixup via heavy math, resulting in unnecessary complexity. I like the following visual explanation much better:
@dkashkin Yes, maybe we can try to implement mixup with fixed
@AlexeyAB I think this would be a great starting point!
@dkashkin Hi, I added Mixup data augmentation. If you want to see the result of the data augmentation, use the flag
Thanks @AlexeyAB! I'm running a test now; I'll report back on the results shortly.
Running the latest repo with the mixup option causes the process to be "killed":
An identical model that I trained yesterday was fine; the only difference is that yesterday's model used the repo as it existed yesterday and didn't have the mixup flag set.
I just tried running the same again without -letter_box and it crashed in the same way. Running again now without mix_up=1 and it seems to be going fine so far. |
@LukeAI Hi, can you share your cfg-file and dataset if it isn't private? How much CPU RAM do you have?
In general, I think it should work, since two sequences (not just random images) will be mixed up.
The dataset is a subset of Google OpenImages |
I fixed it. |
Hey all - just as feedback: I found that mixup slightly hurt my AP in all classes at the final validation, using the above cfg and dataset. (I didn't try older weights, but the chart.png looked fairly flat.)
As I see there are only
I randomly split the labelled KITTI OBJECT2D dataset into 85% training and 15% testing. I wrote my own script to do the conversions.
@LukeAI So it seems that Mixup doesn't increase mAP in most cases, or it requires more iterations.
yolov3-tiny_3l.cfg.txt |
@WongKinYiu I added MixUp and CutMix for Classifier training: #4419 |
@AlexeyAB Great! I need about 2~3 weeks to train a classifier. |
I added all 5 features from the BoF (Bag of Freebies). Have you tested them? Also, there is a new implementation of
@AlexeyAB Hello, I'm still on holiday. Could you please give an example of [net] and [yolo] layers for the suggested data augmentation and CIoU norm hyper-parameters? By the way, CutMix currently performs much better than Mixup for training a classifier.
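A hedged sketch of what such config fragments might look like, based on option names used in the AlexeyAB/darknet fork (the exact keys and values here are my assumption - verify them against your build of the parser before training):

```ini
[net]
# data augmentation options (pick one of mixup / cutmix / mosaic)
mixup=1
# cutmix=1
# mosaic=1

[yolo]
# CIoU loss with the suggested normalizers
iou_loss=ciou
iou_normalizer=0.07
cls_normalizer=1.0
# label smoothing (placement in [net] vs [yolo] is discussed below in the thread)
label_smooth_eps=0.1
```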
@AlexeyAB thanks! |
For training the classifier - and do you mean I should change 1000 to be the same as max_batches?
I don't know. The only thing I did not fully understand from the label smoothing paper,
Yes, so it will have only 1 cycle; therefore, you don't have to calculate max_batches so that the training ends exactly at the end of one of the cycles. Also, I fixed that
Thanks, I'll check the training stats tomorrow and report back to you.
Although the loss continuously increases to 9xxx, the behavior seems normal.
Do you mean that accuracy is increasing? |
@AlexeyAB yes. |
@WongKinYiu @AlexeyAB FYI, I found two different implementations of label smoothing in the TensorFlow code:
where I assume |
Yes, previously we used (1) in both the Classifier and the Detector. Now we use (1) for the Classifier (Softmax) and (2) for the Detector (Logistic).
@AlexeyAB ah perfect then! |
I get NaN when I set
Does anyone know how to use CmBN in cspdarknet53? And another question about label smoothing: should I add label_smooth_eps in both [net] and all [yolo] layers?
|
change all of |
This is a question, not a bug. @AlexeyAB is there a way to incorporate the "Visually Coherent Image Mixup" augmentation strategy? Amazon recently published a paper that claims that this approach results in a massive mAP improvement for YOLOv3: https://arxiv.org/abs/1902.04103
Needless to say, I would love a chance to try this idea and I'd be happy to share my results.