"""
Created on Tuesday April 20 2020
@author: Ahmad Mustapha ([email protected])
"""
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets import VisionDataset
import torch
import copy
import numpy as np
import os


class DeepClusteringDataset(Dataset):
    """A Dataset **decorator** that adds the ability to replace the original
    labels with pseudolabels.

    Args:
        original_dataset (Dataset): a PyTorch dataset (an ImageFolder or
            another VisionDataset exposing ``targets``/``labels``)
        transform (callable, optional): a function/transform that takes in
            a PIL image and returns a transformed version
    """

    def __init__(self, original_dataset, transform=None):
        # Work on a deep copy so the original dataset and its labels stay intact.
        self.dataset = copy.deepcopy(original_dataset)
        self.original_dataset = original_dataset

        if isinstance(self.original_dataset, ImageFolder):
            self.imgs = self.dataset.imgs
        elif isinstance(self.original_dataset, VisionDataset):
            self.data = self.dataset.data
            if hasattr(self.dataset, "targets"):
                self.targets = self.dataset.targets
            elif hasattr(self.dataset, "labels"):
                self.targets = self.dataset.labels
            else:
                raise Exception("The passed dataset is not supported: it has no 'targets' or 'labels' attribute")
        else:
            raise Exception("The passed original dataset is of an unsupported dataset type")

        if transform:
            self.dataset.transform = transform
        else:
            self.dataset.transform = original_dataset.transform
        self.transform = self.dataset.transform

        self.instance_wise_weights = None

    def set_transform(self, transform):
        self.dataset.transform = transform
        self.transform = self.dataset.transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Compare against None explicitly: a weights tensor/array may not
        # support boolean truth-testing.
        if self.instance_wise_weights is not None:
            return self.dataset[index] + (self.instance_wise_weights[index],)
        else:
            return self.dataset[index]

    def get_targets(self):
        """Return the original (ground-truth) targets of the wrapped dataset."""
        if isinstance(self.original_dataset, ImageFolder):
            return [target for (path, target) in self.original_dataset.imgs]
        elif isinstance(self.original_dataset, VisionDataset):
            if hasattr(self.original_dataset, "targets"):
                return self.original_dataset.targets
            elif hasattr(self.original_dataset, "labels"):
                return self.original_dataset.labels
            else:
                raise Exception("The passed dataset is not supported: it has no 'targets' or 'labels' attribute")
        else:
            raise Exception("The passed original dataset is of an unsupported dataset type")

    def set_pseudolabels(self, pseudolabels):
        """Overwrite the labels of the working copy with the given pseudolabels."""
        if isinstance(self.dataset, ImageFolder):
            for i, pseudolabel in enumerate(pseudolabels):
                self.imgs[i] = (self.imgs[i][0], torch.tensor(pseudolabel, dtype=torch.long))
        elif isinstance(self.dataset, VisionDataset):
            if hasattr(self.original_dataset, "targets"):
                self.dataset.targets = torch.tensor(pseudolabels, dtype=torch.long)
                self.targets = self.dataset.targets
            elif hasattr(self.original_dataset, "labels"):
                self.dataset.labels = torch.tensor(pseudolabels, dtype=torch.long)
                self.targets = self.dataset.labels
            else:
                raise Exception("The passed dataset is not supported: it has no 'targets' or 'labels' attribute")
        else:
            raise Exception("The passed original dataset is of an unsupported dataset type")

    # TODO - remove if unused
    def set_instance_wise_weights(self, weights):
        self.instance_wise_weights = weights

    def unset_instance_wise_weights(self):
        self.instance_wise_weights = None

    def save_pseudolabels(self, path, tag):
        """Save the indices grouped by pseudolabel to ``<path>/<tag>.npy``."""
        if not os.path.isdir(path):
            os.mkdir(path)
        grouped_indices = self.group_indices_by_labels()
        # The groups are ragged (different lengths), so store them as an object array.
        np.save(os.path.join(path, str(tag)), np.array(grouped_indices, dtype=object))

    def get_pseudolabels(self):
        if isinstance(self.dataset, ImageFolder):
            return [pseudolabel.item() for (path, pseudolabel) in self.imgs]
        elif isinstance(self.dataset, VisionDataset):
            return self.targets
        else:
            raise Exception("The passed original dataset is of an unsupported dataset type")

    def unset_pseudolabels(self):
        # Restore the original labels on the working copy as well, so that
        # __getitem__ serves the ground-truth targets again.
        if isinstance(self.dataset, ImageFolder):
            self.dataset.imgs = copy.deepcopy(self.original_dataset.imgs)
            self.imgs = self.dataset.imgs
        elif isinstance(self.dataset, VisionDataset):
            self.dataset.targets = copy.deepcopy(self.original_dataset.targets)
            self.targets = self.dataset.targets
        else:
            raise Exception("The passed original dataset is of an unsupported dataset type")

    def group_indices_by_labels(self):
        """Return a list of index lists, one per (pseudo)label."""
        if isinstance(self.dataset, ImageFolder):
            n_labels = len(np.unique([label for (_, label) in self.imgs]))
            grouped_indices = [[] for _ in range(n_labels)]
            for i, (path, label) in enumerate(self.imgs):
                grouped_indices[label].append(i)
            return grouped_indices
        elif isinstance(self.dataset, VisionDataset):
            n_labels = len(np.unique(self.targets))
            grouped_indices = [[] for _ in range(n_labels)]
            for i, label in enumerate(self.targets):
                grouped_indices[label].append(i)
            return grouped_indices
        else:
            raise Exception("The passed original dataset is of an unsupported dataset type")