-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasets.py
164 lines (139 loc) · 4.74 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import pandas as pd
import os
import constants
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def load_csv(csv_path):
    """
    Loads a given csv as a dataframe with correct types.

    The first column (the image filename/id column) is kept unchanged.
    Every other column is converted independently with
    ``pd.to_numeric(errors='coerce')``, so unparseable cells become NaN
    without affecting the dtype of the other label columns.

    Args:
        csv_path: path to a CSV file with a header row.

    Returns:
        (df, columns) where ``columns`` is ``df.columns[1:]``, i.e. the
        label column names.
    """
    df = pd.read_csv(csv_path, header=0)
    id_column = df.columns[0]
    # Convert each label column on its own. The previous
    # list-of-Series + transpose built a frame with one column per CSV row
    # (slow on large files) and forced all label columns to a single
    # common dtype (one NaN turned every column into float).
    labels = df.iloc[:, 1:].apply(pd.to_numeric, errors="coerce")
    df = pd.concat([df[[id_column]], labels], axis=1)
    columns = df.columns[1:]
    return df, columns
def get_train_datagen():
    """
    Build the augmenting ImageDataGenerator used for training.

    Pixel values are rescaled by 1/255, augmentation uses shear_range=0.2,
    zoom_range=0.2 and horizontal flips, and validation_split=0.15 lets
    callers request the 'training' / 'validation' subsets.
    """
    settings = dict(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        validation_split=0.15,
    )
    return ImageDataGenerator(**settings)
def get_test_datagen():
    """
    Build the ImageDataGenerator used for evaluation.

    Only rescales pixel values by 1/255: no augmentation and no
    validation split.
    """
    return ImageDataGenerator(rescale=1. / 255)
def get_ffhq_train(batch_size, shuffle):
    """
    Get training and validation data generators for ffhq.

    Both generators read ./face_data/age_gender/labels.csv, come from the
    same training ImageDataGenerator (so the same augmentation settings
    apply to the validation subset), use seed=1, and differ only in the
    ``subset`` they draw.

    Args:
        batch_size: images per batch for both generators.
        shuffle: forwarded to both generators.

    Returns:
        (train_generator, validation_generator, columns) where ``columns``
        are the label columns reported by ``load_csv``.
    """
    datagen = get_train_datagen()
    labels_df, label_cols = load_csv("./face_data/age_gender/labels.csv")
    # Everything except ``subset`` is identical for the two generators.
    common = dict(
        dataframe=labels_df,
        directory="./face_data/age_gender",
        x_col="filename",
        y_col=label_cols,
        target_size=(constants.IMG_HEIGHT, constants.IMG_WIDTH),
        batch_size=batch_size,
        class_mode='raw',
        shuffle=shuffle,
        seed=1,
    )
    train_gen, val_gen = [
        datagen.flow_from_dataframe(subset=part, **common)
        for part in ("training", "validation")
    ]
    return train_gen, val_gen, label_cols
