Add support for STL10 at resolutions 32, 48, and 96
ilyakava committed Mar 25, 2020
1 parent a738cc9 commit ef48201
Showing 4 changed files with 240 additions and 29 deletions.
21 changes: 20 additions & 1 deletion README.md
@@ -2,7 +2,7 @@
This is an unofficial PyTorch implementation of [MixMatch: A Holistic Approach to Semi-Supervised Learning](https://arxiv.org/abs/1905.02249).
The official Tensorflow implementation is [here](https://github.com/google-research/mixmatch).

- Now only experiments on CIFAR-10 are available.
+ Experiments on CIFAR-10 and STL-10 are available.

This repository carefully implements the important details of the official implementation in order to reproduce its results.

@@ -31,19 +31,38 @@ Train the model with 4000 labeled examples from the CIFAR-10 dataset:
python train.py --gpu <gpu_id> --n-labeled 4000 --out cifar10@4000
```

Train on STL-10 (its labeled split has exactly 5,000 images, so `--n-labeled 5000` is required; at resolutions 32 and 48, training images are randomly cropped to 86×86 and resized down):

```
python train.py --resolution <32|48|96> --out stl10 --data_root data/stl10 --dataset STL10 --n-labeled 5000
```


### Monitoring training progress
```
tensorboard.sh --port 6006 --logdir cifar10@250
```

## Results (Accuracy)

### CIFAR10

| #Labels | 250 | 500 | 1000 | 2000 | 4000 |
|:---|:---:|:---:|:---:|:---:|:---:|
| Paper | 88.92 ± 0.87 | 90.35 ± 0.94 | 92.25 ± 0.32 | 92.97 ± 0.15 | 93.76 ± 0.06 |
| This code | 88.71 | 88.96 | 90.52 | 92.23 | 93.52 |

(Results of this code were evaluated on a single run; results over 5 runs with different seeds will be added later.)

### STL10

Using the full labeled set of 5,000 images (the STL-10 `test` split is used for validation; no separate test set is held out):

| Resolution | 32 | 48 | 96 |
|:---|:---:|:---:|:---:|
| Paper | - | - | 94.41 |
| This code | 82.69 | 86.41 | 91.33 |

## References
```
@article{berthelot2019mixmatch,
88 changes: 84 additions & 4 deletions dataset/cifar10.py
@@ -4,6 +4,29 @@
import torchvision
import torch

import torchvision.transforms as transforms
from torchvision.datasets import STL10

# dict containing supported datasets with their image resolutions
imsize_dict = {'C10': 32, 'STL10': 96}

cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

stl10_mean = (0.4914, 0.4822, 0.4465)
stl10_std = (0.2471, 0.2435, 0.2616)

dataset_stats = {
'C10' : {
'mean': cifar10_mean,
'std': cifar10_std
},
'STL10' : {
'mean': stl10_mean,
'std': stl10_std
},
}

class TransformTwice:
def __init__(self, transform):
self.transform = transform
@@ -27,7 +50,67 @@ def get_cifar10(root, n_labeled,

print (f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset


def get_stl10(root,
transform_train=None, transform_val=None,
download=True):

training_set = STL10(root, split='train', download=download, transform=transform_train)
dev_set = STL10(root, split='test', download=download, transform=transform_val)
unl_set = STL10(root, split='unlabeled', download=download, transform=transform_train)

print (f"#Labeled: {len(training_set)} #Unlabeled: {len(unl_set)} #Val: {len(dev_set)} #Test: None")
return training_set, unl_set, dev_set, None

def validate_dataset(dataset):
if dataset not in imsize_dict:
raise ValueError("Dataset %s not supported." % dataset)

def get_transforms(dataset, resolution):
dataset_resolution = imsize_dict[dataset]

if dataset == 'STL10':
if resolution == 96:
transform_train = transforms.Compose([
transforms.RandomCrop(96, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(stl10_mean, stl10_std),
])
else:
transform_train = transforms.Compose([
transforms.RandomCrop(86, padding=0),
transforms.Resize(resolution),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(stl10_mean, stl10_std),
])
if dataset_resolution == resolution:
transform_val = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(dataset_stats[dataset]['mean'], dataset_stats[dataset]['std']),
])
else:
transform_val = transforms.Compose([
transforms.Resize(resolution),
transforms.ToTensor(),
transforms.Normalize(dataset_stats[dataset]['mean'], dataset_stats[dataset]['std']),
])
if dataset == 'C10':
# already normalized in the CIFAR10_labeled/CIFAR10_unlabeled class
transform_train = transforms.Compose([
RandomPadandCrop(resolution),
RandomFlip(),
ToTensor(),
])
transform_val = transforms.Compose([
ToTensor(),
])


return transform_train, transform_val
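For illustration, a minimal sketch of how these helpers compose into data loaders (the module path, data directory, and batch size are assumptions for the example, not part of the commit):

```python
import torch.utils.data as data

from dataset.cifar10 import get_stl10, get_transforms, validate_dataset

validate_dataset('STL10')                     # raises ValueError for unknown names
t_train, t_val = get_transforms('STL10', 48)  # random 86x86 crop, resize to 48
labeled, unlabeled, val, _ = get_stl10(
    'data/stl10', transform_train=t_train, transform_val=t_val)
labeled_loader = data.DataLoader(labeled, batch_size=64, shuffle=True, drop_last=True)
```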



def train_val_split(labels, n_labeled_per_class):
labels = np.array(labels)
@@ -47,9 +130,6 @@ def train_val_split(labels, n_labeled_per_class):

return train_labeled_idxs, train_unlabeled_idxs, val_idxs

- cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255
- cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255

def normalise(x, mean=cifar10_mean, std=cifar10_std):
x, mean, std = [np.array(a, np.float32) for a in (x, mean, std)]
x -= mean*255
92 changes: 90 additions & 2 deletions models/wideresnet.py
@@ -3,7 +3,6 @@
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False):
super(BasicBlock, self).__init__()
@@ -16,7 +15,7 @@ def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_
self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
padding=1, bias=False)
self.droprate = dropRate
- self.equalInOut = (in_planes == out_planes)
+ self.equalInOut = (in_planes == out_planes) and (stride == 1)
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
padding=0, bias=False) or None
self.activate_before_residual = activate_before_residual
@@ -84,4 +83,93 @@ def forward(self, x):
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
out = out.view(-1, self.nChannels)
return self.fc(out)

class WideResNet48(nn.Module):
def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
super(WideResNet48, self).__init__()
nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
assert((depth - 4) % 6 == 0)
n = (depth - 4) // 6
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
padding=1, bias=False)
# 1st block
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
# 2nd block
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
# 3rd block
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001)
self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.fc = nn.Linear(nChannels[3], num_classes)
self.nChannels = nChannels[3]

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight.data)
m.bias.data.zero_()

def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 12)
out = out.view(-1, self.nChannels)
return self.fc(out)

class WideResNet96(nn.Module):
def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
super(WideResNet96, self).__init__()
nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor, 64*widen_factor]
assert((depth - 4) % 6 == 0)
n = (depth - 4) // 6
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
padding=1, bias=False)
# 1st block
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
# 2nd block
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
# 3rd block
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
# 4th block
self.block4 = NetworkBlock(n, nChannels[3], nChannels[4], block, 2, dropRate)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[4], momentum=0.001)
self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.fc = nn.Linear(nChannels[4], num_classes)
self.nChannels = nChannels[4]

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight.data)
m.bias.data.zero_()

def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.block4(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 12)
out = out.view(-1, self.nChannels)
return self.fc(out)
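A quick way to check the new pooling kernels: with strides 1, 2, 2 (plus a fourth stride-2 block for the 96-pixel model), a 48×48 input reaches the classifier head as a 12×12 feature map, and a 96×96 input does too, which is why both `forward` passes use `F.avg_pool2d(out, 12)`. A minimal smoke test (my sketch, not part of the commit), assuming the classes are importable from `models.wideresnet`:

```python
import torch

from models.wideresnet import WideResNet48, WideResNet96

# Each forward pass should yield (batch, num_classes) logits, confirming
# that the 12x12 feature maps collapse to 1x1 under avg_pool2d(out, 12).
logits48 = WideResNet48(num_classes=10)(torch.randn(2, 3, 48, 48))
logits96 = WideResNet96(num_classes=10)(torch.randn(2, 3, 96, 96))
print(logits48.shape, logits96.shape)  # torch.Size([2, 10]) torch.Size([2, 10])
```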
68 changes: 46 additions & 22 deletions train.py
@@ -14,7 +14,6 @@
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
- import torchvision.transforms as transforms
import torch.nn.functional as F

import models.wideresnet as models
@@ -51,7 +50,12 @@
parser.add_argument('--lambda-u', default=75, type=float)
parser.add_argument('--T', default=0.5, type=float)
parser.add_argument('--ema-decay', default=0.999, type=float)

parser.add_argument('--resolution', default=32, type=int,
help='Input image resolution: 32 | 48 | 96')
# Data options
parser.add_argument('--data_root', default='data',
help='Data directory')
parser.add_argument('--dataset', default='C10',
help='Dataset name: C10 | STL10')

args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}
@@ -67,35 +71,45 @@

best_acc = 0 # best test accuracy

which_model = models.WideResNet
if args.resolution == 96:
which_model = models.WideResNet96
if args.resolution == 48:
which_model = models.WideResNet48
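The two `if` statements above fall back to the 32-pixel `WideResNet` unless a 48- or 96-pixel variant is requested. An equivalent table-driven dispatch (an illustrative alternative, not what the commit uses) would be:

```python
# Hypothetical rewrite: map each supported resolution to its model class.
model_by_resolution = {
    32: models.WideResNet,
    48: models.WideResNet48,
    96: models.WideResNet96,
}
which_model = model_by_resolution[args.resolution]  # KeyError on unsupported sizes
```

This variant fails loudly on an unsupported resolution instead of silently training the 32-pixel model.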


def main():
global best_acc

if not os.path.isdir(args.out):
mkdir_p(args.out)

# Data
- print(f'==> Preparing cifar10')
- transform_train = transforms.Compose([
- dataset.RandomPadandCrop(32),
- dataset.RandomFlip(),
- dataset.ToTensor(),
- ])
-
- transform_val = transforms.Compose([
- dataset.ToTensor(),
- ])
-
- train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transform_val=transform_val)
+ dataset.validate_dataset(args.dataset)
+ print('==> Preparing %s' % args.dataset)
+
+ transform_train, transform_val = dataset.get_transforms(args.dataset, args.resolution)
+
+ if args.dataset == 'C10':
+ train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(args.data_root, args.n_labeled, transform_train=transform_train, transform_val=transform_val)
+ elif args.dataset == 'STL10':
+ if args.n_labeled != 5000:
+ raise ValueError("For STL10 the only supported n_labeled is 5000")
+ train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_stl10(args.data_root, transform_train=transform_train, transform_val=transform_val)

labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
- val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
- test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
+ val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
+ if test_set is not None:
+ test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
+ else:
+ test_loader = None

# Model
print("==> creating WRN-28-2")
print("==> creating %s" % which_model.__name__)

def create_model(ema=False):
- model = models.WideResNet(num_classes=10)
+ model = which_model(num_classes=10)
model = model.cuda()

if ema:
@@ -123,7 +137,6 @@ def create_model(ema=False):
# Load checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
- args.out = os.path.dirname(args.resume)
checkpoint = torch.load(args.resume)
best_acc = checkpoint['best_acc']
start_epoch = checkpoint['epoch']
@@ -146,7 +159,10 @@ def create_model(ema=False):
train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, train_criterion, epoch, use_cuda)
_, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats')
val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats')
- test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
+ if test_loader is not None:
+ test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
+ else:
+ test_loss, test_acc = -1, -1

step = args.val_iteration * (epoch + 1)

@@ -206,10 +222,18 @@ def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_opti
inputs_x, targets_x = labeled_train_iter.next()

try:
- (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
+ if args.dataset == 'STL10':
+ inputs_u, _ = unlabeled_train_iter.next()
+ inputs_u2, _ = unlabeled_train_iter.next()
+ else:
+ (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
except StopIteration:  # restart the unlabeled iterator once exhausted
unlabeled_train_iter = iter(unlabeled_trainloader)
- (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
+ if args.dataset == 'STL10':
+ inputs_u, _ = unlabeled_train_iter.next()
+ inputs_u2, _ = unlabeled_train_iter.next()
+ else:
+ (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
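Two consecutive unlabeled batches are drawn for STL10 because `get_stl10` does not wrap its unlabeled transform in `TransformTwice`, so each batch yields a single augmentation; `inputs_u` and `inputs_u2` therefore contain different images rather than two views of the same batch, a slight departure from MixMatch's label guessing on the CIFAR-10 path. A hedged sketch of the alternative, changing one line in `get_stl10` (not what this commit does):

```python
# Hypothetical variant of the unlabeled split in get_stl10: TransformTwice
# makes every sample return two augmentations of the same image, so the
# CIFAR-10 code path `(inputs_u, inputs_u2), _ = unlabeled_train_iter.next()`
# would then work unchanged for STL10 as well.
unl_set = STL10(root, split='unlabeled', download=download,
                transform=TransformTwice(transform_train))
```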

# measure data loading time
data_time.update(time.time() - end)
