Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bptt cf #891 #926

Merged
merged 4 commits into from
Aug 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions onmt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def train(self, train_iter_fct, valid_iter_fct, train_steps, valid_steps):
normalization += num_tokens
else:
normalization += batch.batch_size

accum += 1
if accum == self.grad_accum_count:
reduce_counter += 1
Expand All @@ -168,7 +167,7 @@ def train(self, train_iter_fct, valid_iter_fct, train_steps, valid_steps):
if self.n_gpu > 1:
normalization = sum(onmt.utils.distributed
.all_gather_list
(normalization))
(normalization.cpu()))

self._gradient_accumulation(
true_batchs, normalization, total_stats,
Expand Down Expand Up @@ -255,7 +254,7 @@ def _gradient_accumulation(self, true_batchs, normalization, total_stats,

for batch in true_batchs:
target_size = batch.tgt.size(0)
# Truncated BPTT
# Truncated BPTT: reminder not compatible with accum > 1
if self.trunc_size:
trunc_size = self.trunc_size
else:
Expand Down Expand Up @@ -288,20 +287,31 @@ def _gradient_accumulation(self, true_batchs, normalization, total_stats,
total_stats.update(batch_stats)
report_stats.update(batch_stats)

# 4. Update the parameters and statistics.
if self.grad_accum_count == 1:
# Multi GPU gradient gather
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
if p.requires_grad
and p.grad is not None]
onmt.utils.distributed.all_reduce_and_rescale_tensors(
grads, float(1))
self.optim.step()

# If truncated, don't backprop fully.
if dec_state is not None:
dec_state.detach()

# 3.bis Multi GPU gradient gather
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
if p.requires_grad
and p.grad is not None]
onmt.utils.distributed.all_reduce_and_rescale_tensors(
grads, float(1))

# 4. Update the parameters and statistics.
self.optim.step()
# in case of multi step gradient accumulation,
# update only after accum batches
if self.grad_accum_count > 1:
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
if p.requires_grad
and p.grad is not None]
onmt.utils.distributed.all_reduce_and_rescale_tensors(
grads, float(1))
self.optim.step()

def _start_report_manager(self, start_time=None):
"""
Expand Down
3 changes: 3 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def main(opt):
if opt.epochs:
raise AssertionError("-epochs is deprecated please use -train_steps.")

if opt.truncated_decoder > 0 and opt.accum_count > 1:
raise AssertionError("BPTT is not compatible with -accum > 1")

if len(opt.gpuid) > 1:
multi_main(opt)
else:
Expand Down