-
Notifications
You must be signed in to change notification settings - Fork 22
/
feamap_visual.py
101 lines (81 loc) · 3.52 KB
/
feamap_visual.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
from PIL import Image
import cv2
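# The module below runs the given model and collects the feature-map outputs
# of the requested layers. If no layers are requested (extracted_layers is
# None), the feature maps of all layers are returned as a dict.
# Fully connected layers produce one-dimensional output that is not worth
# visualising, so they are filtered out.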
class FeatureExtractor(nn.Module):
    """Run a model child-by-child and collect intermediate outputs.

    The direct children of ``submodule`` are applied to the input one after
    another, in registration order. The output of each selected child is
    stored under the child's name.

    Args:
        submodule: The model whose direct children are executed in sequence.
        extracted_layers: Collection of child names to capture, or ``None``
            to capture every child. Children whose name contains ``"fc"``
            are never captured, whatever this argument says.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Return a dict mapping child name -> output tensor of that child.

        Args:
            x: Input tensor fed to the first child of ``submodule``.

        Returns:
            dict: Captured outputs, in execution order. Outputs of children
            whose name contains ``"fc"`` are excluded.
        """
        outputs = {}
        for name, module in self.submodule._modules.items():
            is_fc = "fc" in name
            if is_fc:
                # Flatten everything but the batch dimension before a fully
                # connected layer. reshape (unlike view) also accepts
                # non-contiguous tensors.
                x = x.reshape(x.size(0), -1)
            x = module(x)
            print(name)
            # Parenthesised on purpose: the original
            # `A is None or B and not fc` parsed as `A is None or (B and not fc)`,
            # so fc outputs leaked through when extracted_layers was None.
            selected = (
                self.extracted_layers is None
                or name in self.extracted_layers
            )
            if selected and not is_fc:
                outputs[name] = x
        return outputs
def get_picture(pic_name, transform):
    """Load an image, resize it to 256x256 and pass it through ``transform``.

    Args:
        pic_name: Path of the image file, read with ``skimage.io.imread``.
        transform: Callable applied to the resized float32 array.

    Returns:
        Whatever ``transform`` returns for the resized float32 image.
    """
    target_shape = (256, 256)
    resized = skimage.transform.resize(skimage.io.imread(pic_name), target_shape)
    # Convert to float32 before handing the array to the transform.
    as_float32 = np.asarray(resized, dtype=np.float32)
    return transform(as_float32)
def make_dirs(path):
    """Create directory ``path`` (and any missing parents) if it is absent.

    Calling this on a directory that already exists is a no-op.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() first and then calling makedirs().
    os.makedirs(path, exist_ok=True)
def get_feature(pic_dir='path/imgs/AAA.jpg',
                weights_path='PATH/weights/resnet101-5d3b4d8f.pth',
                exact_list=None,
                dst='./feautures/',
                therd_size=640):
    """Save the feature maps of a ResNet-101 for one image as JET-coloured JPEGs.

    Each captured layer gets its own sub-directory of ``dst``. Every channel
    of the layer output is written as ``<index>.jpg``. Maps shorter than
    ``therd_size`` pixels are additionally written, enlarged with
    nearest-neighbour interpolation, as ``<index>_<therd_size>.jpg``.

    Called without arguments, it uses the same paths and sizes that were
    previously hard-coded.

    Args:
        pic_dir: Path of the input image.
        weights_path: Path of the ResNet-101 state-dict file.
        exact_list: Layer names to capture, or ``None`` for all layers.
        dst: Root output directory (default spelling kept so existing
            output locations do not move).
        therd_size: Side length, in pixels, of the enlarged copies; maps
            whose height is below this value are enlarged.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the image and insert the batch dimension.
    img = get_picture(pic_dir, transforms.ToTensor())
    img = img.unsqueeze(0).to(device)

    net = models.resnet101().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(weights_path, map_location=device))
    # Inference mode: without eval() BatchNorm normalises with the statistics
    # of this single image (and updates its running averages).
    net.eval()

    myexactor = FeatureExtractor(net, exact_list)
    # No gradients are needed for visualisation; skip building the graph.
    with torch.no_grad():
        outs = myexactor(img)

    for k, v in outs.items():
        features = v[0]
        # Only (channels, height, width) outputs can be rendered as images;
        # fully connected outputs are skipped.
        if 'fc' in k or features.dim() != 3:
            continue

        # Move the layer output to the host once, not once per channel.
        feature = features.detach().cpu().numpy()
        dst_path = os.path.join(dst, k)
        make_dirs(dst_path)

        for i, channel in enumerate(feature):
            # Activations are not confined to [0, 1]; saturate instead of
            # letting the uint8 cast wrap around. Values in [0, 1] map
            # exactly as before.
            feature_img = np.clip(channel * 255, 0, 255).astype(np.uint8)
            feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)

            if feature_img.shape[0] < therd_size:
                # Small map: also store an enlarged copy next to the original.
                tmp_file = os.path.join(
                    dst_path, str(i) + '_' + str(therd_size) + '.jpg')
                tmp_img = cv2.resize(
                    feature_img, (therd_size, therd_size),
                    interpolation=cv2.INTER_NEAREST)
                cv2.imwrite(tmp_file, tmp_img)

            dst_file = os.path.join(dst_path, str(i) + '.jpg')
            cv2.imwrite(dst_file, feature_img)
if __name__ == "__main__":
    # Script entry point: dump the feature maps using the default settings.
    get_feature()