Add engine for unified entrypoints in downstream projects #695

Closed
wants to merge 9 commits
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -48,7 +48,7 @@ jobs:
- name: Run unittests and generate coverage report
run: |
pip install -r requirements/test.txt
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_engine

build_without_ops:
runs-on: ubuntu-latest
9 changes: 9 additions & 0 deletions mmcv/engine/__init__.py
@@ -0,0 +1,9 @@
from .test import collect_results_cpu, collect_results_gpu, multi_gpu_test
from .utils import (default_args_parser, gather_info, set_random_seed,
setup_cfg, setup_envs)

__all__ = [
'default_args_parser', 'gather_info', 'setup_cfg', 'setup_envs',
'multi_gpu_test', 'collect_results_gpu', 'collect_results_cpu',
'set_random_seed'
]
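
Taken together, these exports are intended to form a single, unified entry point for downstream projects. The sketch below shows one way a downstream script could wire them up; it is illustrative only: it assumes this branch of mmcv is installed, that leftover command-line tokens are handed to setup_cfg as config overrides, and that the model and data loader are built by the downstream codebase.

# Hypothetical downstream entry point built on this PR's engine module
# (a sketch, not part of this PR).
import torch

from mmcv.engine import (default_args_parser, gather_info, multi_gpu_test,
                         setup_cfg, setup_envs)
from mmcv.utils import get_logger


def run(model, data_loader):
    """``model`` and ``data_loader`` are built by the downstream project."""
    parser = default_args_parser()
    # assumption: unknown tokens are key=value overrides consumed by setup_cfg
    args, cfg_args = parser.parse_known_args()
    cfg = setup_cfg(args, cfg_args)
    setup_envs(cfg)
    logger = get_logger('downstream')
    meta = gather_info(cfg, logger, {'PyTorch': torch.__version__})
    if args.test_only:
        return multi_gpu_test(model, data_loader, tmpdir=args.tmpdir,
                              gpu_collect=args.gpu_collect)
    return meta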
131 changes: 131 additions & 0 deletions mmcv/engine/test.py
@@ -0,0 +1,131 @@
import os.path as osp
import pickle
import shutil
import tempfile
import time

import torch
import torch.distributed as dist

import mmcv
from mmcv.runner import get_dist_info


def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
"""Test model with multiple gpus.

This method tests the model with multiple gpus and collects the results
under two different modes: gpu and cpu. By setting 'gpu_collect=True',
it encodes results to gpu tensors and uses gpu communication for result
collection. In cpu mode, it saves the results on different gpus to 'tmpdir'
and collects them by the rank 0 worker.

Args:
model (nn.Module): Model to be tested.
data_loader (DataLoader): PyTorch data loader.
tmpdir (str): Path of directory to save the temporary results from
different gpus under cpu mode.
gpu_collect (bool): Option to use either gpu or cpu to collect results.

Returns:
list: The prediction results.
"""
model.eval()
results = []
dataset = data_loader.dataset
rank, world_size = get_dist_info()
if rank == 0:
prog_bar = mmcv.ProgressBar(len(dataset))
time.sleep(2)  # This line can prevent deadlock problems in some cases.
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
Member comment: This may not be applicable to all downstream codebases.
results.extend(result)

if rank == 0:
batch_size = len(result)
Member comment: Is this true for all codebases?
for _ in range(batch_size * world_size):
prog_bar.update()

# collect results from all ranks
if gpu_collect:
results = collect_results_gpu(results, len(dataset))
else:
results = collect_results_cpu(results, len(dataset), tmpdir)
return results

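As a usage note for the two collection modes described in the docstring above, a caller that already has a distributed-wrapped model and a data loader from the downstream project might select between them as follows (fragment for illustration only):

# gpu mode: results are pickled into CUDA tensors and all-gathered
results = multi_gpu_test(model, data_loader, gpu_collect=True)

# cpu mode: each rank dumps part_{rank}.pkl into tmpdir and rank 0 merges them
results = multi_gpu_test(model, data_loader, tmpdir='./tmp_results',
                         gpu_collect=False)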

def collect_results_cpu(result_part, size, tmpdir=None):
Member comment: Add a docstring.

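    # Suggested docstring in response to the review comment above; the
    # wording is the editor's, not the author's.
    """Collect results under cpu mode.

    On cpu mode, this function saves the results on different gpus to
    ``tmpdir`` and collects them by the rank 0 worker.

    Args:
        result_part (list): Result list containing result parts to be
            collected.
        size (int): Size of the whole results, commonly equal to the length
            of the dataset.
        tmpdir (str | None): Temporary directory for the collected results.
            If None, a random directory under ``.dist_test`` is created.
            Default: None.

    Returns:
        list: Ordered results on rank 0; other ranks return None.
    """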
rank, world_size = get_dist_info()
# create a tmp dir if it is not specified
if tmpdir is None:
MAX_LEN = 512
# 32 is whitespace
dir_tensor = torch.full((MAX_LEN, ),
32,
dtype=torch.uint8,
device='cuda')
if rank == 0:
mmcv.mkdir_or_exist('.dist_test')
tmpdir = tempfile.mkdtemp(dir='.dist_test')
tmpdir = torch.tensor(
bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
dir_tensor[:len(tmpdir)] = tmpdir
dist.broadcast(dir_tensor, 0)
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
else:
mmcv.mkdir_or_exist(tmpdir)
# dump the part result to the dir
mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
dist.barrier()
# collect all parts
if rank != 0:
return None
else:
# load results of all parts from tmp dir
part_list = []
for i in range(world_size):
part_file = osp.join(tmpdir, f'part_{i}.pkl')
part_list.append(mmcv.load(part_file))
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
# remove tmp dir
shutil.rmtree(tmpdir)
return ordered_results


def collect_results_gpu(result_part, size):
Member comment: Add a docstring.

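    # Suggested docstring in response to the review comment above; the
    # wording is the editor's, not the author's.
    """Collect results under gpu mode.

    On gpu mode, this function pickles the result parts into gpu tensors and
    uses gpu communication (``all_gather``) to collect them.

    Args:
        result_part (list): Result list containing result parts to be
            collected.
        size (int): Size of the whole results, commonly equal to the length
            of the dataset.

    Returns:
        list: Ordered results on rank 0; other ranks return None.
    """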
rank, world_size = get_dist_info()
# dump result part to tensor with pickle
part_tensor = torch.tensor(
bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
# gather all result part tensor shape
shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
shape_list = [shape_tensor.clone() for _ in range(world_size)]
dist.all_gather(shape_list, shape_tensor)
# padding result part tensor to max length
shape_max = torch.tensor(shape_list).max()
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
part_send[:shape_tensor[0]] = part_tensor
part_recv_list = [
part_tensor.new_zeros(shape_max) for _ in range(world_size)
]
# gather all result part
dist.all_gather(part_recv_list, part_send)

if rank == 0:
part_list = []
for recv, shape in zip(part_recv_list, shape_list):
part_list.append(
pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
return ordered_results
Empty file added mmcv/engine/train.py
Empty file.
239 changes: 239 additions & 0 deletions mmcv/engine/utils.py
@@ -0,0 +1,239 @@
import argparse
import os
import os.path as osp
import random

import numpy as np
import torch

from ..runner import get_dist_info, init_dist
from ..utils import (Config, DictAction, import_modules_from_strings,
mkdir_or_exist)


def set_random_seed(seed, deterministic=False):
"""Set random seed.

Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def default_args_parser():
"""Default argument parser for OpenMMLab projects.

This function is used as a default argument parser in OpenMMLab projects.
To add customized arguments, users can create a new parser function which
calls this function first.

Returns:
:obj:`argparse.ArgumentParser`: Argument parser
"""
parser = argparse.ArgumentParser(
description='OpenMMLab Default Argument Parser')

# common arguments for both training and testing
parser.add_argument('config', help='config file path')
parser.add_argument(
'--tmpdir',
help='tmp directory used for collecting results from multiple '
'workers, available when gpu-collect is not specified')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--gpu-collect',
action='store_true',
help='whether to use gpu to collect results in multi-gpu testing.')

# common arguments for training
parser.add_argument(
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
type=int,
help='number of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='ids of gpus to use '
'(only applicable to non-distributed training)')

# common arguments for testing
parser.add_argument(
'--test-only', action='store_true', help='whether to only perform evaluation')
parser.add_argument(
'--checkpoint', help='checkpoint file used in evaluation')
parser.add_argument('--out', help='output result file in pickle format')
parser.add_argument(
'--fuse-conv-bn',
action='store_true',
help='Whether to fuse conv and bn. This will slightly increase '
'the inference speed')
parser.add_argument(
'--format-only',
action='store_true',
help='Format the output results without performing evaluation. It is '
'useful when you want to format the result to a specific format and '
'submit it to the test server')
parser.add_argument(
'--eval',
type=str,
nargs='+',
help='evaluation metrics, which depend on the dataset, e.g., "bbox",'
' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
parser.add_argument('--show', action='store_true', help='show results')
parser.add_argument(
'--show-dir', help='directory where painted images will be saved')
# TODO: decide whether to maintain two places for modifying eval options
parser.add_argument(
'--eval-options',
nargs='+',
action=DictAction,
help='custom options for evaluation, the key-value pair in xxx=yyy '
'format will be kwargs for dataset.evaluate() function')

return parser

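The extension pattern described in the docstring of default_args_parser could look like the following sketch; the extra flag name is chosen purely for illustration:

# Downstream parser that extends the default one (illustrative flag name).
from mmcv.engine import default_args_parser


def parse_args():
    parser = default_args_parser()
    parser.add_argument(
        '--my-extra-option',
        action='store_true',
        help='hypothetical project-specific option')
    # leftover tokens can later be handed to setup_cfg as config overrides
    return parser.parse_known_args()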

def setup_cfg(args, cfg_args):
"""Set up config.

Member comment: Add a detailed description.

Note:
This function assumes the arguments are parsed from the parser defined
by :meth:`default_args_parser`, which contains necessary keys
for distributed training including 'launcher', 'local_rank', etc.

Args:
args (:obj:`argparse.Namespace`): Arguments parsed at the entry point.
cfg_args (list[str]): List of key-value pairs that will be merged
into the config.

Returns:
Config: The config object.
"""
cfg = Config.fromfile(args.config)
# merge config from args.cfg_options
if len(cfg_args) > 0:
cfg.merge_from_arg_list(cfg_args)

if cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])

# initialize some default but necessary options
cfg.seed = cfg.get('seed', None)
cfg.deterministic = cfg.get('deterministic', False)
cfg.resume_from = cfg.get('resume_from', None)

cfg.launcher = args.launcher
cfg.local_rank = args.local_rank
if args.launcher == 'none':
cfg.distributed = False
else:
cfg.distributed = True

if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids
else:
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

if cfg.get('custom_imports', None):
import_modules_from_strings(**cfg['custom_imports'])
return cfg


def setup_envs(cfg, dump_cfg=True):
"""Setup running environments.

This function initializes the running environment.
It does the following things in order:

1. Set local rank in the environment
2. Set cudnn benchmark
3. Initialize the distributed environment
4. Create the work dir and dump the config file
5. Set random seed

Args:
cfg (:obj:`Config`): Config object.
dump_cfg (bool): Whether to dump the config file. Default: True.
"""
# set local rank
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(cfg.local_rank)

# set cudnn_benchmark
torch.backends.cudnn.benchmark = cfg.get('cudnn_benchmark', False)

# init distributed env first, since logger depends on the dist info.
if cfg.distributed:
init_dist(cfg.launcher, **cfg.dist_params)
# re-set gpu_ids with distributed training mode
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)

# create work_dir
mkdir_or_exist(osp.abspath(cfg.work_dir))
if cfg.local_rank == 0 and dump_cfg:
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(cfg.filename)))

# set random seeds
if cfg.seed is not None:
set_random_seed(cfg.seed, deterministic=cfg.deterministic)


def gather_info(cfg, logger, env_info_dict):
"""Gather running information.

This function does the following things in order:

1. collect & log env info
2. collect exp name, config

Args:
cfg (:obj:`Config`): Config object.
logger (:obj:`logging.Logger`): Logger.
env_info_dict (dict): Environment information.

Returns:
dict: Meta dict containing the env info, config text, seed and exp name.
"""
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
meta['seed'] = cfg.seed
meta['exp_name'] = osp.basename(cfg.filename)

# log some basic info
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
logger.info(f'Set random seed to {cfg.seed}, '
f'deterministic: {cfg.deterministic}')
logger.info(f'Distributed training: {cfg.distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')

return meta