-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 5b9a39e
Showing
115 changed files
with
7,991 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Revisiting Class-Incremental Learning with Pre-Trained Models: Generalizability and Adaptivity are All You Need | ||
<div align="center"> | ||
|
||
<div> | ||
<a href='http://www.lamda.nju.edu.cn/zhoudw' target='_blank'>Da-Wei Zhou</a><sup>1</sup>  | ||
<a href='http://www.lamda.nju.edu.cn/yehj' target='_blank'>Han-Jia Ye</a><sup>1</sup>  | ||
<a href='http://www.lamda.nju.edu.cn/zhandc' target='_blank'>De-Chuan Zhan</a><sup>1</sup>  | ||
<a href='https://liuziwei7.github.io/' target='_blank'>Ziwei Liu </a><sup>2</sup> | ||
</div> | ||
<div> | ||
<sup>1</sup>State Key Laboratory for Novel Software Technology, Nanjing University  | ||
|
||
<sup>2</sup>S-Lab, Nanyang Technological University  | ||
</div> | ||
</div> | ||
|
||
The code repository for "Revisiting Class-Incremental Learning with Pre-Trained Models: Generalizability and Adaptivity are All You Need" in PyTorch. | ||
|
||
|
||
|
||
|
||
|
||
## Updates | ||
[03/2023] Code has been released. | ||
|
||
## Introduction | ||
Class-incremental learning (CIL) aims to adapt to emerging new classes without forgetting old ones. Traditional CIL models are trained from scratch to continually acquire knowledge as data evolves. Recently, pre-training has achieved substantial progress, making vast pre-trained models (PTMs) accessible for CIL. Contrary to traditional methods, PTMs possess generalizable embeddings, which can be easily transferred for CIL. In this work, we revisit CIL with PTMs and argue that the core factors in CIL are adaptivity for model updating and generalizability for knowledge transferring. 1) We first reveal that frozen PTM can already provide generalizable | ||
embeddings for CIL. Surprisingly, a simple baseline (**SimpleCIL**) which continually sets the classifiers of PTM to prototype features can beat state-of-the-art even without training on the downstream task. 2) Due to the distribution gap between pre-trained and downstream datasets, PTM can be further cultivated with adaptivity via model adaptation. We propose **ADapt And Merge** (ADAM), which aggregates the embeddings of PTM and adapted models for classifier construction. ADAM is a general framework that can be orthogonally combined with any parameter-efficient tuning method, which holds the advantages of PTM’s generalizability and adapted model’s adaptivity. 3) Additionally, we find previous benchmarks are unsuitable in the era of PTM due to data overlapping and propose four new benchmarks for assessment, namely ImageNet-A, ObjectNet, OmniBenchmark, and VTAB. Extensive experiments validate the effectiveness of ADAM with a unified and concise framework. | ||
|
||
<div align="center"> | ||
<img src="imgs/adam.png" width="95%"> | ||
|
||
<h3>TL;DR</h3> | ||
|
||
A simple baseline (**SimpleCIL**) beats SOTA even without training on the downstream task. **ADapt And Merge** (ADAM) extends SimpleCIL with better adaptivity and generalizability. Four new benchmarks are proposed for assessment. | ||
</div> | ||
|
||
|
||
|
||
## Requirements | ||
### Environment | ||
1. [torch 1.11.0](https://github.com/pytorch/pytorch) | ||
2. [torchvision 0.12.0](https://github.com/pytorch/vision) | ||
3. [timm 0.6.12](https://github.com/huggingface/pytorch-image-models) | ||
|
||
|
||
### Dataset | ||
We provide the processed datasets as follows: | ||
- **CIFAR100**: will be automatically downloaded by the code. | ||
- **CUB200**: Google Drive: [link](https://drive.google.com/file/d/1XbUpnWpJPnItt5zQ6sHJnsjPncnNLvWb/view?usp=sharing) or Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/EVV4pT9VJ9pBrVs2x0lcwd0BlVQCtSrdbLVfhuajMry-lA?e=L6Wjsc) | ||
- **ImageNet-R**: Google Drive: [link](https://drive.google.com/file/d/1SG4TbiL8_DooekztyCVK8mPmfhMo8fkR/view?usp=sharing) or Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/EU4jyLL29CtBsZkB6y-JSbgBzWF5YHhBAUz1Qw8qM2954A?e=hlWpNW) | ||
- **ImageNet-A**:Google Drive: [link](https://drive.google.com/file/d/19l52ua_vvTtttgVRziCZJjal0TPE9f2p/view?usp=sharing) or Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/ERYi36eg9b1KkfEplgFTW3gBg1otwWwkQPSml0igWBC46A?e=NiTUkL) | ||
- **OmniBenchmark**: Google Drive: [link](https://drive.google.com/file/d/1AbCP3zBMtv_TDXJypOCnOgX8hJmvJm3u/view?usp=sharing) or Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/EcoUATKl24JFo3jBMnTV2WcBwkuyBH0TmCAy6Lml1gOHJA?e=eCNcoA) | ||
- **VTAB**: Google Drive: [link](https://drive.google.com/file/d/1xUiwlnx4k0oDhYi26KL5KwrCAya-mvJ_/view?usp=sharing) or Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/EQyTP1nOIH5PrfhXtpPgKQ8BlEFW2Erda1t7Kdi3Al-ePw?e=Yt4RnV) | ||
- **ObjectNet**: Onedrive: [link](https://entuedu-my.sharepoint.com/:u:/g/personal/n2207876b_e_ntu_edu_sg/EZFv9uaaO1hBj7Y40KoCvYkBnuUZHnHnjMda6obiDpiIWw?e=4n8Kpy) You can also refer to the [filelist](https://drive.google.com/file/d/147Mta-HcENF6IhZ8dvPnZ93Romcie7T6/view?usp=sharing) if the file is too large to download. | ||
|
||
These subsets are sampled from the original datasets. Please note that I do not have the right to distribute these datasets. If the distribution violates the license, I shall provide the filenames instead. | ||
|
||
You need to modify the path of the datasets in `./utils/data.py` according to your own path. | ||
|
||
## Running scripts | ||
Please follow the settings in the `exps` folder to prepare your json files, and then run: | ||
|
||
``` | ||
python main.py --config ./exps/[configname].json | ||
``` | ||
|
||
|
||
## Acknolegment | ||
This repo is based on [CIL_Survey](https://github.com/zhoudw-zdw/CIL_Survey) and [PyCIL](https://github.com/G-U-N/PyCIL). | ||
|
||
The implemenations of parameter-efficient tuning methods are based on [VPT](https://github.com/sagizty/VPT), [AdaptFormer](https://github.com/ShoufaChen/AdaptFormer), and [SSF](https://github.com/dongzelian/SSF). | ||
|
||
## Correspondence | ||
If you have any questions, please contact me via [email](mailto:[email protected]) or open an [issue](https://github.com/zhoudw-zdw/RevisitingCIL/issues/new). | ||
|
||
|
||
<div align="center"> | ||
|
||
![visitors](https://visitor-badge.glitch.me/badge?page_id=zhoudw-zdw.RevisitingCIL&left_color=green&right_color=red) | ||
|
||
</div> |
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
''' | ||
Reference: | ||
https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py | ||
''' | ||
import math | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class DownsampleA(nn.Module): | ||
def __init__(self, nIn, nOut, stride): | ||
super(DownsampleA, self).__init__() | ||
assert stride == 2 | ||
self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) | ||
|
||
def forward(self, x): | ||
x = self.avg(x) | ||
return torch.cat((x, x.mul(0)), 1) | ||
|
||
|
||
class DownsampleB(nn.Module): | ||
def __init__(self, nIn, nOut, stride): | ||
super(DownsampleB, self).__init__() | ||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False) | ||
self.bn = nn.BatchNorm2d(nOut) | ||
|
||
def forward(self, x): | ||
x = self.conv(x) | ||
x = self.bn(x) | ||
return x | ||
|
||
|
||
class DownsampleC(nn.Module): | ||
def __init__(self, nIn, nOut, stride): | ||
super(DownsampleC, self).__init__() | ||
assert stride != 1 or nIn != nOut | ||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False) | ||
|
||
def forward(self, x): | ||
x = self.conv(x) | ||
return x | ||
|
||
|
||
class DownsampleD(nn.Module): | ||
def __init__(self, nIn, nOut, stride): | ||
super(DownsampleD, self).__init__() | ||
assert stride == 2 | ||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False) | ||
self.bn = nn.BatchNorm2d(nOut) | ||
|
||
def forward(self, x): | ||
x = self.conv(x) | ||
x = self.bn(x) | ||
return x | ||
|
||
|
||
class ResNetBasicblock(nn.Module): | ||
expansion = 1 | ||
|
||
def __init__(self, inplanes, planes, stride=1, downsample=None): | ||
super(ResNetBasicblock, self).__init__() | ||
|
||
self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
self.bn_a = nn.BatchNorm2d(planes) | ||
|
||
self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) | ||
self.bn_b = nn.BatchNorm2d(planes) | ||
|
||
self.downsample = downsample | ||
|
||
def forward(self, x): | ||
residual = x | ||
|
||
basicblock = self.conv_a(x) | ||
basicblock = self.bn_a(basicblock) | ||
basicblock = F.relu(basicblock, inplace=True) | ||
|
||
basicblock = self.conv_b(basicblock) | ||
basicblock = self.bn_b(basicblock) | ||
|
||
if self.downsample is not None: | ||
residual = self.downsample(x) | ||
|
||
return F.relu(residual + basicblock, inplace=True) | ||
|
||
|
||
class CifarResNet(nn.Module): | ||
""" | ||
ResNet optimized for the Cifar Dataset, as specified in | ||
https://arxiv.org/abs/1512.03385.pdf | ||
""" | ||
|
||
def __init__(self, block, depth, channels=3): | ||
super(CifarResNet, self).__init__() | ||
|
||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' | ||
layer_blocks = (depth - 2) // 6 | ||
|
||
self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) | ||
self.bn_1 = nn.BatchNorm2d(16) | ||
|
||
self.inplanes = 16 | ||
self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) | ||
self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) | ||
self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) | ||
self.avgpool = nn.AvgPool2d(8) | ||
self.out_dim = 64 * block.expansion | ||
self.fc = nn.Linear(64*block.expansion, 10) | ||
|
||
for m in self.modules(): | ||
if isinstance(m, nn.Conv2d): | ||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
m.weight.data.normal_(0, math.sqrt(2. / n)) | ||
# m.bias.data.zero_() | ||
elif isinstance(m, nn.BatchNorm2d): | ||
m.weight.data.fill_(1) | ||
m.bias.data.zero_() | ||
elif isinstance(m, nn.Linear): | ||
nn.init.kaiming_normal_(m.weight) | ||
m.bias.data.zero_() | ||
|
||
def _make_layer(self, block, planes, blocks, stride=1): | ||
downsample = None | ||
if stride != 1 or self.inplanes != planes * block.expansion: | ||
downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) | ||
|
||
layers = [] | ||
layers.append(block(self.inplanes, planes, stride, downsample)) | ||
self.inplanes = planes * block.expansion | ||
for i in range(1, blocks): | ||
layers.append(block(self.inplanes, planes)) | ||
|
||
return nn.Sequential(*layers) | ||
|
||
def forward(self, x): | ||
x = self.conv_1_3x3(x) # [bs, 16, 32, 32] | ||
x = F.relu(self.bn_1(x), inplace=True) | ||
|
||
x_1 = self.stage_1(x) # [bs, 16, 32, 32] | ||
x_2 = self.stage_2(x_1) # [bs, 32, 16, 16] | ||
x_3 = self.stage_3(x_2) # [bs, 64, 8, 8] | ||
|
||
pooled = self.avgpool(x_3) # [bs, 64, 1, 1] | ||
features = pooled.view(pooled.size(0), -1) # [bs, 64] | ||
|
||
return { | ||
'fmaps': [x_1, x_2, x_3], | ||
'features': features | ||
} | ||
|
||
@property | ||
def last_conv(self): | ||
return self.stage_3[-1].conv_b | ||
|
||
|
||
def resnet20mnist(): | ||
"""Constructs a ResNet-20 model for MNIST.""" | ||
model = CifarResNet(ResNetBasicblock, 20, 1) | ||
return model | ||
|
||
|
||
def resnet32mnist(): | ||
"""Constructs a ResNet-32 model for MNIST.""" | ||
model = CifarResNet(ResNetBasicblock, 32, 1) | ||
return model | ||
|
||
|
||
def resnet20(): | ||
"""Constructs a ResNet-20 model for CIFAR-10.""" | ||
model = CifarResNet(ResNetBasicblock, 20) | ||
return model | ||
|
||
|
||
def resnet32(): | ||
"""Constructs a ResNet-32 model for CIFAR-10.""" | ||
model = CifarResNet(ResNetBasicblock, 32) | ||
return model | ||
|
||
|
||
def resnet44(): | ||
"""Constructs a ResNet-44 model for CIFAR-10.""" | ||
model = CifarResNet(ResNetBasicblock, 44) | ||
return model | ||
|
||
|
||
def resnet56(): | ||
"""Constructs a ResNet-56 model for CIFAR-10.""" | ||
model = CifarResNet(ResNetBasicblock, 56) | ||
return model | ||
|
||
|
||
def resnet110(): | ||
"""Constructs a ResNet-110 model for CIFAR-10.""" | ||
model = CifarResNet(ResNetBasicblock, 110) | ||
return model |
Oops, something went wrong.