-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
zhiwei
committed
Oct 18, 2024
1 parent
7233410
commit 9d1e164
Showing
125 changed files
with
10,815 additions
and
121 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from detectron2.config import LazyCall as L | ||
from detectron2.layers import ShapeSpec | ||
from detectron2.modeling.meta_arch import CBGeneralizedRCNN | ||
from detectron2.modeling.anchor_generator import DefaultAnchorGenerator | ||
from detectron2.modeling.backbone.fpn import LastLevelMaxPool | ||
from detectron2.modeling.backbone import BasicStem, FPN, ResNet | ||
from detectron2.modeling.box_regression import Box2BoxTransform | ||
from detectron2.modeling.matcher import Matcher | ||
from detectron2.modeling.poolers import ROIPooler | ||
from detectron2.modeling.proposal_generator import RPN, StandardRPNHead | ||
from detectron2.modeling.roi_heads import ( | ||
StandardROIHeads, | ||
FastRCNNOutputLayers, | ||
MaskRCNNConvUpsampleHead, | ||
FastRCNNConvFCHead, | ||
) | ||
|
||
from ..data.constants import constants | ||
|
||
# Lazy (LazyCall, `L`) description of a Mask R-CNN detector built on the
# CBGeneralizedRCNN meta-architecture with a ResNet-50 + FPN backbone.
model = L(CBGeneralizedRCNN)(
    # Backbone: FPN on top of a ResNet-50 exposing res2..res5.
    backbone=L(FPN)(
        bottom_up=L(ResNet)(
            stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
            stages=L(ResNet.make_default_stages)(
                depth=50,
                stride_in_1x1=True,
                norm="FrozenBN",
            ),
            out_features=["res2", "res3", "res4", "res5"],
        ),
        # Relative config interpolation: reuse the feature names declared on `bottom_up`.
        in_features="${.bottom_up.out_features}",
        out_channels=256,
        top_block=L(LastLevelMaxPool)(),
    ),
    # Region proposal network over the five pyramid levels p2..p6.
    proposal_generator=L(RPN)(
        in_features=["p2", "p3", "p4", "p5", "p6"],
        head=L(StandardRPNHead)(in_channels=256, num_anchors=3),
        anchor_generator=L(DefaultAnchorGenerator)(
            # One anchor size and one stride per entry of `in_features`.
            sizes=[[32], [64], [128], [256], [512]],
            aspect_ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64],
            offset=0.0,
        ),
        anchor_matcher=L(Matcher)(
            thresholds=[0.3, 0.7], labels=[0, -1, 1], allow_low_quality_matches=True
        ),
        box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]),
        batch_size_per_image=256,
        positive_fraction=0.5,
        pre_nms_topk=(2000, 1000),
        post_nms_topk=(1000, 1000),
        nms_thresh=0.7,
    ),
    # Box and mask heads operating on p2..p5.
    roi_heads=L(StandardROIHeads)(
        num_classes=80,
        batch_size_per_image=512,
        positive_fraction=0.25,
        proposal_matcher=L(Matcher)(
            thresholds=[0.5], labels=[0, 1], allow_low_quality_matches=False
        ),
        box_in_features=["p2", "p3", "p4", "p5"],
        box_pooler=L(ROIPooler)(
            output_size=7,
            scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
            sampling_ratio=0,
            pooler_type="ROIAlignV2",
        ),
        box_head=L(FastRCNNConvFCHead)(
            # Matches the 7x7, 256-channel output of `box_pooler`.
            input_shape=ShapeSpec(channels=256, height=7, width=7),
            conv_dims=[],
            fc_dims=[1024, 1024],
        ),
        box_predictor=L(FastRCNNOutputLayers)(
            input_shape=ShapeSpec(channels=1024),
            test_score_thresh=0.05,
            box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)),
            # Interpolated from the enclosing `roi_heads.num_classes`.
            num_classes="${..num_classes}",
        ),
        mask_in_features=["p2", "p3", "p4", "p5"],
        mask_pooler=L(ROIPooler)(
            output_size=14,
            scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
            sampling_ratio=0,
            pooler_type="ROIAlignV2",
        ),
        mask_head=L(MaskRCNNConvUpsampleHead)(
            # Matches the 14x14, 256-channel output of `mask_pooler`.
            input_shape=ShapeSpec(channels=256, width=14, height=14),
            num_classes="${..num_classes}",
            conv_dims=[256, 256, 256, 256, 256],
        ),
    ),
    # BGR input normalised with the BGR ImageNet statistics from `constants`.
    pixel_mean=constants.imagenet_bgr256_mean,
    pixel_std=constants.imagenet_bgr256_std,
    input_format="BGR",
)
87 changes: 87 additions & 0 deletions
87
EVA/EVA-02/det/configs/common/models/cb_mask_rcnn_vitdet.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from functools import partial | ||
import torch.nn as nn | ||
from detectron2.config import LazyCall as L | ||
from detectron2.modeling import CBViT, CBSimpleFeaturePyramid | ||
from detectron2.modeling.backbone.fpn import LastLevelMaxPool | ||
|
||
from .cb_mask_rcnn_fpn import model | ||
from ..data.constants import constants | ||
|
||
# The base config imported above is BGR; this variant switches the model to RGB
# input together with the RGB ImageNet statistics.
model.pixel_mean = constants.imagenet_rgb256_mean
model.pixel_std = constants.imagenet_rgb256_std
model.input_format = "RGB"

