From 83c3a6a1081bc4db7e619f86b4906d1876076887 Mon Sep 17 00:00:00 2001 From: hamlet Date: Tue, 1 Dec 2020 12:36:01 +0800 Subject: [PATCH] Fix pipline dataloader when batch elements contain tuple --- deepspeed/runtime/pipe/engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 0dbcf88eb4e8..9f6958b8a2e7 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -621,7 +621,8 @@ def _exec_load_micro_batch(self, buffer_id): loaded = batch[0].clone().to(self.device).detach() loaded.requires_grad = loaded.is_floating_point() else: - assert isinstance(batch[0], tuple) + # XXX: torch 1.6.0 DataLoader will auto convert tuple to list + assert isinstance(batch[0], (tuple, list)) # Assume list or tuple loaded = [] for x in batch[0]: @@ -637,7 +638,8 @@ def _exec_load_micro_batch(self, buffer_id): loaded = batch[1] if torch.is_tensor(batch[1]): loaded = batch[1].to(self.device) - elif isinstance(batch[1], tuple): + # XXX: torch 1.6.0 DataLoader will auto convert tuple to list + elif isinstance(batch[1], (tuple, list)): loaded = [] for x in batch[1]: assert torch.is_tensor(x)