-
Notifications
You must be signed in to change notification settings - Fork 3
/
util.py
76 lines (65 loc) · 2.59 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import random
import numpy as np
import math
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.utils import make_grid
def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
    '''
    Converts a torch Tensor into an image Numpy array.
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)

    min_max gives the value range of the input tensor; values are clamped to
    it and then rescaled to [0, 255] when out_type is np.uint8.
    '''
    # Work on a detached CPU copy: the original used in-place clamp_(), which
    # mutated the caller's tensor, and .numpy() would fail for CUDA / grad
    # tensors. clamp() on the detached copy is behavior-identical for the
    # returned value.
    tensor = tensor.detach().cpu().clamp(*min_max)
    n_dim = tensor.dim()
    if n_dim == 4:
        # Tile the batch into a near-square grid, then CHW -> HWC.
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
    elif n_dim == 3:
        img_np = np.transpose(tensor.numpy(), (1, 2, 0))  # HWC, RGB
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError('Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Rescale [min_max[0], min_max[1]] -> [0, 255]. The original
        # hard-coded the (-1, 1) mapping ((x + 1) * 127.5), silently ignoring
        # any other min_max; this honors the parameter and is identical for
        # the default (-1, 1).
        lo, hi = min_max
        img_np = ((img_np - lo) / (hi - lo) * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type).squeeze()
def postprocess(images):
    """Convert each tensor in *images* to a uint8 numpy image; returns a list."""
    return list(map(tensor2img, images))
def set_seed(seed, gl_seed=0):
    """ set random seed, gl_seed used in worker_init_fn function """
    reproducible = seed >= 0 and gl_seed >= 0
    if reproducible:
        combined = seed + gl_seed
        torch.manual_seed(combined)
        torch.cuda.manual_seed_all(combined)
        np.random.seed(combined)
        random.seed(combined)
    # Deterministic cuDNN + no benchmarking is slower but reproducible; the
    # inverse is faster but nondeterministic. Speed-reproducibility tradeoff:
    # https://pytorch.org/docs/stable/notes/randomness.html
    torch.backends.cudnn.deterministic = reproducible
    torch.backends.cudnn.benchmark = not reproducible
def set_gpu(args, distributed=False, rank=0):
    """ set parameter to gpu or ddp """
    # Nothing to move — propagate None unchanged.
    if args is None:
        return None
    moved = args.cuda()
    # Only nn.Modules can be wrapped in DDP; tensors are just moved.
    if distributed and isinstance(args, torch.nn.Module):
        return DDP(moved, device_ids=[rank], output_device=rank,
                   broadcast_buffers=True, find_unused_parameters=True)
    return moved
def set_device(args, distributed=False, rank=0):
    """
    Move *args* (a module/tensor, or a list/dict of them) to GPU when CUDA is
    available; otherwise return it unchanged (CPU). Container types are
    preserved in the result.
    """
    if not torch.cuda.is_available():
        return args
    if isinstance(args, list):
        # Bug fix: the original returned a generator expression here, so the
        # result was lazily evaluated, single-use, and not indexable — unlike
        # the eager dict branch below. Return a real list instead.
        return [set_gpu(item, distributed, rank) for item in args]
    if isinstance(args, dict):
        return {key: set_gpu(val, distributed, rank) for key, val in args.items()}
    return set_gpu(args, distributed, rank)