-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain_gazecap.py
executable file
·74 lines (58 loc) · 2.87 KB
/
train_gazecap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/python
import argparse
from itracker.training import experiment
import logging_config
def parse_args():
    """ Assembles the command-line interface for the training script.

    Despite the name, no arguments are parsed here; the configured parser
    itself is handed back to the caller.

    Returns:
      The `argparse.ArgumentParser` with every option registered. """
    # Each entry pairs the flag names with the keyword options for
    # `add_argument`. Registration order is kept because it determines the
    # order of positional arguments and of the generated help text.
    option_table = (
        (("train_dataset",),
         dict(help="The location of the training dataset.")),
        (("test_dataset",),
         dict(help="The location of the testing dataset.")),
        (("-v", "--valid_dataset"),
         dict(help="The location of the validation dataset.")),
        (("-m", "--model"),
         dict(help="Existing model to load. Necessary if validating.")),
        (("-o", "--output"),
         dict(default="eye_model.hd5",
              help="Where to save the trained model.")),
        (("-f", "--fine_tune"),
         dict(action="store_true", help="Fine-tune the model.")),
        (("--tpu",),
         dict(default=None,
              help="Name of a TPU to train on. (experimental)")),
        (("--bucket",),
         dict(default=None,
              help="Google cloud storage bucket URL when using TPU.")),
        (("--autoencoder",),
         dict(action="store_true",
              help="Specifies that we want to evaluate the autoencoder.")),
        (("--autoencoder_weights",),
         dict(help="For branched autoencoder network, location of autoencoder weights.")),
        (("--clusters",),
         dict(help="For branched autoencoder network, location of cluster data.")),
        (("--batch_size",),
         dict(type=int, default=32, help="Examples in each batch.")),
        (("--testing_interval",),
         dict(type=int, default=4,
              help="No. of training iterations to run before testing.")),
        (("--learning_rate",),
         dict(type=float, default=0.001,
              help="The initial learning rate.")),
        (("--momentum",),
         dict(type=float, default=0.9, help="The initial momentum.")),
        (("--training_steps",),
         dict(type=int, default=40,
              help="Num. of batches to run for each training iter.")),
        (("--testing_steps",),
         dict(type=int, default=24,
              help="Num. of batches to run for each testing iter.")),
        (("--valid_iters",),
         dict(type=int, default=124,
              help="How many iterations to validate for.")),
        (("--pose",),
         dict(action="store_true",
              help="Whether the dataset supports head pose.")),
        (("--reg",),
         dict(type=float, default=0.0005,
              help="Alpha value to use for l2 regularization.")),
    )

    cli = argparse.ArgumentParser()
    for flag_names, settings in option_table:
        cli.add_argument(*flag_names, **settings)
    return cli
def main():
    """ Script entry point: configures logging, then builds and runs the
    experiment. """
    # Logging is set up first so it is in place before the experiment exists.
    logging_config.configure_logging()

    # The experiment receives the argument parser itself, not parsed values.
    experiment.Experiment(parse_args()).run()
# Only start the experiment when executed as a script, not when imported.
if __name__ == "__main__":
    main()