diff --git a/docs/en/user_guides/useful_tools.md b/docs/en/user_guides/useful_tools.md
index 34a2f2c99e8..5cce0cb97e6 100644
--- a/docs/en/user_guides/useful_tools.md
+++ b/docs/en/user_guides/useful_tools.md
@@ -509,3 +509,78 @@ python tools/analysis_tools/confusion_matrix.py ${CONFIG} ${DETECTION_RESULTS}
 And you will get a confusion matrix like this:
 
 ![confusion_matrix_example](https://user-images.githubusercontent.com/12907710/140513068-994cdbf4-3a4a-48f0-8fd8-2830d93fd963.png)
+
+## COCO Separated & Occluded Mask Metric
+
+Detecting occluded objects remains a challenge for state-of-the-art object detectors.
+We implemented the metric presented in the paper [A Tri-Layer Plugin to Improve Occluded Detection](https://arxiv.org/abs/2210.10046) to calculate the recall of separated and occluded masks.
+
+There are two ways to use this metric:
+
+### Offline evaluation
+
+We provide a script to calculate the metric with a dumped prediction file.
+
+First, use the `tools/test.py` script to dump the detection results:
+
+```shell
+python tools/test.py ${CONFIG} ${MODEL_PATH} --out results.pkl
+```
+
+Then, run the `tools/analysis_tools/coco_occluded_separated_recall.py` script to get the recall of separated and occluded masks:
+
+```shell
+python tools/analysis_tools/coco_occluded_separated_recall.py results.pkl --out occluded_separated_recall.json
+```
+
+The output should look like this:
+
+```
+loading annotations into memory...
+Done (t=0.51s)
+creating index...
+index created!
+processing detection results...
+[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 5000/5000, 109.3 task/s, elapsed: 46s, ETA: 0s
+computing occluded mask recall...
+[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 5550/5550, 780.5 task/s, elapsed: 7s, ETA: 0s
+COCO occluded mask recall: 58.79%
+COCO occluded mask success num: 3263
+computing separated mask recall...
+[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 3522/3522, 778.3 task/s, elapsed: 5s, ETA: 0s
+COCO separated mask recall: 31.94%
+COCO separated mask success num: 1125
+
++-----------+--------+-------------+
+| mask type | recall | num correct |
++-----------+--------+-------------+
+| occluded  | 58.79% | 3263        |
+| separated | 31.94% | 1125        |
++-----------+--------+-------------+
+Evaluation results have been saved to occluded_separated_recall.json.
+```
+
+### Online evaluation
+
+We implement `CocoOccludedSeparatedMetric`, which inherits from `CocoMetric`.
+To evaluate the recall of separated and occluded masks during training, just replace the evaluator metric type with `'CocoOccludedSeparatedMetric'` in your config:
+
+```python
+val_evaluator = dict(
+    type='CocoOccludedSeparatedMetric',  # modify this
+    ann_file=data_root + 'annotations/instances_val2017.json',
+    metric=['bbox', 'segm'],
+    format_only=False)
+test_evaluator = val_evaluator
+```
+
+Please cite the paper if you use this metric:
+
+```latex
+@article{zhan2022triocc,
+  title={A Tri-Layer Plugin to Improve Occluded Detection},
+  author={Zhan, Guanqi and Xie, Weidi and Zisserman, Andrew},
+  journal={British Machine Vision Conference},
+  year={2022}
+}
+```
diff --git a/docs/zh_cn/user_guides/useful_tools.md b/docs/zh_cn/user_guides/useful_tools.md
index 7f88f79bb4a..e2b2d626d70 100644
--- a/docs/zh_cn/user_guides/useful_tools.md
+++ b/docs/zh_cn/user_guides/useful_tools.md
@@ -485,3 +485,78 @@ python tools/analysis_tools/confusion_matrix.py ${CONFIG} ${DETECTION_RESULTS}
 最后你可以得到如图的混淆矩阵:
 
 ![confusion_matrix_example](https://user-images.githubusercontent.com/12907710/140513068-994cdbf4-3a4a-48f0-8fd8-2830d93fd963.png)
+
+## COCO 分离和遮挡实例分割性能评估
+
+对于最先进的目标检测器来说,检测被遮挡的物体仍然是一个挑战。
+我们实现了论文 [A Tri-Layer Plugin to Improve Occluded Detection](https://arxiv.org/abs/2210.10046) 中提出的指标来计算分离和遮挡目标的召回率。
+
+使用此评价指标有两种方法:
+
+### 离线评测
+
+我们提供了一个脚本对存储后的检测结果文件计算指标。
+
+首先,使用 `tools/test.py` 脚本存储检测结果:
+
+```shell
+python tools/test.py ${CONFIG} ${MODEL_PATH} --out results.pkl
+```
+
+然后,运行 `tools/analysis_tools/coco_occluded_separated_recall.py` 脚本来计算分离和遮挡目标掩码的召回率:
+
+```shell
+python tools/analysis_tools/coco_occluded_separated_recall.py results.pkl --out occluded_separated_recall.json
+```
+
+输出如下:
+
+```
+loading annotations into memory...
+Done (t=0.51s)
+creating index...
+index created!
+processing detection results...
+[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 5000/5000, 109.3 task/s, elapsed: 46s, ETA: 0s
+computing occluded mask recall...
+[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 5550/5550, 780.5 task/s, elapsed: 7s, ETA: 0s
+COCO occluded mask recall: 58.79%
+COCO occluded mask success num: 3263
+computing separated mask recall...
+[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 3522/3522, 778.3 task/s, elapsed: 5s, ETA: 0s
+COCO separated mask recall: 31.94%
+COCO separated mask success num: 1125
+
++-----------+--------+-------------+
+| mask type | recall | num correct |
++-----------+--------+-------------+
+| occluded  | 58.79% | 3263        |
+| separated | 31.94% | 1125        |
++-----------+--------+-------------+
+Evaluation results have been saved to occluded_separated_recall.json.
+```
+
+### 在线评测
+
+我们实现了继承自 `CocoMetric` 的 `CocoOccludedSeparatedMetric`。
+要在训练期间评估分离和遮挡掩码的召回率,只需在配置中将 evaluator 类型替换为 `CocoOccludedSeparatedMetric`:
+
+```python
+val_evaluator = dict(
+    type='CocoOccludedSeparatedMetric',  # 修改此处
+    ann_file=data_root + 'annotations/instances_val2017.json',
+    metric=['bbox', 'segm'],
+    format_only=False)
+test_evaluator = val_evaluator
+```
+
+如果您使用了此指标,请引用论文:
+
+```latex
+@article{zhan2022triocc,
+  title={A Tri-Layer Plugin to Improve Occluded Detection},
+  author={Zhan, Guanqi and Xie, Weidi and Zisserman, Andrew},
+  journal={British Machine Vision Conference},
+  year={2022}
+}
+```
diff --git a/mmdet/evaluation/metrics/__init__.py b/mmdet/evaluation/metrics/__init__.py
index 9b2ee7c31bc..da000e0d535 100644
--- a/mmdet/evaluation/metrics/__init__.py
+++ b/mmdet/evaluation/metrics/__init__.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .cityscapes_metric import CityScapesMetric
 from .coco_metric import CocoMetric
+from .coco_occluded_metric import CocoOccludedSeparatedMetric
 from .coco_panoptic_metric import CocoPanopticMetric
 from .crowdhuman_metric import CrowdHumanMetric
+from .dump_det_results import DumpDetResults
 from .dump_proposals_metric import DumpProposals
 from .lvis_metric import LVISMetric
 from .openimages_metric import OpenImagesMetric
@@ -10,5 +12,6 @@
 
 __all__ = [
     'CityScapesMetric', 'CocoMetric', 'CocoPanopticMetric', 'OpenImagesMetric',
-    'VOCMetric', 'LVISMetric', 'CrowdHumanMetric', 'DumpProposals'
+    'VOCMetric', 'LVISMetric', 'CrowdHumanMetric', 'DumpProposals',
+    'CocoOccludedSeparatedMetric', 'DumpDetResults'
 ]
diff --git a/mmdet/evaluation/metrics/coco_metric.py b/mmdet/evaluation/metrics/coco_metric.py
index 67267442115..bd56803da3d 100644
--- a/mmdet/evaluation/metrics/coco_metric.py
+++ b/mmdet/evaluation/metrics/coco_metric.py
@@ -7,6 +7,7 @@
 from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
+import torch
 from mmengine.evaluator import BaseMetric
 from mmengine.fileio import FileClient, dump, load
 from mmengine.logging import MMLogger
@@ -350,7 +351,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
             # encode mask to RLE
             if 'masks' in pred:
                 result['masks'] = encode_mask_results(
-                    pred['masks'].detach().cpu().numpy())
+                    pred['masks'].detach().cpu().numpy()) if isinstance(
+                        pred['masks'], torch.Tensor) else pred['masks']
             # some detectors use different scores for bbox and mask
             if 'mask_scores' in pred:
                 result['mask_scores'] = pred['mask_scores'].cpu().numpy()
diff --git a/mmdet/evaluation/metrics/coco_occluded_metric.py b/mmdet/evaluation/metrics/coco_occluded_metric.py
new file mode 100644
index 00000000000..544ff4426ba
--- /dev/null
+++ b/mmdet/evaluation/metrics/coco_occluded_metric.py
@@ -0,0 +1,211 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import os.path as osp
+from typing import Dict, List, Optional, Union
+
+import mmengine
+import numpy as np
+from mmengine.fileio import load
+from mmengine.logging import print_log
+from pycocotools import mask as coco_mask
+from terminaltables import AsciiTable
+
+from mmdet.registry import METRICS
+from .coco_metric import CocoMetric
+
+
+@METRICS.register_module()
+class CocoOccludedSeparatedMetric(CocoMetric):
+    """Metric of separated and occluded masks, presented in the paper `A Tri-
+    Layer Plugin to Improve Occluded Detection.
+
+    <https://arxiv.org/abs/2210.10046>`_.
+
+    Separated COCO and Occluded COCO are automatically generated subsets of
+    the COCO val dataset, collecting separated objects and partially occluded
+    objects for a large variety of categories. In this way, we divide
+    occlusion into two major categories: separated and partially occluded.
+
+    - Separation: target object segmentation mask is separated into distinct
+      regions by the occluder.
+    - Partial Occlusion: target object is partially occluded but the
+      segmentation mask is connected.
+
+    These two new scalable real-image datasets are intended to benchmark a
+    model's capability to detect occluded objects of 80 common categories.
+
+    Please cite the paper if you use this dataset:
+
+    @article{zhan2022triocc,
+        title={A Tri-Layer Plugin to Improve Occluded Detection},
+        author={Zhan, Guanqi and Xie, Weidi and Zisserman, Andrew},
+        journal={British Machine Vision Conference},
+        year={2022}
+    }
+
+    Args:
+        occluded_ann (str): Path to the occluded coco annotation file.
+        separated_ann (str): Path to the separated coco annotation file.
+        score_thr (float): Score threshold of the detection masks.
+            Defaults to 0.3.
+        iou_thr (float): IoU threshold for the recall calculation.
+            Defaults to 0.75.
+        metric (str | List[str]): Metrics to be evaluated. Valid metrics
+            include 'bbox', 'segm', 'proposal', and 'proposal_fast'.
+            Defaults to ['bbox', 'segm'].
+    """
+    default_prefix: Optional[str] = 'coco'
+
+    def __init__(
+            self,
+            *args,
+            occluded_ann:
+            str = 'https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/occluded_coco.pkl',  # noqa
+            separated_ann:
+            str = 'https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/separated_coco.pkl',  # noqa
+            score_thr: float = 0.3,
+            iou_thr: float = 0.75,
+            metric: Union[str, List[str]] = ['bbox', 'segm'],
+            **kwargs) -> None:
+        super().__init__(*args, metric=metric, **kwargs)
+        # load from local file
+        if osp.isfile(occluded_ann) and not osp.isabs(occluded_ann):
+            occluded_ann = osp.join(self.data_root, occluded_ann)
+        if osp.isfile(separated_ann) and not osp.isabs(separated_ann):
+            separated_ann = osp.join(self.data_root, separated_ann)
+        self.occluded_ann = load(occluded_ann)
+        self.separated_ann = load(separated_ann)
+        self.score_thr = score_thr
+        self.iou_thr = iou_thr
+
+    def compute_metrics(self, results: list) -> Dict[str, float]:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict[str, float]: The computed metrics. The keys are the names of
+            the metrics, and the values are corresponding results.
+        """
+        coco_metric_res = super().compute_metrics(results)
+        eval_res = self.evaluate_occluded_separated(results)
+        coco_metric_res.update(eval_res)
+        return coco_metric_res
+
+    def evaluate_occluded_separated(self, results: List[tuple]) -> dict:
+        """Compute the recall of occluded and separated masks.
+
+        Args:
+            results (list[tuple]): Testing results of the dataset.
+
+        Returns:
+            dict[str, float]: The recall of occluded and separated masks.
+        """
+        dict_det = {}
+        print_log('processing detection results...')
+        prog_bar = mmengine.ProgressBar(len(results))
+        for i in range(len(results)):
+            gt, dt = results[i]
+            img_id = dt['img_id']
+            cur_img_name = self._coco_api.imgs[img_id]['file_name']
+            if cur_img_name not in dict_det.keys():
+                dict_det[cur_img_name] = []
+
+            for bbox, score, label, mask in zip(dt['bboxes'], dt['scores'],
+                                                dt['labels'], dt['masks']):
+                cur_binary_mask = coco_mask.decode(mask)
+                dict_det[cur_img_name].append([
+                    score, self.dataset_meta['classes'][label],
+                    cur_binary_mask, bbox
+                ])
+            dict_det[cur_img_name].sort(
+                key=lambda x: (-x[0], x[3][0], x[3][1])
+            )  # sort by confidence from high to low; tie-break by bbox
+            prog_bar.update()
+        print_log('\ncomputing occluded mask recall...', logger='current')
+        occluded_correct_num, occluded_recall = self.compute_recall(
+            dict_det, gt_ann=self.occluded_ann, is_occ=True)
+        print_log(
+            f'\nCOCO occluded mask recall: {occluded_recall:.2f}%',
+            logger='current')
+        print_log(
+            f'COCO occluded mask success num: {occluded_correct_num}',
+            logger='current')
+        print_log('computing separated mask recall...', logger='current')
+        separated_correct_num, separated_recall = self.compute_recall(
+            dict_det, gt_ann=self.separated_ann, is_occ=False)
+        print_log(
+            f'\nCOCO separated mask recall: {separated_recall:.2f}%',
+            logger='current')
+        print_log(
+            f'COCO separated mask success num: {separated_correct_num}',
+            logger='current')
+        table_data = [
+            ['mask type', 'recall', 'num correct'],
+            ['occluded', f'{occluded_recall:.2f}%', occluded_correct_num],
+            ['separated', f'{separated_recall:.2f}%', separated_correct_num]
+        ]
+        table = AsciiTable(table_data)
+        print_log('\n' + table.table, logger='current')
+        return dict(
+            occluded_recall=occluded_recall, separated_recall=separated_recall)
+
+    def compute_recall(self,
+                       result_dict: dict,
+                       gt_ann: list,
+                       is_occ: bool = True) -> tuple:
+        """Compute the recall of occluded or separated masks.
+
+        Args:
+            result_dict (dict): Processed mask results.
+            gt_ann (list): Occluded or separated coco annotations.
+            is_occ (bool): Whether the annotations are occluded masks.
+                Defaults to True.
+        Returns:
+            tuple: Number of correct masks and the recall.
+        """
+        correct = 0
+        prog_bar = mmengine.ProgressBar(len(gt_ann))
+        for iter_i in range(len(gt_ann)):
+            cur_item = gt_ann[iter_i]
+            cur_img_name = cur_item[0]
+            cur_gt_bbox = cur_item[3]
+            if is_occ:
+                cur_gt_bbox = [
+                    cur_gt_bbox[0], cur_gt_bbox[1],
+                    cur_gt_bbox[0] + cur_gt_bbox[2],
+                    cur_gt_bbox[1] + cur_gt_bbox[3]
+                ]
+            cur_gt_class = cur_item[1]
+            cur_gt_mask = coco_mask.decode(cur_item[4])
+
+            assert cur_img_name in result_dict.keys()
+            cur_detections = result_dict[cur_img_name]
+
+            correct_flag = False
+            for i in range(len(cur_detections)):
+                cur_det_confidence = cur_detections[i][0]
+                if cur_det_confidence < self.score_thr:
+                    break
+                cur_det_class = cur_detections[i][1]
+                if cur_det_class != cur_gt_class:
+                    continue
+                cur_det_mask = cur_detections[i][2]
+                cur_iou = self.mask_iou(cur_det_mask, cur_gt_mask)
+                if cur_iou >= self.iou_thr:
+                    correct_flag = True
+                    break
+            if correct_flag:
+                correct += 1
+            prog_bar.update()
+        recall = correct / len(gt_ann) * 100
+        return correct, recall
+
+    def mask_iou(self, mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray:
+        """Compute IoU between two masks."""
+        mask1_area = np.count_nonzero(mask1 == 1)
+        mask2_area = np.count_nonzero(mask2 == 1)
+        intersection = np.count_nonzero(np.logical_and(mask1 == 1, mask2 == 1))
+        iou = intersection / (mask1_area + mask2_area - intersection)
+        return iou
diff --git a/mmdet/evaluation/metrics/dump_det_results.py b/mmdet/evaluation/metrics/dump_det_results.py
new file mode 100644
index 00000000000..f3071d19a6a
--- /dev/null
+++ b/mmdet/evaluation/metrics/dump_det_results.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Sequence
+
+from mmengine.evaluator import DumpResults
+from mmengine.evaluator.metric import _to_cpu
+
+from mmdet.registry import METRICS
+from mmdet.structures.mask import encode_mask_results
+
+
+@METRICS.register_module()
+class DumpDetResults(DumpResults):
+    """Dump model predictions to a pickle file for offline evaluation.
+
+    Different from `DumpResults` in MMEngine, it compresses instance
+    segmentation masks into RLE format.
+
+    Args:
+        out_file_path (str): Path of the dumped file. Must end with '.pkl'
+            or '.pickle'.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+    """
+
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Transfer tensors in predictions to CPU."""
+        data_samples = _to_cpu(data_samples)
+        for data_sample in data_samples:
+            # remove gt
+            data_sample.pop('gt_instances', None)
+            data_sample.pop('ignored_instances', None)
+            data_sample.pop('gt_panoptic_seg', None)
+
+            if 'pred_instances' in data_sample:
+                pred = data_sample['pred_instances']
+                # encode mask to RLE
+                if 'masks' in pred:
+                    pred['masks'] = encode_mask_results(pred['masks'].numpy())
+            if 'pred_panoptic_seg' in data_sample:
+                warnings.warn(
+                    'Panoptic segmentation map will not be compressed. '
+                    'The dumped file will be extremely large! '
+                    'Suggest using `CocoPanopticMetric` to save the coco '
+                    'format json and segmentation png files directly.')
+        self.results.extend(data_samples)
diff --git a/tests/test_evaluation/test_metrics/test_coco_occluded_metric.py b/tests/test_evaluation/test_metrics/test_coco_occluded_metric.py
new file mode 100644
index 00000000000..29c4d568554
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_coco_occluded_metric.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from tempfile import TemporaryDirectory
+
+import mmengine
+import numpy as np
+
+from mmdet.datasets import CocoDataset
+from mmdet.evaluation import CocoOccludedSeparatedMetric
+
+
+def test_coco_occluded_separated_metric():
+    ann = [[
+        'fake1.jpg', 'person', 8, [219.9, 176.12, 11.14, 34.23], {
+            'size': [480, 640],
+            'counts': b'nYW31n>2N2FNbA48Kf=?XBDe=m0OM3M4YOPB8_>L4JXao5'
+        }
+    ]] * 3
+    dummy_mask = np.zeros((10, 10), dtype=np.uint8)
+    dummy_mask[:5, :5] = 1
+    rle = {
+        'size': [480, 640],
+        'counts': b'nYW31n>2N2FNbA48Kf=?XBDe=m0OM3M4YOPB8_>L4JXao5'
+    }
+    res = [(None,
+            dict(
+                img_id=0,
+                bboxes=np.array([[50, 60, 70, 80]] * 2),
+                masks=[rle] * 2,
+                labels=np.array([0, 1], dtype=np.int64),
+                scores=np.array([0.77, 0.77])))] * 3
+
+    tempdir = TemporaryDirectory()
+    ann_path = osp.join(tempdir.name, 'coco_occluded.pkl')
+    mmengine.dump(ann, ann_path)
+
+    metric = CocoOccludedSeparatedMetric(
+        ann_file='tests/data/coco_sample.json',
+        occluded_ann=ann_path,
+        separated_ann=ann_path,
+        metric=[])
+    metric.dataset_meta = CocoDataset.METAINFO
+    eval_res = metric.compute_metrics(res)
+    assert isinstance(eval_res, dict)
+    assert eval_res['occluded_recall'] == 100
+    assert eval_res['separated_recall'] == 100
diff --git a/tests/test_evaluation/test_metrics/test_dump_det_results.py b/tests/test_evaluation/test_metrics/test_dump_det_results.py
new file mode 100644
index 00000000000..fc793229730
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_dump_det_results.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+from unittest import TestCase
+
+import torch
+from mmengine.fileio import load
+from torch import Tensor
+
+from mmdet.evaluation import DumpDetResults
+from mmdet.structures.mask import encode_mask_results
+
+
+class TestDumpResults(TestCase):
+
+    def test_init(self):
+        with self.assertRaisesRegex(ValueError,
+                                    'The output file must be a pkl file.'):
+            DumpDetResults(out_file_path='./results.json')
+
+    def test_process(self):
+        metric = DumpDetResults(out_file_path='./results.pkl')
+        data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
+        metric.process(None, data_samples)
+        self.assertEqual(len(metric.results), 1)
+        self.assertEqual(metric.results[0]['data'][0].device,
+                         torch.device('cpu'))
+
+        metric = DumpDetResults(out_file_path='./results.pkl')
+        masks = torch.zeros(10, 10, 4)
+        data_samples = [
+            dict(pred_instances=dict(masks=masks), gt_instances=[])
+        ]
+        metric.process(None, data_samples)
+        self.assertEqual(len(metric.results), 1)
+        self.assertEqual(metric.results[0]['pred_instances']['masks'],
+                         encode_mask_results(masks.numpy()))
+        self.assertNotIn('gt_instances', metric.results[0])
+
+    def test_compute_metrics(self):
+        temp_dir = tempfile.TemporaryDirectory()
+        path = osp.join(temp_dir.name, 'results.pkl')
+        metric = DumpDetResults(out_file_path=path)
+        data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
+        metric.process(None, data_samples)
+        metric.compute_metrics(metric.results)
+        self.assertTrue(osp.isfile(path))
+
+        results = load(path)
+        self.assertEqual(len(results), 1)
+        self.assertEqual(results[0]['data'][0].device, torch.device('cpu'))
+
+        temp_dir.cleanup()
diff --git a/tools/analysis_tools/coco_occluded_separated_recall.py b/tools/analysis_tools/coco_occluded_separated_recall.py
new file mode 100644
index 00000000000..e61f2ccd945
--- /dev/null
+++ b/tools/analysis_tools/coco_occluded_separated_recall.py
@@ -0,0 +1,48 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from argparse import ArgumentParser
+
+import mmengine
+from mmengine.logging import print_log
+
+from mmdet.datasets import CocoDataset
+from mmdet.evaluation import CocoOccludedSeparatedMetric
+
+
+def main():
+    parser = ArgumentParser(
+        description='Compute recall of COCO occluded and separated masks '
+        'presented in the paper https://arxiv.org/abs/2210.10046.')
+    parser.add_argument('result', help='result file (pkl format) path')
+    parser.add_argument('--out', help='file path to save evaluation results')
+    parser.add_argument(
+        '--score-thr',
+        type=float,
+        default=0.3,
+        help='Score threshold for the recall calculation. Defaults to 0.3.')
+    parser.add_argument(
+        '--iou-thr',
+        type=float,
+        default=0.75,
+        help='IoU threshold for the recall calculation. Defaults to 0.75.')
+    parser.add_argument(
+        '--ann',
+        default='data/coco/annotations/instances_val2017.json',
+        help='coco annotation file path')
+    args = parser.parse_args()
+
+    results = mmengine.load(args.result)
+    assert 'masks' in results[0]['pred_instances'], \
+        'The results must be predicted by an instance segmentation model.'
+    metric = CocoOccludedSeparatedMetric(
+        ann_file=args.ann, iou_thr=args.iou_thr, score_thr=args.score_thr)
+    metric.dataset_meta = CocoDataset.METAINFO
+    for datasample in results:
+        metric.process(data_batch=None, data_samples=[datasample])
+    metric_res = metric.compute_metrics(metric.results)
+    if args.out is not None:
+        mmengine.dump(metric_res, args.out)
+        print_log(f'Evaluation results have been saved to {args.out}.')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/test.py b/tools/test.py
index 4de587901b8..7fddcf4fbf4 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -4,10 +4,10 @@
 import os.path as osp
 
 from mmengine.config import Config, DictAction
-from mmengine.evaluator import DumpResults
 from mmengine.runner import Runner
 
 from mmdet.engine.hooks.utils import trigger_visualization_hook
+from mmdet.evaluation import DumpDetResults
 from mmdet.registry import RUNNERS
 
 
@@ -92,7 +92,7 @@ def main():
         assert args.out.endswith(('.pkl', '.pickle')), \
            'The dump file must be a pkl file.'
         runner.test_evaluator.metrics.append(
-            DumpResults(out_file_path=args.out))
+            DumpDetResults(out_file_path=args.out))
 
     # start testing
     runner.test()
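
Usage note (not part of the patch): the `results.pkl` written by `DumpDetResults` is a list of per-image sample dicts whose `pred_instances` entry holds RLE-compressed masks, which is exactly what `coco_occluded_separated_recall.py` consumes. The snippet below is a minimal sketch of inspecting such a dump before running the metric; the file name `results.pkl` and the use of pycocotools for decoding are assumptions for illustration, not something this diff adds.

# Minimal sketch (assumed file name): peek at a DumpDetResults dump before
# feeding it to CocoOccludedSeparatedMetric.
import mmengine
import numpy as np
from pycocotools import mask as coco_mask

results = mmengine.load('results.pkl')  # list of per-image data samples
sample = results[0]
pred = sample['pred_instances']  # bboxes / scores / labels / masks (RLE)
print('image id:', sample['img_id'], 'num instances:', len(pred['scores']))
if 'masks' in pred:
    # masks were RLE-encoded by DumpDetResults; decode one back to binary
    binary_mask = coco_mask.decode(pred['masks'][0])
    print('first mask pixels:', int(np.count_nonzero(binary_mask)))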