-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
135 lines (125 loc) · 6.15 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torchvision
import torchvision.transforms as transforms
class BinaryDataset(torch.utils.data.Dataset):
    """Dataset backed by a ``(data, labels)`` pair stored with ``torch.save``.

    Every sample is cast to ``torch.float`` before the optional transform runs.

    Args:
        root: Path handed to ``torch.load``; the loaded object must unpack into
            exactly two indexable objects (data, labels).
        transform: Optional callable applied to each float sample.
        return_idx: When true, ``__getitem__`` also yields the sample index.
    """

    def __init__(self, root, transform=None, return_idx=False):
        samples, targets = torch.load(root)
        self.root = root
        self.data = samples
        self.labels = targets
        self.transform = transform
        self.return_idx = return_idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx].type(torch.float)
        target = self.labels[idx]
        # Truthiness check (not ``is not None``) mirrors the original contract.
        if self.transform:
            sample = self.transform(sample)
        return (sample, target, idx) if self.return_idx else (sample, target)
class IndexedDataset(torch.utils.data.Dataset):
    """Wrap a dataset so each item also carries the index it was fetched with.

    Where the wrapped dataset yields ``(X, y)``, this wrapper yields
    ``(X, y, idx)``. Items of the base dataset must unpack into exactly two
    values.

    source: https://github.com/tneumann/minimal_glo/blob/master/glo.py
    """

    def __init__(self, base_dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        features, target = self.base[idx]
        indexed_item = (features, target, idx)
        return indexed_item
# Names accepted by get_dataset that are served by torchvision.
_BUILTIN_DATASETS = ('mnist', 'emnist', 'fashion', 'cifar', 'stl')


def _grayscale_32_transform():
    """Return a transform: resize a grayscale image to 32x32, then map it to [-1, 1]."""
    return transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])])


def _rgb_transform():
    """Return a transform: convert an RGB image to a tensor in [-1, 1] (no resizing)."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


def _load_builtin(name):
    """Load one of ``_BUILTIN_DATASETS`` (downloading into ./data if needed).

    Returns:
        ``(trainset, testset, num_of_classes)``.

    Raises:
        ValueError: if ``name`` is not a built-in dataset name.
    """
    if name == 'mnist':
        transform = _grayscale_32_transform()
        trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
        return trainset, testset, 10
    if name == 'emnist':
        transform = _grayscale_32_transform()
        trainset = torchvision.datasets.EMNIST(root='./data', train=True, split='balanced', download=True, transform=transform)
        testset = torchvision.datasets.EMNIST(root='./data', train=False, split='balanced', download=True, transform=transform)
        # Transpose every image (swap the two spatial axes); EMNIST's raw images
        # are stored transposed. ``train_data``/``test_data`` are read-only
        # deprecated aliases in current torchvision, so assign ``data`` itself.
        trainset.data = trainset.data.permute(0, 2, 1).contiguous()
        testset.data = testset.data.permute(0, 2, 1).contiguous()
        return trainset, testset, 47
    if name == 'fashion':
        transform = _grayscale_32_transform()
        trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
        return trainset, testset, 10
    if name == 'cifar':
        transform = _rgb_transform()
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
        return trainset, testset, 10
    if name == 'stl':
        transform = _rgb_transform()
        trainset = torchvision.datasets.STL10(root='./data', split='train', download=True, transform=transform)
        testset = torchvision.datasets.STL10(root='./data', split='test', download=True, transform=transform)
        return trainset, testset, 10
    raise ValueError('unknown built-in dataset: {!r}'.format(name))


def get_dataset(name, batch_size, test_batch=10000, embedding=False, return_idx=False):
    """Build train/test DataLoaders for a built-in dataset or a saved tensor file.

    Args:
        name: One of 'mnist', 'emnist' (balanced split), 'fashion', 'cifar'
            (CIFAR-10) or 'stl' (STL-10). Any other value is treated as a path
            to a file readable by ``torch.load``.
        batch_size: Batch size of the shuffled training loader. The test loader
            always uses batches of 100.
        test_batch: File-based datasets only: number of randomly chosen samples
            held out as the test split. Ignored for built-in datasets.
        embedding: If true, ``name`` must be a file holding a single tensor whose
            first dimension indexes samples. It is standardized with the mean
            and std over dim 0 and paired with all-zero integer labels.
        return_idx: If true, every sample is yielded as ``(x, y, idx)``.

    Returns:
        ``(trainloader, testloader, train_size, test_size, num_of_classes)``,
        followed by ``mu, std`` when ``embedding`` is true.

    Raises:
        ValueError: if ``embedding`` is requested for a built-in dataset, or if
            ``test_batch`` is negative or exceeds the file-based dataset's size.
    """
    if name in _BUILTIN_DATASETS:
        if embedding:
            # mu/std only exist for file-based embeddings; fail before downloading.
            raise ValueError(
                'embedding=True is only supported for tensor files, '
                'got built-in dataset {!r}'.format(name))
        trainset, testset, num_of_classes = _load_builtin(name)
        if return_idx:
            trainset = IndexedDataset(trainset)
            testset = IndexedDataset(testset)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
        testloader = torch.utils.data.DataLoader(testset, batch_size=100, num_workers=4)
        # Sizes come from the datasets themselves rather than hard-coded
        # constants (e.g. the EMNIST-balanced test split has 18800 samples).
        return trainloader, testloader, len(trainset), len(testset), num_of_classes

    if embedding:
        X = torch.load(name)
        # Use the GPU when present, but keep working on CPU-only machines.
        dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        X = X.to(dev)
        mu = X.mean(dim=0)
        std = X.std(dim=0)
        # NOTE(review): a constant feature has std == 0 and yields NaN/inf here.
        X = ((X - mu) / std).cpu()
        # Placeholder labels: the embedding file carries no targets.
        Y = torch.zeros(X.shape[0], dtype=torch.int)
        dataset = torch.utils.data.TensorDataset(X, Y)
        if return_idx:
            dataset = IndexedDataset(dataset)
    else:
        # (x - 127.5) / 127.5 per channel maps [0, 255] to [-1, 1]; assumes the
        # file stores 3-channel samples in that value range — TODO confirm.
        dataset = BinaryDataset(
            name,
            transform=transforms.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]),
            return_idx=return_idx)

    dataset_size = len(dataset)
    if not 0 <= test_batch <= dataset_size:
        raise ValueError(
            'test_batch must be between 0 and the dataset size ({}), got {}'.format(dataset_size, test_batch))

    # Random disjoint train/test split over the same underlying dataset.
    perm = torch.randperm(dataset_size)
    train_sampler = torch.utils.data.SubsetRandomSampler(perm[test_batch:])
    test_sampler = torch.utils.data.SubsetRandomSampler(perm[:test_batch])
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
    testloader = torch.utils.data.DataLoader(dataset, batch_size=100, sampler=test_sampler)
    train_size = dataset_size - test_batch
    test_size = test_batch
    num_of_classes = 1
    if embedding:
        return trainloader, testloader, train_size, test_size, num_of_classes, mu, std
    return trainloader, testloader, train_size, test_size, num_of_classes