# from apex.normalization import FusedLayerNorm

# Base
embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1


def _cb_vit():
    """Return a fresh lazy CBViT config.

    `net` and `cb_net` use identical settings; building each from its own call
    keeps the two configs independent so later overrides on one do not touch
    the other.
    """
    return L(CBViT)(  # Single-scale ViT backbone
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        drop_path_rate=dp,
        window_size=14,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_block_indexes=[
            # 2, 5, 8, 11 for global attention
            0,
            1,
            3,
            4,
            6,
            7,
            9,
            10,
        ],
        residual_block_indexes=[],
        use_rel_pos=True,
        out_feature="last_feat",
    )


# Creates Simple Feature Pyramid from ViT backbone
model.backbone = L(CBSimpleFeaturePyramid)(
    net=_cb_vit(),
    cb_net=_cb_vit(),
    # Relative config interpolation: reuse the feature name declared on `net`.
    in_feature="${.net.out_feature}",
    out_channels=256,
    scale_factors=(4.0, 2.0, 1.0, 0.5),
    top_block=L(LastLevelMaxPool)(),
    norm="LN",
    square_pad=1024,
)

# Use LayerNorm in the conv layers of both the box head and the mask head.
model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN"

# 2conv in RPN:
model.proposal_generator.head.conv_dims = [-1, -1]

