Skip to content

Commit

Permalink
Fix datamodule in unit-test and remove unused codes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Apr 19, 2021
1 parent b0bbedb commit 5b461e1
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 76 deletions.
10 changes: 5 additions & 5 deletions test/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@


class DataPipelineTester(unittest.TestCase):
def test_vanilla_dataset(self):
def test_get_dataset(self):
# Acquire the images and labels from the coco128 dataset
dataset = data_helper.get_dataset(data_root='data-bin', mode='train')
train_dataset = data_helper.get_dataset(data_root='data-bin', mode='train')
# Test the datasets
image, target = next(iter(dataset))
image, target = next(iter(train_dataset))
self.assertIsInstance(image, Tensor)
self.assertIsInstance(target, Dict)

def test_vanilla_dataloader(self):
def test_get_dataloader(self):
batch_size = 8
data_loader = data_helper.get_dataloader(data_root='data-bin', mode='train', batch_size=batch_size)
# Test the dataloader
Expand All @@ -38,7 +38,7 @@ def test_vanilla_dataloader(self):
def test_detection_data_module(self):
# Setup the DataModule
batch_size = 4
train_dataset = data_helper.DummyCOCODetectionDataset(num_samples=128)
train_dataset = data_helper.get_dataset(data_root='data-bin', mode='train')
data_module = DetectionDataModule(train_dataset, batch_size=batch_size)
self.assertEqual(data_module.batch_size, batch_size)

Expand Down
6 changes: 4 additions & 2 deletions test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ def test_train_with_vanilla_module(self):

def test_training_step(self):
# Setup the DataModule
train_dataset = data_helper.DummyCOCODetectionDataset(num_samples=128)
data_module = DetectionDataModule(train_dataset, batch_size=16)
data_path = 'data-bin'
train_dataset = data_helper.get_dataset(data_root=data_path, mode='train')
val_dataset = data_helper.get_dataset(data_root=data_path, mode='val')
data_module = DetectionDataModule(train_dataset, val_dataset, batch_size=16)
# Load model
model = yolov5s()
model.train()
Expand Down
69 changes: 0 additions & 69 deletions yolort/data/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,72 +89,3 @@ def get_dataloader(data_root: str, mode: str = 'val', batch_size: int = 4):
)

return loader


class DummyCOCODetectionDataset(torch.utils.data.Dataset):
"""
Generate a dummy dataset for detection
Example::
>>> ds = DummyDetectionDataset()
>>> dl = DataLoader(ds, batch_size=16)
"""
def __init__(
self,
im_size_min: int = 320,
im_size_max: int = 640,
num_classes: int = 5,
num_boxes_max: int = 12,
class_start: int = 0,
num_samples: int = 1000,
box_fmt: str = "cxcywh",
normalize: bool = True,
):
"""
Args:
im_size_min: Minimum image size
im_size_max: Maximum image size
num_classes: Number of classes for image.
num_boxes_max: Maximum number of boxes per images
num_samples: how many samples to use in this dataset.
box_fmt: Format of Bounding boxes, supported : "xyxy", "xywh", "cxcywh"
"""
super().__init__()
self.im_size_min = im_size_min
self.im_size_max = im_size_max
self.num_classes = num_classes
self.num_boxes_max = num_boxes_max
self.num_samples = num_samples
self.box_fmt = box_fmt
self.class_start = class_start
self.class_end = self.class_start + self.num_classes
self.normalize = normalize

def __len__(self):
return self.num_samples

@staticmethod
def _random_bbox(img_shape):
_, h, w = img_shape
xs = torch.randint(w, (2,), dtype=torch.float32)
ys = torch.randint(h, (2,), dtype=torch.float32)

# A small hacky fix to avoid degenerate boxes.
return [min(xs), min(ys), max(xs) + 1, max(ys) + 1]

def __getitem__(self, idx: int):
h = random.randint(self.im_size_min, self.im_size_max)
w = random.randint(self.im_size_min, self.im_size_max)
img_shape = (3, h, w)
img = torch.rand(img_shape)

num_boxes = random.randint(1, self.num_boxes_max)
labels = torch.randint(self.class_start, self.class_end, (num_boxes,), dtype=torch.long)

boxes = torch.tensor([self._random_bbox(img_shape) for _ in range(num_boxes)], dtype=torch.float32)
boxes = ops.clip_boxes_to_image(boxes, (h, w))
# No problems if we pass same in_fmt and out_fmt, it is covered by box_convert
boxes = ops.box_convert(boxes, in_fmt="xyxy", out_fmt=self.box_fmt)
if self.normalize:
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
image_id = torch.tensor([idx])
return img, {"image_id": image_id, "boxes": boxes, "labels": labels}

0 comments on commit 5b461e1

Please sign in to comment.