Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Expose iter_size solver option #744

Merged
merged 2 commits into from
May 18, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions digits/frameworks/caffe_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,8 @@ def get_network_visualization(self, desc):
net.name = 'Network'
return '<image src="data:image/png;base64,' + caffe.draw.draw_net(net, 'UD').encode('base64') + '" style="max-width:100%" />'

@override
def can_accumulate_gradients(self):
    """Return True if this Caffe build supports gradient accumulation.

    The solver's ``iter_size`` option first appeared in NVIDIA Caffe
    0.14.0-alpha, so any installed version newer than that qualifies.
    NOTE(review): assumes config_value('caffe_root')['version'] is already
    a parsed-version object comparable with parse_version() — confirm.
    """
    minimum_version = parse_version('0.14.0-alpha')
    installed_version = config_value('caffe_root')['version']
    return installed_version > minimum_version

5 changes: 2 additions & 3 deletions digits/frameworks/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def get_network_visualization(self, desc):
"""
raise NotImplementedError('Please implement me')




def can_accumulate_gradients(self):
    """Report whether this framework can accumulate gradients over batches.

    The base framework offers no such support; frameworks that do
    (e.g. the Caffe framework) override this method to return True.
    """
    supported = False
    return supported

9 changes: 9 additions & 0 deletions digits/model/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def validate_py_ext(form, field):
tooltip = "How many images to process at once. If blank, values are used from the network definition."
)

# Number of batches over which gradients are accumulated before an update.
# Must be a positive integer; the field may also be left blank (Optional).
batch_accumulation = utils.forms.IntegerField(
    'Batch Accumulation',
    default=1,
    validators=[
        validators.NumberRange(min=1),
        validators.Optional(),
    ],
    tooltip="Accumulate gradients over multiple batches (useful when you need a bigger batch size for training but it doesn't fit in memory)."
)

### Solver types

solver_type = utils.forms.SelectField(
Expand Down
35 changes: 18 additions & 17 deletions digits/model/images/classification/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,23 +253,24 @@ def create():
else ''), form.python_layer_server_file.data)

job.tasks.append(fw.create_train_task(
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data[0],
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data[0],
val_interval = form.val_interval.data,
pretrained_model= pretrained_model,
crop_size = form.crop_size.data,
use_mean = form.use_mean.data,
network = network,
random_seed = form.random_seed.data,
solver_type = form.solver_type.data,
shuffle = form.shuffle.data,
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data[0],
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data[0],
batch_accumulation = form.batch_accumulation.data,
val_interval = form.val_interval.data,
pretrained_model = pretrained_model,
crop_size = form.crop_size.data,
use_mean = form.use_mean.data,
network = network,
random_seed = form.random_seed.data,
solver_type = form.solver_type.data,
shuffle = form.shuffle.data,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the whitespace changes. It had to be done at some point.

)
)

Expand Down
35 changes: 18 additions & 17 deletions digits/model/images/generic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,23 +205,24 @@ def create():
else ''), form.python_layer_server_file.data)

job.tasks.append(fw.create_train_task(
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data[0],
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data[0],
val_interval = form.val_interval.data,
pretrained_model= pretrained_model,
crop_size = form.crop_size.data,
use_mean = form.use_mean.data,
network = network,
random_seed = form.random_seed.data,
solver_type = form.solver_type.data,
shuffle = form.shuffle.data,
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data[0],
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data[0],
batch_accumulation = form.batch_accumulation.data,
val_interval = form.val_interval.data,
pretrained_model = pretrained_model,
crop_size = form.crop_size.data,
use_mean = form.use_mean.data,
network = network,
random_seed = form.random_seed.data,
solver_type = form.solver_type.data,
shuffle = form.shuffle.data,
)
)

Expand Down
10 changes: 10 additions & 0 deletions digits/model/tasks/caffe_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,11 @@ def save_files_classification(self):

solver.snapshot_prefix = self.snapshot_prefix

# Batch accumulation
from digits.frameworks import CaffeFramework
if CaffeFramework().can_accumulate_gradients():
solver.iter_size = self.batch_accumulation

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to make this change in save_files_generic() too.

Copy link
Member Author

@lukeyeager lukeyeager May 18, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

d'oh. Done.

# Epochs -> Iterations
train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) / train_data_layer.data_param.batch_size))
solver.max_iter = train_iter * self.train_epochs
Expand Down Expand Up @@ -623,6 +628,11 @@ def save_files_generic(self):

solver.snapshot_prefix = self.snapshot_prefix

# Batch accumulation
from digits.frameworks import CaffeFramework
if CaffeFramework().can_accumulate_gradients():
solver.iter_size = self.batch_accumulation

# Epochs -> Iterations
train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) / train_image_data_layer.data_param.batch_size))
solver.max_iter = train_iter * self.train_epochs
Expand Down
2 changes: 2 additions & 0 deletions digits/model/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, job, dataset, train_epochs, snapshot_interval, learning_rate,
gpu_count -- how many GPUs to use for training (integer)
selected_gpus -- a list of GPU indexes to be used for training
batch_size -- if set, override any network specific batch_size with this value
batch_accumulation -- accumulate gradients over multiple batches
val_interval -- how many epochs between validating the model with an epoch of validation data
pretrained_model -- filename for a model to use for fine-tuning
crop_size -- crop each image down to a square of this size
Expand All @@ -47,6 +48,7 @@ def __init__(self, job, dataset, train_epochs, snapshot_interval, learning_rate,
self.gpu_count = kwargs.pop('gpu_count', None)
self.selected_gpus = kwargs.pop('selected_gpus', None)
self.batch_size = kwargs.pop('batch_size', None)
self.batch_accumulation = kwargs.pop('batch_accumulation', None)
self.val_interval = kwargs.pop('val_interval', None)
self.pretrained_model = kwargs.pop('pretrained_model', None)
self.crop_size = kwargs.pop('crop_size', None)
Expand Down
17 changes: 16 additions & 1 deletion digits/templates/models/images/classification/new.html
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ <h4>Solver Options</h4>
</small>
{{form.batch_size(class='form-control', placeholder='[network defaults]')}}
</div>
<div class="form-group{{mark_errors([form.batch_accumulation])}}"
id="batch-accumulation-option" style="display:none;">
{{form.batch_accumulation.label}}
{{form.batch_accumulation.tooltip}}
{{form.batch_accumulation(class='form-control')}}
</div>
<div class="form-group{{mark_errors([form.solver_type])}}">
{{form.solver_type.label}}
{{form.solver_type.tooltip}}
Expand Down Expand Up @@ -407,7 +413,8 @@ <h4>Solver Options</h4>
{% for fw in frameworks %}
framework = {
name : '{{ fw.get_name() }}',
can_shuffle : '{{ fw.can_shuffle_data() }}'=='True'
can_shuffle : '{{ fw.can_shuffle_data() }}'=='True',
can_accumulate_gradients : '{{ fw.can_accumulate_gradients() }}'=='True',
};
frameworks['{{ fw.get_id() }}'] = framework;
{% endfor %}
Expand All @@ -432,6 +439,14 @@ <h4>Solver Options</h4>
$("#torch-warning").hide();
$('#stdnetRole a[href="'+"#"+fwid+"_standard"+'"]').tab('show');
$('#customFramework a[href="'+"#"+fwid+"_custom"+'"]').tab('show');

if (frameworks[fwid].can_accumulate_gradients) {
$('#batch_accumulation').prop('disabled', false);
$('#batch-accumulation-option').show();
} else {
$('#batch-accumulation-option').hide();
$('#batch_accumulation').prop('disabled', true);
}
}
</script>

Expand Down
17 changes: 16 additions & 1 deletion digits/templates/models/images/generic/new.html
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ <h4>Solver Options</h4>
</small>
{{form.batch_size(class='form-control', placeholder='[network defaults]')}}
</div>
<div class="form-group{{mark_errors([form.batch_accumulation])}}"
id="batch-accumulation-option" style="display:none;">
{{form.batch_accumulation.label}}
{{form.batch_accumulation.tooltip}}
{{form.batch_accumulation(class='form-control')}}
</div>
<div class="form-group{{mark_errors([form.solver_type])}}">
{{form.solver_type.label}}
{{form.solver_type.tooltip}}
Expand Down Expand Up @@ -404,7 +410,8 @@ <h4>Solver Options</h4>
{% for fw in frameworks %}
framework = {
name : '{{ fw.get_name() }}',
can_shuffle : '{{ fw.can_shuffle_data() }}'=='True'
can_shuffle : '{{ fw.can_shuffle_data() }}'=='True',
can_accumulate_gradients : '{{ fw.can_accumulate_gradients() }}'=='True',
};
frameworks['{{ fw.get_id() }}'] = framework;
{% endfor %}
Expand All @@ -422,6 +429,14 @@ <h4>Solver Options</h4>
$("select[name=solver_type] > option:selected").prop("selected", false);
}
$('#customFramework a[href="'+"#"+fwid+"_custom"+'"]').tab('show');

if (frameworks[fwid].can_accumulate_gradients) {
$('#batch_accumulation').prop('disabled', false);
$('#batch-accumulation-option').show();
} else {
$('#batch-accumulation-option').hide();
$('#batch_accumulation').prop('disabled', true);
}
}
</script>

Expand Down