# 4conv1fc box head
model.roi_heads.box_head.conv_dims = [256, 256, 256, 256]
model.roi_heads.box_head.fc_dims = [1024]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
from .build import build_backbone, BACKBONE_REGISTRY # noqa F401 isort:skip | ||
|
||
from .backbone import Backbone | ||
from .fpn import FPN | ||
from .regnet import RegNet | ||
from .resnet import ( | ||
BasicStem, | ||
ResNet, | ||
ResNetBlockBase, | ||
build_resnet_backbone, | ||
make_stage, | ||
BottleneckBlock, | ||
) | ||
from .vit import ViT, SimpleFeaturePyramid, get_vit_lr_decay_rate | ||
from .mvit import MViT | ||
from .swin import SwinTransformer | ||
from .cb_vit import CBViT, CBSimpleFeaturePyramid | ||
|
||
# Export every public (non-underscore) name that the imports above bound in this module.
__all__ = [k for k in globals() if not k.startswith("_")]
# TODO can expose more resnet blocks after careful consideration
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
from abc import ABCMeta, abstractmethod | ||
from typing import Dict | ||
import torch.nn as nn | ||
|
||
from detectron2.layers import ShapeSpec | ||
|
||
__all__ = ["Backbone"] | ||
|
||
|
||
class Backbone(nn.Module, metaclass=ABCMeta):
    """
    Abstract base class for network backbones.

    Subclasses must implement :meth:`forward`. The default :meth:`output_shape`
    additionally reads ``self._out_features``, ``self._out_feature_channels`` and
    ``self._out_feature_strides``; this base class does not set them, so a subclass
    must either define those attributes or override :meth:`output_shape`.
    """

    def __init__(self):
        """
        The `__init__` method of any subclass can specify its own set of arguments.
        """
        super().__init__()

    @abstractmethod
    def forward(self):
        """
        Subclasses must override this method, but adhere to the same return type.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
        """
        pass

    @property
    def size_divisibility(self) -> int:
        """
        Some backbones require the input height and width to be divisible by a
        specific integer. This is typically true for encoder / decoder type networks
        with lateral connection (e.g., FPN) for which feature maps need to match
        dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
        input size divisibility is required.
        """
        return 0

    @property
    def padding_constraints(self) -> Dict[str, int]:
        """
        This property is a generalization of size_divisibility. Some backbones and training
        recipes require specific padding constraints, such as enforcing divisibility by a specific
        integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale jitter
        in :paper:vitdet). `padding_constraints` contains these optional items like:
        {
            "size_divisibility": int,
            "square_size": int,
            # Future options are possible
        }
        `size_divisibility` will read from here if presented and `square_size` indicates the
        square padding size if `square_size` > 0.

        TODO: use type of Dict[str, int] to avoid torchscript issues. The type of
        padding_constraints could be generalized as TypedDict (Python 3.8+) to support
        more types in the future.
        """
        return {}

    def output_shape(self):
        """
        Returns:
            dict[str->ShapeSpec]: one entry per name in ``self._out_features``, carrying
            that feature's channel count and stride.
        """
        # this is a backward-compatible default
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
from detectron2.layers import ShapeSpec | ||
from detectron2.utils.registry import Registry | ||
|
||
from .backbone import Backbone | ||
|
||
# Global registry mapping a backbone name to the callable that builds it;
# `build_backbone` looks builders up here by `cfg.MODEL.BACKBONE.NAME`.
BACKBONE_REGISTRY = Registry("BACKBONE")
BACKBONE_REGISTRY.__doc__ = """
Registry for backbones, which extract feature maps from images
The registered object must be a callable that accepts two arguments:
1. A :class:`detectron2.config.CfgNode`
2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification.
Registered object must return instance of :class:`Backbone`.
"""
|
||
|
||
def build_backbone(cfg, input_shape=None):
    """
    Build a backbone from `cfg.MODEL.BACKBONE.NAME`.

    Args:
        cfg: config object; `cfg.MODEL.BACKBONE.NAME` selects the builder registered
            in `BACKBONE_REGISTRY`, and `cfg.MODEL.PIXEL_MEAN` is used to derive the
            default input shape.
        input_shape (ShapeSpec, optional): input shape specification passed to the
            builder. When omitted, a `ShapeSpec` with one channel per entry of
            `cfg.MODEL.PIXEL_MEAN` is used.

    Returns:
        an instance of :class:`Backbone`

    Raises:
        AssertionError: if the registered builder returns something that is not a
            :class:`Backbone`.
    """
    if input_shape is None:
        input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))

    backbone_name = cfg.MODEL.BACKBONE.NAME
    backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape)
    # Enforce the registry contract; the message names the offending builder.
    assert isinstance(backbone, Backbone), (
        f"Backbone builder '{backbone_name}' returned {type(backbone).__name__}, "
        "expected an instance of Backbone"
    )
    return backbone
Oops, something went wrong.