def get_ffhq_test(test_dataset):
    """
    Get testing data generator for ffhq.

    Must choose dataset to test:
        "overall"       - labels_generated.csv followed by labels.csv,
        "ffhqgenerated" - labels_generated.csv only,
        "ffhq"          - labels.csv only.
    All CSVs live in ./face_data/age_gender_test. The generator is not
    shuffled, so predictions line up with the CSV row order.

    Args:
        test_dataset: one of "overall", "ffhqgenerated", "ffhq".

    Returns:
        (test_generator, columns) with ``columns`` the label columns.

    Raises:
        ValueError: if ``test_dataset`` is not one of the names above
            (previously any unknown value silently fell back to "ffhq").
    """
    valid = ("overall", "ffhqgenerated", "ffhq")
    if test_dataset not in valid:
        raise ValueError(
            f"Unknown test_dataset {test_dataset!r}; expected one of {valid}")
    if test_dataset == "overall":
        generated_test_data, columns = load_csv(
            "./face_data/age_gender_test/labels_generated.csv")
        ffhq_test_data, columns = load_csv(
            "./face_data/age_gender_test/labels.csv")
        # Both frames start at index 0; renumber to avoid duplicate labels.
        ffhq_test_data = pd.concat(
            [generated_test_data, ffhq_test_data], ignore_index=True)
    elif test_dataset == "ffhqgenerated":
        ffhq_test_data, columns = load_csv(
            "./face_data/age_gender_test/labels_generated.csv")
    else:
        ffhq_test_data, columns = load_csv(
            "./face_data/age_gender_test/labels.csv")
    test_datagen = get_test_datagen()
    test_generator = test_datagen.flow_from_dataframe(
        dataframe=ffhq_test_data,
        directory="./face_data/age_gender_test",
        x_col="filename",
        y_col=columns,
        target_size=(
            constants.IMG_HEIGHT, constants.IMG_WIDTH),
        class_mode='raw',
        shuffle=False,
        seed=1)
    return test_generator, columns
def get_celeba(batch_size, attributes=("Blond_Hair", "Black_Hair"),
               max_rows=1000, shuffle=True):
    """
    Get celeba training and validation generators.

    Reads ./celeba-dataset/list_attr_celeba.csv and loads images from
    celeba-dataset/img_align_celeba/img_align_celeba. Both subsets come
    from the training ImageDataGenerator (with its validation split).

    Args:
        batch_size: images per batch for both generators.
        attributes: CSV attribute columns used as raw labels. Defaults to
            ("Blond_Hair", "Black_Hair"), the pair previously hard-coded;
            other columns such as "Male", "No_Beard" or "Young" may be
            passed instead.
        max_rows: only the first ``max_rows`` CSV rows are used (before
            the training/validation split). Defaults to the previous
            hard-coded 1000; pass ``None`` to use every row.
        shuffle: forwarded to both generators. True matches the previous
            behavior, where shuffle was left at the Keras default.

    Returns:
        (train_generator, validation_generator, columns) where ``columns``
        is the list of label column names.
    """
    train_datagen = get_train_datagen()
    # load_csv's own column list is not used; labels come from `attributes`.
    train_df, _ = load_csv("./celeba-dataset/list_attr_celeba.csv")
    # y_col must be a list: pandas would treat a tuple as a single key.
    columns = list(attributes)
    if max_rows is not None:
        train_df = train_df.iloc[:max_rows]
    # NOTE(review): confirm the attribute encoding in list_attr_celeba.csv
    # matches what the training loss expects for class_mode='raw'.
    common = dict(
        dataframe=train_df,
        directory="celeba-dataset/img_align_celeba/img_align_celeba",
        x_col="image_id",
        y_col=columns,
        target_size=(constants.IMG_HEIGHT, constants.IMG_WIDTH),
        batch_size=batch_size,
        class_mode='raw',
        shuffle=shuffle,
    )
    train_generator = train_datagen.flow_from_dataframe(
        subset='training', **common)
    validation_generator = train_datagen.flow_from_dataframe(
        subset='validation', **common)
    return train_generator, validation_generator, columns
def get_training_data(batch_size, dataset="ffhq", shuffle=True):
    """
    Return (train_generator, validation_generator, columns) for a dataset.

    ``dataset == "celeba"`` selects the CelebA loader; any other value
    selects FFHQ. ``shuffle`` (True by default) is forwarded only to the
    FFHQ loader; the CelebA call does not receive it.
    """
    if dataset != "celeba":
        return get_ffhq_train(batch_size, shuffle)
    return get_celeba(batch_size)
def get_testing_data(test_dataset="ffhq"):
    """
    Return (test_generator, columns) for the chosen FFHQ test split.

    Thin wrapper over ``get_ffhq_test``; ``test_dataset`` defaults to
    "ffhq".
    """
    generator, label_columns = get_ffhq_test(test_dataset)
    return generator, label_columns