-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloaders.py
84 lines (49 loc) · 3.49 KB
/
loaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from partition import *
import torchvision.transforms as transforms
import torchvision
from torchvision import datasets
from copy import copy
class CIFAR100_withIndex(datasets.CIFAR100):
    """CIFAR-100 dataset whose items also report their own position.

    Behaves exactly like ``torchvision.datasets.CIFAR100`` except that each
    item is a 3-tuple ``(img, label, index)`` rather than ``(img, label)``,
    so downstream code can tell which sample it received.
    """

    def __getitem__(self, index):
        # Delegate image loading / transforms to the parent dataset, then
        # append the requested index to the returned pair.
        sample = super().__getitem__(index)
        img, label = sample
        return img, label, index
def loader_build(dataset, mode, partition, splits, n_clients, beta, batch_size, common_dataset_size):
    """Build the CIFAR train / validation / test data pipeline.

    Creates three views of the chosen CIFAR dataset under ``./CifarTrainData``
    (constructed with ``download=True``):

    * an augmented training copy (random horizontal flip + random grayscale),
    * an un-augmented training copy,
    * the un-augmented test split, wrapped in a ``DataLoader``.

    The training copies are then split across clients by
    ``client_subset_creation``.

    Args:
        dataset: ``'cifar10'`` or ``'cifar100'``. For ``'cifar100'`` the
            ``CIFAR100_withIndex`` wrapper is used, so its samples are
            ``(img, label, index)`` triples; CIFAR-10 samples are
            ``(img, label)`` pairs.
        mode: ``'distillation'`` or ``'traditional'``. Selects which training
            copy feeds the client loaders and which feeds the server
            validation loader (see the comment in the body).
        partition, splits, n_clients, beta, batch_size, common_dataset_size:
            Forwarded unchanged to ``client_subset_creation``.

    Returns:
        ``(train_loader_list, valid_loader_list, test_loader, n_classes,
        valid_loader_server, train_set, valid_set, valid_dataset_server,
        testSet)``.

    Raises:
        ValueError: if ``dataset`` or ``mode`` is not a supported value.
    """
    # Validate before touching torchvision. Previously an unsupported value
    # only surfaced later as an UnboundLocalError, and in the bad-mode case
    # only after every dataset had already been downloaded.
    n_classes_by_dataset = {'cifar10': 10, 'cifar100': 100}
    if dataset not in n_classes_by_dataset:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(n_classes_by_dataset)}"
        )
    if mode not in ('distillation', 'traditional'):
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'distillation' or 'traditional'"
        )
    n_classes = n_classes_by_dataset[dataset]

    # NOTE(review): these look like the commonly used CIFAR-10 channel
    # statistics, and they are applied to CIFAR-100 as well. Kept unchanged
    # to preserve existing model inputs; confirm whether CIFAR-100 should use
    # its own statistics.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    # Augmented pipeline for the training copy.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomGrayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    # Un-augmented pipeline shared by the second training copy and the test split.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    dataset_cls = torchvision.datasets.CIFAR10 if dataset == 'cifar10' else CIFAR100_withIndex
    root = './CifarTrainData'
    trainSet = dataset_cls(root=root, train=True, download=True, transform=train_transform)
    trainSet_val = dataset_cls(root=root, train=True, download=True, transform=eval_transform)
    testSet = dataset_cls(root=root, train=False, download=True, transform=eval_transform)

    # DataLoader defaults are used here (batch_size=1, no shuffling).
    test_loader = torch.utils.data.DataLoader(testSet)

    # The two modes mirror each other:
    #   'distillation': clients get the augmented copy; the server validation
    #                   loader is built from the un-augmented copy.
    #   'traditional':  clients get the un-augmented copy; the server
    #                   validation loader is built from the augmented copy.
    # NOTE(review): in 'traditional' mode the server validation loader sees
    # augmented images. Confirm this is intended.
    if mode == 'distillation':
        client_source, server_source = trainSet, trainSet_val
    else:  # 'traditional'
        client_source, server_source = trainSet_val, trainSet

    train_loader_list, valid_loader_list, _, _, train_set, valid_set, valid_dataset_server = \
        client_subset_creation(partition, client_source, splits, n_clients, beta,
                               batch_size, mode, common_dataset_size)
    # Second pass: only the server validation loader is kept.
    # NOTE(review): presumably relies on client_subset_creation producing the
    # same split on both calls (e.g. seeded RNG). Verify against partition.py.
    _, _, _, valid_loader_server, _, _, _ = \
        client_subset_creation(partition, server_source, splits, n_clients, beta,
                               batch_size, mode, common_dataset_size)

    return (train_loader_list, valid_loader_list, test_loader, n_classes,
            valid_loader_server, train_set, valid_set, valid_dataset_server, testSet)