Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add vis-cam tool #577

Merged
merged 37 commits into from
Dec 23, 2021
Merged
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
7841e27
add cam-grad tool
Ezra-Yu Nov 20, 2021
6e3e5c6
Merge branch 'open-mmlab:master' into cam
Ezra-Yu Nov 24, 2021
1636df9
refactor cam-grad tool
Ezra-Yu Nov 24, 2021
f23c1b2
Merge branch 'open-mmlab:master' into cam
Ezra-Yu Nov 29, 2021
86a3f57
add docs
Ezra-Yu Nov 29, 2021
a955b45
update docs
Ezra-Yu Nov 30, 2021
6f2fbec
Merge branch 'cam' of github.com:Ezra-Yu/mmclassification into cam
Ezra-Yu Nov 30, 2021
f2cb8b2
Update docs and support Transformer
Ezra-Yu Dec 2, 2021
b6ce116
Merge branch 'open-mmlab:master' into cam
Ezra-Yu Dec 2, 2021
44ad8eb
remove pictures and use link
Ezra-Yu Dec 2, 2021
0bb57dd
Merge branch 'cam' of github.com:Ezra-Yu/mmclassification into cam
Ezra-Yu Dec 2, 2021
e8b2368
replace example img and finish EN docs
Ezra-Yu Dec 3, 2021
f2653c1
improve docs
Ezra-Yu Dec 7, 2021
271040c
improve code
Ezra-Yu Dec 7, 2021
a1c93a1
Merge remote-tracking branch 'origin/master' into cam
Ezra-Yu Dec 7, 2021
780ea76
Fix MobileNet V3 configs
mzr1996 Dec 8, 2021
569a72e
Refactor to support more powerful feature extraction.
mzr1996 Dec 8, 2021
998fa0a
Add unit tests
mzr1996 Dec 8, 2021
c32a4a4
Fix unit test
mzr1996 Dec 8, 2021
a9293dd
merge feature extraction
Ezra-Yu Dec 13, 2021
d82276b
fix distortion of visualization exapmles in docs
Ezra-Yu Dec 13, 2021
cbd2b7c
fix distortion
Ezra-Yu Dec 13, 2021
28900e7
fix distortion
Ezra-Yu Dec 13, 2021
3474cb1
fix distortion
Ezra-Yu Dec 13, 2021
2e5d6ba
merge master
Ezra-Yu Dec 20, 2021
03bc739
merge master
Ezra-Yu Dec 20, 2021
7a2152f
merge fix conficts
Ezra-Yu Dec 20, 2021
ba1ddfd
Imporve the tool
mzr1996 Dec 22, 2021
f38af5f
Support use both attribute name and index to get layer
mzr1996 Dec 22, 2021
0a6e046
add default get_target-layers
Ezra-Yu Dec 22, 2021
76afdec
Merge branch 'cam' of github.com:Ezra-Yu/mmclassification into cam
Ezra-Yu Dec 22, 2021
40490ce
add default get_target-layers
Ezra-Yu Dec 22, 2021
fd51580
update docs
Ezra-Yu Dec 22, 2021
e3cad26
update docs
Ezra-Yu Dec 22, 2021
e2874a9
add additional printt info when not using target-layers
Ezra-Yu Dec 22, 2021
3445afa
Imporve docs
mzr1996 Dec 23, 2021
9d5c053
Fix enumerate list.
mzr1996 Dec 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add unit tests
mzr1996 committed Dec 8, 2021
commit 998fa0a06f2c537dec2e4fd2cb9a8eb526e482bd
100 changes: 99 additions & 1 deletion tests/test_models/test_classifiers.py
Original file line number Diff line number Diff line change
@@ -73,6 +73,19 @@ def test_image_classifier():
pred = model(single_img, return_loss=False, img_metas=None)
assert isinstance(pred, list) and len(pred) == 1

pred = model.simple_test(imgs, softmax=False)
assert isinstance(pred, list) and len(pred) == 16
assert len(pred[0] == 10)

pred = model.simple_test(imgs, softmax=False, post_process=False)
assert isinstance(pred, torch.Tensor)
assert pred.shape == (16, 10)

