From ce374ba87767d551f720242d5e64bfa976531079 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 13 Jul 2020 08:37:38 -0400 Subject: [PATCH] Fix Trainer in DataParallel setting (#5685) * Fix Trainer in DataParallel setting * Fix typo Co-authored-by: Sam Shleifer --- src/transformers/trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 067f793d12a287..e21e6cdc46679a 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -618,6 +618,9 @@ def _training_step( if self.args.past_index >= 0 and self._past is not None: inputs["mems"] = self._past + # Our model outputs do not work with DataParallel, so forcing return tuple. + if self.args.n_gpu > 1: + inputs["return_tuple"] = True outputs = model(**inputs) loss = outputs[0] # model outputs are always tuple in transformers (see doc) @@ -818,6 +821,9 @@ def _prediction_loop( inputs[k] = v.to(self.args.device) if self.args.past_index >= 0: inputs["mems"] = past + # Our model outputs do not work with DataParallel, so forcing return tuple. + if self.args.n_gpu > 1: + inputs["return_tuple"] = True with torch.no_grad(): outputs = model(**inputs)