diff --git a/digits/frameworks/caffe_framework.py b/digits/frameworks/caffe_framework.py
index f5d4eac9b..878d8759d 100644
--- a/digits/frameworks/caffe_framework.py
+++ b/digits/frameworks/caffe_framework.py
@@ -121,3 +121,8 @@ def get_network_visualization(self, desc):
         net.name = 'Network'
         return ''
 
+    @override
+    def can_accumulate_gradients(self):
+        return (config_value('caffe_root')['version']
+                > parse_version('0.14.0-alpha'))
+
diff --git a/digits/frameworks/framework.py b/digits/frameworks/framework.py
index 60d816db8..095e72efd 100644
--- a/digits/frameworks/framework.py
+++ b/digits/frameworks/framework.py
@@ -76,7 +76,6 @@ def get_network_visualization(self, desc):
         """
         raise NotImplementedError('Please implement me')
 
-
-
-
+    def can_accumulate_gradients(self):
+        return False
 
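Note on the two hunks above: the base Framework class now reports no support for gradient accumulation, and CaffeFramework overrides it with a version check, so the option is only offered for Caffe builds reporting a version newer than 0.14.0-alpha. A minimal sketch of the comparison the override relies on, assuming parse_version here is pkg_resources.parse_version (the import sits outside the hunk):

    # Illustrative only -- how parse_version orders the versions involved.
    from pkg_resources import parse_version

    print(parse_version('0.14.0') > parse_version('0.14.0-alpha'))  # True  -> accumulation offered
    print(parse_version('0.13.2') > parse_version('0.14.0-alpha'))  # False -> feature stays off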
diff --git a/digits/model/forms.py b/digits/model/forms.py
index 210513813..276cddc06 100644
--- a/digits/model/forms.py
+++ b/digits/model/forms.py
@@ -128,6 +128,15 @@ def validate_py_ext(form, field):
             tooltip = "How many images to process at once. If blank, values are used from the network definition."
             )
 
+    batch_accumulation = utils.forms.IntegerField('Batch Accumulation',
+            default=1,
+            validators = [
+                validators.NumberRange(min=1),
+                validators.Optional(),
+                ],
+            tooltip = "Accumulate gradients over multiple batches (useful when you need a bigger batch size for training but it doesn't fit in memory)."
+            )
+
     ### Solver types
 
     solver_type = utils.forms.SelectField(
diff --git a/digits/model/images/classification/views.py b/digits/model/images/classification/views.py
index cff56656a..65f99ab98 100644
--- a/digits/model/images/classification/views.py
+++ b/digits/model/images/classification/views.py
@@ -253,23 +253,24 @@ def create():
                         else ''),
                     form.python_layer_server_file.data)
 
             job.tasks.append(fw.create_train_task(
-                        job = job,
-                        dataset = datasetJob,
-                        train_epochs = form.train_epochs.data,
-                        snapshot_interval = form.snapshot_interval.data,
-                        learning_rate = form.learning_rate.data[0],
-                        lr_policy = policy,
-                        gpu_count = gpu_count,
-                        selected_gpus = selected_gpus,
-                        batch_size = form.batch_size.data[0],
-                        val_interval = form.val_interval.data,
-                        pretrained_model= pretrained_model,
-                        crop_size = form.crop_size.data,
-                        use_mean = form.use_mean.data,
-                        network = network,
-                        random_seed = form.random_seed.data,
-                        solver_type = form.solver_type.data,
-                        shuffle = form.shuffle.data,
+                        job                = job,
+                        dataset            = datasetJob,
+                        train_epochs       = form.train_epochs.data,
+                        snapshot_interval  = form.snapshot_interval.data,
+                        learning_rate      = form.learning_rate.data[0],
+                        lr_policy          = policy,
+                        gpu_count          = gpu_count,
+                        selected_gpus      = selected_gpus,
+                        batch_size         = form.batch_size.data[0],
+                        batch_accumulation = form.batch_accumulation.data,
+                        val_interval       = form.val_interval.data,
+                        pretrained_model   = pretrained_model,
+                        crop_size          = form.crop_size.data,
+                        use_mean           = form.use_mean.data,
+                        network            = network,
+                        random_seed        = form.random_seed.data,
+                        solver_type        = form.solver_type.data,
+                        shuffle            = form.shuffle.data,
                         )
                     )
diff --git a/digits/model/images/generic/views.py b/digits/model/images/generic/views.py
index 63014438c..e4badb6bd 100644
--- a/digits/model/images/generic/views.py
+++ b/digits/model/images/generic/views.py
@@ -205,23 +205,24 @@ def create():
                         else ''),
                     form.python_layer_server_file.data)
 
             job.tasks.append(fw.create_train_task(
-                        job = job,
-                        dataset = datasetJob,
-                        train_epochs = form.train_epochs.data,
-                        snapshot_interval = form.snapshot_interval.data,
-                        learning_rate = form.learning_rate.data[0],
-                        lr_policy = policy,
-                        gpu_count = gpu_count,
-                        selected_gpus = selected_gpus,
-                        batch_size = form.batch_size.data[0],
-                        val_interval = form.val_interval.data,
-                        pretrained_model= pretrained_model,
-                        crop_size = form.crop_size.data,
-                        use_mean = form.use_mean.data,
-                        network = network,
-                        random_seed = form.random_seed.data,
-                        solver_type = form.solver_type.data,
-                        shuffle = form.shuffle.data,
+                        job                = job,
+                        dataset            = datasetJob,
+                        train_epochs       = form.train_epochs.data,
+                        snapshot_interval  = form.snapshot_interval.data,
+                        learning_rate      = form.learning_rate.data[0],
+                        lr_policy          = policy,
+                        gpu_count          = gpu_count,
+                        selected_gpus      = selected_gpus,
+                        batch_size         = form.batch_size.data[0],
+                        batch_accumulation = form.batch_accumulation.data,
+                        val_interval       = form.val_interval.data,
+                        pretrained_model   = pretrained_model,
+                        crop_size          = form.crop_size.data,
+                        use_mean           = form.use_mean.data,
+                        network            = network,
+                        random_seed        = form.random_seed.data,
+                        solver_type        = form.solver_type.data,
+                        shuffle            = form.shuffle.data,
                         )
                     )
diff --git a/digits/model/tasks/caffe_train.py b/digits/model/tasks/caffe_train.py
index 55b9abde9..3602233f8 100644
--- a/digits/model/tasks/caffe_train.py
+++ b/digits/model/tasks/caffe_train.py
@@ -426,6 +426,11 @@ def save_files_classification(self):
 
         solver.snapshot_prefix = self.snapshot_prefix
 
+        # Batch accumulation
+        from digits.frameworks import CaffeFramework
+        if CaffeFramework().can_accumulate_gradients():
+            solver.iter_size = self.batch_accumulation
+
         # Epochs -> Iterations
         train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) / train_data_layer.data_param.batch_size))
         solver.max_iter = train_iter * self.train_epochs
@@ -623,6 +628,11 @@ def save_files_generic(self):
 
         solver.snapshot_prefix = self.snapshot_prefix
 
+        # Batch accumulation
+        from digits.frameworks import CaffeFramework
+        if CaffeFramework().can_accumulate_gradients():
+            solver.iter_size = self.batch_accumulation
+
         # Epochs -> Iterations
         train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) / train_image_data_layer.data_param.batch_size))
         solver.max_iter = train_iter * self.train_epochs
diff --git a/digits/model/tasks/train.py b/digits/model/tasks/train.py
index af6e5dc92..adaedad25 100644
--- a/digits/model/tasks/train.py
+++ b/digits/model/tasks/train.py
@@ -38,6 +38,7 @@ def __init__(self, job, dataset, train_epochs, snapshot_interval, learning_rate,
         gpu_count -- how many GPUs to use for training (integer)
         selected_gpus -- a list of GPU indexes to be used for training
         batch_size -- if set, override any network specific batch_size with this value
+        batch_accumulation -- accumulate gradients over multiple batches
         val_interval -- how many epochs between validating the model with an epoch of validation data
         pretrained_model -- filename for a model to use for fine-tuning
         crop_size -- crop each image down to a square of this size
@@ -47,6 +48,7 @@ def __init__(self, job, dataset, train_epochs, snapshot_interval, learning_rate,
         self.gpu_count = kwargs.pop('gpu_count', None)
         self.selected_gpus = kwargs.pop('selected_gpus', None)
         self.batch_size = kwargs.pop('batch_size', None)
+        self.batch_accumulation = kwargs.pop('batch_accumulation', None)
         self.val_interval = kwargs.pop('val_interval', None)
         self.pretrained_model = kwargs.pop('pretrained_model', None)
         self.crop_size = kwargs.pop('crop_size', None)
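On the Caffe side the new field maps straight onto the solver's iter_size, so one weight update effectively covers batch_size * batch_accumulation samples. A rough sketch of that arithmetic, assuming Caffe's standard Python bindings (caffe.proto.caffe_pb2) are importable; the concrete numbers are only illustrative:

    # Illustrative only -- what the caffe_train.py hunks above amount to.
    from caffe.proto import caffe_pb2

    solver = caffe_pb2.SolverParameter()
    batch_size = 32           # per-iteration batch size from the network or the form
    batch_accumulation = 4    # value of the new form field

    solver.iter_size = batch_accumulation
    # Caffe runs iter_size forward/backward passes before each weight update,
    # so one update effectively sees batch_size * iter_size samples.
    effective_batch_size = batch_size * solver.iter_size   # 128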
diff --git a/digits/templates/models/images/classification/new.html b/digits/templates/models/images/classification/new.html
index 834e82167..c48ec1d23 100644
--- a/digits/templates/models/images/classification/new.html
+++ b/digits/templates/models/images/classification/new.html
@@ -160,6 +160,12 @@ Solver Options
               {{form.batch_size(class='form-control', placeholder='[network defaults]')}}
+              [batch_accumulation form-group markup; lost in extraction]
               {{form.solver_type.label}} {{form.solver_type.tooltip}}
@@ -407,7 +413,8 @@ Solver Options
         {% for fw in frameworks %}
         framework = {
             name : '{{ fw.get_name() }}',
-            can_shuffle : '{{ fw.can_shuffle_data() }}'=='True'
+            can_shuffle : '{{ fw.can_shuffle_data() }}'=='True',
+            can_accumulate_gradients : '{{ fw.can_accumulate_gradients() }}'=='True',
         };
         frameworks['{{ fw.get_id() }}'] = framework;
         {% endfor %}
@@ -432,6 +439,14 @@ Solver Options
             $("#torch-warning").hide();
             $('#stdnetRole a[href="'+"#"+fwid+"_standard"+'"]').tab('show');
             $('#customFramework a[href="'+"#"+fwid+"_custom"+'"]').tab('show');
+
+            if (frameworks[fwid].can_accumulate_gradients) {
+                $('#batch_accumulation').prop('disabled', false);
+                $('#batch-accumulation-option').show();
+            } else {
+                $('#batch-accumulation-option').hide();
+                $('#batch_accumulation').prop('disabled', true);
+            }
         }
diff --git a/digits/templates/models/images/generic/new.html b/digits/templates/models/images/generic/new.html
index bc6afcf86..b305a91a2 100644
--- a/digits/templates/models/images/generic/new.html
+++ b/digits/templates/models/images/generic/new.html
@@ -160,6 +160,12 @@ Solver Options
               {{form.batch_size(class='form-control', placeholder='[network defaults]')}}
+              [batch_accumulation form-group markup; lost in extraction]
               {{form.solver_type.label}} {{form.solver_type.tooltip}}
@@ -404,7 +410,8 @@ Solver Options
         {% for fw in frameworks %}
         framework = {
             name : '{{ fw.get_name() }}',
-            can_shuffle : '{{ fw.can_shuffle_data() }}'=='True'
+            can_shuffle : '{{ fw.can_shuffle_data() }}'=='True',
+            can_accumulate_gradients : '{{ fw.can_accumulate_gradients() }}'=='True',
         };
         frameworks['{{ fw.get_id() }}'] = framework;
         {% endfor %}
@@ -422,6 +429,14 @@ Solver Options
                 $("select[name=solver_type] > option:selected").prop("selected", false);
             }
             $('#customFramework a[href="'+"#"+fwid+"_custom"+'"]').tab('show');
+
+            if (frameworks[fwid].can_accumulate_gradients) {
+                $('#batch_accumulation').prop('disabled', false);
+                $('#batch-accumulation-option').show();
+            } else {
+                $('#batch-accumulation-option').hide();
+                $('#batch_accumulation').prop('disabled', true);
+            }
         }