diff --git a/digits/frameworks/caffe_framework.py b/digits/frameworks/caffe_framework.py index f5d4eac9b..878d8759d 100644 --- a/digits/frameworks/caffe_framework.py +++ b/digits/frameworks/caffe_framework.py @@ -121,3 +121,8 @@ def get_network_visualization(self, desc): net.name = 'Network' return '' + @override + def can_accumulate_gradients(self): + return (config_value('caffe_root')['version'] + > parse_version('0.14.0-alpha')) + diff --git a/digits/frameworks/framework.py b/digits/frameworks/framework.py index 60d816db8..095e72efd 100644 --- a/digits/frameworks/framework.py +++ b/digits/frameworks/framework.py @@ -76,7 +76,6 @@ def get_network_visualization(self, desc): """ raise NotImplementedError('Please implement me') - - - + def can_accumulate_gradients(self): + return False diff --git a/digits/model/forms.py b/digits/model/forms.py index ee15c6847..647036964 100644 --- a/digits/model/forms.py +++ b/digits/model/forms.py @@ -129,11 +129,12 @@ def validate_py_ext(form, field): ) iter_size = utils.forms.IntegerField('Iteration size', + default=1, validators = [ validators.NumberRange(min=1), validators.Optional(), ], - tooltip = "Accumulate gradients over `iter_size` x `batch_size` instances." + tooltip = "Accumulate gradients over multiple batches (useful when you need a bigger batch size for training but it doesn't fit in memory)." ) ### Solver types diff --git a/digits/model/tasks/caffe_train.py b/digits/model/tasks/caffe_train.py index a5a89bf8e..f1c82c0de 100644 --- a/digits/model/tasks/caffe_train.py +++ b/digits/model/tasks/caffe_train.py @@ -427,7 +427,7 @@ def save_files_classification(self): solver.snapshot_prefix = self.snapshot_prefix # Iteration size - solver.iter_size = self.iter_size or 1 + solver.iter_size = self.iter_size # Epochs -> Iterations train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) / train_data_layer.data_param.batch_size)) diff --git a/digits/templates/models/images/classification/new.html b/digits/templates/models/images/classification/new.html index fb3f4b860..0c2b9d7f7 100644 --- a/digits/templates/models/images/classification/new.html +++ b/digits/templates/models/images/classification/new.html @@ -160,10 +160,11 @@

Solver Options

{{form.batch_size(class='form-control', placeholder='[network defaults]')}} -
+
{{form.solver_type.label}} @@ -437,6 +438,17 @@

Solver Options

$("#torch-warning").hide(); $('#stdnetRole a[href="'+"#"+fwid+"_standard"+'"]').tab('show'); $('#customFramework a[href="'+"#"+fwid+"_custom"+'"]').tab('show'); + + $('#iter-size-option').hide(); + $('#iter_size').prop('disabled', true); + {% for fw in frameworks %} + {% if fw.can_accumulate_gradients() %} + if (fwid == "{{ fw.get_id() }}") { + $('#iter_size').prop('disabled', false); + $('#iter-size-option').show(); + } + {% endif %} + {% endfor %} } diff --git a/digits/templates/models/images/generic/new.html b/digits/templates/models/images/generic/new.html index c52d8d921..674102ba5 100644 --- a/digits/templates/models/images/generic/new.html +++ b/digits/templates/models/images/generic/new.html @@ -163,7 +163,7 @@

Solver Options

{{form.iter_size.label}} {{form.iter_size.tooltip}} - {{form.iter_size(class='form-control', placeholder='[default = 1]')}} + {{form.iter_size(class='form-control')}}
{{form.solver_type.label}}