Utilizing resnet_50.pth for 3D Feature Map Extraction #83
Hi @aeinkoupaei! Did you figure out whether this is the correct approach to get a feature map from an image? I also need to use this to make a feature map. Is there any reason you set num_seg_classes=1?
Hi @Ram2314,

To use a pre-trained ResNet model for extracting 3D feature maps, you'll need to focus on the `ResNet` class within the `resnet.py` file. Here's what to change:

Here's an example of how to use a pre-trained ResNet-10 model for feature extraction:

```python
resnet_10 = resnet10(shortcut_type='B', no_cuda=True)
resnet_10.load_state_dict(new_state_dict)
```
@aeinkoupaei Awesome, thanks! Also, did you figure out the length/width/height issue? It seems I can set those to anything and it still works.
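A likely reason any size works: the convolution and pooling weights in the checkpoint depend only on kernel sizes and channel counts, not on D/H/W, so the same weights load and run on any volume large enough to survive the downsampling. As far as I can tell, `sample_input_D/H/W` in MedicalNet's `ResNet.__init__` are not used to size any layer. What does change is the spatial size of the output. The sketch below computes the layer4 output shape with PyTorch's standard conv/pool output-length formula. The layer settings (conv1 kernel 7/stride 2/padding 3, max-pool kernel 3/stride 2, layer2 stride 2, layer3/layer4 dilated with stride 1, 2048 output channels for ResNet-50) are my assumption about `models/resnet.py`, so check them against your copy. `conv_out` and `medicalnet_feature_shape` are hypothetical helpers.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int, dilation: int = 1) -> int:
    """Output length along one spatial axis for a Conv3d/MaxPool3d layer (PyTorch formula)."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def medicalnet_feature_shape(d: int, h: int, w: int, channels: int = 2048) -> tuple:
    """Shape (C, D', H', W') of the layer4 output for one input volume.

    ASSUMPTION: layer settings as in MedicalNet's models/resnet.py (see lead-in).
    Use channels=2048 for Bottleneck models (resnet50 and up), 512 for resnet10/18/34.
    """
    out = []
    for n in (d, h, w):
        n = conv_out(n, kernel=7, stride=2, padding=3)               # conv1
        n = conv_out(n, kernel=3, stride=2, padding=1)               # maxpool
        # layer1 has stride 1 and keeps the size
        n = conv_out(n, kernel=3, stride=2, padding=1)               # layer2, strided 3x3 conv
        n = conv_out(n, kernel=3, stride=1, padding=2, dilation=2)   # layer3, size preserved
        n = conv_out(n, kernel=3, stride=1, padding=4, dilation=4)   # layer4, size preserved
        out.append(n)
    return (channels, *out)


print(medicalnet_feature_shape(32, 256, 256))  # (2048, 4, 32, 32)
print(medicalnet_feature_shape(56, 448, 448))  # (2048, 7, 56, 56)
```

Under these assumptions each axis shrinks by roughly a factor of 8. That the weights load for any size does not by itself guarantee the features are equally useful at resolutions far from those seen in pre-training.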
Hi, I want to use the resnet_50.pth pre-trained encoder to extract 3D feature maps from medical images. Is the following method correct? It seems strange that the width, height, depth, and number-of-channels parameters can be set manually. Wasn't the resnet_50.pth model pre-trained with a specific architecture and a specific input length, width, height, and channel count? If so, shouldn't the input used for extracting 3D feature maps have the same dimensions as the inputs used during training?
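```python
import torch
from collections import OrderedDict
from models.resnet import resnet50  # assumes MedicalNet's models/resnet.py is importable

model = resnet50(
    sample_input_D=32,
    sample_input_H=256,
    sample_input_W=256,
    shortcut_type='B',
    no_cuda=True,
    num_seg_classes=1,
)
pretrain = torch.load("pretrain/resnet_50.pth", map_location="cpu")  # Load the weights from the pretrained file
pretrained_dict = pretrain['state_dict']
new_state_dict = OrderedDict()
for k, v in pretrained_dict.items():
    # Remove the 'module.' prefix that nn.DataParallel adds to saved keys
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
# Load into the model built above; strict=False silently skips keys that do not match
model.load_state_dict(new_state_dict, strict=False)

model.eval()  # use the stored BatchNorm statistics instead of per-batch statistics
with torch.no_grad():
    # A_img: float tensor of shape (N, 1, D, H, W)
    # Note: the stock ResNet.forward() ends with the conv_seg head, so this returns
    # that head's output rather than the layer4 feature map.
    A_img_feature_map = model(A_img)
```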
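Because the snippet above uses `strict=False`, `load_state_dict` will not complain if the key renaming goes wrong; in the worst case nothing matches and the encoder silently keeps its random initialization. PyTorch's `load_state_dict` returns an object with `missing_keys` and `unexpected_keys` fields, so inspecting those after loading is the simplest safeguard. The pure-Python sketch below shows the same check on key names alone. `strip_prefix` and `report_key_mismatch` are hypothetical helpers, and the key names are toy examples, not the real checkpoint contents.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Remove a leading prefix (e.g. the one nn.DataParallel adds) from every key that has it."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )


def report_key_mismatch(model_keys, ckpt_keys):
    """Return (missing, unexpected) sorted key lists, mirroring what strict=False skips."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = sorted(model_keys - ckpt_keys)     # expected by the model, absent from checkpoint
    unexpected = sorted(ckpt_keys - model_keys)  # in the checkpoint, unknown to the model
    return missing, unexpected


# Toy key names for illustration only (hypothetical, not the real checkpoint contents)
ckpt = {"module.conv1.weight": 0, "module.layer1.0.conv1.weight": 1}
model_keys = ["conv1.weight", "layer1.0.conv1.weight", "conv_seg.0.weight"]

missing, unexpected = report_key_mismatch(model_keys, strip_prefix(ckpt))
print(missing)     # ['conv_seg.0.weight']
print(unexpected)  # []
```

With a real model you would pass `model.state_dict().keys()` as `model_keys`. Seeing only the segmentation head in `missing` may be expected if the checkpoint does not include it; seeing encoder layers such as `conv1` or `layer1` there suggests the prefix handling is wrong.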