Add SSD architecture with VGG16 backbone #3403

Merged
merged 113 commits into from
Apr 30, 2021
Changes from 110 commits
db93f46
Early skeleton of API.
datumbox Feb 15, 2021
ebfd624
Merge branch 'master' into models/ssd
datumbox Feb 28, 2021
b2e42bb
Adding MultiFeatureMap and vgg16 backbone.
datumbox Mar 6, 2021
6cfa98c
Merge branch 'master' into models/ssd
datumbox Mar 6, 2021
80da3b9
Merge branch 'models/ssd' of https://github.com/datumbox/vision into …
datumbox Mar 6, 2021
9779324
Making vgg16 backbone same as paper.
datumbox Mar 7, 2021
bffe4bc
Making code generic to support all vggs.
datumbox Mar 8, 2021
eced9f0
Moving vgg's extra layers a separate class + L2 scaling.
datumbox Mar 8, 2021
869ede4
Adding header vgg layers.
datumbox Mar 8, 2021
c5ba9c1
Fix maxpool patching.
datumbox Mar 10, 2021
c91bfae
Refactoring code to allow for support of different backbones & sizes:
datumbox Mar 11, 2021
3820e09
Complete the implementation of DefaultBox generator.
datumbox Mar 11, 2021
044d178
Replace randn with empty.
datumbox Mar 11, 2021
6a7b9b4
Minor refactoring
datumbox Mar 11, 2021
e85e631
Making clamping between 0 and 1 optional.
datumbox Mar 11, 2021
c574878
Merge branch 'master' into models/ssd
datumbox Mar 11, 2021
327e004
Change xywh to xyxy encoding.
datumbox Mar 12, 2021
11c9839
Adding parameters and reusing objects in constructor.
datumbox Mar 12, 2021
34237e4
Temporarily inherit from Retina to avoid dup code.
datumbox Mar 12, 2021
d3f345e
Implement forward methods + temp workarounds to inherit from retina.
datumbox Mar 12, 2021
ac25158
Inherit more methods from retinanet.
datumbox Mar 12, 2021
c9c8148
Merge branch 'master' into models/ssd
datumbox Mar 31, 2021
eed06f4
Fix type error.
datumbox Mar 31, 2021
69c4d24
Merge branch 'master' into models/ssd
datumbox Apr 5, 2021
b185e91
Add Regression loss.
datumbox Apr 7, 2021
3a9166f
Fixing JIT issues.
datumbox Apr 7, 2021
a604ab4
Change JIT workaround to minimize new code.
datumbox Apr 7, 2021
6e996d9
Fixing initialization bug.
datumbox Apr 7, 2021
c524ee1
Add classification loss.
datumbox Apr 7, 2021
44d8a0b
Update todos.
datumbox Apr 7, 2021
15b3ebf
Add weight loading support.
datumbox Apr 8, 2021
f67db92
Support SSD512.
datumbox Apr 8, 2021
d19144d
Change kernel_size to get output size 1x1
datumbox Apr 8, 2021
661eb31
Add xavier init and refactoring.
datumbox Apr 8, 2021
dcdd04d
Adding unit-tests and fixing JIT issues.
datumbox Apr 8, 2021
5a2e22c
Merge branch 'master' into models/ssd
datumbox Apr 8, 2021
5b5e8f8
Add a test for dbox generator.
datumbox Apr 8, 2021
6e04fb9
Merge branch 'master' into models/ssd
datumbox Apr 8, 2021
c709805
Remove unnecessary import.
datumbox Apr 8, 2021
b3d40c4
Merge branch 'master' into models/ssd
datumbox Apr 9, 2021
2d0f267
Workaround on GeneralizedRCNNTransform to support fixed size input.
datumbox Apr 9, 2021
39abfb4
Remove unnecessary random calls from the test.
datumbox Apr 9, 2021
0b7eb43
Remove more rand calls from the test.
datumbox Apr 9, 2021
e74b4fe
change mapping and handling of empty labels
datumbox Apr 11, 2021
3f0c99c
Fix JIT warnings.
datumbox Apr 11, 2021
eb33940
Speed up loss.
datumbox Apr 11, 2021
c880de4
Convert 0-1 dboxes to original size.
datumbox Apr 11, 2021
0883889
Fix warning.
datumbox Apr 11, 2021
7c56cc8
Fix tests.
datumbox Apr 11, 2021
218ca55
Update comments.
datumbox Apr 11, 2021
36f53f5
Fixing minor bugs.
datumbox Apr 12, 2021
fe95322
Introduce a custom DBoxMatcher.
datumbox Apr 12, 2021
67195c6
Merge branch 'master' into models/ssd
datumbox Apr 12, 2021
0342e7e
Minor refactoring
datumbox Apr 12, 2021
acdcd78
Move extra layer definition inside feature extractor.
datumbox Apr 13, 2021
6c3b3fa
handle no bias on init.
datumbox Apr 13, 2021
9ad0634
Remove fixed image size limitation
datumbox Apr 14, 2021
5a00a0c
Change initialization values for bias of classification head.
datumbox Apr 14, 2021
0347c36
Refactoring and update test file.
datumbox Apr 14, 2021
5661ac7
Adding ResNet backbone.
datumbox Apr 14, 2021
9e1da62
Minor refactoring.
datumbox Apr 15, 2021
61482c6
Remove inheritance of retina and general refactoring.
datumbox Apr 16, 2021
dc5b7d5
SSD should fix the input size.
datumbox Apr 16, 2021
82f8ddb
Fixing messages and comments.
datumbox Apr 17, 2021
fc90ffa
Merge branch 'master' into models/ssd
datumbox Apr 17, 2021
2cbd58d
Silently ignoring exception if test-only.
datumbox Apr 17, 2021
52940d4
Update comments.
datumbox Apr 18, 2021
db432f6
Update regression loss.
datumbox Apr 19, 2021
9f221ee
Merge branch 'master' into models/ssd
datumbox Apr 21, 2021
ff6ba4a
Restore Xavier init everywhere, update the negative sampling method, …
datumbox Apr 22, 2021
458d01e
Merge branch 'master' into models/ssd
datumbox Apr 23, 2021
84d81f4
Merge branch 'master' into models/ssd
datumbox Apr 23, 2021
88bd38f
Fixing tests.
datumbox Apr 23, 2021
fad5508
Refactor to move the losses from the Head to the SSD.
datumbox Apr 26, 2021
38e6e72
Removing resnet50 ssd version.
datumbox Apr 26, 2021
30de463
Adding support for best performing backbone and its config.
datumbox Apr 26, 2021
cdcbbcd
Refactor and clean up the API.
datumbox Apr 26, 2021
efebeb5
Fix lint
datumbox Apr 26, 2021
90e7b67
Update todos and comments.
datumbox Apr 26, 2021
8ec186e
Adding RandomHorizontalFlip and RandomIoUCrop transforms.
datumbox Apr 26, 2021
ebb7f90
Adding necessary checks to our tranforms.
datumbox Apr 26, 2021
92552de
Adding RandomZoomOut.
datumbox Apr 27, 2021
9b4b2ce
Adding RandomPhotometricDistort.
datumbox Apr 27, 2021
6f0a61e
Moving Detection transforms to references.
datumbox Apr 27, 2021
6ce9bd4
Update presets
datumbox Apr 27, 2021
60c6f72
fix lint
datumbox Apr 27, 2021
a818cc6
Merge branch 'master' into models/ssd
datumbox Apr 27, 2021
ff83c2d
leave compose and object
datumbox Apr 27, 2021
2423a2a
Adding scaling for completeness.
datumbox Apr 27, 2021
017c634
Adding params in the repr
datumbox Apr 27, 2021
3669795
Remove unnecessary import.
datumbox Apr 27, 2021
20e5839
Merge branch 'master' into models/ssd
datumbox Apr 27, 2021
75f578b
Merge branch 'master' into models/ssd
datumbox Apr 27, 2021
0f581d3
minor refactoring
datumbox Apr 28, 2021
3fb1e0b
Remove unnecessary call.
datumbox Apr 28, 2021
1084847
Give better names to DBox* classes
datumbox Apr 28, 2021
57140bb
Port num_anchors estimation in generator
datumbox Apr 28, 2021
8942dd0
Remove rescaling and fix presets
datumbox Apr 28, 2021
517c1da
Add the ability to pass a custom head and refactoring.
datumbox Apr 28, 2021
2deb51e
fix lint
datumbox Apr 28, 2021
937fde8
Merge branch 'master' into models/ssd
datumbox Apr 28, 2021
02a2af5
Fix unit-test
datumbox Apr 28, 2021
2befe43
Update todos.
datumbox Apr 28, 2021
a167edc
Change mean values.
datumbox Apr 28, 2021
7c4d70d
Change the default parameter of SSD to train the full VGG16 and remov…
datumbox Apr 29, 2021
a62d4e6
Adding documentation
datumbox Apr 29, 2021
bc8063a
Adding weights and updating readmes.
datumbox Apr 29, 2021
b2d5ec9
Merge branch 'master' into models/ssd
datumbox Apr 29, 2021
4760197
Update the model weights with a more performing model.
datumbox Apr 30, 2021
3dce96d
Merge branch 'master' into models/ssd
datumbox Apr 30, 2021
365d1ef
Adding doc for head.
datumbox Apr 30, 2021
06477d6
Merge branch 'master' into models/ssd
datumbox Apr 30, 2021
6c94ff0
Restore import.
datumbox Apr 30, 2021
19 changes: 14 additions & 5 deletions docs/source/models.rst
@@ -381,17 +381,18 @@ Object Detection, Instance Segmentation and Person Keypoint Detection
The models subpackage contains definitions for the following model
architectures for detection:

- `Faster R-CNN ResNet-50 FPN <https://arxiv.org/abs/1506.01497>`_
- `Mask R-CNN ResNet-50 FPN <https://arxiv.org/abs/1703.06870>`_
- `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_
- `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_
- `RetinaNet <https://arxiv.org/abs/1708.02002>`_
- `SSD <https://arxiv.org/abs/1512.02325>`_

The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision.

The models expect a list of ``Tensor[C, H, W]``, in the range ``0-1``.
The models internally resize the images so that they have a minimum size
of ``800``. This option can be changed by passing the option ``min_size``
to the constructor of the models.
The models internally resize the images but the behaviour varies depending
on the model. Check the constructor of the models for more information.


For object detection and instance segmentation, the pre-trained
@@ -425,6 +426,7 @@ Faster R-CNN ResNet-50 FPN 37.0 - -
Faster R-CNN MobileNetV3-Large FPN 32.8 - -
Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
RetinaNet ResNet-50 FPN 36.4 - -
SSD VGG16 25.1 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
====================================== ======= ======== ===========

@@ -483,6 +485,7 @@ Faster R-CNN ResNet-50 FPN 0.2288 0.0590
Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
SSD VGG16 0.2093 0.0744 1.5
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
====================================== =================== ================== ===========
@@ -502,6 +505,12 @@ RetinaNet
.. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn


SSD
------------

.. autofunction:: torchvision.models.detection.ssd300_vgg16


Mask R-CNN
----------

8 changes: 8 additions & 0 deletions references/detection/README.md
@@ -48,6 +48,14 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
```

### SSD VGG16
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model ssd300_vgg16 --epochs 120\
--lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
--weight-decay 0.0005 --data-augmentation ssd
```


### Mask R-CNN
```
22 changes: 16 additions & 6 deletions references/detection/presets.py
@@ -2,12 +2,22 @@


class DetectionPresetTrain:
def __init__(self, hflip_prob=0.5):
trans = [T.ToTensor()]
if hflip_prob > 0:
trans.append(T.RandomHorizontalFlip(hflip_prob))

self.transforms = T.Compose(trans)
def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
if data_augmentation == 'hflip':
self.transforms = T.Compose([
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
])
elif data_augmentation == 'ssd':
self.transforms = T.Compose([
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=list(mean)),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
])
else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

def __call__(self, img, target):
return self.transforms(img, target)
10 changes: 6 additions & 4 deletions references/detection/train.py
@@ -47,8 +47,8 @@ def get_dataset(name, image_set, transform, data_path):
return ds, num_classes


def get_transform(train):
return presets.DetectionPresetTrain() if train else presets.DetectionPresetEval()
def get_transform(train, data_augmentation):
return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval()


def main(args):
@@ -60,8 +60,9 @@ def main(args):
# Data loading code
print("Loading data")

dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path)
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args.data_augmentation),
args.data_path)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path)

print("Creating data loaders")
if args.distributed:
@@ -179,6 +180,7 @@ def main(args):
parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
help='number of trainable layers of backbone')
parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)')
parser.add_argument(
"--test-only",
dest="test_only",
230 changes: 210 additions & 20 deletions references/detection/transforms.py
@@ -1,6 +1,10 @@
import random
import torch
import torchvision

from torch import nn, Tensor
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T
from typing import List, Tuple, Dict, Optional


def _flip_coco_person_keypoints(kps, width):
@@ -23,27 +27,213 @@ def __call__(self, image, target):
return image, target


class RandomHorizontalFlip(object):
def __init__(self, prob):
self.prob = prob

def __call__(self, image, target):
if random.random() < self.prob:
height, width = image.shape[-2:]
image = image.flip(-1)
bbox = target["boxes"]
bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
target["boxes"] = bbox
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
if "keypoints" in target:
keypoints = target["keypoints"]
keypoints = _flip_coco_person_keypoints(keypoints, width)
target["keypoints"] = keypoints
class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if torch.rand(1) < self.p:
image = F.hflip(image)
if target is not None:
width, _ = F._get_image_size(image)
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
if "keypoints" in target:
keypoints = target["keypoints"]
keypoints = _flip_coco_person_keypoints(keypoints, width)
target["keypoints"] = keypoints
return image, target


class ToTensor(object):
def __call__(self, image, target):
class ToTensor(nn.Module):
def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
image = F.to_tensor(image)
return image, target


class RandomIoUCrop(nn.Module):
def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
if sampler_options is None:
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
self.options = sampler_options
self.trials = trials

def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if target is None:
raise ValueError("The targets can't be None for this transform.")

if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
elif image.ndimension() == 2:
image = image.unsqueeze(0)

orig_w, orig_h = F._get_image_size(image)

while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value of 1.0 or larger encodes the leave as-is option
return image, target

for _ in range(self.trials):
# check the aspect ratio limitations
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
new_w = int(orig_w * r[0])
new_h = int(orig_h * r[1])
aspect_ratio = new_w / new_h
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue

# check for 0 area crops
r = torch.rand(2)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
bottom = top + new_h
if left == right or top == bottom:
continue

# check for any valid boxes with centers within the crop area
cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
if not is_within_crop_area.any():
continue

# check at least 1 box with jaccard limitations
boxes = target["boxes"][is_within_crop_area]
ious = torchvision.ops.boxes.box_iou(boxes, torch.tensor([[left, top, right, bottom]],
dtype=boxes.dtype, device=boxes.device))
if ious.max() < min_jaccard_overlap:
continue
Comment on lines +114 to +118
Member

I double-checked the logic and it seems good to me.

For the future, we might be able to avoid some of the excessive `continue` statements by selecting the sampling more carefully.

For example, in the first block we can sample the aspect ratio in log-scale so that the aspect ratio is correct from the beginning, and then sample one value for the scale.
The same can be done for the crop (sampling so that none of the values are zero after rounding).

Contributor Author

Yep, this can definitely be improved. I implemented it exactly as originally described and cross-referenced it with the original implementation to be sure, because many similar implementations online are buggy. I would not touch this until proper unit-tests are in place to ensure we maintain the same behaviour, as this transform was crucial for hitting the accuracy reported in the paper.
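
For reference, the log-scale sampling idea from this thread could be sketched roughly as follows. This is not part of the PR: `sample_crop_size` is a hypothetical helper, it uses Python's `random` module rather than `torch.rand`, and it draws from a different distribution than the merged rejection loop, so it would need the unit-tests discussed here before replacing anything.

```python
import math
import random


def sample_crop_size(orig_w, orig_h, min_scale=0.3, max_scale=1.0,
                     min_aspect_ratio=0.5, max_aspect_ratio=2.0, rng=random):
    """Hypothetical helper: draw (new_w, new_h) with a valid aspect ratio by construction.

    Returns None when no scale can satisfy the limits for this image shape.
    """
    # Sample log(new_w / new_h) uniformly, so 0.5 and 2.0 are equally likely.
    log_ratio = rng.uniform(math.log(min_aspect_ratio), math.log(max_aspect_ratio))
    # Convert the pixel aspect ratio into a ratio of relative side scales.
    rel = math.exp(log_ratio) * orig_h / orig_w
    root = math.sqrt(rel)
    # Relative width is scale * root, relative height is scale / root.
    # Both must stay inside [min_scale, max_scale].
    low = min_scale * max(root, 1.0 / root)
    high = max_scale * min(root, 1.0 / root)
    if low > high:
        return None  # assumption: the caller falls back to another trial
    scale = rng.uniform(low, high)
    new_w = min(orig_w, max(1, int(orig_w * scale * root)))
    new_h = min(orig_h, max(1, int(orig_h * scale / root)))
    return new_w, new_h
```

With the default limits, a call such as `sample_crop_size(500, 375)` always yields a size, so the aspect-ratio `continue` never fires; only integer rounding can push the final ratio marginally outside the limits.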


# keep only valid boxes and perform cropping
target["boxes"] = boxes
target["labels"] = target["labels"][is_within_crop_area]
target["boxes"][:, 0::2] -= left
target["boxes"][:, 1::2] -= top
target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
image = F.crop(image, top, left, new_h, new_w)

return image, target


class RandomZoomOut(nn.Module):
def __init__(self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1., 4.), p: float = 0.5):
super().__init__()
if fill is None:
fill = [0., 0., 0.]
self.fill = fill
self.side_range = side_range
if side_range[0] < 1. or side_range[0] > side_range[1]:
raise ValueError("Invalid canvas side range provided {}.".format(side_range))
self.p = p

@torch.jit.unused
def _get_fill_value(self, is_pil):
# type: (bool) -> int
# We fake the type to make it work on JIT
return tuple(int(x) for x in self.fill) if is_pil else 0

def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
elif image.ndimension() == 2:
image = image.unsqueeze(0)

if torch.rand(1) < self.p:
return image, target

orig_w, orig_h = F._get_image_size(image)

r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)

r = torch.rand(2)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)

if torch.jit.is_scripting():
fill = 0
else:
fill = self._get_fill_value(F._is_pil_image(image))

image = F.pad(image, [left, top, right, bottom], fill=fill)
if isinstance(image, torch.Tensor):
v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h):, :] = \
image[..., :, (left + orig_w):] = v

if target is not None:
target["boxes"][:, 0::2] += left
target["boxes"][:, 1::2] += top

return image, target


class RandomPhotometricDistort(nn.Module):
def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] = (0.5, 1.5),
hue: Tuple[float] = (-0.05, 0.05), brightness: Tuple[float] = (0.875, 1.125), p: float = 0.5):
super().__init__()
self._brightness = T.ColorJitter(brightness=brightness)
self._contrast = T.ColorJitter(contrast=contrast)
self._hue = T.ColorJitter(hue=hue)
self._saturation = T.ColorJitter(saturation=saturation)
self.p = p

def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
elif image.ndimension() == 2:
image = image.unsqueeze(0)

r = torch.rand(7)

if r[0] < self.p:
image = self._brightness(image)

contrast_before = r[1] < 0.5
if contrast_before:
if r[2] < self.p:
image = self._contrast(image)

if r[3] < self.p:
image = self._saturation(image)

if r[4] < self.p:
image = self._hue(image)

if not contrast_before:
if r[5] < self.p:
image = self._contrast(image)

if r[6] < self.p:
channels = F._get_image_num_channels(image)
permutation = torch.randperm(channels)

is_pil = F._is_pil_image(image)
if is_pil:
image = F.to_tensor(image)
image = image[..., permutation, :, :]
if is_pil:
image = F.to_pil_image(image)

return image, target
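
To make the box bookkeeping in these transforms easier to follow, here is a minimal plain-Python sketch of the three coordinate updates (horizontal flip, zoom-out offset, crop offset with clamping) applied to a single `[x1, y1, x2, y2]` box. The function names are hypothetical and plain lists stand in for tensors; the arithmetic is meant to mirror the diff above, not replace it.

```python
def hflip_box(box, width):
    # Mirrors: boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
    x1, y1, x2, y2 = box
    return [width - x2, y1, width - x1, y2]


def zoom_out_box(box, left, top):
    # The image is placed at (left, top) on a larger canvas,
    # so every box shifts by that offset.
    x1, y1, x2, y2 = box
    return [x1 + left, y1 + top, x2 + left, y2 + top]


def crop_box(box, left, top, new_w, new_h):
    # Shift into the crop's coordinate frame, then clamp to the crop borders.
    def clamp(value, upper):
        return min(max(value, 0), upper)

    x1, y1, x2, y2 = box
    return [clamp(x1 - left, new_w), clamp(y1 - top, new_h),
            clamp(x2 - left, new_w), clamp(y2 - top, new_h)]
```

Note that the flip swaps the roles of `x1` and `x2`, which keeps `x1 <= x2` after mirroring.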
Binary file not shown.
1 change: 1 addition & 0 deletions test/test_models.py
@@ -44,6 +44,7 @@ def get_available_video_models():
"maskrcnn_resnet50_fpn": lambda x: x[1],
"keypointrcnn_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn": lambda x: x[1],
"ssd300_vgg16": lambda x: x[1],
}

