diff --git a/onmt/trainer.py b/onmt/trainer.py
index ffd5462433..21e4104c08 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -156,7 +156,6 @@ def train(self, train_iter_fct, valid_iter_fct, train_steps, valid_steps):
                         normalization += num_tokens
                     else:
                         normalization += batch.batch_size
-                    accum += 1
 
                 if accum == self.grad_accum_count:
                     reduce_counter += 1
@@ -168,7 +167,7 @@ def train(self, train_iter_fct, valid_iter_fct, train_steps, valid_steps):
                         if self.n_gpu > 1:
                             normalization = sum(onmt.utils.distributed
                                                 .all_gather_list
-                                                (normalization))
+                                                (normalization.cpu()))
 
                         self._gradient_accumulation(
                             true_batchs, normalization, total_stats,
@@ -255,7 +254,7 @@ def _gradient_accumulation(self, true_batchs, normalization, total_stats,
         for batch in true_batchs:
             target_size = batch.tgt.size(0)
 
-            # Truncated BPTT
+            # Truncated BPTT: reminder not compatible with accum > 1
             if self.trunc_size:
                 trunc_size = self.trunc_size
             else:
@@ -288,20 +287,31 @@ def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                 total_stats.update(batch_stats)
                 report_stats.update(batch_stats)
 
+                # 4. Update the parameters and statistics.
+                if self.grad_accum_count == 1:
+                    # Multi GPU gradient gather
+                    if self.n_gpu > 1:
+                        grads = [p.grad.data for p in self.model.parameters()
+                                 if p.requires_grad
+                                 and p.grad is not None]
+                        onmt.utils.distributed.all_reduce_and_rescale_tensors(
+                            grads, float(1))
+                    self.optim.step()
+
                 # If truncated, don't backprop fully.
                 if dec_state is not None:
                     dec_state.detach()
 
-            # 3.bis Multi GPU gradient gather
-            if self.n_gpu > 1:
-                grads = [p.grad.data for p in self.model.parameters()
-                         if p.requires_grad
-                         and p.grad is not None]
-                onmt.utils.distributed.all_reduce_and_rescale_tensors(
-                    grads, float(1))
-
-            # 4. Update the parameters and statistics.
-            self.optim.step()
+        # in case of multi step gradient accumulation,
+        # update only after accum batches
+        if self.grad_accum_count > 1:
+            if self.n_gpu > 1:
+                grads = [p.grad.data for p in self.model.parameters()
+                         if p.requires_grad
+                         and p.grad is not None]
+                onmt.utils.distributed.all_reduce_and_rescale_tensors(
+                    grads, float(1))
+            self.optim.step()
 
     def _start_report_manager(self, start_time=None):
         """
diff --git a/train.py b/train.py
index d0517d4f24..99d9df3b5d 100755
--- a/train.py
+++ b/train.py
@@ -18,6 +18,9 @@ def main(opt):
     if opt.epochs:
         raise AssertionError("-epochs is deprecated please use -train_steps.")
 
+    if opt.truncated_decoder > 0 and opt.accum_count > 1:
+        raise AssertionError("BPTT is not compatible with -accum > 1")
+
     if len(opt.gpuid) > 1:
         multi_main(opt)
     else:
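For readers unfamiliar with the accumulation pattern this patch fixes, the sketch below shows the same update rule in generic PyTorch: losses from `grad_accum_count` consecutive batches are backpropagated into the same gradient buffers, and the optimizer steps once per group. The names `train_with_accumulation`, `model`, `optimizer`, and `batches` are hypothetical placeholders for illustration, not the OpenNMT-py `Trainer` interface.

```python
# Illustrative sketch only: generic gradient accumulation with one optimizer
# step per `grad_accum_count` batches, mirroring the rule the patch enforces.
import torch.nn as nn


def train_with_accumulation(model, optimizer, batches, grad_accum_count=4):
    criterion = nn.CrossEntropyLoss()
    model.train()
    optimizer.zero_grad()
    accum = 0

    for src, tgt in batches:
        logits = model(src)
        # Scale the loss so the accumulated gradient matches a single update
        # computed over the whole group of batches.
        loss = criterion(logits, tgt) / grad_accum_count
        loss.backward()  # gradients add up in p.grad across the group
        accum += 1

        if accum == grad_accum_count:
            # In the multi-GPU path of the patch, gradients are additionally
            # all-reduced across workers right before this step.
            optimizer.step()
            optimizer.zero_grad()
            accum = 0
```

Because parameters only move after the whole group has been backpropagated, truncated BPTT (which needs an up-to-date decoder state between truncation windows of the same batch) cannot be combined with `accum_count > 1`, which is exactly the assertion added to train.py above.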