From 35ace6c855ca26eb06f9da6691dbeec3f14560a5 Mon Sep 17 00:00:00 2001
From: Ilya Kavalerov
Date: Wed, 11 Dec 2019 15:29:27 -0500
Subject: [PATCH] Add support for STL10 at resolutions 32, 48, and 96

---
 README.md            | 21 +++++++++-
 dataset/cifar10.py   | 88 ++++++++++++++++++++++++++++++++++++++++--
 models/wideresnet.py | 92 +++++++++++++++++++++++++++++++++++++++++++-
 train.py             | 68 +++++++++++++++++++++-----------
 4 files changed, 240 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index 9842801..3cf7444 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 This is an unofficial PyTorch implementation of [MixMatch: A Holistic Approach to Semi-Supervised Learning](https://arxiv.org/abs/1905.02249).
 The official Tensorflow implementation is [here](https://github.com/google-research/mixmatch).
 
-Now only experiments on CIFAR-10 are available.
+Experiments on CIFAR-10 and STL-10 are available.
 
 This repository carefully implements important details of the official implementation in order to reproduce the results.
@@ -31,12 +31,22 @@
 Train the model with 4000 labeled examples from CIFAR-10:
 
 ```
 python train.py --gpu --n-labeled 4000 --out cifar10@4000
 ```
 
+Train on STL-10 at a chosen resolution (32, 48, or 96):
+
+```
+python train.py --resolution <32|48|96> --out stl10 --data_root data/stl10 --dataset STL10 --n-labeled 5000
+```
+
 ### Monitoring training progress
 
 ```
 tensorboard.sh --port 6006 --logdir cifar10@250
 ```
 
 ## Results (Accuracy)
+
+### CIFAR10
+
 | #Labels | 250 | 500 | 1000 | 2000 | 4000 |
 |:---|:---:|:---:|:---:|:---:|:---:|
 |Paper | 88.92 ± 0.87 | 90.35 ± 0.94 | 92.25 ± 0.32 | 92.97 ± 0.15 | 93.76 ± 0.06 |
@@ -44,6 +54,15 @@ tensorboard.sh --port 6006 --logdir cifar10@250
 (Results of this code were evaluated on 1 run. Results of 5 runs with different seeds will be updated later.)
+### STL10
+
+Using all 5000 labeled examples:
+
+| Resolution | 32 | 48 | 96 |
+|:---|:---:|:---:|:---:|
+|Paper | - | - | 94.41 |
+|This code | 82.69 | 86.41 | 91.33 |
+
 ## References
 ```
 @article{berthelot2019mixmatch,

diff --git a/dataset/cifar10.py b/dataset/cifar10.py
index f74b963..a2a986b 100644
--- a/dataset/cifar10.py
+++ b/dataset/cifar10.py
@@ -4,6 +4,29 @@
 import torchvision
 import torch
 
+import torchvision.transforms as transforms
+from torchvision.datasets import STL10
+
+# supported datasets mapped to their native image resolutions
+imsize_dict = {'C10': 32, 'STL10': 96}
+
+cifar10_mean = (0.4914, 0.4822, 0.4465)
+cifar10_std = (0.2023, 0.1994, 0.2010)
+
+stl10_mean = (0.4914, 0.4822, 0.4465)
+stl10_std = (0.2471, 0.2435, 0.2616)
+
+dataset_stats = {
+    'C10': {
+        'mean': cifar10_mean,
+        'std': cifar10_std
+    },
+    'STL10': {
+        'mean': stl10_mean,
+        'std': stl10_std
+    },
+}
+
 class TransformTwice:
     def __init__(self, transform):
         self.transform = transform
@@ -27,7 +50,67 @@ def get_cifar10(root, n_labeled,
 
     print (f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
     return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset
-
+
+def get_stl10(root,
+              transform_train=None, transform_val=None,
+              download=True):
+
+    training_set = STL10(root, split='train', download=download, transform=transform_train)
+    dev_set = STL10(root, split='test', download=download, transform=transform_val)
+    unl_set = STL10(root, split='unlabeled', download=download, transform=transform_train)
+
+    print(f"#Labeled: {len(training_set)} #Unlabeled: {len(unl_set)} #Val: {len(dev_set)} #Test: None")
+    return training_set, unl_set, dev_set, None
+
+def validate_dataset(dataset):
+    if dataset not in imsize_dict:
+        raise ValueError("Dataset %s not supported." % dataset)
+
+def get_transforms(dataset, resolution):
+    dataset_resolution = imsize_dict[dataset]
+
+    if dataset == 'STL10':
+        if resolution == 96:
+            transform_train = transforms.Compose([
+                transforms.RandomCrop(96, padding=4),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize(stl10_mean, stl10_std),
+            ])
+        else:
+            # random 86x86 crop for augmentation, then downscale to the target resolution
+            transform_train = transforms.Compose([
+                transforms.RandomCrop(86, padding=0),
+                transforms.Resize(resolution),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize(stl10_mean, stl10_std),
+            ])
+        if dataset_resolution == resolution:
+            transform_val = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize(dataset_stats[dataset]['mean'], dataset_stats[dataset]['std']),
+            ])
+        else:
+            transform_val = transforms.Compose([
+                transforms.Resize(resolution),
+                transforms.ToTensor(),
+                transforms.Normalize(dataset_stats[dataset]['mean'], dataset_stats[dataset]['std']),
+            ])
+    elif dataset == 'C10':
+        # normalisation already happens inside the CIFAR10_labeled/CIFAR10_unlabeled
+        # classes, so the module's own RandomPadandCrop/RandomFlip/ToTensor are used here
+        transform_train = transforms.Compose([
+            RandomPadandCrop(resolution),
+            RandomFlip(),
+            ToTensor(),
+        ])
+        transform_val = transforms.Compose([
+            ToTensor(),
+        ])
+
+    return transform_train, transform_val
+
 
 def train_val_split(labels, n_labeled_per_class):
     labels = np.array(labels)
@@ -47,9 +130,6 @@ def train_val_split(labels, n_labeled_per_class):
 
     return train_labeled_idxs, train_unlabeled_idxs, val_idxs
 
-cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255
-cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255
-
 def normalise(x, mean=cifar10_mean, std=cifar10_std):
     x, mean, std = [np.array(a, np.float32) for a in (x, mean, std)]
     x -= mean*255
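For orientation, the new helpers are meant to compose as follows. A minimal sketch (the data path and the 48-pixel resolution are arbitrary illustration choices, not values fixed by the patch; running it downloads STL-10 on first use):

```
from dataset.cifar10 import validate_dataset, get_transforms, get_stl10

dataset_name, resolution = 'STL10', 48       # resolution may be 32, 48, or 96
validate_dataset(dataset_name)               # raises ValueError for unknown names
transform_train, transform_val = get_transforms(dataset_name, resolution)

# STL10 has no separate test split here, so the last return value is None.
labeled_set, unlabeled_set, val_set, test_set = get_stl10(
    'data/stl10', transform_train=transform_train, transform_val=transform_val)
```

This mirrors what `train.py` does below; only the hard-coded values differ.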
diff --git a/models/wideresnet.py b/models/wideresnet.py
index 5b74fad..903ff9f 100644
--- a/models/wideresnet.py
+++ b/models/wideresnet.py
@@ -3,7 +3,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-
 class BasicBlock(nn.Module):
     def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False):
         super(BasicBlock, self).__init__()
@@ -16,7 +15,7 @@ def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_
         self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                                padding=1, bias=False)
         self.droprate = dropRate
-        self.equalInOut = (in_planes == out_planes)
+        self.equalInOut = (in_planes == out_planes) and (stride == 1)
         self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                padding=0, bias=False) or None
         self.activate_before_residual = activate_before_residual
@@ -84,4 +83,93 @@ def forward(self, x):
         out = self.relu(self.bn1(out))
         out = F.avg_pool2d(out, 8)
         out = out.view(-1, self.nChannels)
-        return self.fc(out)
\ No newline at end of file
+        return self.fc(out)
+
+class WideResNet48(nn.Module):
+    def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
+        super(WideResNet48, self).__init__()
+        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
+        assert((depth - 4) % 6 == 0)
+        n = (depth - 4) / 6
+        block = BasicBlock
+        # 1st conv before any network block
+        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
+                               padding=1, bias=False)
+        # 1st block
+        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
+        # 2nd block
+        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
+        # 3rd block
+        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
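A quick way to sanity-check the three variants' pooling arithmetic (a kernel of 8 for 32-pixel inputs, 12 for both 48 and 96) is to push dummy batches through each model. A throwaway sketch, assuming the classes are imported as in `train.py`:

```
import torch
import models.wideresnet as models

# 32: 32 -> 16 -> 8, pooled by 8; 48: 48 -> 24 -> 12, pooled by 12;
# 96: 96 -> 48 -> 24 -> 12, pooled by 12. Each should yield (2, 10) logits.
for cls, res in [(models.WideResNet, 32),
                 (models.WideResNet48, 48),
                 (models.WideResNet96, 96)]:
    logits = cls(num_classes=10)(torch.randn(2, 3, res, res))
    assert logits.shape == (2, 10), (cls.__name__, logits.shape)
```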
+        # global average pooling and classifier
+        self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001)
+        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.fc = nn.Linear(nChannels[3], num_classes)
+        self.nChannels = nChannels[3]
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight.data)
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.block1(out)
+        out = self.block2(out)
+        out = self.block3(out)
+        out = self.relu(self.bn1(out))
+        # 48 -> 24 -> 12 through the two strided blocks, so pool over 12x12
+        out = F.avg_pool2d(out, 12)
+        out = out.view(-1, self.nChannels)
+        return self.fc(out)
+
+class WideResNet96(nn.Module):
+    def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
+        super(WideResNet96, self).__init__()
+        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor, 64*widen_factor]
+        assert((depth - 4) % 6 == 0)
+        n = (depth - 4) / 6
+        block = BasicBlock
+        # 1st conv before any network block
+        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
+                               padding=1, bias=False)
+        # 1st block
+        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
+        # 2nd block
+        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
+        # 3rd block
+        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
+        # 4th block
+        self.block4 = NetworkBlock(n, nChannels[3], nChannels[4], block, 2, dropRate)
+        # global average pooling and classifier (block4 outputs nChannels[4] channels)
+        self.bn1 = nn.BatchNorm2d(nChannels[4], momentum=0.001)
+        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.fc = nn.Linear(nChannels[4], num_classes)
+        self.nChannels = nChannels[4]
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
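The cascaded `if`s below work, but the same dispatch can be written as a lookup table. A hypothetical alternative, not part of the patch:

```
# resolution -> model class; an unsupported resolution fails fast with KeyError
which_model = {32: models.WideResNet,
               48: models.WideResNet48,
               96: models.WideResNet96}[args.resolution]
```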
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight.data)
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.block1(out)
+        out = self.block2(out)
+        out = self.block3(out)
+        out = self.block4(out)
+        out = self.relu(self.bn1(out))
+        # 96 -> 48 -> 24 -> 12 through the three strided blocks, so pool over 12x12
+        out = F.avg_pool2d(out, 12)
+        out = out.view(-1, self.nChannels)
+        return self.fc(out)
\ No newline at end of file
diff --git a/train.py b/train.py
index eb38221..8c66243 100644
--- a/train.py
+++ b/train.py
@@ -14,7 +14,6 @@
 import torch.backends.cudnn as cudnn
 import torch.optim as optim
 import torch.utils.data as data
-import torchvision.transforms as transforms
 import torch.nn.functional as F
 
 import models.wideresnet as models
@@ -51,7 +50,12 @@
 parser.add_argument('--lambda-u', default=75, type=float)
 parser.add_argument('--T', default=0.5, type=float)
 parser.add_argument('--ema-decay', default=0.999, type=float)
-
+parser.add_argument('--resolution', default=32, type=int, choices=[32, 48, 96],
+                    help='Input resolution: 32 | 48 | 96')
+# Data options
+parser.add_argument('--data_root', default='data',
+                    help='Data directory')
+parser.add_argument('--dataset', default='C10',
+                    help='Dataset name: C10 | STL10')
 args = parser.parse_args()
 state = {k: v for k, v in args._get_kwargs()}
@@ -67,6 +71,13 @@
 
 best_acc = 0  # best test accuracy
 
+# pick the model variant matching the requested input resolution
+which_model = models.WideResNet
+if args.resolution == 96:
+    which_model = models.WideResNet96
+elif args.resolution == 48:
+    which_model = models.WideResNet48
+
 
 def main():
     global best_acc
@@ -74,28 +85,31 @@ def main():
         mkdir_p(args.out)
 
     # Data
-    print(f'==> Preparing cifar10')
-    transform_train = transforms.Compose([
-        dataset.RandomPadandCrop(32),
-        dataset.RandomFlip(),
-        dataset.ToTensor(),
-    ])
-
-    transform_val = transforms.Compose([
-        dataset.ToTensor(),
-    ])
-
-    train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transform_val=transform_val)
+    dataset.validate_dataset(args.dataset)
+    print(f'==> Preparing {args.dataset}')
+
+    transform_train, transform_val = dataset.get_transforms(args.dataset, args.resolution)
+
+    if args.dataset == 'C10':
+        train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(args.data_root, args.n_labeled, transform_train=transform_train, transform_val=transform_val)
+    elif args.dataset == 'STL10':
+        if args.n_labeled != 5000:
+            raise ValueError("For STL10 the only supported n_labeled is 5000")
+        train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_stl10(args.data_root, transform_train=transform_train, transform_val=transform_val)
+
     labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
     unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
-    val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
-    test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
+    val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
+    if test_set is not None:
+        test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
+    else:
+        test_loader = None
 
     # Model
-    print("==> creating WRN-28-2")
+    print("==> creating %s" % which_model.__name__)
 
     def create_model(ema=False):
-        model = models.WideResNet(num_classes=10)
+        model = which_model(num_classes=10)
         model = model.cuda()
 
         if ema:
@@ -123,7 +137,6 @@ def create_model(ema=False):
         # Load checkpoint.
         print('==> Resuming from checkpoint..')
         assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
-        args.out = os.path.dirname(args.resume)
         checkpoint = torch.load(args.resume)
         best_acc = checkpoint['best_acc']
         start_epoch = checkpoint['epoch']
@@ -146,7 +159,10 @@ def create_model(ema=False):
         train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, train_criterion, epoch, use_cuda)
         _, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats')
         val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats')
-        test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
+        if test_loader is not None:
+            test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
+        else:
+            # no held-out test split (STL10), so log sentinel values
+            test_loss, test_acc = -1, -1
 
         step = args.val_iteration * (epoch + 1)
@@ -206,10 +222,18 @@ def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_opti
             inputs_x, targets_x = labeled_train_iter.next()
 
         try:
-            (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
+            if args.dataset == 'STL10':
+                # STL10's unlabeled loader is not wrapped in TransformTwice,
+                # so draw two independently augmented batches
+                inputs_u, _ = unlabeled_train_iter.next()
+                inputs_u2, _ = unlabeled_train_iter.next()
+            else:
+                (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
         except:
             unlabeled_train_iter = iter(unlabeled_trainloader)
-            (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
+            if args.dataset == 'STL10':
+                inputs_u, _ = unlabeled_train_iter.next()
+                inputs_u2, _ = unlabeled_train_iter.next()
+            else:
+                (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
 
         # measure data loading time
         data_time.update(time.time() - end)
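One deviation worth noting: because `get_stl10` applies a plain `transform_train` to the unlabeled split, the two STL10 batches drawn above contain different images, whereas MixMatch's label guessing uses two augmentations of the same batch (as the CIFAR-10 path does via `TransformTwice`). A sketch of the closer-to-paper alternative, not what this patch implements:

```
from dataset.cifar10 import TransformTwice, get_transforms
from torchvision.datasets import STL10

transform_train, _ = get_transforms('STL10', 96)
# Each sample yields two augmentations of the SAME image, so the CIFAR-10
# unpacking `(inputs_u, inputs_u2), _ = unlabeled_train_iter.next()` works unchanged.
unl_set = STL10('data/stl10', split='unlabeled', download=True,
                transform=TransformTwice(transform_train))
```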