Skip to content

Commit

Permalink
Refactor dataloader getter in UnitTest
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Apr 13, 2021
1 parent f122aa7 commit a78e4ac
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 81 deletions.
18 changes: 7 additions & 11 deletions test/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,18 @@
from pathlib import Path
import unittest

import torch.utils.data
from torch import Tensor

from yolort.data import COCOEvaluator, DetectionDataModule
from yolort.data.coco import CocoDetection
from yolort.data.transforms import collate_fn, default_train_transforms
from yolort.utils import prepare_coco128

from .dataset_utils import DummyCOCODetectionDataset
from yolort.data.transforms import default_train_transforms
from yolort.utils.dataset_utils import prepare_coco128, get_data_loader, DummyCOCODetectionDataset

from typing import Dict


class DataPipelineTester(unittest.TestCase):
def test_vanilla_dataloader(self):
def test_vanilla_dataset(self):
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
coco128_dirname = 'coco128'
Expand All @@ -33,11 +30,9 @@ def test_vanilla_dataloader(self):
self.assertIsInstance(image, Tensor)
self.assertIsInstance(target, Dict)

batch_size = 4
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
def test_vanilla_dataloader(self):
batch_size = 8
data_loader = get_data_loader(mode='train', batch_size=batch_size)
# Test the dataloader
images, targets = next(iter(data_loader))

Expand Down Expand Up @@ -76,6 +71,7 @@ def test_prepare_coco128(self):
annotation_file = data_path / coco128_dirname / 'annotations' / 'instances_train2017.json'
self.assertTrue(annotation_file.is_file())

@unittest.skip("Currently it isn't well implemented")
def test_coco_evaluator(self):
    """Placeholder: only constructs a COCOEvaluator; skipped until the test is fleshed out."""
    # No assertions yet -- instantiation alone is the whole body.
    evaluator = COCOEvaluator()
42 changes: 9 additions & 33 deletions test/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import unittest

import torch
Expand All @@ -10,16 +9,12 @@
import pytorch_lightning as pl

from yolort.data import DetectionDataModule
from yolort.data.coco import CocoDetection
from yolort.data.transforms import collate_fn, default_train_transforms

from yolort.models.yolo import yolov5_darknet_pan_s_r31
from yolort.models.transform import nested_tensor_from_tensor_list
from yolort.models import yolov5s

from yolort.utils import prepare_coco128

from .dataset_utils import DummyCOCODetectionDataset
from yolort.utils.dataset_utils import DummyCOCODetectionDataset, get_data_loader

from typing import Dict

Expand Down Expand Up @@ -65,26 +60,9 @@ def test_train_with_vanilla_module(self):
# Define the device
device = torch.device('cpu')

# Prepare the datasets for training
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
coco128_dirname = 'coco128'
coco128_path = data_path / coco128_dirname
image_root = coco128_path / 'images' / 'train2017'
annotation_file = coco128_path / 'annotations' / 'instances_train2017.json'

if not annotation_file.is_file():
prepare_coco128(data_path, dirname=coco128_dirname)

batch_size = 4

dataset = CocoDetection(image_root, annotation_file, default_train_transforms())
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
train_dataloader = get_data_loader(mode='train')
# Sample a pair of images/targets
images, targets = next(iter(data_loader))
images, targets = next(iter(train_dataloader))
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

Expand All @@ -109,19 +87,17 @@ def test_train_one_epoch(self):
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, data_module)

def test_test_dataloaders(self):
# Config dataset
num_samples = 128
batch_size = 4
# Setup the DataModule
train_dataset = DummyCOCODetectionDataset(num_samples=num_samples)
data_module = DetectionDataModule(train_dataset, batch_size=batch_size)
@unittest.skip("Currently it isn't well implemented")
def test_test_with_dataloader(self):
# Get dataloader to test
val_dataloader = get_data_loader(mode='val')

# Load model
model = yolov5s(pretrained=True)
model.eval()
# Trainer
trainer = pl.Trainer(max_epochs=1)
trainer.test(model, test_dataloaders=data_module.val_dataloader(batch_size=batch_size))
trainer.test(model, test_dataloaders=val_dataloader)

