-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathdemo.py
90 lines (65 loc) · 2.48 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# -*- coding: utf-8 -*-
"""
Demo code
@author: Denis Tome'
"""
from torch.utils.data import DataLoader
import torch
from torchvision import transforms
from base import SetType
import dataset.transform as trsf
from dataset import Mocap
from utils import config, ConsoleLogger
from utils import evaluate, io
# Console logger shared by main() for its progress and shape/debug messages.
LOGGER = ConsoleLogger("Main")
def main():
    """Run the demo evaluation loop and save the resulting metrics.

    Steps:
      1. Build a ``Mocap`` dataset over ``config.dataset.val`` (validation
         split) with image / 3D-joint transforms, wrapped in a ``DataLoader``
         configured from ``config.data_loader``.
      2. For each batch, produce a prediction ``p3d_hat`` (currently a dummy
         ``torch.ones_like`` placeholder for a real model's output) and feed
         it, together with the ground-truth ``p3d`` and the ``action`` values,
         to full-body, upper-body and lower-body evaluators.
      3. Write the three evaluators' results to ``config.eval.output_file``
         as JSON under the keys 'FullBody', 'UpperBody' and 'LowerBody'.

    Note: the loop currently stops after the first batch (see the ``break``).
    """
    LOGGER.info('Starting demo...')

    # ------------------- Data loader -------------------

    data_transform = transforms.Compose([
        trsf.ImageTrsf(),
        trsf.Joints3DTrsf(),
        trsf.ToTensor()])

    # let's load data from validation set as example
    data = Mocap(
        config.dataset.val,
        SetType.VAL,
        transform=data_transform)
    data_loader = DataLoader(
        data,
        batch_size=config.data_loader.batch_size,
        shuffle=config.data_loader.shuffle)

    # ------------------- Evaluation -------------------

    eval_body = evaluate.EvalBody()
    eval_upper = evaluate.EvalUpperBody()
    # Previously EvalUpperBody(), which made the 'LowerBody' results a copy of
    # the upper-body metrics.
    # NOTE(review): assumes utils.evaluate defines EvalLowerBody next to
    # EvalUpperBody - confirm.
    eval_lower = evaluate.EvalLowerBody()

    # ------------------- Read dataset frames -------------------

    # Each batch unpacks into (img, p2d, p3d, action); presumably image, 2D
    # joints, 3D joints and action labels as produced by Mocap - confirm.
    for it, (img, p2d, p3d, action) in enumerate(data_loader):

        LOGGER.info('Iteration: {}'.format(it))
        LOGGER.info('Images: {}'.format(img.shape))
        LOGGER.info('p2ds: {}'.format(p2d.shape))
        LOGGER.info('p3ds: {}'.format(p3d.shape))
        LOGGER.info('Actions: {}'.format(action))

        # -----------------------------------------------------------
        # ------------------- Run your model here -------------------
        # -----------------------------------------------------------

        # TODO: replace p3d_hat with model predictions
        p3d_hat = torch.ones_like(p3d)

        # Evaluate results using different evaluation metrics.
        # Convert predictions and targets to CPU numpy arrays for the evaluators.
        y_output = p3d_hat.data.cpu().numpy()
        y_target = p3d.data.cpu().numpy()

        eval_body.eval(y_output, y_target, action)
        eval_upper.eval(y_output, y_target, action)
        eval_lower.eval(y_output, y_target, action)

        # TODO: remove break (only the first batch is processed for the demo)
        break

    # ------------------- Save results -------------------

    LOGGER.info('Saving evaluation results...')
    res = {'FullBody': eval_body.get_results(),
           'UpperBody': eval_upper.get_results(),
           'LowerBody': eval_lower.get_results()}

    io.write_json(config.eval.output_file, res)

    LOGGER.info('Done.')
if __name__ == "__main__":
    # Script entry point: run the demo only when executed directly, not on import.
    main()