-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
89 lines (74 loc) · 3.1 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from paddle.vision import transforms
from PIL import Image
import os
import paddle
import glob
import numpy as np
def get_data_transforms(size, isize):
    """Build the image and ground-truth-mask preprocessing pipelines.

    Args:
        size: side length both pipelines resize their input to, as
            ``(size, size)``.
        isize: crop size handed to ``CenterCrop`` after resizing.

    Returns:
        A ``(data_transforms, gt_transforms)`` pair of ``transforms.Compose``
        objects. ``data_transforms`` resizes, converts to a tensor,
        center-crops and normalizes with the per-channel mean/std below;
        ``gt_transforms`` resizes, center-crops and converts to a tensor
        with no normalization.
    """
    square = (size, size)
    # Per-channel RGB statistics for Normalize (the standard ImageNet values).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    image_steps = [
        transforms.Resize(square),
        # Note: for images ToTensor runs *before* CenterCrop, so the crop is
        # applied to the tensor rather than to the PIL image.
        transforms.ToTensor(),
        transforms.CenterCrop(isize),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    mask_steps = [
        transforms.Resize(square),
        transforms.CenterCrop(isize),
        transforms.ToTensor(),
    ]
    return transforms.Compose(image_steps), transforms.Compose(mask_steps)
class MVTecDataset(paddle.io.Dataset):
    """MVTec-style anomaly-detection dataset (train or test split).

    Directory layout read by this class, relative to ``root``:

        train/<type>/*.png          (phase == 'train')
        test/<type>/*.png           (any other phase)
        ground_truth/<type>/*.png   (masks for non-'good' test types)

    Each sample is ``(img, gt, label, img_type)``. ``label`` is 0 for images
    in a ``good`` directory and 1 otherwise. ``gt`` is the transformed mask,
    or an all-zero ``[1, H, W]`` tensor for ``good`` images.

    Args:
        root: dataset category directory.
        transform: callable applied to the RGB PIL image.
        gt_transform: callable applied to the PIL ground-truth mask.
        phase: ``'train'`` reads ``root/train``; any other value reads
            ``root/test`` and ``root/ground_truth``.
    """

    def __init__(self, root, transform, gt_transform, phase):
        if phase == 'train':
            self.img_path = os.path.join(root, 'train')
        else:
            self.img_path = os.path.join(root, 'test')
            # Masks are only looked up for the non-train split.
            self.gt_path = os.path.join(root, 'ground_truth')
        self.transform = transform
        self.gt_transform = gt_transform
        # Parallel lists; self.labels => good : 0, anomaly : 1
        self.img_paths, self.gt_paths, self.labels, self.types = self.load_dataset()

    def load_dataset(self):
        """Collect image paths, mask paths, labels and type names.

        Returns:
            Four parallel lists ``(img_paths, gt_paths, labels, types)``.
            For ``good`` images the mask entry is the integer sentinel ``0``,
            which ``__getitem__`` turns into an all-zero mask.

        Raises:
            AssertionError: if a defect type has a different number of
                images than ground-truth masks.
        """
        img_tot_paths = []
        gt_tot_paths = []
        tot_labels = []
        tot_types = []

        # os.listdir / glob return entries in arbitrary order; sort both so
        # sample indices are reproducible across machines.
        for defect_type in sorted(os.listdir(self.img_path)):
            img_paths = sorted(
                glob.glob(os.path.join(self.img_path, defect_type, '*.png')))
            if defect_type == 'good':
                # good images have label 0 and no mask file.
                gt_paths = [0] * len(img_paths)
                label = 0
            else:
                # every other type is an anomaly (label 1) paired with masks.
                gt_paths = sorted(
                    glob.glob(os.path.join(self.gt_path, defect_type, '*.png')))
                label = 1
                # Check pairing per type: comparing only the grand totals lets
                # a surplus in one type hide a deficit in another and silently
                # mis-pair images with masks. Explicit raise (not assert) so
                # the check survives `python -O`.
                if len(img_paths) != len(gt_paths):
                    raise AssertionError(
                        "Something wrong with test and ground truth pair!")
            img_tot_paths.extend(img_paths)
            gt_tot_paths.extend(gt_paths)
            tot_labels.extend([label] * len(img_paths))
            tot_types.extend([defect_type] * len(img_paths))

        return img_tot_paths, gt_tot_paths, tot_labels, tot_types

    def __len__(self):
        """Return the total number of samples."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(img, gt, label, img_type)``."""
        img_path = self.img_paths[idx]
        gt = self.gt_paths[idx]
        label = self.labels[idx]
        img_type = self.types[idx]

        # Close the file handle promptly; convert() returns an independent image.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        img = self.transform(img)

        if gt == 0:
            # good sample: all-zero mask matching the image's (H, W).
            gt = paddle.zeros([1, img.shape[-2], img.shape[-1]])
        else:
            # Transform inside the `with`: PIL reads pixel data lazily, so the
            # file must still be open while the mask transform runs.
            with Image.open(gt) as src:
                gt = self.gt_transform(src)

        assert img.shape[1:] == gt.shape[1:], "image.size != gt.size !!!"
        return img, gt, label, img_type