Skip to content

Commit

Permalink
Introduce coco evaluator metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Apr 13, 2021
1 parent a22ba8f commit f122aa7
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 47 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ ipython
# pycocotools need python3.7 as minimal
# pip install -U pycocotools>=2.0.2 # corresponds to https://github.com/ppwwyyxx/cocoapi
pytorch-lightning>=1.1.1
torchmetrics
6 changes: 5 additions & 1 deletion test/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.utils.data
from torch import Tensor

from yolort.data import DetectionDataModule
from yolort.data import COCOEvaluator, DetectionDataModule
from yolort.data.coco import CocoDetection
from yolort.data.transforms import collate_fn, default_train_transforms
from yolort.utils import prepare_coco128
Expand Down Expand Up @@ -75,3 +75,7 @@ def test_prepare_coco128(self):
prepare_coco128(data_path, dirname=coco128_dirname)
annotation_file = data_path / coco128_dirname / 'annotations' / 'instances_train2017.json'
self.assertTrue(annotation_file.is_file())

def test_coco_evaluator(self):
coco_evaluator = COCOEvaluator()
pass
14 changes: 14 additions & 0 deletions test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ def test_train_one_epoch(self):
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, data_module)

def test_test_dataloaders(self):
# Config dataset
num_samples = 128
batch_size = 4
# Setup the DataModule
train_dataset = DummyCOCODetectionDataset(num_samples=num_samples)
data_module = DetectionDataModule(train_dataset, batch_size=batch_size)
# Load model
model = yolov5s(pretrained=True)
model.eval()
# Trainer
trainer = pl.Trainer(max_epochs=1)
trainer.test(model, test_dataloaders=data_module.val_dataloader(batch_size=batch_size))

def test_predict_with_vanilla_model(self):
# Set image inputs
img_name = "test/assets/zidane.jpg"
Expand Down
1 change: 1 addition & 0 deletions yolort/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from .coco_eval import COCOEvaluator
from .data_pipeline import DataPipeline
from .data_module import DetectionDataModule, VOCDetectionDataModule, CocoDetectionDataModule
83 changes: 40 additions & 43 deletions yolort/data/coco_eval.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,38 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO evaluator that works in distributed mode.
Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
The difference is that there is less copy-pasting from pycocotools
in the end of the file, as python3 can suppress prints with contextlib
"""

# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import os
import copy
import contextlib

import numpy as np
import copy

import torch

from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO

from ..utils.misc import all_gather


class CocoEvaluator(object):
def __init__(self, coco_gt, iou_types):
from torchmetrics import Metric

from typing import List, Any, Callable, Optional


class COCOEvaluator(Metric):
"""
COCO evaluator that works in distributed mode.
"""
def __init__(
self,
coco_gt: Any,
iou_types: List[str] = ['bbox'],
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn
)
assert isinstance(iou_types, (list, tuple))
coco_gt = copy.deepcopy(coco_gt)
self.coco_gt = coco_gt
Expand All @@ -44,38 +55,29 @@ def update(self, predictions):
# suppress pycocotools prints
with open(os.devnull, 'w') as devnull:
with contextlib.redirect_stdout(devnull):
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
self.coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()

coco_eval = self.coco_eval[iou_type]

coco_eval.cocoDt = coco_dt
coco_eval.cocoDt = self.coco_dt
coco_eval.params.imgIds = list(img_ids)
img_ids, eval_imgs = evaluate(coco_eval)

self.eval_imgs[iou_type].append(eval_imgs)

def synchronize_between_processes(self):
for iou_type in self.iou_types:
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

def accumulate(self):
def compute(self):
for coco_eval in self.coco_eval.values():
coco_eval.accumulate()

def summarize(self):
for iou_type, coco_eval in self.coco_eval.items():
print("IoU metric: {}".format(iou_type))
coco_eval.summarize()

def prepare(self, predictions, iou_type):
if iou_type == "bbox":
return self.prepare_for_coco_detection(predictions)
elif iou_type == "segm":
return self.prepare_for_coco_segmentation(predictions)
elif iou_type == "keypoints":
return self.prepare_for_coco_keypoint(predictions)
else:
raise ValueError("Unknown iou type {}".format(iou_type))
raise ValueError(f"Unknown iou type {iou_type}, fell free to report on GitHub issues")

def prepare_for_coco_detection(self, predictions):
coco_results = []
Expand Down Expand Up @@ -140,14 +142,12 @@ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################


def evaluate(self):
'''
From pycocotools, just removed the prints and fixed a Python3 bug about unicode
not defined. Mostly copy-paste from
<https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py>
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
:return: None
'''
Expand All @@ -165,18 +165,19 @@ def evaluate(self):
p.maxDets = sorted(p.maxDets)
self.params = p

self._prepare()
self._prepare() # bottleneck

# loop through images, area range, max detection number
catIds = p.catIds if p.useCats else [-1]

if p.iouType == 'segm' or p.iouType == 'bbox':
computeIoU = self.computeIoU
elif p.iouType == 'keypoints':
computeIoU = self.computeOks

self.ious = {
(imgId, catId): computeIoU(imgId, catId)
for imgId in p.imgIds
for catId in catIds}
(imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds
} # bottleneck

evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1]
Expand All @@ -192,7 +193,3 @@ def evaluate(self):
# toc = time.time()
# print('DONE (t={:0.2f}s).'.format(toc-tic))
return p.imgIds, evalImgs

#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
20 changes: 17 additions & 3 deletions yolort/models/yolo_module.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import warnings
import argparse
import json

import torch
from torch import Tensor
from torchvision.ops import box_iou

from pytorch_lightning import LightningModule

from . import yolo
from .transform import GeneralizedYOLOTransform
from ..data import DetectionDataModule, DataPipeline
from ..data import DetectionDataModule, DataPipeline, COCOEvaluator

from typing import Any, List, Dict, Tuple, Optional

Expand All @@ -31,6 +29,7 @@ def __init__(
num_classes: int = 80,
min_size: int = 320,
max_size: int = 416,
coco_gt: Optional[Any] = None,
**kwargs: Any,
):
"""
Expand All @@ -52,6 +51,9 @@ def __init__(

self._data_pipeline = None

# metrics
self.evaluator = None if coco_gt else COCOEvaluator(coco_gt, iou_types=["bbox"])

# used only on torchscript mode
self._has_warned = False

Expand Down Expand Up @@ -137,6 +139,18 @@ def training_step(self, batch, batch_idx):
self.log_dict(loss_dict, on_step=True, on_epoch=True, prog_bar=True)
return loss

def test_step(self, batch, batch_idx):
"""
The test step.
"""
preds = self._forward_impl(*batch)
# log step metric
self.log('eval_step', self.evaluator(preds))

def test_epoch_end(self, outs):
# log epoch metric
self.log('coco_eval', self.evaluator.compute())

@torch.no_grad()
def predict(
self,
Expand Down

0 comments on commit f122aa7

Please sign in to comment.