soft_pred = model.simple_test(imgs, softmax=True, post_process=False)
assert isinstance(soft_pred, torch.Tensor)
assert soft_pred.shape == (16, 10)
torch.testing.assert_allclose(soft_pred, torch.softmax(pred, dim=1))

# test pretrained
# TODO remove deprecated pretrained
with pytest.warns(UserWarning):
@@ -83,7 +96,7 @@ def test_image_classifier():
type='Pretrained', checkpoint='checkpoint')

# test show_result
img = np.random.random_integers(0, 255, (224, 224, 3)).astype(np.uint8)
img = np.random.randint(0, 256, (224, 224, 3)).astype(np.uint8)
result = dict(pred_class='cat', pred_label=0, pred_score=0.9)

with tempfile.TemporaryDirectory() as tmpdir:
@@ -304,3 +317,88 @@ def forward(self, x):

with pytest.warns(DeprecationWarning):
model.extract_feat(imgs)


def test_classifier_extract_feat():
model_cfg = ConfigDict(
type='ImageClassifier',
backbone=dict(
type='ResNet',
depth=18,
num_stages=4,
out_indices=(0, 1, 2, 3),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss'),
topk=(1, 5),
))

model = CLASSIFIERS.build(model_cfg)

# test backbone output
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
assert outs[0].shape == (1, 64, 56, 56)
assert outs[1].shape == (1, 128, 28, 28)
assert outs[2].shape == (1, 256, 14, 14)
assert outs[3].shape == (1, 512, 7, 7)

# test neck output
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
assert outs[0].shape == (1, 64)
assert outs[1].shape == (1, 128)
assert outs[2].shape == (1, 256)
assert outs[3].shape == (1, 512)

# test pre_logits output
out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
assert out.shape == (1, 512)

# test transformer style feature extraction
model_cfg = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer', arch='b', out_indices=[-3, -2, -1]),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
in_channels=768,
hidden_dim=1024,
loss=dict(type='CrossEntropyLoss'),
))
model = CLASSIFIERS.build(model_cfg)

# test backbone output
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
for out in outs:
patch_token, cls_token = out
assert patch_token.shape == (1, 768, 14, 14)
assert cls_token.shape == (1, 768)

# test neck output (the same with backbone)
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
for out in outs:
patch_token, cls_token = out
assert patch_token.shape == (1, 768, 14, 14)
assert cls_token.shape == (1, 768)

# test pre_logits output
out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
assert out.shape == (1, 1024)

# test extract_feats
multi_imgs = [torch.rand(1, 3, 224, 224) for _ in range(3)]
outs = model.extract_feats(multi_imgs)
for outs_per_img in outs:
for out in outs_per_img:
patch_token, cls_token = out
assert patch_token.shape == (1, 768, 14, 14)
assert cls_token.shape == (1, 768)

outs = model.extract_feats(multi_imgs, stage='pre_logits')
for out_per_img in outs:
assert out_per_img.shape == (1, 1024)
173 changes: 141 additions & 32 deletions tests/test_models/test_heads.py
Original file line number Diff line number Diff line change
@@ -4,36 +4,52 @@
import pytest
import torch

from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
MultiLabelLinearClsHead, StackedLinearClsHead,
VisionTransformerClsHead)
from mmcls.models.heads import (ClsHead, ConformerHead, LinearClsHead,
MultiLabelClsHead, MultiLabelLinearClsHead,
StackedLinearClsHead, VisionTransformerClsHead)


@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
def test_cls_head(feat):
fake_gt_label = torch.randint(0, 10, (4, ))

# test ClsHead with cal_acc=False
head = ClsHead()
fake_gt_label = torch.randint(0, 2, (4, ))

# test forward_train with cal_acc=True
head = ClsHead(cal_acc=True)
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
assert 'accuracy' in losses

# test ClsHead with cal_acc=True
head = ClsHead(cal_acc=True)
feat = torch.rand(4, 3)
fake_gt_label = torch.randint(0, 2, (4, ))

# test forward_train with cal_acc=False
head = ClsHead()
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0

# test ClsHead with weight
# test forward_train with weight
weight = torch.tensor([0.5, 0.5, 0.5, 0.5])

