-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutilis.py
37 lines (33 loc) · 875 Bytes
/
utilis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
import torch
import random
import torch.nn as nn
def set_seed(seed):
    """
    Seed every random number generator this module relies on.

    The same value is handed to Python's ``random`` module, NumPy's global
    generator, and torch's CPU and CUDA generators; the two cuDNN backend
    flags are then fixed (``benchmark`` off, ``deterministic`` on).

    Args:
        seed: value passed unchanged to each seeding routine

    Returns: None
    """
    # Each seeder takes the seed as its only argument, so apply them in turn.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Pin the cuDNN backend flags.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
class style:
    """ANSI escape sequences exposed as class attributes for terminal text."""

    # Foreground colours (SGR codes 30-37).
    BLACK = '\033[30m'
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    MAGENTA = '\033[35m'
    CYAN = '\033[36m'
    WHITE = '\033[37m'

    # Text attribute (SGR code 4).
    UNDERLINE = '\033[4m'

    # SGR code 0: clears any colour/attribute set earlier.
    RESET = '\033[0m'
def glorot_init(input_dim, output_dim):
    """
    Build a trainable weight matrix with Glorot-style uniform initialisation.

    Values are drawn from ``torch.rand`` and rescaled to lie between
    ``-r`` and ``r``, where ``r = sqrt(6 / (input_dim + output_dim))``.

    Args:
        input_dim: number of rows of the returned matrix
        output_dim: number of columns of the returned matrix

    Returns: an ``nn.Parameter`` of shape ``(input_dim, output_dim)``
    """
    bound = np.sqrt(6.0 / (input_dim + output_dim))
    unit_samples = torch.rand(input_dim, output_dim)
    # Stretch the unit-interval samples to width 2*bound, then centre on zero.
    weights = unit_samples * 2 * bound - bound
    return nn.Parameter(weights)