# util.py
from typing import Dict
import os
import shutil
import random
import pickle
import argparse

import numpy as np
import torch


class Args:
    # Runtime arguments, populated from the command line (see add_argument).
    dataset = None
    epoch = None
    lr = None
    lr_scheduler = None
    lr_milestones = None
    lr_gamma = None
    obs_len = None
    pred_len = None
    train_batch_size = None
    test_batch_size = None
    seed = None
    gpu_num = None
    checkpoint = None
    data_dir = None
    log_dir = None
    cuda = None
    end_centered = None
    data_flip = None
    data_scaling = None
    # Arguments for building the tree
    split_thea = None
    split_temporal_interval = None
    tree_degree = None
    num_k = None


class ModelArgs:
    # Arguments for the model
    in_dim = 2
    obs_len = 8
    pred_len = 12
    hidden1 = 1024
    hidden2 = 256
    enc_dim = 64
    att_layer = 3
    tf = True  # teacher forcing
    out_dim = 2
    num_k = 20


def add_argument(parser):
    # Register every command-line flag on the given ArgumentParser.
    assert isinstance(parser, argparse.ArgumentParser)
    parser.add_argument('--dataset', type=str, default='eth', help='eth, hotel, univ, zara1, zara2, sdd')
    parser.add_argument('--data_dir', type=str, default='./dataset/')
    parser.add_argument('--log_dir', type=str)
    parser.add_argument('--epoch', type=int, default=350)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lr_scheduler', type=int, default=0,
                        help='0: MultiStepLR, 1: CosineAnnealingLR, other numbers: None')
    parser.add_argument('--lr_milestones', type=int, nargs='+', default=[50, 150, 250])
    parser.add_argument('--lr_gamma', type=float, default=0.5)
    parser.add_argument('--obs_len', type=int, default=8)
    parser.add_argument('--pred_len', type=int, default=12)
    parser.add_argument('--train_batch_size', type=int, default=512,
                        help='256 or 512 for eth-ucy, 512 for sdd')
    parser.add_argument('--test_batch_size', type=int, default=512,
                        help='256, 512 or 4096 for eth-ucy, 4096 for sdd')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--gpu_num', type=str, default='6')
    parser.add_argument('--checkpoint', type=str, default='./checkpoints/')
    parser.add_argument('--end_centered', action='store_true')
    parser.add_argument('--data_flip', action='store_true')
    parser.add_argument('--data_scaling', type=float, nargs='+', default=None)
    # Arguments for building the tree
    parser.add_argument('--split_thea', type=int, default=4)
    parser.add_argument('--split_temporal_interval', type=int, default=4)
    parser.add_argument('--tree_degree', type=int, default=3)
    parser.add_argument('--num_k', type=int, default=20)
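

def _example_parse_args():
    # Illustrative usage sketch (not part of the original file; the function
    # name is hypothetical): build the parser, parse a flag list, and copy
    # the parsed values onto an Args instance so the rest of the code can
    # read them as attributes.
    parser = argparse.ArgumentParser()
    add_argument(parser)
    ns = parser.parse_args(['--dataset', 'eth', '--cuda'])
    run_args = Args()
    for key, value in vars(ns).items():
        setattr(run_args, key, value)
    return run_args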


def get_input_data(data_dict: Dict, key=None):
    # Fetch one entry from the preprocessed data dict; warn instead of
    # raising when the key is missing (the caller then receives None).
    try:
        return data_dict[key]
    except KeyError:
        print('KeyError: {} not found in data dict'.format(key))
        return None


# Module-level handles, set once by init() and shared by the helpers below.
args: Args = None
logger = None


def init(args_: Args, logger_):
    global args, logger
    args = args_
    logger = logger_
    # The preprocessed train/test pickles must already exist.
    assert os.path.exists(args.data_dir + 'test')
    assert os.path.exists(args.data_dir + 'train')

    if args.log_dir is None:
        args.log_dir = args.checkpoint + args.dataset

    # Start from a clean checkpoint directory for this dataset.
    if os.path.exists(args.checkpoint + args.dataset):
        shutil.rmtree(args.checkpoint + args.dataset)
    os.makedirs(args.checkpoint + args.dataset, exist_ok=False)

    logger.info("*******" + ' args ' + "******")
    logging(vars(args_), verbose=True, sep=' ', save_as_pickle=True, file_type=args.dataset + '.args')

    # Seed every RNG in use for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
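

def _example_startup():
    # Illustrative startup order (not part of the original file; all names
    # here are hypothetical): parse the command-line flags, build a stdlib
    # logger, then call init() once so the module-level handles are set.
    # Note that init() deletes and recreates the checkpoint directory.
    import logging as std_logging  # aliased: this module defines its own logging()
    run_args = _example_parse_args()
    std_logging.basicConfig(level=std_logging.INFO)
    init(run_args, std_logging.getLogger('train'))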


def logging(*inputs, verbose=False, sep=' ', save_as_pickle=False, file_type='args', append_log=False):
    '''
    Write the given values into a file under args.log_dir, optionally echoing
    them to stdout. Note: this local helper shadows the stdlib logging module.
    '''
    if verbose:
        print(*inputs, sep=sep)
    if args is None or args.log_dir is None:
        return
    file = os.path.join(args.log_dir, file_type)
    if save_as_pickle:
        with open(file, 'wb') as pickle_file:
            # pickle.dump serializes exactly one object: unwrap a single
            # input, otherwise store the whole tuple.
            obj = inputs[0] if len(inputs) == 1 else inputs
            pickle.dump(obj, pickle_file)
    if append_log:
        with open(file, "a", encoding='utf-8') as fout:
            print(*inputs, file=fout, sep=sep)
            print(file=fout)
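

def _example_logging():
    # Illustrative usage (not part of the original file; the file name
    # 'train.log' is hypothetical): append one tab-separated line to a log
    # file under args.log_dir and echo it to stdout. Requires init() to
    # have run first so args is set.
    logging('epoch', 1, 'ade', 0.42, verbose=True, sep='\t',
            append_log=True, file_type='train.log')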


def get_train_test_data(data_path, dataset_name, batch_size, is_test):
    # Build the path to a preprocessed pickle following the naming scheme
    # social_<dataset>_<split>_<batch>_0_<tail>.pickle. SDD uses fixed batch
    # sizes (4096 for test, 512 for train) and tail 100; the ETH/UCY scenes
    # use the requested batch size and tail 50.
    split = 'test' if is_test else 'train'
    if dataset_name == 'sdd':
        batch = 4096 if is_test else 512
        tail = 100
    else:
        batch = batch_size
        tail = 50
    return '{}/{}/social_{}_{}_{}_0_{}.pickle'.format(
        data_path, split, dataset_name, split, batch, tail)
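

def _example_pickle_path():
    # Illustrative (not part of the original file): the helper only builds a
    # file path; the caller is expected to load the pickle itself, e.g. with
    # pickle.load(). For these inputs the result is
    # './dataset/test/social_eth_test_512_0_50.pickle'.
    return get_train_test_data('./dataset', 'eth', 512, is_test=True)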


def data_augmentation(data_, end_centered, is_flip, data_scaling):
    # Optional preprocessing applied to trajectories of shape (N, T, 2).
    if end_centered:
        # Translate each trajectory so the last observed frame (index 7,
        # i.e. obs_len - 1 for obs_len = 8) sits at the origin.
        data_ = data_ - data_[:, 7:8]
    if is_flip:
        # Reverse the last axis, swapping the x and y coordinates.
        data_ = np.flip(data_, axis=-1).copy()
    if data_scaling is not None:
        # Scale x and y independently.
        data_[:, :, 0] = data_[:, :, 0] * data_scaling[0]
        data_[:, :, 1] = data_[:, :, 1] * data_scaling[1]
    return data_
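

def _example_data_augmentation():
    # Illustrative sanity check (not part of the original file): after
    # end-centering, frame 7 of every trajectory becomes the origin.
    trajs = np.random.randn(3, 20, 2).astype(np.float32)
    centered = data_augmentation(trajs, end_centered=True, is_flip=False,
                                 data_scaling=None)
    assert np.allclose(centered[:, 7], 0.0)
    return centered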


def get_ade_fde(pred_trajs, gt_trajs, num_k):
    # Best-of-k ADE/FDE, summed over the batch. pred_trajs holds num_k
    # predictions per ground-truth trajectory, stacked along the batch axis.
    pred_trajs = pred_trajs.reshape(gt_trajs.shape[0], num_k, gt_trajs.shape[1], -1)
    gt_trajs = gt_trajs.unsqueeze(1)
    norm_ = torch.norm(pred_trajs - gt_trajs, p=2, dim=-1)  # (N, k, T) point-wise errors
    ade_ = torch.mean(norm_, dim=-1)                        # average over time steps
    fde_ = norm_[:, :, -1]                                  # error at the final step
    min_ade, _ = torch.min(ade_, dim=-1)                    # best of the k guesses
    min_fde, _ = torch.min(fde_, dim=-1)
    return torch.sum(min_ade), torch.sum(min_fde)
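

def _example_ade_fde():
    # Illustrative sanity check (not part of the original file): get_ade_fde
    # returns batch *sums* of the best-of-k ADE/FDE, so divide by the number
    # of agents to report the usual averages.
    num_agents, num_k, pred_len = 4, 20, 12
    gt = torch.randn(num_agents, pred_len, 2)
    preds = torch.randn(num_agents * num_k, pred_len, 2)
    ade_sum, fde_sum = get_ade_fde(preds, gt, num_k)
    return ade_sum / num_agents, fde_sum / num_agents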