losses_ = head.forward_train(feat, fake_gt_label)
losses = head.forward_train(feat, fake_gt_label, weight=weight)
assert losses['loss'].item() == losses_['loss'].item() * 0.5

# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == (4, 10)

# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, softmax=False, post_process=False)
torch.testing.assert_close(pred, torch.softmax(logits, dim=1))

# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)


@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
def test_linear_head(feat):
@@ -50,35 +66,85 @@ def test_linear_head(feat):
head.init_weights()
assert abs(head.fc.weight).sum() > 0

# test simple_test
head = LinearClsHead(10, 3)
# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4

with patch('torch.onnx.is_in_onnx_export', return_value=True):
head = LinearClsHead(10, 3)
pred = head.simple_test(feat)
assert pred.shape == (4, 10)

# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, softmax=False, post_process=False)
torch.testing.assert_close(pred, torch.softmax(logits, dim=1))

@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)


@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
def test_multilabel_head(feat):
head = MultiLabelClsHead()
fake_gt_label = torch.randint(0, 2, (4, 3))
fake_gt_label = torch.randint(0, 2, (4, 10))

losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0

# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == (4, 10)

# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, sigmoid=False, post_process=False)
torch.testing.assert_close(pred, torch.sigmoid(logits))

# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)


@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
def test_multilabel_linear_head(feat):
head = MultiLabelLinearClsHead(3, 5)
fake_gt_label = torch.randint(0, 2, (4, 3))
head = MultiLabelLinearClsHead(10, 5)
fake_gt_label = torch.randint(0, 2, (4, 10))

head.init_weights()
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0

# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == (4, 10)

# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, sigmoid=False, post_process=False)
torch.testing.assert_close(pred, torch.sigmoid(logits))

# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)


@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
def test_stacked_linear_cls_head(feat):
@@ -93,20 +159,28 @@ def test_stacked_linear_cls_head(feat):

# test forward with default setting
head = StackedLinearClsHead(
num_classes=3, in_channels=5, mid_channels=[10])
num_classes=10, in_channels=5, mid_channels=[20])
head.init_weights()

losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0

# test simple test
# test simple_test with post_process
pred = head.simple_test(feat)
assert len(pred) == 4

# test simple test in tracing
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == torch.Size((4, 3))
assert pred.shape == (4, 10)

# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, softmax=False, post_process=False)
torch.testing.assert_close(pred, torch.softmax(logits, dim=1))

# test pre_logits
features = head.pre_logits(feat)
assert features.shape == (4, 20)

# test forward with full function
head = StackedLinearClsHead(
@@ -144,16 +218,51 @@ def test_vit_head():
head.init_weights()
assert abs(head.layers.pre_logits.weight).sum() > 0

# test simple_test
head = VisionTransformerClsHead(10, 100, hidden_dim=20)
# test simple_test with post_process
pred = head.simple_test(fake_features)
assert isinstance(pred, list) and len(pred) == 4

with patch('torch.onnx.is_in_onnx_export', return_value=True):
head = VisionTransformerClsHead(10, 100, hidden_dim=20)
pred = head.simple_test(fake_features)
assert pred.shape == (4, 10)

# test simple_test without post_process
pred = head.simple_test(fake_features, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(fake_features, softmax=False, post_process=False)
torch.testing.assert_close(pred, torch.softmax(logits, dim=1))

# test pre_logits
features = head.pre_logits(fake_features)
assert features.shape == (4, 20)

# test assertion
with pytest.raises(ValueError):
VisionTransformerClsHead(-1, 100)


def test_conformer_head():
fake_features = ([torch.rand(4, 64), torch.rand(4, 96)], )
fake_gt_label = torch.randint(0, 10, (4, ))

# test conformer head forward
head = ConformerHead(num_classes=10, in_channels=[64, 96])
losses = head.forward_train(fake_features, fake_gt_label)
assert losses['loss'].item() > 0

# test simple_test with post_process
pred = head.simple_test(fake_features)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(fake_features)
assert pred.shape == (4, 10)

# test simple_test without post_process
pred = head.simple_test(fake_features, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(fake_features, softmax=False, post_process=False)
torch.testing.assert_close(pred, torch.softmax(sum(logits), dim=1))

# test pre_logits
features = head.pre_logits(fake_features)
assert features is fake_features[0]