Commit 5b9a39e: first commit
zhoudw-zdw committed Mar 13, 2023
Showing 115 changed files with 7,991 additions and 0 deletions.
82 changes: 82 additions & 0 deletions README.md
@@ -0,0 +1,82 @@
# Revisiting Class-Incremental Learning with Pre-Trained Models: Generalizability and Adaptivity are All You Need
<div align="center">

<div>
<a href='http://www.lamda.nju.edu.cn/zhoudw' target='_blank'>Da-Wei Zhou</a><sup>1</sup>&emsp;
<a href='http://www.lamda.nju.edu.cn/yehj' target='_blank'>Han-Jia Ye</a><sup>1</sup>&emsp;
<a href='http://www.lamda.nju.edu.cn/zhandc' target='_blank'>De-Chuan Zhan</a><sup>1</sup>&emsp;
<a href='https://liuziwei7.github.io/' target='_blank'>Ziwei Liu</a><sup>2</sup>
</div>
<div>
<sup>1</sup>State Key Laboratory for Novel Software Technology, Nanjing University&emsp;

<sup>2</sup>S-Lab, Nanyang Technological University&emsp;
</div>
</div>

The code repository for "Revisiting Class-Incremental Learning with Pre-Trained Models: Generalizability and Adaptivity are All You Need" in PyTorch.





## Updates
[03/2023] Code has been released.

## Introduction
Class-incremental learning (CIL) aims to adapt to emerging new classes without forgetting old ones. Traditional CIL models are trained from scratch to continually acquire knowledge as data evolves. Recently, pre-training has made substantial progress, making vast pre-trained models (PTMs) accessible for CIL. In contrast to traditional methods, PTMs possess generalizable embeddings that can be readily transferred to CIL. In this work, we revisit CIL with PTMs and argue that the core factors in CIL are adaptivity for model updating and generalizability for knowledge transfer. 1) We first reveal that a frozen PTM can already provide generalizable embeddings for CIL. Surprisingly, a simple baseline (**SimpleCIL**), which continually sets the classifiers of the PTM to prototype features, can beat the state of the art even without training on the downstream task. 2) Due to the distribution gap between the pre-training and downstream datasets, the PTM can be further cultivated with adaptivity via model adaptation. We propose **ADapt And Merge** (ADAM), which aggregates the embeddings of the PTM and the adapted model for classifier construction. ADAM is a general framework that can be orthogonally combined with any parameter-efficient tuning method, retaining the PTM's generalizability and the adapted model's adaptivity. 3) Additionally, we find that previous benchmarks are unsuitable in the era of PTMs due to data overlap, and we propose four new benchmarks for assessment: ImageNet-A, ObjectNet, OmniBenchmark, and VTAB. Extensive experiments validate the effectiveness of ADAM within a unified and concise framework.
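
To make the two ideas concrete, here is a minimal sketch of the prototype-classifier construction (the function and loader names are illustrative stand-ins, not this repository's API); the final comment hints at the ADAM-style aggregation, which concatenates the frozen PTM embedding with the adapted model's embedding before building prototypes:

```
import torch

@torch.no_grad()
def build_prototype_classifier(extract_features, loader, num_classes, feat_dim):
    """SimpleCIL idea (sketch): set each class's classifier weight to the
    mean (prototype) of its frozen PTM embeddings -- no training needed."""
    sums = torch.zeros(num_classes, feat_dim)
    counts = torch.zeros(num_classes)
    for images, labels in loader:
        feats = extract_features(images)  # frozen PTM embeddings, [B, feat_dim]
        sums.index_add_(0, labels, feats)
        counts.index_add_(0, labels, torch.ones(labels.shape[0]))
    return sums / counts.clamp(min=1).unsqueeze(1)  # [num_classes, feat_dim]

# ADAM idea (sketch): feed concatenated embeddings to the same routine, e.g.
# feats = torch.cat([ptm(images), adapted_model(images)], dim=1)
```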

<div align="center">
<img src="imgs/adam.png" width="95%">

<h3>TL;DR</h3>

A simple baseline (**SimpleCIL**) beats SOTA even without training on the downstream task. **ADapt And Merge** (ADAM) extends SimpleCIL with better adaptivity and generalizability. Four new benchmarks are proposed for assessment.
</div>



## Requirements
### Environment
1. [torch 1.11.0](https://github.com/pytorch/pytorch)
2. [torchvision 0.12.0](https://github.com/pytorch/vision)
3. [timm 0.6.12](https://github.com/huggingface/pytorch-image-models)
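
For example, assuming a pip-based environment (choose the build matching your CUDA version):

```
pip install torch==1.11.0 torchvision==0.12.0 timm==0.6.12
```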


### Dataset
We provide the processed datasets as follows:
- **CIFAR100**: will be automatically downloaded by the code.
- **CUB200**: Google Drive: [link](https://drive.google.com/file/d/1XbUpnWpJPnItt5zQ6sHJnsjPncnNLvWb/view?usp=sharing) or Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/EVV4pT9VJ9pBrVs2x0lcwd0BlVQCtSrdbLVfhuajMry-lA?e=L6Wjsc)
- **ImageNet-R**: Google Drive: [link](https://drive.google.com/file/d/1SG4TbiL8_DooekztyCVK8mPmfhMo8fkR/view?usp=sharing) or Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/EU4jyLL29CtBsZkB6y-JSbgBzWF5YHhBAUz1Qw8qM2954A?e=hlWpNW)
- **ImageNet-A**: Google Drive: [link](https://drive.google.com/file/d/19l52ua_vvTtttgVRziCZJjal0TPE9f2p/view?usp=sharing) or Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/ERYi36eg9b1KkfEplgFTW3gBg1otwWwkQPSml0igWBC46A?e=NiTUkL)
- **OmniBenchmark**: Google Drive: [link](https://drive.google.com/file/d/1AbCP3zBMtv_TDXJypOCnOgX8hJmvJm3u/view?usp=sharing) or Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/EcoUATKl24JFo3jBMnTV2WcBwkuyBH0TmCAy6Lml1gOHJA?e=eCNcoA)
- **VTAB**: Google Drive: [link](https://drive.google.com/file/d/1xUiwlnx4k0oDhYi26KL5KwrCAya-mvJ_/view?usp=sharing) or Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/EQyTP1nOIH5PrfhXtpPgKQ8BlEFW2Erda1t7Kdi3Al-ePw?e=Yt4RnV)
- **ObjectNet**: Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/EZFv9uaaO1hBj7Y40KoCvYkBnuUZHnHnjMda6obiDpiIWw?e=4n8Kpy). You can also refer to the [filelist](https://drive.google.com/file/d/147Mta-HcENF6IhZ8dvPnZ93Romcie7T6/view?usp=sharing) if the file is too large to download.

These subsets are sampled from the original datasets. Please note that I do not have the right to distribute them; if the distribution violates any license, I will provide the file lists instead.

You need to modify the dataset paths in `./utils/data.py` to match your local directories.
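
For instance, the dataset definitions in `./utils/data.py` point to train/test directories; the lines to edit look roughly like this (paths are illustrative; check the actual file):

```
train_dir = "[your-data-path]/cub/train/"
test_dir = "[your-data-path]/cub/test/"
```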

## Running scripts
Please follow the settings in the `exps` folder to prepare your JSON configuration files, and then run:

```
python main.py --config ./exps/[configname].json
```
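
The exact keys are defined by the files in `exps`; a rough, hypothetical sketch of what such a config looks like (field names are illustrative and should be checked against the provided files):

```
{
    "dataset": "cifar100",
    "model_name": "simplecil",
    "init_cls": 10,
    "increment": 10,
    "device": ["0"]
}
```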


## Acknowledgment
This repo is based on [CIL_Survey](https://github.com/zhoudw-zdw/CIL_Survey) and [PyCIL](https://github.com/G-U-N/PyCIL).

The implementations of the parameter-efficient tuning methods are based on [VPT](https://github.com/sagizty/VPT), [AdaptFormer](https://github.com/ShoufaChen/AdaptFormer), and [SSF](https://github.com/dongzelian/SSF).

## Correspondence
If you have any questions, please contact me via [email](mailto:[email protected]) or open an [issue](https://github.com/zhoudw-zdw/RevisitingCIL/issues/new).


<div align="center">

![visitors](https://visitor-badge.glitch.me/badge?page_id=zhoudw-zdw.RevisitingCIL&left_color=green&right_color=red)

</div>
Empty file added convs/__init__.py
Binary files added (not shown): compiled caches under convs/__pycache__/, including __init__, bamboo_vit, cifar_resnet, clip, clipmodel, linears, mae, resnet, resnet_cbam, resnet_scale, ucir_resnet, and vpt (.cpython-39.pyc), plus several others.
198 changes: 198 additions & 0 deletions convs/cifar_resnet.py
@@ -0,0 +1,198 @@
'''
Reference:
https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
'''
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsampleA(nn.Module):
    """Option-A shortcut: subsample spatially, then zero-pad the extra channels."""

    def __init__(self, nIn, nOut, stride):
        super(DownsampleA, self).__init__()
        assert stride == 2
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        x = self.avg(x)
        return torch.cat((x, x.mul(0)), 1)


class DownsampleB(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super(DownsampleB, self).__init__()
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(nOut)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class DownsampleC(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super(DownsampleC, self).__init__()
        assert stride != 1 or nIn != nOut
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        x = self.conv(x)
        return x


class DownsampleD(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super(DownsampleD, self).__init__()
        assert stride == 2
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(nOut)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class ResNetBasicblock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResNetBasicblock, self).__init__()

        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn_a = nn.BatchNorm2d(planes)

        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_b = nn.BatchNorm2d(planes)

        self.downsample = downsample

    def forward(self, x):
        residual = x

        basicblock = self.conv_a(x)
        basicblock = self.bn_a(basicblock)
        basicblock = F.relu(basicblock, inplace=True)

        basicblock = self.conv_b(basicblock)
        basicblock = self.bn_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(x)

        return F.relu(residual + basicblock, inplace=True)


class CifarResNet(nn.Module):
    """
    ResNet optimized for the CIFAR datasets, as specified in
    https://arxiv.org/abs/1512.03385
    """

    def __init__(self, block, depth, channels=3):
        super(CifarResNet, self).__init__()

        # Depth determines the number of layers for the CIFAR-10/CIFAR-100 models.
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        layer_blocks = (depth - 2) // 6

        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        self.inplanes = 16
        self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
        self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
        self.avgpool = nn.AvgPool2d(8)
        self.out_dim = 64 * block.expansion
        self.fc = nn.Linear(64 * block.expansion, 10)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_1_3x3(x)  # [bs, 16, 32, 32]
        x = F.relu(self.bn_1(x), inplace=True)

        x_1 = self.stage_1(x)  # [bs, 16, 32, 32]
        x_2 = self.stage_2(x_1)  # [bs, 32, 16, 16]
        x_3 = self.stage_3(x_2)  # [bs, 64, 8, 8]

        pooled = self.avgpool(x_3)  # [bs, 64, 1, 1]
        features = pooled.view(pooled.size(0), -1)  # [bs, 64]

        return {
            'fmaps': [x_1, x_2, x_3],
            'features': features
        }

    @property
    def last_conv(self):
        return self.stage_3[-1].conv_b


def resnet20mnist():
    """Constructs a ResNet-20 model for MNIST."""
    model = CifarResNet(ResNetBasicblock, 20, 1)
    return model


def resnet32mnist():
    """Constructs a ResNet-32 model for MNIST."""
    model = CifarResNet(ResNetBasicblock, 32, 1)
    return model


def resnet20():
    """Constructs a ResNet-20 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 20)
    return model


def resnet32():
    """Constructs a ResNet-32 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 32)
    return model


def resnet44():
    """Constructs a ResNet-44 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 44)
    return model


def resnet56():
    """Constructs a ResNet-56 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 56)
    return model


def resnet110():
    """Constructs a ResNet-110 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 110)
    return model
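
A quick usage sketch of the interface above (not from the repository; note that the forward pass returns a dict rather than logits):

```
import torch
from convs.cifar_resnet import resnet32

model = resnet32()
out = model(torch.randn(2, 3, 32, 32))
print(out['features'].shape)            # torch.Size([2, 64])
print([f.shape for f in out['fmaps']])  # feature maps from the three stages
```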