'''
Copyright (C) 2020, Northwestern University and Lawrence Berkeley National Laboratory
See COPYRIGHT notice in top-level directory.
'''
import tensorflow as tf
import argparse
import multiprocessing as mp
import horovod.tensorflow as hvd

from model import model
from train import Trainer
from io_daemon import IOdaemon
from feeder_sync import cosmoflow_sync
from feeder_async import cosmoflow_async


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--batch_size", type=int, default=8,
                        help="number of training samples in each mini-batch")
    parser.add_argument("-o", "--overlap", type=int, default=1,
                        help="0: do not overlap I/O with computation, 1: overlap I/O with computation")
    parser.add_argument("-p", "--checkpoint", type=int, default=0,
                        help="0: do not checkpoint the model, 1: checkpoint the model")
    parser.add_argument("-e", "--epochs", type=int, default=1,
                        help="number of training epochs")
    parser.add_argument("-k", "--buffer_size", type=int, default=128,
                        help="I/O buffer size, in number of samples")
    parser.add_argument("-f", "--file_shuffle", type=int, default=0,
                        help="0: do not shuffle the files across processes, 1: shuffle them")
    parser.add_argument("-r", "--record_acc", type=int, default=0,
                        help="0: do not record loss and accuracy, 1: write them into output files")
    parser.add_argument("-v", "--evaluate", type=int, default=0,
                        help="0: do not evaluate, 1: evaluate the model after every epoch")
    parser.add_argument("-y", "--config", default="test.yaml",
                        help="YAML file that describes the input data")
    parser.add_argument("-a", "--async_io", type=int, default=1,
                        help="0: synchronous I/O in the training process, "
                             "1: asynchronous I/O through a separate I/O process")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_parser()

    # Initialize Horovod.
    hvd.init()

    # Allocate GPU memory on demand and pin each process to a single GPU.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        # On Summit, each resource set sees only its own GPUs, so every
        # process would pin gpus[0]:
        #tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
    # Synchronization primitives shared between the training process and
    # the I/O process.
    lock = mp.Lock()
    cv = mp.Condition(lock=lock)
    finish = mp.Value('i')
    finish.value = 0
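    # A note on the handshake (inferred from this file; the daemon side
    # lives in io_daemon.py, which is not shown here): the trainer and the
    # I/O daemon coordinate through cv, and `finish` is the shutdown flag.
    # At the end of training, the main process sets finish.value = 1 under
    # the lock and calls cv.notify(), which the daemon is expected to
    # observe from a cv.wait() loop before exiting.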
    data_buffer_size = args.buffer_size * 128 * 128 * 128 * 12
    label_buffer_size = args.buffer_size * 4
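    # Buffer sizing (inferred from the element types below, not from the
    # feeder code): each sample appears to be a 128^3 grid with 12 uint16
    # channels ('H') plus 4 float32 label values ('f'). That is
    # 128**3 * 12 * 2 bytes = 48 MiB of data per sample, so the default
    # buffer_size of 128 samples allocates 6 GiB per data buffer.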
    # Initialize the shared memory space between the TF process and the
    # I/O process.
    data = []
    label = []
    num_samples = []
    if args.overlap == 0:
        num_buffers = 1
    else:
        num_buffers = 2
    for i in range(num_buffers):
        data.append(mp.RawArray('H', data_buffer_size))
        label.append(mp.RawArray('f', label_buffer_size))
        num_samples.append(mp.Value('i'))
        num_samples[i].value = 0
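    # With overlap enabled, the two buffers form a classic double-buffering
    # scheme: the I/O daemon fills one buffer while the trainer consumes the
    # other, then the roles swap. With overlap == 0 there is a single buffer,
    # so reading and training alternate instead of running concurrently.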
    # Initialize model, dataset, and trainer.
    cosmo_model = model()
    if args.async_io == 1:
        dataset = cosmoflow_async(args.config, lock, cv,
                                  data, label, num_samples,
                                  batch_size=args.batch_size,
                                  buffer_size=args.buffer_size)
        # Initialize the I/O daemon.
        async_io_module = IOdaemon(dataset,
                                   args.file_shuffle,
                                   args.buffer_size)
    else:
        dataset = cosmoflow_sync(args.config,
                                 do_shuffle=args.file_shuffle,
                                 batch_size=args.batch_size,
                                 buffer_size=args.buffer_size)
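    # Data path (as the constructor arguments suggest; the feeder classes
    # themselves are defined in feeder_async.py / feeder_sync.py): with
    # async_io, cosmoflow_async reads samples out of the shared buffers
    # filled by IOdaemon, while cosmoflow_sync performs file reads directly
    # inside the training process.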
    trainer = Trainer(cosmo_model,
                      args.async_io,
                      dataset,
                      do_shuffle=args.file_shuffle,
                      num_epochs=args.epochs,
                      do_checkpoint=args.checkpoint,
                      do_record_acc=args.record_acc,
                      do_evaluate=args.evaluate)
    # Spawn the I/O daemon process.
    if args.async_io == 1:
        io_process = mp.Process(target=async_io_module.run,
                                args=(lock, cv, finish,
                                      data, label, num_samples))
        io_process.start()
    # Start the training.
    trainer.train()

    # Signal the I/O process to stop, then wait for it to exit.
    if args.async_io == 1:
        lock.acquire()
        finish.value = 1
        cv.notify()
        lock.release()
        io_process.join()
    print("All done!")
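
# Example launch (a sketch, not taken from this repository's documentation;
# the process count and flag values are placeholders). With Horovod's
# launcher installed, something like
#
#   horovodrun -np 4 python main.py -b 8 -e 3 -a 1 -y test.yaml
#
# starts 4 training processes, each pinned to one local GPU via
# hvd.local_rank(), with the asynchronous I/O daemon enabled.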