Skip to content

Commit

Permalink
Added [net] dynamic_minibatch=1 for increasing mini_batch_size when r…
Browse files Browse the repository at this point in the history
…andom=1 is used
  • Loading branch information
AlexeyAB committed Mar 2, 2020
1 parent df9e602 commit c814d56
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 55 deletions.
3 changes: 3 additions & 0 deletions include/darknet.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ struct layer {
int batch_normalize;
int shortcut;
int batch;
int dynamic_minibatch;
int forced;
int flipped;
int inputs;
Expand Down Expand Up @@ -640,6 +641,7 @@ typedef struct network {
int n;
int batch;
uint64_t *seen;
int *cur_iteration;
int *t;
float epoch;
int subdivisions;
Expand Down Expand Up @@ -739,6 +741,7 @@ typedef struct network {
size_t max_delta_gpu_size;
//#endif // GPU
int optimized_memory;
int dynamic_minibatch;
size_t workspace_size_limit;
} network;

Expand Down
6 changes: 3 additions & 3 deletions src/convolutional_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)

//#ifdef CUDNN_HALF
//if (state.use_mixed_precision) {
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
int iteration_num = get_current_iteration(state.net); // (*state.net.seen) / (state.net.batch*state.net.subdivisions);
if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3*state.net.burn_in) &&
(l.c / l.groups) % 8 == 0 && l.n % 8 == 0 && !state.train && l.groups <= 1 && l.size > 1)
{
Expand Down Expand Up @@ -671,7 +671,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
float alpha = 1, beta = 0;

//#ifdef CUDNN_HALF
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);
if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3*state.net.burn_in) &&
(l.c / l.groups) % 8 == 0 && l.n % 8 == 0 && !state.train && l.groups <= 1 && l.size > 1)
{
Expand Down Expand Up @@ -978,7 +978,7 @@ void assisted_activation2_gpu(float alpha, float *output, float *gt_gpu, float *

void assisted_excitation_forward_gpu(convolutional_layer l, network_state state)
{
const int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
const int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);

// epoch
//const float epoch = (float)(*state.net.seen) / state.net.train_images_num;
Expand Down
2 changes: 1 addition & 1 deletion src/convolutional_layer.c
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)

if (l->activation == SWISH || l->activation == MISH) l->activation_input = (float*)realloc(l->activation_input, total_batch*l->outputs * sizeof(float));
#ifdef GPU
if (old_w < w || old_h < h) {
if (old_w < w || old_h < h || l->dynamic_minibatch) {
if (l->train) {
cuda_free(l->delta_gpu);
l->delta_gpu = cuda_make_array(l->delta, total_batch*l->outputs);
Expand Down
108 changes: 73 additions & 35 deletions src/detector.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,19 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i

srand(time(0));
int seed = rand();
int i;
for (i = 0; i < ngpus; ++i) {
int k;
for (k = 0; k < ngpus; ++k) {
srand(seed);
#ifdef GPU
cuda_set_device(gpus[i]);
cuda_set_device(gpus[k]);
#endif
nets[i] = parse_network_cfg(cfgfile);
nets[i].benchmark_layers = benchmark_layers;
nets[k] = parse_network_cfg(cfgfile);
nets[k].benchmark_layers = benchmark_layers;
if (weightfile) {
load_weights(&nets[i], weightfile);
load_weights(&nets[k], weightfile);
}
if (clear) *nets[i].seen = 0;
nets[i].learning_rate *= ngpus;
if (clear) *nets[k].seen = 0;
nets[k].learning_rate *= ngpus;
}
srand(time(0));
network net = nets[0];
Expand All @@ -105,12 +105,13 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
int train_images_num = plist->size;
char **paths = (char **)list_to_array(plist);

int init_w = net.w;
int init_h = net.h;
const int init_w = net.w;
const int init_h = net.h;
const int init_b = net.batch;
int iter_save, iter_save_last, iter_map;
iter_save = get_current_batch(net);
iter_save_last = get_current_batch(net);
iter_map = get_current_batch(net);
iter_save = get_current_iteration(net);
iter_save_last = get_current_iteration(net);
iter_map = get_current_iteration(net);
float mean_average_precision = -1;
float best_map = mean_average_precision;

Expand Down Expand Up @@ -165,7 +166,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
pthread_t load_thread = load_data(args);
int count = 0;
//while(i*imgs < N*120){
while (get_current_batch(net) < net.max_batches) {
while (get_current_iteration(net) < net.max_batches) {
if (l.random && count++ % 10 == 0) {
float rand_coef = 1.4;
if (l.random != 1.0) rand_coef = l.random;
Expand All @@ -175,26 +176,48 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
int dim_h = roundl(random_val*init_h / net.resize_step + 1) * net.resize_step;
if (random_val < 1 && (dim_w > init_w || dim_h > init_h)) dim_w = init_w, dim_h = init_h;

// at the beginning
if (avg_loss < 0) {
dim_w = roundl(rand_coef*init_w / net.resize_step + 1) * net.resize_step;
dim_h = roundl(rand_coef*init_h / net.resize_step + 1) * net.resize_step;
int max_dim_w = roundl(rand_coef*init_w / net.resize_step + 1) * net.resize_step;
int max_dim_h = roundl(rand_coef*init_h / net.resize_step + 1) * net.resize_step;

// at the beginning (check if enough memory) and at the end (calc rolling mean/variance)
if (avg_loss < 0 || get_current_iteration(net) > net.max_batches - 100) {
dim_w = max_dim_w;
dim_h = max_dim_h;
}

if (dim_w < net.resize_step) dim_w = net.resize_step;
if (dim_h < net.resize_step) dim_h = net.resize_step;
int dim_b = (init_b * max_dim_w * max_dim_h) / (dim_w * dim_h);
int new_dim_b = (int)(dim_b * 0.8);
if (new_dim_b > init_b) dim_b = new_dim_b;

printf("%d x %d \n", dim_w, dim_h);
args.w = dim_w;
args.h = dim_h;

int k;
if (net.dynamic_minibatch) {
for (k = 0; k < ngpus; ++k) {
(*nets[k].seen) = init_b * net.subdivisions * get_current_iteration(net); // remove this line, when you will save to weights-file both: seen & cur_iteration
nets[k].batch = dim_b;
int j;
for (j = 0; j < nets[k].n; ++j)
nets[k].layers[j].batch = dim_b;
}
net.batch = dim_b;
imgs = net.batch * net.subdivisions * ngpus;
args.n = imgs;
printf("\n %d x %d (batch = %d) \n", dim_w, dim_h, net.batch);
}
else
printf("\n %d x %d \n", dim_w, dim_h);

pthread_join(load_thread, 0);
train = buffer;
free_data(train);
load_thread = load_data(args);

for (i = 0; i < ngpus; ++i) {
resize_network(nets + i, dim_w, dim_h);
for (k = 0; k < ngpus; ++k) {
resize_network(nets + k, dim_w, dim_h);
}
net = nets[0];
}
Expand Down Expand Up @@ -246,7 +269,8 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
if (avg_loss < 0 || avg_loss != avg_loss) avg_loss = loss; // if(-inf or nan)
avg_loss = avg_loss*.9 + loss*.1;

i = get_current_batch(net);
const int iteration = get_current_iteration(net);
//i = get_current_batch(net);

int calc_map_for_each = 4 * train_images_num / (net.batch * net.subdivisions); // calculate mAP for each 4 Epochs
calc_map_for_each = fmax(calc_map_for_each, 100);
Expand All @@ -259,22 +283,36 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
}

if (net.cudnn_half) {
if (i < net.burn_in * 3) fprintf(stderr, "\n Tensor Cores are disabled until the first %d iterations are reached.", 3 * net.burn_in);
if (iteration < net.burn_in * 3) fprintf(stderr, "\n Tensor Cores are disabled until the first %d iterations are reached.", 3 * net.burn_in);
else fprintf(stderr, "\n Tensor Cores are used.");
}
printf("\n %d: %f, %f avg loss, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), (what_time_is_it_now() - time), i*imgs);
printf("\n %d: %f, %f avg loss, %f rate, %lf seconds, %d images\n", iteration, loss, avg_loss, get_current_rate(net), (what_time_is_it_now() - time), iteration*imgs);

int draw_precision = 0;
if (calc_map && (i >= next_map_calc || i == net.max_batches)) {
if (calc_map && (iteration >= next_map_calc || iteration == net.max_batches)) {
if (l.random) {
printf("Resizing to initial size: %d x %d \n", init_w, init_h);
printf("Resizing to initial size: %d x %d ", init_w, init_h);
args.w = init_w;
args.h = init_h;
int k;
if (net.dynamic_minibatch) {
for (k = 0; k < ngpus; ++k) {
for (k = 0; k < ngpus; ++k) {
nets[k].batch = init_b;
int j;
for (j = 0; j < nets[k].n; ++j)
nets[k].layers[j].batch = init_b;
}
}
net.batch = init_b;
imgs = init_b * net.subdivisions * ngpus;
args.n = imgs;
printf("\n %d x %d (batch = %d) \n", init_w, init_h, init_b);
}
pthread_join(load_thread, 0);
free_data(train);
train = buffer;
load_thread = load_data(args);
int k;
for (k = 0; k < ngpus; ++k) {
resize_network(nets + k, init_w, init_h);
}
Expand All @@ -286,7 +324,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
// combine Training and Validation networks
//network net_combined = combine_train_valid_networks(net, net_map);

iter_map = i;
iter_map = iteration;
mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, net.letter_box, &net_map);// &net_combined);
printf("\n mean_average_precision ([email protected]) = %f \n", mean_average_precision);
if (mean_average_precision > best_map) {
Expand All @@ -300,23 +338,23 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
draw_precision = 1;
}
#ifdef OPENCV
draw_train_loss(windows_name, img, img_size, avg_loss, max_img_loss, i, net.max_batches, mean_average_precision, draw_precision, "mAP%", dont_show, mjpeg_port);
draw_train_loss(windows_name, img, img_size, avg_loss, max_img_loss, iteration, net.max_batches, mean_average_precision, draw_precision, "mAP%", dont_show, mjpeg_port);
#endif // OPENCV

//if (i % 1000 == 0 || (i < 1000 && i % 100 == 0)) {
//if (i % 100 == 0) {
if (i >= (iter_save + 1000) || i % 1000 == 0) {
iter_save = i;
if (iteration >= (iter_save + 1000) || iteration % 1000 == 0) {
iter_save = iteration;
#ifdef GPU
if (ngpus != 1) sync_nets(nets, ngpus, 0);
#endif
char buff[256];
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, iteration);
save_weights(net, buff);
}

if (i >= (iter_save_last + 100) || i % 100 == 0) {
iter_save_last = i;
if (iteration >= (iter_save_last + 100) || iteration % 100 == 0) {
iter_save_last = iteration;
#ifdef GPU
if (ngpus != 1) sync_nets(nets, ngpus, 0);
#endif
Expand Down Expand Up @@ -350,7 +388,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
free_list_contents_kvp(options);
free_list(options);

for (i = 0; i < ngpus; ++i) free_network(nets[i]);
for (k = 0; k < ngpus; ++k) free_network(nets[k]);
free(nets);
//free_network(net);

Expand Down
19 changes: 14 additions & 5 deletions src/dropout_layer.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ dropout_layer make_dropout_layer(int batch, int inputs, float probability, int d
l.forward_gpu = forward_dropout_layer_gpu;
l.backward_gpu = backward_dropout_layer_gpu;
l.rand_gpu = cuda_make_array(l.rand, inputs*batch);
l.drop_blocks_scale = cuda_make_array_pinned(l.rand, l.batch);
l.drop_blocks_scale_gpu = cuda_make_array(l.rand, l.batch);
if (l.dropblock) {
l.drop_blocks_scale = cuda_make_array_pinned(l.rand, l.batch);
l.drop_blocks_scale_gpu = cuda_make_array(l.rand, l.batch);
}
#endif
if (l.dropblock) {
if(l.dropblock_size_abs) fprintf(stderr, "dropblock p = %.3f l.dropblock_size_abs = %d %4d -> %4d\n", probability, l.dropblock_size_abs, inputs, inputs);
Expand All @@ -48,11 +50,18 @@ void resize_dropout_layer(dropout_layer *l, int inputs)
{
l->inputs = l->outputs = inputs;
l->rand = (float*)xrealloc(l->rand, l->inputs * l->batch * sizeof(float));
#ifdef GPU
#ifdef GPU
cuda_free(l->rand_gpu);

l->rand_gpu = cuda_make_array(l->rand, l->inputs*l->batch);
#endif

if (l->dropblock) {
cudaFreeHost(l->drop_blocks_scale);
l->drop_blocks_scale = cuda_make_array_pinned(l->rand, l->batch);

cuda_free(l->drop_blocks_scale_gpu);
l->drop_blocks_scale_gpu = cuda_make_array(l->rand, l->batch);
}
#endif
}

void forward_dropout_layer(dropout_layer l, network_state state)
Expand Down
12 changes: 6 additions & 6 deletions src/dropout_layer_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ __global__ void yoloswag420blazeit360noscope(float *input, int size, float *rand
void forward_dropout_layer_gpu(dropout_layer l, network_state state)
{
if (!state.train) return;
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
int iteration_num = get_current_iteration(state.net); // (*state.net.seen) / (state.net.batch*state.net.subdivisions);
//if (iteration_num < state.net.burn_in) return;

// We gradually increase the block size and the probability of dropout - during the first half of the training
Expand Down Expand Up @@ -141,9 +141,9 @@ void forward_dropout_layer_gpu(dropout_layer l, network_state state)
for (int b = 0; b < l.batch; ++b) {
const float prob = l.drop_blocks_scale[b] * block_size * block_size / (float)l.outputs;
const float scale = 1.0f / (1.0f - prob);
printf(" %d x %d - block_size = %d, block_size*block_size = %d , ", l.w, l.h, block_size, block_size*block_size);
printf(" , l.drop_blocks_scale[b] = %f, prob = %f, calc scale = %f \t cur_prob = %f, cur_scale = %f \n",
l.drop_blocks_scale[b], prob, scale, cur_prob, cur_scale);
//printf(" %d x %d - block_size = %d, block_size*block_size = %d , ", l.w, l.h, block_size, block_size*block_size);
//printf(" , l.drop_blocks_scale[b] = %f, prob = %f, calc scale = %f \t cur_prob = %f, cur_scale = %f \n",
// l.drop_blocks_scale[b], prob, scale, cur_prob, cur_scale);
l.drop_blocks_scale[b] = scale;
}

Expand Down Expand Up @@ -176,14 +176,14 @@ void forward_dropout_layer_gpu(dropout_layer l, network_state state)
void backward_dropout_layer_gpu(dropout_layer l, network_state state)
{
if(!state.delta) return;
//int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
//int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);
//if (iteration_num < state.net.burn_in) return;

int size = l.inputs*l.batch;

// dropblock
if (l.dropblock) {
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);
float multiplier = 1.0;
if (iteration_num < (state.net.max_batches*0.85))
multiplier = (iteration_num / (float)(state.net.max_batches*0.85));
Expand Down
Loading

0 comments on commit c814d56

Please sign in to comment.