diff --git a/digits/frameworks/caffe_framework.py b/digits/frameworks/caffe_framework.py
index f5d4eac9b..878d8759d 100644
--- a/digits/frameworks/caffe_framework.py
+++ b/digits/frameworks/caffe_framework.py
@@ -121,3 +121,8 @@ def get_network_visualization(self, desc):
net.name = 'Network'
return ''
+ @override
+ def can_accumulate_gradients(self):
+ return (config_value('caffe_root')['version']
+ > parse_version('0.14.0-alpha'))
+
diff --git a/digits/frameworks/framework.py b/digits/frameworks/framework.py
index 60d816db8..095e72efd 100644
--- a/digits/frameworks/framework.py
+++ b/digits/frameworks/framework.py
@@ -76,7 +76,6 @@ def get_network_visualization(self, desc):
"""
raise NotImplementedError('Please implement me')
-
-
-
+ def can_accumulate_gradients(self):
+ return False
diff --git a/digits/model/forms.py b/digits/model/forms.py
index ee15c6847..647036964 100644
--- a/digits/model/forms.py
+++ b/digits/model/forms.py
@@ -129,11 +129,12 @@ def validate_py_ext(form, field):
)
iter_size = utils.forms.IntegerField('Iteration size',
+ default=1,
validators = [
validators.NumberRange(min=1),
validators.Optional(),
],
- tooltip = "Accumulate gradients over `iter_size` x `batch_size` instances."
+ tooltip = "Accumulate gradients over multiple batches (useful when you need a bigger batch size for training but it doesn't fit in memory)."
)
### Solver types
diff --git a/digits/model/tasks/caffe_train.py b/digits/model/tasks/caffe_train.py
index a5a89bf8e..f1c82c0de 100644
--- a/digits/model/tasks/caffe_train.py
+++ b/digits/model/tasks/caffe_train.py
@@ -427,7 +427,7 @@ def save_files_classification(self):
solver.snapshot_prefix = self.snapshot_prefix
# Iteration size
- solver.iter_size = self.iter_size or 1
+ solver.iter_size = self.iter_size
# Epochs -> Iterations
train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) / train_data_layer.data_param.batch_size))
diff --git a/digits/templates/models/images/classification/new.html b/digits/templates/models/images/classification/new.html
index fb3f4b860..0c2b9d7f7 100644
--- a/digits/templates/models/images/classification/new.html
+++ b/digits/templates/models/images/classification/new.html
@@ -160,10 +160,11 @@
Solver Options
{{form.batch_size(class='form-control', placeholder='[network defaults]')}}
-