-
Notifications
You must be signed in to change notification settings - Fork 4
/
IIIT5K.py
26 lines (22 loc) · 961 Bytes
/
IIIT5K.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import PIL
class IIIT5K(Dataset):
    """Map-style dataset of (image, label) pairs listed in an annotation file.

    The annotation file is ``<root_dir>/train_Data.txt`` when ``train`` is
    true and ``<root_dir>/test_Data.txt`` otherwise. Each non-blank line is
    split on commas: the first field is an image file name relative to
    ``<root_dir>/train`` (or ``<root_dir>/test``) and the second field is the
    label string. Any further fields are ignored.

    Args:
        root_dir: Directory holding the annotation files and the ``train`` /
            ``test`` image sub-directories.
        transform: Optional callable applied to each loaded RGB image.
        train: Select the training split when true, the test split otherwise.

    Attributes:
        files: Dict mapping full image path -> label, in annotation order.
            A file name listed more than once keeps its last label.

    Raises:
        OSError: If the annotation file cannot be opened.
        IndexError: If a non-blank annotation line has no comma.
    """

    def __init__(self, root_dir, transform=None, train=True):
        self.root_dir = root_dir
        self.transform = transform
        self.train = train

        split = 'train' if train else 'test'
        image_dir = os.path.join(root_dir, split)
        annotation_path = os.path.join(root_dir, split + '_Data.txt')

        self.files = {}
        # The context manager guarantees the annotation file is closed even
        # when a malformed line raises part-way through parsing.
        with open(annotation_path) as annotations:
            for line in annotations:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. trailing ones) instead of
                    # failing on the missing label field.
                    continue
                fields = line.split(',')
                self.files[os.path.join(image_dir, fields[0])] = fields[1]

        # Indexable snapshot of the mapping, built once so __getitem__ does
        # not have to materialise the whole item list on every access.
        # It is not refreshed if ``files`` is mutated after construction.
        self._samples = list(self.files.items())

    def __len__(self):
        """Return the number of distinct image paths in the split."""
        return len(self._samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is opened from disk, converted to RGB and, when a
        transform was supplied, passed through it. ``idx`` follows list
        indexing rules, so an out-of-range value raises ``IndexError``.
        """
        # Import the submodule explicitly: a bare ``import PIL`` does not
        # load ``PIL.Image`` by itself.
        from PIL import Image

        img_name, label = self._samples[idx]
        # convert() returns a new, fully loaded image, so the source file
        # can be closed as soon as the conversion is done.
        with Image.open(img_name) as raw:
            image = raw.convert("RGB")  # A few images are grayscale
        if self.transform:
            image = self.transform(image)
        return (image, label)