def test_predict_with_vanilla_model(self):
# Set image inputs
Expand Down
1 change: 0 additions & 1 deletion yolort/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .flash_utils import get_callable_dict
from .image_utils import cv2_imshow, get_image_from_url, read_image_to_tensor
from .update_module_state import update_module_state_from_ultralytics
from .file_utils import prepare_coco128
77 changes: 74 additions & 3 deletions test/dataset_utils.py → yolort/utils/dataset_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,82 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import random
import torch
from torch.utils.data import Dataset
from pathlib import Path, PosixPath
from zipfile import ZipFile

import torch
from torchvision import ops

from ..data.coco import CocoDetection
from ..data.transforms import (
collate_fn,
default_train_transforms,
default_val_transforms,
)


def prepare_coco128(
    data_path: Path,
    dirname: str = 'coco128',
) -> None:
    """
    Prepare coco128 dataset to test.

    Every step is skipped when its result is already on disk: the root
    directory is created only if missing, the archive is downloaded only if
    ``coco128.zip`` is absent, and it is unzipped only if ``data_path / dirname``
    does not exist yet.

    Note: the archive is extracted into ``data_path`` as-is; ``dirname`` is only
    used to detect a previous extraction, it does not rename the extracted folder.

    Args:
        data_path (Path): root path of coco128 dataset.
        dirname (str): the directory name of coco128 dataset. Default: 'coco128'.
    """
    if not data_path.is_dir():
        print(f'Create a new directory: {data_path}')
        data_path.mkdir(parents=True, exist_ok=True)

    zip_path = data_path / 'coco128.zip'
    coco128_url = 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip'
    if not zip_path.is_file():
        print(f'Downloading coco128 datasets from {coco128_url}')
        # The destination is documented as a string path, so convert explicitly.
        torch.hub.download_url_to_file(coco128_url, str(zip_path), hash_prefix='a67d2887')

    coco128_path = data_path / dirname
    if not coco128_path.is_dir():
        print(f'Unzipping dataset to {coco128_path}')
        with ZipFile(zip_path, 'r') as zip_obj:
            zip_obj.extractall(data_path)


def get_data_loader(mode: str = 'train', batch_size: int = 4) -> torch.utils.data.DataLoader:
    """
    Build a DataLoader over the coco128 dataset stored under ``data-bin/coco128``.

    The dataset is fetched via ``prepare_coco128`` when the annotation file is
    missing. Both modes read the same ``train2017`` images and annotations;
    only the transforms differ.

    Args:
        mode (str): 'train' uses ``default_train_transforms()``, 'val' uses
            ``default_val_transforms()``. Default: 'train'.
        batch_size (int): number of samples per batch. Default: 4.

    Returns:
        A DataLoader that iterates sequentially, keeps the last partial batch
        and collates with ``collate_fn``.

    Raises:
        NotImplementedError: if ``mode`` is neither 'train' nor 'val'.
    """
    # Reject unsupported modes up front so a bad argument never triggers
    # the dataset download/extraction below.
    if mode not in ('train', 'val'):
        raise NotImplementedError(f"Currently not support {mode} mode")

    # Acquire the images and labels from the coco128 dataset
    data_path = Path('data-bin')
    coco128_dirname = 'coco128'
    coco128_path = data_path / coco128_dirname
    image_root = coco128_path / 'images' / 'train2017'
    annotation_file = coco128_path / 'annotations' / 'instances_train2017.json'

    if not annotation_file.is_file():
        prepare_coco128(data_path, dirname=coco128_dirname)

    if mode == 'train':
        transforms = default_train_transforms()
    else:
        transforms = default_val_transforms()
    dataset = CocoDetection(image_root, annotation_file, transforms)

    # We adopt the sequential sampler in order to repeat the experiment
    sampler = torch.utils.data.SequentialSampler(dataset)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size,
        sampler=sampler,
        drop_last=False,
        collate_fn=collate_fn,
        num_workers=0,
    )


class DummyCOCODetectionDataset(Dataset):
class DummyCOCODetectionDataset(torch.utils.data.Dataset):
"""
Generate a dummy dataset for detection
Example::
Expand Down
33 changes: 0 additions & 33 deletions yolort/utils/file_utils.py

This file was deleted.

0 comments on commit a78e4ac

Please sign in to comment.