diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py
index 394b9331bf0..25888afea76 100644
--- a/torchvision/ops/poolers.py
+++ b/torchvision/ops/poolers.py
@@ -85,13 +85,20 @@ class MultiScaleRoIAlign(nn.Module):
     """
     Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
 
-    It infers the scale of the pooling via the heuristics present in the FPN paper.
+    It infers the scale of the pooling via the heuristics specified in eq. 1
+    of the `Feature Pyramid Network paper <https://arxiv.org/abs/1612.03144>`_.
+    The keyword-only parameters ``canonical_scale`` and ``canonical_level``
+    correspond respectively to ``224`` and ``k0=4`` in eq. 1, and
+    have the following meaning: ``canonical_level`` is the target level of the pyramid from
+    which to pool a region of interest with ``w x h = canonical_scale x canonical_scale``.
 
     Args:
         featmap_names (List[str]): the names of the feature maps that will be used
             for the pooling.
         output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
         sampling_ratio (int): sampling ratio for ROIAlign
+        canonical_scale (int, optional): canonical_scale for LevelMapper
+        canonical_level (int, optional): canonical_level for LevelMapper
 
     Examples::
 
@@ -120,6 +127,9 @@ def __init__(
         featmap_names: List[str],
         output_size: Union[int, Tuple[int], List[int]],
         sampling_ratio: int,
+        *,
+        canonical_scale: int = 224,
+        canonical_level: int = 4,
     ):
         super(MultiScaleRoIAlign, self).__init__()
         if isinstance(output_size, int):
@@ -129,6 +139,8 @@ def __init__(
         self.output_size = tuple(output_size)
         self.scales = None
         self.map_levels = None
+        self.canonical_scale = canonical_scale
+        self.canonical_level = canonical_level
 
     def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
         concat_boxes = torch.cat(boxes, dim=0)
@@ -173,7 +185,12 @@ def setup_scales(
         lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
         lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
         self.scales = scales
-        self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))
+        self.map_levels = initLevelMapper(
+            int(lvl_min),
+            int(lvl_max),
+            canonical_scale=self.canonical_scale,
+            canonical_level=self.canonical_level,
+        )
 
     def forward(
         self,
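
For reference, a minimal usage sketch of the new keyword-only arguments; the feature-map names, tensor shapes, and box coordinates below are illustrative assumptions, not taken from the patch:

```python
import torch
from torchvision.ops import MultiScaleRoIAlign

# Hypothetical two-level pyramid: "feat1" at stride 4 and "feat2" at stride 8
# relative to a 256x256 image, so the inferred scales are 0.25 and 0.125.
features = {
    "feat1": torch.rand(1, 5, 64, 64),
    "feat2": torch.rand(1, 5, 32, 32),
}
boxes = [torch.tensor([[0.0, 0.0, 100.0, 100.0]])]  # one RoI per image, (x1, y1, x2, y2)
image_sizes = [(256, 256)]

# canonical_scale / canonical_level plug into eq. 1 of the FPN paper:
#     k = floor(canonical_level + log2(sqrt(w * h) / canonical_scale))
# i.e. an RoI of area 224 x 224 is pooled from pyramid level k0 = 4.
pooler = MultiScaleRoIAlign(
    featmap_names=["feat1", "feat2"],
    output_size=7,
    sampling_ratio=2,
    canonical_scale=224,
    canonical_level=4,
)
pooled = pooler(features, boxes, image_sizes)
print(pooled.shape)  # torch.Size([1, 5, 7, 7])
```

Lowering ``canonical_scale`` (or raising ``canonical_level``) shifts RoIs of a given size toward higher pyramid levels; the defaults of ``224`` and ``4`` reproduce the previous hard-coded behavior of ``initLevelMapper``.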