Skip to content

Commit

Permalink
Merge remote-tracking branch 'mmcv/master' into fix_wandb_logger_drop…
Browse files Browse the repository at this point in the history
…_res_bug_by_delete_step_param
  • Loading branch information
shenmishajing committed Apr 9, 2021
2 parents ed213c3 + 03a2e3a commit a1ec8a7
Show file tree
Hide file tree
Showing 9 changed files with 956 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/readme.md
7 changes: 7 additions & 0 deletions mmcv/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
single_gpu_test)

__all__ = [
'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
'single_gpu_test'
]
197 changes: 197 additions & 0 deletions mmcv/engine/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import os.path as osp
import pickle
import shutil
import tempfile
import time

import torch
import torch.distributed as dist

import mmcv
from mmcv.runner import get_dist_info


def single_gpu_test(model, data_loader):
"""Test model with a single gpu.
This method tests model with a single gpu and displays test progress bar.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
Returns:
list: The prediction results.
"""
model.eval()
results = []
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
for data in data_loader:
with torch.no_grad():
result = model(return_loss=False, **data)
results.extend(result)

# use the first key as main key to calculate the batch size
batch_size = len(next(iter(data.values())))
for _ in range(batch_size):
prog_bar.update()
return results


def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
"""Test model with multiple gpus.
This method tests model with multiple gpus and collects the results
under two different modes: gpu and cpu modes. By setting
``gpu_collect=True``, it encodes results to gpu tensors and use gpu
communication for results collection. On cpu mode it saves the results on
different gpus to ``tmpdir`` and collects them by the rank 0 worker.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
tmpdir (str): Path of directory to save the temporary results from
different gpus under cpu mode.
gpu_collect (bool): Option to use either gpu or cpu to collect results.
Returns:
list: The prediction results.
"""
model.eval()
results = []
dataset = data_loader.dataset
rank, world_size = get_dist_info()
if rank == 0:
prog_bar = mmcv.ProgressBar(len(dataset))
time.sleep(2) # This line can prevent deadlock problem in some cases.
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, **data)
results.extend(result)

if rank == 0:
batch_size = len(result)
for _ in range(batch_size * world_size):
prog_bar.update()

# collect results from all ranks
if gpu_collect:
results = collect_results_gpu(results, len(dataset))
else:
results = collect_results_cpu(results, len(dataset), tmpdir)
return results


def collect_results_cpu(result_part, size, tmpdir=None):
"""Collect results under cpu mode.
On cpu mode, this function will save the results on different gpus to
``tmpdir`` and collect them by the rank 0 worker.
Args:
result_part (list): Result list containing result parts
to be collected.
size (int): Size of the results, commonly equal to length of
the results.
tmpdir (str | None): temporal directory for collected results to
store. If set to None, it will create a random temporal directory
for it.
Returns:
list: The collected results.
"""
rank, world_size = get_dist_info()
# create a tmp dir if it is not specified
if tmpdir is None:
MAX_LEN = 512
# 32 is whitespace
dir_tensor = torch.full((MAX_LEN, ),
32,
dtype=torch.uint8,
device='cuda')
if rank == 0:
mmcv.mkdir_or_exist('.dist_test')
tmpdir = tempfile.mkdtemp(dir='.dist_test')
tmpdir = torch.tensor(
bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
dir_tensor[:len(tmpdir)] = tmpdir
dist.broadcast(dir_tensor, 0)
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
else:
mmcv.mkdir_or_exist(tmpdir)
# dump the part result to the dir
mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
dist.barrier()
# collect all parts
if rank != 0:
return None
else:
# load results of all parts from tmp dir
part_list = []
for i in range(world_size):
part_file = osp.join(tmpdir, f'part_{i}.pkl')
part_result = mmcv.load(part_file)
# When data is severely insufficient, an empty part_result
# on a certain gpu could makes the overall outputs empty.
if part_result:
part_list.append(part_result)
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
# remove tmp dir
shutil.rmtree(tmpdir)
return ordered_results


def collect_results_gpu(result_part, size):
"""Collect results under gpu mode.
On gpu mode, this function will encode results to gpu tensors and use gpu
communication for results collection.
Args:
result_part (list): Result list containing result parts
to be collected.
size (int): Size of the results, commonly equal to length of
the results.
Returns:
list: The collected results.
"""
rank, world_size = get_dist_info()
# dump result part to tensor with pickle
part_tensor = torch.tensor(
bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
# gather all result part tensor shape
shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
shape_list = [shape_tensor.clone() for _ in range(world_size)]
dist.all_gather(shape_list, shape_tensor)
# padding result part tensor to max length
shape_max = torch.tensor(shape_list).max()
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
part_send[:shape_tensor[0]] = part_tensor
part_recv_list = [
part_tensor.new_zeros(shape_max) for _ in range(world_size)
]
# gather all result part
dist.all_gather(part_recv_list, part_send)

if rank == 0:
part_list = []
for recv, shape in zip(part_recv_list, shape_list):
part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
# When data is severely insufficient, an empty part_result
# on a certain gpu could makes the overall outputs empty.
if part_result:
part_list.append(part_result)
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
return ordered_results
7 changes: 6 additions & 1 deletion mmcv/onnx/onnx_utils/symbolic_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,13 @@ def _interpolate_size_to_scales(g, input, output_size, dim):
def _interpolate_get_scales_if_available(g, scales):
if len(scales) == 0:
return None
# scales[0] is NoneType in Pytorch == 1.5.1
# scales[0] is TensorType with sizes = [] in Pytorch == 1.6.0
# scales[0] is ListType in Pytorch == 1.7.0
scale_desc = 'fs' if scales[0].type().kind() == 'ListType' else 'f'
# scales[0] is TensorType with sizes = [2] in Pytorch == 1.8.0
scale_desc = 'fs' if scales[0].type().kind() == 'ListType' or (
scales[0].type().kind() == 'TensorType' and
(sum(scales[0].type().sizes()) > 1)) else 'f'
available_scales = _maybe_get_const(
scales[0], scale_desc) != -1 and not _is_none(scales[0])

Expand Down
13 changes: 7 additions & 6 deletions mmcv/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
init_dist, master_only)
from .epoch_based_runner import EpochBasedRunner, Runner
from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
EMAHook, Fp16OptimizerHook, Hook, IterTimerHook,
LoggerHook, LrUpdaterHook, MlflowLoggerHook, OptimizerHook,
PaviLoggerHook, SyncBuffersHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
DistSamplerSeedHook, EMAHook, EvalHook, Fp16OptimizerHook,
Hook, IterTimerHook, LoggerHook, LrUpdaterHook,
MlflowLoggerHook, OptimizerHook, PaviLoggerHook,
SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
Expand All @@ -37,5 +38,5 @@
'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler',
'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix',
'Sequential', 'ModuleList'
'EvalHook', 'DistEvalHook', 'Sequential', 'ModuleList'
]
4 changes: 3 additions & 1 deletion mmcv/runner/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .ema import EMAHook
from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
Expand All @@ -18,5 +19,6 @@
'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook'
'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook',
'EvalHook', 'DistEvalHook'
]
Loading

0 comments on commit a1ec8a7

Please sign in to comment.