Skip to content

Commit

Permalink
Add dir inference support.
Browse files Browse the repository at this point in the history
  • Loading branch information
JiayuXu0 committed Oct 10, 2022
1 parent 3dc1277 commit 7fef33c
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 72 deletions.
85 changes: 30 additions & 55 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
from argparse import ArgumentParser

import mmcv
from mmdet.apis import inference_detector, init_detector
from mmengine.utils import ProgressBar
from mmengine.utils import ProgressBar, scandir

from mmyolo.registry import VISUALIZERS
from mmyolo.utils import register_all_modules

# Lower-case file extensions (without the leading dot) accepted as images.
IMG_FORMATS = ('bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff',
               'webp', 'pfm')


class LoadFiles:
    """Collect image file paths from files, directories or glob patterns.

    Each input entry is converted to an absolute path and then expanded:

    - an entry containing ``*`` is treated as a glob pattern (recursive),
    - a directory is searched recursively for names containing a dot,
    - an existing file is taken as is.

    Only regular files whose extension (case-insensitive) is listed in
    ``IMG_FORMATS`` are kept. The object is its own iterator: ``__iter__``
    resets the cursor and returns ``self``, so two simultaneous iterations
    over the same instance share one cursor.

    Args:
        path (str | list[str] | tuple[str]): A single path/pattern or a
            sequence of them. A sequence is processed in sorted order.

    Raises:
        FileNotFoundError: If an entry is neither a glob pattern, a
            directory nor an existing file.
        AssertionError: If no image file is found at all.
    """

    def __init__(self, path):
        entries = sorted(path) if isinstance(path, (list, tuple)) else [path]

        candidates = []
        for entry in entries:
            entry = str(os.path.abspath(entry))
            if '*' in entry:
                # glob pattern
                candidates.extend(sorted(glob.glob(entry, recursive=True)))
            elif os.path.isdir(entry):
                # directory: walk it recursively
                pattern = os.path.join(entry, '**/*.*')
                candidates.extend(sorted(glob.glob(pattern, recursive=True)))
            elif os.path.isfile(entry):
                # single file
                candidates.append(entry)
            else:
                raise FileNotFoundError(f'{entry} does not exist')

        # ``**/*.*`` (and user globs) also match directories whose name
        # contains a dot, e.g. ``shots.png/``; keep regular files only.
        self.files = [
            x for x in candidates
            if x.rsplit('.', 1)[-1].lower() in IMG_FORMATS
            and os.path.isfile(x)
        ]
        self.num_files = len(self.files)
        self.count = 0
        if self.num_files == 0:
            # Raised explicitly (instead of ``assert``) so the check survives
            # ``python -O``; the exception type is kept for existing callers.
            # The message reports the original input rather than the loop
            # variable, which is undefined when ``path`` is an empty sequence.
            raise AssertionError(f'No images found in {path}.')

    def __iter__(self):
        # Restart from the first file on every new iteration.
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.num_files:
            raise StopIteration
        current = self.files[self.count]
        self.count += 1
        return current

    def __len__(self):
        return self.num_files


def parse_args():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument(
'source', help='Image source, include image file, dir and URL.')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('--out-path', default='./', help='Path to output file')
parser.add_argument(
'--out-path', default='./output', help='Path to output file')
parser.add_argument(
'--recursive',
action='store_true',
help='Recursively scan the directory when source is dir.')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--show', action='store_true', help='Show the detection results')
parser.add_argument(
'--palette',
default='coco',
choices=['coco', 'voc', 'citys', 'random'],
help='Color palette used for visualization')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='bbox score threshold')
args = parser.parse_args()
Expand All @@ -85,22 +43,39 @@ def main(args):
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta

files = LoadFiles(args.img)
progress_bar = ProgressBar(len(files))
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')

is_file = os.path.isfile(args.source) and os.path.splitext(
args.source)[-1] in (IMG_EXTENSIONS)
is_dir = os.path.isdir(args.source)

files = []
if is_file:
files.append(args.source)
elif is_dir:
files = [
os.path.join(args.source, file) for file in scandir(
args.source, IMG_EXTENSIONS, recursive=args.recursive)
]
assert len(files) > 0, 'Images list is empty, please check input source.'

