Skip to content

Commit

Permalink
Add dir inference support.
Browse files Browse the repository at this point in the history
  • Loading branch information
JiayuXu0 committed Oct 10, 2022
1 parent 3dc1277 commit 1d9e091
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 143 deletions.
76 changes: 22 additions & 54 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
from argparse import ArgumentParser

import mmcv
from mmdet.apis import inference_detector, init_detector
from mmengine.utils import ProgressBar
from mmengine.utils import ProgressBar, scandir

from mmyolo.registry import VISUALIZERS
from mmyolo.utils import register_all_modules

IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif',\
'tiff', 'webp', 'pfm'


class LoadFiles:

def __init__(self, path):
files = []
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
p = str(os.path.abspath(p))
if '*' in p:
files.extend(sorted(glob.glob(p, recursive=True))) # glob
elif os.path.isdir(p):
files.extend(
sorted(
glob.glob(os.path.join(p, '**/*.*'),
recursive=True))) # dir
elif os.path.isfile(p):
files.append(p) # files
else:
raise FileNotFoundError(f'{p} does not exist')
self.files = [
x for x in files if x.split('.')[-1].lower() in IMG_FORMATS
]
self.num_files = len(self.files)
self.count = 0
assert self.num_files > 0, f'No images found in {p}.'

def __iter__(self):
self.count = 0
return self

def __next__(self):
if self.count == self.num_files:
raise StopIteration
path = self.files[self.count]
self.count += 1
return path

def __len__(self):
return self.num_files
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')


def parse_args():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument(
'img', help='Image img, include image file, dir and URL.')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('--out-path', default='./', help='Path to output file')
parser.add_argument(
'--out-path', default='./output', help='Path to output file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--show', action='store_true', help='Show the detection results')
parser.add_argument(
'--palette',
default='coco',
choices=['coco', 'voc', 'citys', 'random'],
help='Color palette used for visualization')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='bbox score threshold')
args = parser.parse_args()
Expand All @@ -85,22 +42,33 @@ def main(args):
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta

files = LoadFiles(args.img)
progress_bar = ProgressBar(len(files))
is_file = os.path.splitext(args.img)[-1] in (IMG_EXTENSIONS)
is_dir = os.path.isdir(args.img)

files = []
if is_file:
files.append(args.img)
elif is_dir:
files = [
os.path.join(args.img, file)
for file in scandir(args.img, IMG_EXTENSIONS)
]
assert len(files) > 0, 'Images list is empty, please check input img.'

# start detector inference
progress_bar = ProgressBar(len(files))
for file in files:
result = inference_detector(model, file)
img = mmcv.imread(file)
img = mmcv.imconvert(img, 'bgr', 'rgb')
_, file_name = os.path.split(file)
visualizer.add_datasample(
'result',
img,
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=os.path.join(args.out_path, file_name),
out_file=os.path.join(args.out_path, os.path.basename(file)),
pred_score_thr=args.score_thr)
progress_bar.update()

Expand Down
82 changes: 0 additions & 82 deletions demo/video_demo.py

This file was deleted.

17 changes: 15 additions & 2 deletions docs/en/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,23 @@ The downloading will take several seconds or more, depending on your network env
Option (a). If you install MMYOLO from source, just run the following command.

```shell
python demo/image_demo.py demo/demo.jpg yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth --device cpu
python demo/image_demo.py demo/demo.jpg \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth

# Optional parameters
# --output ./output *The detection results are output to the specified directory. Default: ./output
# --device cuda:0 *The computing resources used, including cuda and cpu. Default: cuda:0
# --show *Display the results on the screen. Default: False
# --score-thr 0.3 *Confidence threshold. Default: 0.3
```

You will see a new image `result.jpg` on your current folder, where bounding boxes are plotted.
You will see a new image on your `output` folder, where bounding boxes are plotted.

Supported input types:

- Single image, include jpg, jpeg, png, ppm, bmp, pgm, tif, tiff, webp.
- Folder, all image files in the folder will be traversed and the corresponding results will be output.

Option (b). If you install MMYOLO with MIM, open your python interpreter and copy&paste the following codes.

Expand Down
16 changes: 13 additions & 3 deletions docs/zh_cn/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,21 @@ mim download mmyolo --config yolov5_s-v61_syncbn_fast_8xb16-300e_coco --dest .
```shell
python demo/image_demo.py demo/demo.jpg \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
--device cpu
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth

# 可选参数
# --output ./output *检测结果输出到指定目录下,默认为./output
# --device cuda:0 *使用的计算资源,包括cuda, cpu等,默认为cuda:0
# --show *使用该参数表示在屏幕上显示检测结果,默认为False
# --score-thr 0.3 *置信度阈值,默认为0.3
```

你会在当前文件夹中看到一个新的图像 `result.jpg`,图像中包含有网络预测的检测框。
运行结束后,你会在output文件夹中看到检测结果图像,图像中包含有网络预测的检测框。

支持输入类型包括

- 单张图片, 支持jpg, jpeg, png, ppm, bmp, pgm, tif, tiff, webp。
- 文件目录,会遍历文件目录下所有图片文件,并输出对应结果。

方案 2. 如果你通过 MIM 安装的 MMYOLO, 那么可以打开你的 Python 解析器,复制并粘贴以下代码:

Expand Down
4 changes: 2 additions & 2 deletions mmyolo/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""

from mmdet.registry import TRANSFORMS as MMDET_TRANSFORM
from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
from mmengine.registry import DATASETS as MMENGINE_DATASETS
from mmengine.registry import HOOKS as MMENGINE_HOOKS
Expand All @@ -23,6 +22,7 @@
RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
from mmengine.registry import \
Expand All @@ -42,7 +42,7 @@
# manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMDET_TRANSFORM)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)

# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
Expand Down

0 comments on commit 1d9e091

Please sign in to comment.