Update datasampler.py
chaoyi-wu authored Mar 25, 2024
1 parent 12504fa commit 5505f3c
Showing 1 changed file with 33 additions and 8 deletions.
41 changes: 33 additions & 8 deletions src/datasampler.py
@@ -2,10 +2,10 @@
 import math
 from torch.utils.data.sampler import Sampler
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, DistributedSampler
 import random
 import torch
-from New_Dataset.multi_dataset import multi_dataset
+from Dataset.multi_dataset import multi_dataset

 def make_batch(index_list, batch_size, drop_last):
     if drop_last:
@@ -20,7 +20,7 @@ def make_batch(index_list, batch_size, drop_last):
         batches.append(index_list[batch_size*_:(batch_size*(_+1))])
     return batches

-def batch_generation(dataset,batch_size_2D, batch_size_3D,drop_last=False,shuffle = True):
+def batch_generation(dataset,batch_size_2D, batch_size_3D,drop_last=False,shuffle = True, seed = 0):

     len_2D = len(dataset.data_whole_2D)
     len_3D = len(dataset.data_whole_3D)
@@ -29,6 +29,9 @@ def batch_generation(dataset,batch_size_2D, batch_size_3D,drop_last=False,shuffl
     assert len(index_2D) + len(index_3D) == len(dataset.data_whole)

     if shuffle:
+        # deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(seed)
         random.shuffle(index_2D)
         random.shuffle(index_3D)

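Note that as committed, the torch.Generator `g` seeded above is never passed to `random.shuffle`, which draws from Python's global random state, so the 2D/3D shuffles are not actually tied to `seed`. A minimal sketch of one way the seed could drive the shuffle, using a local `random.Random` instance (the `rng` name and `seeded_shuffle` helper are illustrative, not part of the commit):

    import random

    def seeded_shuffle(index_2D, index_3D, seed=0):
        # A dedicated Random instance consumes the seed directly, so the
        # ordering is reproducible for a given seed (or seed + epoch).
        rng = random.Random(seed)
        rng.shuffle(index_2D)
        rng.shuffle(index_3D)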
@@ -53,14 +56,16 @@ class My_DistributedBatchSampler(Sampler):
     [1, 3, 5, 7, 9]
     """

-    def __init__(self, dataset, num_replicas=None, rank=None, batch_size_2D = 4, batch_size_3D = 1, drop_last = False, shuffle = True):
+    def __init__(self, dataset, num_replicas=None, rank=None, batch_size_2D = 4, batch_size_3D = 1, drop_last = False, shuffle = True, seed: int = 0):
         self.num_replicas = num_replicas
         self.rank = rank
         self.drop_last = drop_last
         self.shuffle = shuffle
         self.dataset = dataset
         self.batch_size_2D = batch_size_2D
         self.batch_size_3D = batch_size_3D
+        self.seed = seed
+        self.epoch = 0

         if num_replicas is None or rank is None: # pragma: no cover
             if not torch.distributed.is_initialized():
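For context, a hedged sketch of how the sampler might be constructed under torch.distributed once the new `seed` argument exists; the `train_dataset` variable, batch sizes, and seed value are illustrative, and the process group is assumed to be initialized already (e.g. via torchrun):

    import torch.distributed as dist

    sampler = My_DistributedBatchSampler(
        dataset=train_dataset,               # any multi_dataset instance
        num_replicas=dist.get_world_size(),  # matches the fallback above
        rank=dist.get_rank(),
        batch_size_2D=4,
        batch_size_3D=1,
        shuffle=True,
        seed=42,                             # new in this commit
    )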
@@ -86,9 +91,12 @@ def __init__(self, dataset, num_replicas=None, rank=None, batch_size_2D = 4, bat
         self.total_size = self.num_samples * self.num_replicas

     def __iter__(self):
-        indices = batch_generation(self.dataset,self.batch_size_2D,self.batch_size_3D,self.drop_last,self.shuffle)
+        indices = batch_generation(self.dataset,self.batch_size_2D,self.batch_size_3D,self.drop_last,self.shuffle,self.seed + self.epoch)
         # print(indices)
         if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
             random.shuffle(indices)

         if not self.drop_last:
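Since `__iter__` yields whole batches (lists of indices) rather than single indices, the sampler would be handed to a DataLoader as `batch_sampler`, not `sampler`; a minimal sketch, assuming the `sampler` and `train_dataset` objects from the previous example:

    from torch.utils.data import DataLoader

    # batch_sampler consumes the per-rank batches produced by __iter__;
    # batch_size and shuffle must be left unset when batch_sampler is used.
    loader = DataLoader(train_dataset, batch_sampler=sampler, num_workers=4)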
@@ -108,16 +116,33 @@ def __iter__(self):
         assert len(indices) == self.num_samples

         return iter(indices)

     def __len__(self):
         return self.num_samples

+    def set_epoch(self, epoch: int) -> None:
+        r"""
+        Set the epoch for this sampler.
+        When :attr:`shuffle=True`, this ensures all replicas
+        use a different random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch


 # print(My_DistributedBatchSampler)
-# Train_dataset = multi_dataset(text_tokenizer = '/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/tokenizer')
+# Train_dataset = multi_dataset(text_tokenizer = '/mnt/petrelfs/share_data/zhangxiaoman/CODE/RadFM/src/Language_models/tokenizer')

 # DDP_sample_0 = list(My_DistributedBatchSampler(dataset= Train_dataset , num_replicas = 32, rank = 0,))
 # DDP_sample_1 = list(My_DistributedBatchSampler(dataset= Train_dataset , num_replicas = 32, rank = 1,))

 # for ii in DDP_sample_0:
 #     print(ii)

 # for ii in DDP_sample_1:
 #     print(ii)

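Taken together, `seed`, `epoch`, and `set_epoch` follow the standard DistributedSampler pattern: call `set_epoch` at the start of every epoch so that `seed + epoch` reseeds `batch_generation` and the batch-order shuffle. A minimal training-loop sketch, with `num_epochs` and the loop body as placeholders:

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)   # reshuffles as seed + epoch for this epoch
        for batch in loader:
            ...                    # forward/backward pass goes here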