-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'mmcv/master' into fix_wandb_logger_drop…
…_res_bug_by_delete_step_param
- Loading branch information
Showing
9 changed files
with
956 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
../README.md | ||
../README.md |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test, | ||
single_gpu_test) | ||
|
||
__all__ = [ | ||
'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test', | ||
'single_gpu_test' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
import os.path as osp | ||
import pickle | ||
import shutil | ||
import tempfile | ||
import time | ||
|
||
import torch | ||
import torch.distributed as dist | ||
|
||
import mmcv | ||
from mmcv.runner import get_dist_info | ||
|
||
|
||
def single_gpu_test(model, data_loader): | ||
"""Test model with a single gpu. | ||
This method tests model with a single gpu and displays test progress bar. | ||
Args: | ||
model (nn.Module): Model to be tested. | ||
data_loader (nn.Dataloader): Pytorch data loader. | ||
Returns: | ||
list: The prediction results. | ||
""" | ||
model.eval() | ||
results = [] | ||
dataset = data_loader.dataset | ||
prog_bar = mmcv.ProgressBar(len(dataset)) | ||
for data in data_loader: | ||
with torch.no_grad(): | ||
result = model(return_loss=False, **data) | ||
results.extend(result) | ||
|
||
# use the first key as main key to calculate the batch size | ||
batch_size = len(next(iter(data.values()))) | ||
for _ in range(batch_size): | ||
prog_bar.update() | ||
return results | ||
|
||
|
||
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): | ||
"""Test model with multiple gpus. | ||
This method tests model with multiple gpus and collects the results | ||
under two different modes: gpu and cpu modes. By setting | ||
``gpu_collect=True``, it encodes results to gpu tensors and use gpu | ||
communication for results collection. On cpu mode it saves the results on | ||
different gpus to ``tmpdir`` and collects them by the rank 0 worker. | ||
Args: | ||
model (nn.Module): Model to be tested. | ||
data_loader (nn.Dataloader): Pytorch data loader. | ||
tmpdir (str): Path of directory to save the temporary results from | ||
different gpus under cpu mode. | ||
gpu_collect (bool): Option to use either gpu or cpu to collect results. | ||
Returns: | ||
list: The prediction results. | ||
""" | ||
model.eval() | ||
results = [] | ||
dataset = data_loader.dataset | ||
rank, world_size = get_dist_info() | ||
if rank == 0: | ||
prog_bar = mmcv.ProgressBar(len(dataset)) | ||
time.sleep(2) # This line can prevent deadlock problem in some cases. | ||
for i, data in enumerate(data_loader): | ||
with torch.no_grad(): | ||
result = model(return_loss=False, **data) | ||
results.extend(result) | ||
|
||
if rank == 0: | ||
batch_size = len(result) | ||
for _ in range(batch_size * world_size): | ||
prog_bar.update() | ||
|
||
# collect results from all ranks | ||
if gpu_collect: | ||
results = collect_results_gpu(results, len(dataset)) | ||
else: | ||
results = collect_results_cpu(results, len(dataset), tmpdir) | ||
return results | ||
|
||
|
||
def collect_results_cpu(result_part, size, tmpdir=None): | ||
"""Collect results under cpu mode. | ||
On cpu mode, this function will save the results on different gpus to | ||
``tmpdir`` and collect them by the rank 0 worker. | ||
Args: | ||
result_part (list): Result list containing result parts | ||
to be collected. | ||
size (int): Size of the results, commonly equal to length of | ||
the results. | ||
tmpdir (str | None): temporal directory for collected results to | ||
store. If set to None, it will create a random temporal directory | ||
for it. | ||
Returns: | ||
list: The collected results. | ||
""" | ||
rank, world_size = get_dist_info() | ||
# create a tmp dir if it is not specified | ||
if tmpdir is None: | ||
MAX_LEN = 512 | ||
# 32 is whitespace | ||
dir_tensor = torch.full((MAX_LEN, ), | ||
32, | ||
dtype=torch.uint8, | ||
device='cuda') | ||
if rank == 0: | ||
mmcv.mkdir_or_exist('.dist_test') | ||
tmpdir = tempfile.mkdtemp(dir='.dist_test') | ||
tmpdir = torch.tensor( | ||
bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') | ||
dir_tensor[:len(tmpdir)] = tmpdir | ||
dist.broadcast(dir_tensor, 0) | ||
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() | ||
else: | ||
mmcv.mkdir_or_exist(tmpdir) | ||
# dump the part result to the dir | ||
mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl')) | ||
dist.barrier() | ||
# collect all parts | ||
if rank != 0: | ||
return None | ||
else: | ||
# load results of all parts from tmp dir | ||
part_list = [] | ||
for i in range(world_size): | ||
part_file = osp.join(tmpdir, f'part_{i}.pkl') | ||
part_result = mmcv.load(part_file) | ||
# When data is severely insufficient, an empty part_result | ||
# on a certain gpu could makes the overall outputs empty. | ||
if part_result: | ||
part_list.append(part_result) | ||
# sort the results | ||
ordered_results = [] | ||
for res in zip(*part_list): | ||
ordered_results.extend(list(res)) | ||
# the dataloader may pad some samples | ||
ordered_results = ordered_results[:size] | ||
# remove tmp dir | ||
shutil.rmtree(tmpdir) | ||
return ordered_results | ||
|
||
|
||
def collect_results_gpu(result_part, size): | ||
"""Collect results under gpu mode. | ||
On gpu mode, this function will encode results to gpu tensors and use gpu | ||
communication for results collection. | ||
Args: | ||
result_part (list): Result list containing result parts | ||
to be collected. | ||
size (int): Size of the results, commonly equal to length of | ||
the results. | ||
Returns: | ||
list: The collected results. | ||
""" | ||
rank, world_size = get_dist_info() | ||
# dump result part to tensor with pickle | ||
part_tensor = torch.tensor( | ||
bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') | ||
# gather all result part tensor shape | ||
shape_tensor = torch.tensor(part_tensor.shape, device='cuda') | ||
shape_list = [shape_tensor.clone() for _ in range(world_size)] | ||
dist.all_gather(shape_list, shape_tensor) | ||
# padding result part tensor to max length | ||
shape_max = torch.tensor(shape_list).max() | ||
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') | ||
part_send[:shape_tensor[0]] = part_tensor | ||
part_recv_list = [ | ||
part_tensor.new_zeros(shape_max) for _ in range(world_size) | ||
] | ||
# gather all result part | ||
dist.all_gather(part_recv_list, part_send) | ||
|
||
if rank == 0: | ||
part_list = [] | ||
for recv, shape in zip(part_recv_list, shape_list): | ||
part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()) | ||
# When data is severely insufficient, an empty part_result | ||
# on a certain gpu could makes the overall outputs empty. | ||
if part_result: | ||
part_list.append(part_result) | ||
# sort the results | ||
ordered_results = [] | ||
for res in zip(*part_list): | ||
ordered_results.extend(list(res)) | ||
# the dataloader may pad some samples | ||
ordered_results = ordered_results[:size] | ||
return ordered_results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.