diff --git a/src/datasampler.py b/src/datasampler.py
index 888ce67..7f7b7c8 100644
--- a/src/datasampler.py
+++ b/src/datasampler.py
@@ -2,10 +2,10 @@ import math
 from torch.utils.data.sampler import Sampler
 from torch.utils.data.sampler import Sampler
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, DistributedSampler
 import random
 import torch
-from New_Dataset.multi_dataset import multi_dataset
+from Dataset.multi_dataset import multi_dataset
 
 
 def make_batch(index_list, batch_size, drop_last):
     if drop_last:
@@ -20,7 +20,7 @@ def make_batch(index_list, batch_size, drop_last):
             batches.append(index_list[batch_size*_:(batch_size*(_+1))])
         return batches
 
-def batch_generation(dataset,batch_size_2D, batch_size_3D,drop_last=False,shuffle = True):
+def batch_generation(dataset,batch_size_2D, batch_size_3D,drop_last=False,shuffle = True, seed = 0):
     len_2D = len(dataset.data_whole_2D)
     len_3D = len(dataset.data_whole_3D)
 
@@ -29,6 +29,8 @@ def batch_generation(dataset,batch_size_2D, batch_size_3D,drop_last=False,shuffle = True):
     assert len(index_2D) + len(index_3D) == len(dataset.data_whole)
 
     if shuffle:
-        random.shuffle(index_2D)
-        random.shuffle(index_3D)
+        # deterministically shuffle with a local RNG so every rank sees the same order
+        rng = random.Random(seed)
+        rng.shuffle(index_2D)
+        rng.shuffle(index_3D)
 
@@ -53,7 +55,7 @@ class My_DistributedBatchSampler(Sampler):
         [1, 3, 5, 7, 9]
     """
 
-    def __init__(self, dataset, num_replicas=None, rank=None, batch_size_2D = 4, batch_size_3D = 1, drop_last = False, shuffle = True):
+    def __init__(self, dataset, num_replicas=None, rank=None, batch_size_2D = 4, batch_size_3D = 1, drop_last = False, shuffle = True, seed: int = 0):
         self.num_replicas = num_replicas
         self.rank = rank
         self.drop_last = drop_last
@@ -61,6 +63,8 @@ def __init__(self, dataset, num_replicas=None, rank=None, batch_size_2D = 4, batch_size_3D = 1, drop_last = False, shuffle = True):
         self.dataset = dataset
         self.batch_size_2D = batch_size_2D
         self.batch_size_3D = batch_size_3D
+        self.seed = seed
+        self.epoch = 0
 
         if num_replicas is None or rank is None: # pragma: no cover
             if not torch.distributed.is_initialized():
@@ -86,9 +90,11 @@ def __init__(self, dataset, num_replicas=None, rank=None, batch_size_2D = 4, batch_size_3D = 1, drop_last = False, shuffle = True):
         self.total_size = self.num_samples * self.num_replicas
 
     def __iter__(self):
-        indices = batch_generation(self.dataset,self.batch_size_2D,self.batch_size_3D,self.drop_last,self.shuffle)
+        indices = batch_generation(self.dataset,self.batch_size_2D,self.batch_size_3D,self.drop_last,self.shuffle,self.seed + self.epoch)
         # print(indices)
         if self.shuffle:
-            random.shuffle(indices)
+            # deterministically shuffle based on epoch and seed
+            rng = random.Random(self.seed + self.epoch)
+            rng.shuffle(indices)
 
         if not self.drop_last:
@@ -108,16 +114,33 @@ def __iter__(self):
         assert len(indices) == self.num_samples
         return iter(indices)
-    
+
     def __len__(self):
         return self.num_samples
 
+    def set_epoch(self, epoch: int) -> None:
+        r"""
+        Set the epoch for this sampler.
+
+        When :attr:`shuffle=True`, this ensures all replicas
+        use a different random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+ """ + self.epoch = epoch # print(My_DistributedBatchSampler) -# Train_dataset = multi_dataset(text_tokenizer = '/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/tokenizer') +# Train_dataset = multi_dataset(text_tokenizer = '/mnt/petrelfs/share_data/zhangxiaoman/CODE/RadFM/src/Language_models/tokenizer') # DDP_sample_0 = list(My_DistributedBatchSampler(dataset= Train_dataset , num_replicas = 32, rank = 0,)) +# DDP_sample_1 = list(My_DistributedBatchSampler(dataset= Train_dataset , num_replicas = 32, rank = 1,)) # for ii in DDP_sample_0: -# print(ii) \ No newline at end of file +# print(ii) + +# for ii in DDP_sample_1: +# print(ii) +