# start detector inference
progress_bar = ProgressBar(len(files))
for file in files:
result = inference_detector(model, file)
img = mmcv.imread(file)
img = mmcv.imconvert(img, 'bgr', 'rgb')
_, file_name = os.path.split(file)
rel_path = os.path.relpath(
file, args.source) if is_dir else os.path.basename(file)
visualizer.add_datasample(
'result',
img,
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=os.path.join(args.out_path, file_name),
out_file=os.path.join(args.out_path, rel_path.replace('/', '_')),
pred_score_thr=args.score_thr)
progress_bar.update()

Expand Down
23 changes: 10 additions & 13 deletions demo/video_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,23 @@ def main():

for frame in video_reader:
result = inference_detector(model, frame)
visualizer.add_datasample(
'result',
frame,
data_sample=result,
draw_gt=False,
show=False,
wait_time=0,
out_file=None,
pred_score_thr=args.score_thr)
if args.show:
cv2.namedWindow('video', 0)
mmcv.imshow(frame, 'video', args.wait_time)
else:
visualizer.add_datasample(
'result',
frame,
data_sample=result,
draw_gt=False,
show=args.show,
wait_time=0,
out_file=None,
pred_score_thr=args.score_thr)
mmcv.imshow(visualizer.get_image(), 'video', args.wait_time)
if args.out:
video_writer.write(visualizer.get_image())
progress_bar.update()

if video_writer:
video_writer.release()
cv2.destroyAllWindows()


if __name__ == '__main__':
Expand Down
5 changes: 4 additions & 1 deletion docs/en/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ The downloading will take several seconds or more, depending on your network env
Option (a). If you install MMYOLO from source, just run the following command.

```shell
python demo/image_demo.py demo/demo.jpg yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth --device cpu
python demo/image_demo.py demo/demo.jpg \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \
--device cpu
```

You will see a new image `result.jpg` in your current folder, where bounding boxes are plotted.
Expand Down
42 changes: 41 additions & 1 deletion mmyolo/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import mmcv
import numpy as np
import torch
from mmcv.transforms import BaseTransform
from mmcv.transforms import BaseTransform, LoadImageFromFile
from mmcv.transforms.utils import cache_randomness
from mmdet.datasets.transforms import LoadAnnotations as MMDET_LoadAnnotations
from mmdet.datasets.transforms import Resize as MMDET_Resize
Expand Down Expand Up @@ -661,3 +661,43 @@ def _get_translation_matrix(x: float, y: float) -> np.ndarray:
translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
dtype=np.float32)
return translation_matrix


@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
    """Take the image that is already stored in ``results['img']``.

    Similar to :obj:`LoadImageFromFile`, but the image has already been
    loaded as :obj:`np.ndarray` in ``results['img']``. Can be used when
    loading images from a webcam.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the image to a float32 numpy
            array. If set to False, the array is stored without conversion.
            Defaults to False.
    """

    def transform(self, results: dict) -> dict:
        """Fill in the image meta information for an in-memory image.

        Args:
            results (dict): Result dict holding the already loaded image in
                ``results['img']``.

        Returns:
            dict: The same dict, updated with the image and its meta
            information (``img_path`` is set to ``None``).
        """
        image = results['img']
        if self.to_float32:
            image = image.astype(np.float32)

        # Both shape entries describe the first two dimensions of the array.
        spatial_shape = image.shape[:2]
        results.update(
            img_path=None,
            img=image,
            img_shape=spatial_shape,
            ori_shape=spatial_shape)
        return results
4 changes: 2 additions & 2 deletions mmyolo/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""

from mmdet.registry import TRANSFORMS as MMDET_TRANSFORM
from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
from mmengine.registry import DATASETS as MMENGINE_DATASETS
from mmengine.registry import HOOKS as MMENGINE_HOOKS
Expand All @@ -23,6 +22,7 @@
RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
from mmengine.registry import \
Expand All @@ -42,7 +42,7 @@
# manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMDET_TRANSFORM)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)

# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
Expand Down

0 comments on commit 7fef33c

Please sign in to comment.