From e8afb76900b65d2ffbbfaad6622a02827485d2d7 Mon Sep 17 00:00:00 2001 From: chaoyi-wu <18706207256@163.com> Date: Mon, 25 Mar 2024 21:25:01 +0800 Subject: [PATCH] Update trainer.py --- src/My_Trainer/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/My_Trainer/trainer.py b/src/My_Trainer/trainer.py index 6eb3b22..d2e93e9 100644 --- a/src/My_Trainer/trainer.py +++ b/src/My_Trainer/trainer.py @@ -1880,9 +1880,10 @@ def _inner_training_loop( total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): - if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): + ### 吴超逸加 ### + if isinstance(train_dataloader, DataLoader) and (isinstance(train_dataloader.sampler, DistributedSampler) or self.args.data_sampler != None): train_dataloader.sampler.set_epoch(epoch) - elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): + elif hasattr(train_dataloader, "dataset") and (isinstance(train_dataloader.sampler, DistributedSampler) or self.args.data_sampler != None): train_dataloader.dataset.set_epoch(epoch) if is_torch_tpu_available():