import json
import os
import shutil
import subprocess
import sys
from os.path import join

import joblib as jl
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import smplx
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make the data_processing directory importable before loading the
# project-local modules below, which live in (or depend on) it.
module_path = os.path.abspath("data_processing")
if module_path not in sys.path:
    sys.path.append(module_path)

from utils import SMPLXWrapper
from data_processing.prepare_saga_dataset import DatasetProcessor
from motion.datasets.motion_data import MotionDataset, TestDataset
def inv_standardize(data, scaler):
    """Invert standardization on a (n_clips, n_frames, n_channels) array."""
    shape = data.shape
    flat = data.copy().reshape((shape[0] * shape[1], shape[2]))
    scaled = scaler.inverse_transform(flat).reshape(shape)
    return scaled
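
# Example (sketch): undoing the standardization on generated motion. The
# (n_clips, n_frames, n_channels) layout is inferred from the reshape above;
# the scaler is the sklearn-style object stored as motion_scaler.sav, and the
# path used here is an illustrative assumption.
#
#   scaler = jl.load("data/saga/motion_scaler.sav")
#   motion = np.zeros((4, 120, scaler.mean_.shape[0]), dtype=np.float32)
#   unscaled = inv_standardize(motion, scaler)  # back in original units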
def open_data_array(data_root, fname, framerate):
    """
    Open the .npy file with the given name and framerate and return the
    motion as a float32 array.
    """
    fpath = join(data_root, f"{fname}_{framerate}fps.npy")
    return np.load(fpath).astype(np.float32)
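
# For example, open_data_array("data/saga", "train_input", 20) would load
# "data/saga/train_input_20fps.npy"; the root path and the 20 fps value here
# are illustrative, not taken from the repo's configs.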
class SagaDataModule(LightningDataModule):
    def __init__(
        self,
        # Batch size is defined by the model
        batch_size: int,
        dropout: float,
        seqlen: int,
        n_lookahead: int,
        data_root: str,
        x_channels: int,
        cond_channels: int,
        framerate: int,
        is_training: bool,
        num_workers: int,
        trinity: bool,
        add_caption_to_vis: bool,
    ):
        super().__init__()
        self.data_root = data_root
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.is_trinity = trinity
        self.add_caption_to_vis = add_caption_to_vis
        self.motion_scaler = jl.load(join(data_root, "motion_scaler.sav"))
        self.fps = framerate

        if is_training:
            # Load the data. This should already be standardized.
            train_input = open_data_array(data_root, "train_input", framerate)
            train_output = open_data_array(data_root, "train_output", framerate)
            val_input = open_data_array(data_root, "val_input", framerate)
            val_output = open_data_array(data_root, "val_output", framerate)

            # Create PyTorch datasets
            self.train_dataset = MotionDataset(
                train_input, train_output, seqlen, n_lookahead, dropout,
            )
            self.validation_dataset = MotionDataset(
                val_input, val_output, seqlen, n_lookahead, dropout,
            )

            # Load longer sequences for tuning the network
            subj_evaluation_seq_input = open_data_array(
                data_root, "val_seq_input", framerate
            )
            self.velocity_histogram_output = torch.as_tensor(
                open_data_array(data_root, "val_output_for_histograms", framerate)
            )
            self.velocity_histogram_input = torch.as_tensor(
                open_data_array(data_root, "val_input_for_histograms", framerate)
            )
        else:
            self.train_dataset = None
            self.validation_dataset = None
            # Use this to generate test data for evaluation.
            subj_evaluation_seq_input = open_data_array(
                data_root, "test_seq_input", framerate
            )

        # Conditioning channels: seqlen frames of autoregressive pose history
        # plus the control features for seqlen past frames, the current frame
        # and n_lookahead future frames.
        self.n_x_channels = self.motion_scaler.mean_.shape[0]
        self.n_cond_channels = (
            self.n_x_channels * seqlen
        ) + subj_evaluation_seq_input.shape[2] * (seqlen + 1 + n_lookahead)

        # Initialise test output with zeros (mean pose).
        test_output = np.zeros(
            (
                subj_evaluation_seq_input.shape[0],
                subj_evaluation_seq_input.shape[1],
                self.n_x_channels,
            ),
            dtype=np.float32,
        )
        self.test_dataset = TestDataset(subj_evaluation_seq_input, test_output)
        self.eval_batch = self.load_eval_batch()
        self.n_visualization_samples = len(self.eval_batch)
        self.prepare_visualization_data(is_training)
    def prepare_visualization_data(self, is_training):
        partition = "val" if is_training else "test"
        with open(f"data_processing/{partition}_visualization_metadata.json", "r") as f:
            self.visualization_metadata = json.load(f)

        self.visualization_audio_paths = []
        self.visualization_text_transcripts = []
        self.visualization_gest_types = []
        for recording_name, segments in self.visualization_metadata.items():
            for idx, segment in enumerate(segments):
                _, gest_type, text = segment
                audio_file = f"{recording_name}_{partition}_{str(idx).zfill(3)}.wav"
                self.visualization_audio_paths.append(
                    join(self.data_root, f"{partition}_seq_audio", audio_file)
                )
                self.visualization_gest_types.append(gest_type)
                self.visualization_text_transcripts.append(text)
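
    # The metadata JSON is assumed to map recording names to lists of
    # [<unused>, gest_type, text] segments; that shape is inferred from the
    # unpacking in prepare_visualization_data, which discards the first field.
    # Illustrative example (values assumed, not taken from the dataset):
    #
    #   {"recording_001": [[<unused>, "beat", "and then he said ..."]]}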
    def load_eval_batch(self):
        """Load the entire test split as a single batch for evaluation."""
        loader = DataLoader(
            self.test_dataset,
            batch_size=len(self.test_dataset),
            num_workers=8,
            shuffle=False,
        )
        return next(iter(loader))
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.validation_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            drop_last=True,
        )

    def n_channels(self):
        return self.n_x_channels, self.n_cond_channels
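

# Usage sketch, not taken from the repo's configs: the hyperparameter values
# below are illustrative assumptions, and data_root is assumed to contain the
# <name>_<framerate>fps.npy files and motion_scaler.sav produced by the
# preprocessing step. x_channels and cond_channels are not read in this
# module; the real counts are recomputed from the scaler and the loaded arrays.
if __name__ == "__main__":
    dm = SagaDataModule(
        batch_size=32,           # assumed; normally set by the model config
        dropout=0.0,
        seqlen=10,               # assumed history length
        n_lookahead=0,
        data_root="data/saga",   # assumed location of the preprocessed data
        x_channels=0,
        cond_channels=0,
        framerate=20,            # assumed; must match the preprocessed files
        is_training=False,       # test mode: only test_seq_input is needed
        num_workers=4,
        trinity=False,
        add_caption_to_vis=False,
    )
    print(dm.n_channels())       # -> (n_x_channels, n_cond_channels)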