diff --git a/digits/frameworks/caffe_framework.py b/digits/frameworks/caffe_framework.py
index f5d4eac9b..878d8759d 100644
--- a/digits/frameworks/caffe_framework.py
+++ b/digits/frameworks/caffe_framework.py
@@ -121,3 +121,8 @@ def get_network_visualization(self, desc):
net.name = 'Network'
return ''
+ @override
+ def can_accumulate_gradients(self):
+ return (config_value('caffe_root')['version']
+ > parse_version('0.14.0-alpha'))
+
diff --git a/digits/frameworks/framework.py b/digits/frameworks/framework.py
index 60d816db8..095e72efd 100644
--- a/digits/frameworks/framework.py
+++ b/digits/frameworks/framework.py
@@ -76,7 +76,10 @@ def get_network_visualization(self, desc):
         """
         raise NotImplementedError('Please implement me')
-
-
-
+
+    def can_accumulate_gradients(self):
+        """
+        returns True if the framework can accumulate gradients over several batches
+        """
+        return False
diff --git a/digits/model/forms.py b/digits/model/forms.py
index 210513813..276cddc06 100644
--- a/digits/model/forms.py
+++ b/digits/model/forms.py
@@ -128,6 +128,15 @@ def validate_py_ext(form, field):
tooltip = "How many images to process at once. If blank, values are used from the network definition."
)
+ batch_accumulation = utils.forms.IntegerField('Batch Accumulation',
+ default=1,
+ validators = [
+ validators.NumberRange(min=1),
+ validators.Optional(),
+ ],
+ tooltip = "Accumulate gradients over multiple batches (useful when you need a bigger batch size for training but it doesn't fit in memory)."
+ )
+
### Solver types
solver_type = utils.forms.SelectField(
diff --git a/digits/model/images/classification/views.py b/digits/model/images/classification/views.py
index cff56656a..65f99ab98 100644
--- a/digits/model/images/classification/views.py
+++ b/digits/model/images/classification/views.py
@@ -253,23 +253,24 @@ def create():
else ''), form.python_layer_server_file.data)
job.tasks.append(fw.create_train_task(
- job = job,
- dataset = datasetJob,
- train_epochs = form.train_epochs.data,
- snapshot_interval = form.snapshot_interval.data,
- learning_rate = form.learning_rate.data[0],
- lr_policy = policy,
- gpu_count = gpu_count,
- selected_gpus = selected_gpus,
- batch_size = form.batch_size.data[0],
- val_interval = form.val_interval.data,
- pretrained_model= pretrained_model,
- crop_size = form.crop_size.data,
- use_mean = form.use_mean.data,
- network = network,
- random_seed = form.random_seed.data,
- solver_type = form.solver_type.data,
- shuffle = form.shuffle.data,
+ job = job,
+ dataset = datasetJob,
+ train_epochs = form.train_epochs.data,
+ snapshot_interval = form.snapshot_interval.data,
+ learning_rate = form.learning_rate.data[0],
+ lr_policy = policy,
+ gpu_count = gpu_count,
+ selected_gpus = selected_gpus,
+ batch_size = form.batch_size.data[0],
+ batch_accumulation = form.batch_accumulation.data,
+ val_interval = form.val_interval.data,
+ pretrained_model = pretrained_model,
+ crop_size = form.crop_size.data,
+ use_mean = form.use_mean.data,
+ network = network,
+ random_seed = form.random_seed.data,
+ solver_type = form.solver_type.data,
+ shuffle = form.shuffle.data,
)
)
diff --git a/digits/model/images/generic/views.py b/digits/model/images/generic/views.py
index 63014438c..e4badb6bd 100644
--- a/digits/model/images/generic/views.py
+++ b/digits/model/images/generic/views.py
@@ -205,23 +205,24 @@ def create():
else ''), form.python_layer_server_file.data)
job.tasks.append(fw.create_train_task(
- job = job,
- dataset = datasetJob,
- train_epochs = form.train_epochs.data,
- snapshot_interval = form.snapshot_interval.data,
- learning_rate = form.learning_rate.data[0],
- lr_policy = policy,
- gpu_count = gpu_count,
- selected_gpus = selected_gpus,
- batch_size = form.batch_size.data[0],
- val_interval = form.val_interval.data,
- pretrained_model= pretrained_model,
- crop_size = form.crop_size.data,
- use_mean = form.use_mean.data,
- network = network,
- random_seed = form.random_seed.data,
- solver_type = form.solver_type.data,
- shuffle = form.shuffle.data,
+ job = job,
+ dataset = datasetJob,
+ train_epochs = form.train_epochs.data,
+ snapshot_interval = form.snapshot_interval.data,
+ learning_rate = form.learning_rate.data[0],
+ lr_policy = policy,
+ gpu_count = gpu_count,
+ selected_gpus = selected_gpus,
+ batch_size = form.batch_size.data[0],
+ batch_accumulation = form.batch_accumulation.data,
+ val_interval = form.val_interval.data,
+ pretrained_model = pretrained_model,
+ crop_size = form.crop_size.data,
+ use_mean = form.use_mean.data,
+ network = network,
+ random_seed = form.random_seed.data,
+ solver_type = form.solver_type.data,
+ shuffle = form.shuffle.data,
)
)
diff --git a/digits/model/tasks/caffe_train.py b/digits/model/tasks/caffe_train.py
index 55b9abde9..3602233f8 100644
--- a/digits/model/tasks/caffe_train.py
+++ b/digits/model/tasks/caffe_train.py
@@ -426,6 +426,11 @@ def save_files_classification(self):
solver.snapshot_prefix = self.snapshot_prefix
+        # Batch accumulation
+        from digits.frameworks import CaffeFramework
+        if self.batch_accumulation and CaffeFramework().can_accumulate_gradients():
+            solver.iter_size = self.batch_accumulation
+
# Epochs -> Iterations
train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) / train_data_layer.data_param.batch_size))
solver.max_iter = train_iter * self.train_epochs
@@ -623,6 +628,11 @@ def save_files_generic(self):
solver.snapshot_prefix = self.snapshot_prefix
+        # Batch accumulation
+        from digits.frameworks import CaffeFramework
+        if self.batch_accumulation and CaffeFramework().can_accumulate_gradients():
+            solver.iter_size = self.batch_accumulation
+
# Epochs -> Iterations
train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) / train_image_data_layer.data_param.batch_size))
solver.max_iter = train_iter * self.train_epochs
diff --git a/digits/model/tasks/train.py b/digits/model/tasks/train.py
index af6e5dc92..adaedad25 100644
--- a/digits/model/tasks/train.py
+++ b/digits/model/tasks/train.py
@@ -38,6 +38,7 @@ def __init__(self, job, dataset, train_epochs, snapshot_interval, learning_rate,
gpu_count -- how many GPUs to use for training (integer)
selected_gpus -- a list of GPU indexes to be used for training
batch_size -- if set, override any network specific batch_size with this value
+ batch_accumulation -- accumulate gradients over multiple batches
val_interval -- how many epochs between validating the model with an epoch of validation data
pretrained_model -- filename for a model to use for fine-tuning
crop_size -- crop each image down to a square of this size
@@ -47,6 +48,7 @@ def __init__(self, job, dataset, train_epochs, snapshot_interval, learning_rate,
self.gpu_count = kwargs.pop('gpu_count', None)
self.selected_gpus = kwargs.pop('selected_gpus', None)
self.batch_size = kwargs.pop('batch_size', None)
+ self.batch_accumulation = kwargs.pop('batch_accumulation', None)
self.val_interval = kwargs.pop('val_interval', None)
self.pretrained_model = kwargs.pop('pretrained_model', None)
self.crop_size = kwargs.pop('crop_size', None)
diff --git a/digits/templates/models/images/classification/new.html b/digits/templates/models/images/classification/new.html
index 834e82167..c48ec1d23 100644
--- a/digits/templates/models/images/classification/new.html
+++ b/digits/templates/models/images/classification/new.html
@@ -160,6 +160,12 @@
Solver Options
{{form.batch_size(class='form-control', placeholder='[network defaults]')}}
+
+ {{form.batch_accumulation.label}}
+ {{form.batch_accumulation.tooltip}}
+ {{form.batch_accumulation(class='form-control')}}
+
{{form.solver_type.label}}
{{form.solver_type.tooltip}}
@@ -407,7 +413,8 @@
Solver Options
{% for fw in frameworks %}
framework = {
name : '{{ fw.get_name() }}',
- can_shuffle : '{{ fw.can_shuffle_data() }}'=='True'
+ can_shuffle : '{{ fw.can_shuffle_data() }}'=='True',
+ can_accumulate_gradients : '{{ fw.can_accumulate_gradients() }}'=='True',
};
frameworks['{{ fw.get_id() }}'] = framework;
{% endfor %}
@@ -432,6 +439,14 @@ Solver Options
$("#torch-warning").hide();
$('#stdnetRole a[href="'+"#"+fwid+"_standard"+'"]').tab('show');
$('#customFramework a[href="'+"#"+fwid+"_custom"+'"]').tab('show');
+
+ if (frameworks[fwid].can_accumulate_gradients) {
+ $('#batch_accumulation').prop('disabled', false);
+ $('#batch-accumulation-option').show();
+ } else {
+ $('#batch-accumulation-option').hide();
+ $('#batch_accumulation').prop('disabled', true);
+ }
}
diff --git a/digits/templates/models/images/generic/new.html b/digits/templates/models/images/generic/new.html
index bc6afcf86..b305a91a2 100644
--- a/digits/templates/models/images/generic/new.html
+++ b/digits/templates/models/images/generic/new.html
@@ -160,6 +160,12 @@ Solver Options
{{form.batch_size(class='form-control', placeholder='[network defaults]')}}
+
+ {{form.batch_accumulation.label}}
+ {{form.batch_accumulation.tooltip}}
+ {{form.batch_accumulation(class='form-control')}}
+
{{form.solver_type.label}}
{{form.solver_type.tooltip}}
@@ -404,7 +410,8 @@
Solver Options
{% for fw in frameworks %}
framework = {
name : '{{ fw.get_name() }}',
- can_shuffle : '{{ fw.can_shuffle_data() }}'=='True'
+ can_shuffle : '{{ fw.can_shuffle_data() }}'=='True',
+ can_accumulate_gradients : '{{ fw.can_accumulate_gradients() }}'=='True',
};
frameworks['{{ fw.get_id() }}'] = framework;
{% endfor %}
@@ -422,6 +429,14 @@ Solver Options
$("select[name=solver_type] > option:selected").prop("selected", false);
}
$('#customFramework a[href="'+"#"+fwid+"_custom"+'"]').tab('show');
+
+ if (frameworks[fwid].can_accumulate_gradients) {
+ $('#batch_accumulation').prop('disabled', false);
+ $('#batch-accumulation-option').show();
+ } else {
+ $('#batch-accumulation-option').hide();
+ $('#batch_accumulation').prop('disabled', true);
+ }
}