Add validation methods in tasks #93

Merged · 2 commits · Apr 19, 2021
10 changes: 5 additions & 5 deletions test/test_data_pipeline.py
@@ -11,15 +11,15 @@


 class DataPipelineTester(unittest.TestCase):
-    def test_vanilla_dataset(self):
+    def test_get_dataset(self):
         # Acquire the images and labels from the coco128 dataset
-        dataset = data_helper.get_dataset(data_root='data-bin', mode='train')
+        train_dataset = data_helper.get_dataset(data_root='data-bin', mode='train')
         # Test the datasets
-        image, target = next(iter(dataset))
+        image, target = next(iter(train_dataset))
         self.assertIsInstance(image, Tensor)
         self.assertIsInstance(target, Dict)

-    def test_vanilla_dataloader(self):
+    def test_get_dataloader(self):
         batch_size = 8
         data_loader = data_helper.get_dataloader(data_root='data-bin', mode='train', batch_size=batch_size)
         # Test the dataloader
@@ -38,7 +38,7 @@ def test_vanilla_dataloader(self):
     def test_detection_data_module(self):
         # Setup the DataModule
         batch_size = 4
-        train_dataset = data_helper.DummyCOCODetectionDataset(num_samples=128)
+        train_dataset = data_helper.get_dataset(data_root='data-bin', mode='train')
         data_module = DetectionDataModule(train_dataset, batch_size=batch_size)
         self.assertEqual(data_module.batch_size, batch_size)

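
For reference, a minimal sketch of how the refactored helpers exercised above are used; the module alias and the exact keys of the target dict ('boxes', 'labels', 'image_id') are assumptions based on torchvision-style detection datasets, not confirmed by this diff.

# Hedged sketch: pull one sample through the coco128 helpers tested above.
from torch import Tensor
from yolort.data import _helper as data_helper  # alias assumed to match the tests

train_dataset = data_helper.get_dataset(data_root='data-bin', mode='train')
image, target = next(iter(train_dataset))
assert isinstance(image, Tensor) and isinstance(target, dict)
print(image.shape)            # e.g. torch.Size([3, H, W])
print(sorted(target.keys()))  # assumed: ['boxes', 'image_id', 'labels']
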
6 changes: 4 additions & 2 deletions test/test_engine.py
@@ -77,8 +77,10 @@ def test_train_with_vanilla_module(self):

     def test_training_step(self):
         # Setup the DataModule
-        train_dataset = data_helper.DummyCOCODetectionDataset(num_samples=128)
-        data_module = DetectionDataModule(train_dataset, batch_size=16)
+        data_path = 'data-bin'
+        train_dataset = data_helper.get_dataset(data_root=data_path, mode='train')
+        val_dataset = data_helper.get_dataset(data_root=data_path, mode='val')
+        data_module = DetectionDataModule(train_dataset, val_dataset, batch_size=16)
         # Load model
         model = yolov5s()
         model.train()
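
The new train/val DataModule wiring can also be smoke-tested end to end with a Lightning trainer. This is a hedged sketch rather than the body of test_training_step: fast_dev_run is a standard pytorch_lightning.Trainer flag that runs a single training and validation batch, and the yolov5s import path is assumed from the surrounding tests.

# Hedged sketch: one train/val batch through the model with the DataModule above.
import pytorch_lightning as pl
from yolort.data import _helper as data_helper  # alias assumed
from yolort.data import DetectionDataModule
from yolort.models import yolov5s  # import path assumed

data_path = 'data-bin'
train_dataset = data_helper.get_dataset(data_root=data_path, mode='train')
val_dataset = data_helper.get_dataset(data_root=data_path, mode='val')
data_module = DetectionDataModule(train_dataset, val_dataset, batch_size=16)

model = yolov5s()
trainer = pl.Trainer(fast_dev_run=True)  # runs one training and one validation batch
trainer.fit(model, datamodule=data_module)
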
69 changes: 0 additions & 69 deletions yolort/data/_helper.py
@@ -89,72 +89,3 @@ def get_dataloader(data_root: str, mode: str = 'val', batch_size: int = 4):
     )

     return loader
-
-
-class DummyCOCODetectionDataset(torch.utils.data.Dataset):
-    """
-    Generate a dummy dataset for detection
-    Example::
-        >>> ds = DummyDetectionDataset()
-        >>> dl = DataLoader(ds, batch_size=16)
-    """
-    def __init__(
-        self,
-        im_size_min: int = 320,
-        im_size_max: int = 640,
-        num_classes: int = 5,
-        num_boxes_max: int = 12,
-        class_start: int = 0,
-        num_samples: int = 1000,
-        box_fmt: str = "cxcywh",
-        normalize: bool = True,
-    ):
-        """
-        Args:
-            im_size_min: Minimum image size
-            im_size_max: Maximum image size
-            num_classes: Number of classes for image.
-            num_boxes_max: Maximum number of boxes per images
-            num_samples: how many samples to use in this dataset.
-            box_fmt: Format of Bounding boxes, supported : "xyxy", "xywh", "cxcywh"
-        """
-        super().__init__()
-        self.im_size_min = im_size_min
-        self.im_size_max = im_size_max
-        self.num_classes = num_classes
-        self.num_boxes_max = num_boxes_max
-        self.num_samples = num_samples
-        self.box_fmt = box_fmt
-        self.class_start = class_start
-        self.class_end = self.class_start + self.num_classes
-        self.normalize = normalize
-
-    def __len__(self):
-        return self.num_samples
-
-    @staticmethod
-    def _random_bbox(img_shape):
-        _, h, w = img_shape
-        xs = torch.randint(w, (2,), dtype=torch.float32)
-        ys = torch.randint(h, (2,), dtype=torch.float32)
-
-        # A small hacky fix to avoid degenerate boxes.
-        return [min(xs), min(ys), max(xs) + 1, max(ys) + 1]
-
-    def __getitem__(self, idx: int):
-        h = random.randint(self.im_size_min, self.im_size_max)
-        w = random.randint(self.im_size_min, self.im_size_max)
-        img_shape = (3, h, w)
-        img = torch.rand(img_shape)
-
-        num_boxes = random.randint(1, self.num_boxes_max)
-        labels = torch.randint(self.class_start, self.class_end, (num_boxes,), dtype=torch.long)
-
-        boxes = torch.tensor([self._random_bbox(img_shape) for _ in range(num_boxes)], dtype=torch.float32)
-        boxes = ops.clip_boxes_to_image(boxes, (h, w))
-        # No problems if we pass same in_fmt and out_fmt, it is covered by box_convert
-        boxes = ops.box_convert(boxes, in_fmt="xyxy", out_fmt=self.box_fmt)
-        if self.normalize:
-            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
-        image_id = torch.tensor([idx])
-        return img, {"image_id": image_id, "boxes": boxes, "labels": labels}
12 changes: 11 additions & 1 deletion yolort/models/_utils.py
@@ -4,11 +4,21 @@
 import torch
 from torch import Tensor
 import torch.nn.functional as F
-from torchvision.ops import box_convert
+from torchvision.ops import box_convert, box_iou

 from typing import Tuple


+def _evaluate_iou(target, pred):
+    """
+    Evaluate the intersection over union (IoU) between the targets from the
+    dataset and the predictions of the model.
+    """
+    if pred["boxes"].shape[0] == 0:
+        # no boxes detected, so the IoU is 0
+        return torch.tensor(0.0, device=pred["boxes"].device)
+    return box_iou(target["boxes"], pred["boxes"]).diag().mean()
+
+
 class BoxCoder(object):
     """
     This class encodes and decodes a set of bounding boxes into
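
A small worked example of the metric added above: box_iou returns the pairwise IoU matrix between target and predicted boxes, .diag() pairs the i-th target with the i-th prediction, and the mean of that diagonal is reported. Note that pairing by index is a rough proxy; it does not match predictions to targets by overlap.

# Worked example for _evaluate_iou: one target box, one prediction shifted by half its width.
import torch
from torchvision.ops import box_iou

target = {"boxes": torch.tensor([[0., 0., 10., 10.]])}
pred = {"boxes": torch.tensor([[5., 0., 15., 10.]])}

iou = box_iou(target["boxes"], pred["boxes"]).diag().mean()
print(iou)  # tensor(0.3333): intersection 50 / union 150
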
14 changes: 14 additions & 0 deletions yolort/models/yolo_module.py
@@ -10,6 +10,7 @@

 from . import yolo
 from .transform import GeneralizedYOLOTransform
+from ._utils import _evaluate_iou
 from ..data import DetectionDataModule, DataPipeline, COCOEvaluator

 from typing import Any, List, Dict, Tuple, Optional, Union
@@ -142,6 +143,19 @@ def training_step(self, batch, batch_idx):
         self.log_dict(loss_dict, on_step=True, on_epoch=True, prog_bar=True)
         return loss

+    def validation_step(self, batch, batch_idx):
+        images, targets = batch
+        # in eval mode the model consumes only the images; targets are kept for the IoU metric
+        preds = self._forward_impl(images)
+        iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, preds)]).mean()
+        outs = {"val_iou": iou}
+        self.log_dict(outs, on_step=True, on_epoch=True, prog_bar=True)
+        return outs
+
+    def validation_epoch_end(self, outs):
+        avg_iou = torch.stack([o["val_iou"] for o in outs]).mean()
+        self.log("avg_val_iou", avg_iou)
+
     def test_step(self, batch, batch_idx):
         """
         The test step.
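
To make the epoch-end reduction concrete, here is a hand-rolled sketch of what validation_epoch_end computes from the per-batch dictionaries returned by validation_step; the IoU values are toy placeholders.

# Toy illustration of the reduction over validation_step outputs.
import torch

step_outputs = [
    {"val_iou": torch.tensor(0.62)},
    {"val_iou": torch.tensor(0.48)},
    {"val_iou": torch.tensor(0.55)},
]
avg_iou = torch.stack([o["val_iou"] for o in step_outputs]).mean()
print(avg_iou)  # tensor(0.5500), logged as 'avg_val_iou'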