-
Notifications
You must be signed in to change notification settings - Fork 1
/
CustomDataSet.py
42 lines (36 loc) · 1.38 KB
/
CustomDataSet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from torch.utils.data import Dataset
from natsort import natsorted
from PIL import Image
import numpy as np
import torch
import glob
import os
class CustomDataSet(Dataset):
    """Dataset pairing ``.tif`` images with label rows from a label matrix.

    Each image is matched to the row of ``labelmat`` whose first column equals
    the image's file name without directory or extension; the remaining
    columns of that row are returned as an ``int64`` label vector.

    Args:
        main_dir: Root directory searched for ``.tif`` files.
        transform: Callable applied to each image after it is loaded and
            converted to RGB; its return value is the image item returned by
            ``__getitem__`` (presumably a tensor -- TODO confirm with callers).
        labelmat: 2-D array indexed as ``labelmat[:, 0]`` (image base names)
            and ``labelmat[row, 1:]`` (label values castable to ``int64``).
            Assumed to be a NumPy array -- TODO confirm against callers.
        num_labels: Number of label values expected per row. Defaults to 17,
            the value that used to be hard-coded.
    """

    def __init__(self, main_dir, transform, labelmat, num_labels=17):
        self.main_dir = main_dir
        self.transforms = transform
        # With recursive=False, '**' behaves like a plain '*', so this pattern
        # matches .tif files exactly one directory level below main_dir (not
        # files directly inside main_dir, and not deeper ones).
        # NOTE(review): if a fully recursive search was intended, this needs
        # recursive=True -- the original behavior is kept here.
        self.all_imgs = glob.glob(
            os.path.join(main_dir, '**/*.tif'), recursive=False)
        # Natural sort gives a deterministic order ("img_2" before "img_10"),
        # so a given index always maps to the same file.
        self.total_imgs = natsorted(self.all_imgs)
        self.xlabels = labelmat
        self.num_labels = num_labels

    def __len__(self):
        """Return the number of images found under ``main_dir``."""
        return len(self.total_imgs)

    def _lookup_label(self, img_loc):
        """Return the ``int64`` label vector for the image at *img_loc*.

        The image is identified by its base name (no directory, no
        extension), which must appear exactly once in column 0 of the
        label matrix.

        Raises:
            ValueError: if the base name matches zero or several rows, or if
                the matched row does not hold exactly ``num_labels`` values.
        """
        base_name = os.path.splitext(os.path.basename(img_loc))[0]
        matches = np.flatnonzero(self.xlabels[:, 0] == base_name)
        if matches.size != 1:
            raise ValueError(
                f"expected exactly one label row for image {base_name!r} "
                f"({img_loc}), found {matches.size}")
        row = self.xlabels[matches[0], 1:]
        # reshape() enforces the contract of exactly num_labels label values.
        return row.reshape(self.num_labels).astype(np.int64)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_tensor)`` for item *idx*."""
        img_loc = self.total_imgs[idx]
        tensor_label = torch.from_numpy(self._lookup_label(img_loc))
        # Context manager releases the file handle opened by Image.open;
        # convert() reads the pixel data before the file is closed.
        with Image.open(img_loc) as img:
            image = img.convert("RGB")
        tensor_image = self.transforms(image)
        return tensor_image, tensor_label

    def __getlabel__(self, idx):
        """Return the ``int64`` NumPy label vector for item *idx* (no image I/O)."""
        return self._lookup_label(self.total_imgs[idx])