dataloader.py
import numpy as np
from os.path import join, basename, isfile

import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader


class SemanLineDataset(Dataset):
    """Training/validation dataset pairing each image with its Hough-space label."""

    def __init__(self, root_dir, label_file, split='train', transform=None, t_transform=None):
        # label_file lists one sample name per line; each name maps to an image
        # (<name>.jpg) and a precomputed annotation (<name>.npy) under root_dir.
        with open(label_file) as f:
            lines = [line.rstrip('\n') for line in f]
        self.image_path = [join(root_dir, i + ".jpg") for i in lines]
        self.data_path = [join(root_dir, i + ".npy") for i in lines]
        self.split = split
        self.transform = transform
        self.t_transform = t_transform

    def __getitem__(self, item):
        assert isfile(self.image_path[item]), self.image_path[item]
        image = Image.open(self.image_path[item]).convert('RGB')
        data = np.load(self.data_path[item], allow_pickle=True).item()
        hough_space_label8 = data["hough_space_label8"].astype(np.float32)
        if self.transform is not None:
            image = self.transform(image)
        hough_space_label8 = torch.from_numpy(hough_space_label8).unsqueeze(0)  # add channel dim
        gt_coords = data["coords"]
        # 'train' and 'val' return the same tuple: (image, label, gt line coords, file name).
        return image, hough_space_label8, gt_coords, basename(self.image_path[item])

    def __len__(self):
        return len(self.image_path)

    def collate_fn(self, batch):
        # gt_coords vary in length per image, so they stay a tuple while the
        # images and Hough-space labels are stacked into batch tensors.
        images, hough_space_label8, gt_coords, names = list(zip(*batch))
        images = torch.stack(images)
        hough_space_label8 = torch.stack(hough_space_label8)
        return images, hough_space_label8, gt_coords, names


class SemanLineDatasetTest(Dataset):
    """Test dataset: images only, keeping the original (h, w) so predictions can be rescaled."""

    def __init__(self, root_dir, label_file, transform=None, t_transform=None):
        with open(label_file) as f:
            lines = [line.rstrip('\n') for line in f]
        self.image_path = [join(root_dir, i + ".jpg") for i in lines]
        self.transform = transform
        self.t_transform = t_transform

    def __getitem__(self, item):
        assert isfile(self.image_path[item]), self.image_path[item]
        image = Image.open(self.image_path[item]).convert('RGB')
        w, h = image.size  # PIL reports (width, height)
        if self.transform is not None:
            image = self.transform(image)
        return image, basename(self.image_path[item]), (h, w)

    def __len__(self):
        return len(self.image_path)

    def collate_fn(self, batch):
        images, names, sizes = list(zip(*batch))
        images = torch.stack(images)
        return images, names, sizes


def get_loader(root_dir, label_file, batch_size, img_size=0, num_thread=4, pin=True, test=False, split='train'):
    # img_size is currently unused; kept for interface compatibility.
    # Normalization uses ImageNet statistics to match the pretrained backbone.
    if test is False:
        transform = transforms.Compose([
            # transforms.Resize((400, 400)),  # Not used in the current version.
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        dataset = SemanLineDataset(root_dir, label_file, transform=transform, t_transform=None, split=split)
        data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_thread,
                                 pin_memory=pin, collate_fn=dataset.collate_fn)
    else:
        transform = transforms.Compose([
            transforms.Resize((400, 400)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        dataset = SemanLineDatasetTest(root_dir, label_file, transform=transform, t_transform=None)
        data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_thread,
                                 pin_memory=pin, collate_fn=dataset.collate_fn)
    return data_loader
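

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of building a training loader. The paths below are
# hypothetical placeholders: point root_dir at the image/annotation directory
# and label_file at a text file listing sample names one per line. Note that
# the training transform applies no Resize, so this assumes the training
# images already share a fixed size (otherwise torch.stack in collate_fn fails).
if __name__ == "__main__":
    train_loader = get_loader(root_dir="data/images", label_file="data/train.txt",
                              batch_size=8, num_thread=4, split='train')
    images, hough_labels, gt_coords, names = next(iter(train_loader))
    # images: [B, 3, H, W]; hough_labels: [B, 1, H_hough, W_hough];
    # gt_coords / names: per-image tuples preserved by collate_fn.
    print(images.shape, hough_labels.shape, len(gt_coords), names)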