-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
34 lines (25 loc) · 876 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
import importlib
import os
import sys
import absl
from utilities.utils import define_flags_with_default
def main():
    """Load a config file, parse the remaining flags, and run the chosen trainer.

    Command-line arguments handled here:
        --config: required string passed to ``import_file``; the object it
            returns must expose a ``get_config()`` callable.
        -g: integer (default 0) written to ``CUDA_VISIBLE_DEVICES``. A
            negative value sets the variable to the empty string instead.

    Every argument argparse does not recognise is forwarded to absl's flag
    parser. The trainer class is looked up by name (the ``trainer`` flag) in
    the ``diffuser.trainer`` module, instantiated with the config, and its
    ``train()`` method is called.
    """
    # Import the submodule explicitly: a bare ``import absl`` only binds the
    # package and does not guarantee that ``absl.flags`` has been loaded, so
    # ``absl.flags.FLAGS`` would depend on some other module importing it.
    from absl import flags

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("-g", type=int, default=0)
    # parse_known_args keeps unrecognised arguments so they can be handed to
    # absl below instead of triggering an argparse error.
    args, unknown_flags = parser.parse_known_args()

    # Negative index -> empty value (no device index listed); otherwise
    # expose only the requested device index.
    os.environ["CUDA_VISIBLE_DEVICES"] = "" if args.g < 0 else str(args.g)

    from utilities.utils import import_file

    config_module = import_file(args.config, "default_config")
    config = config_module.get_config()
    # NOTE(review): presumably registers one absl flag per config entry, with
    # the config value as its default - confirm against utilities.utils.
    config = define_flags_with_default(**config)

    # absl expects argv[0] to be the program name, so prepend it to the
    # arguments argparse left over.
    flags.FLAGS(sys.argv[:1] + unknown_flags)

    trainer_module = importlib.import_module("diffuser.trainer")
    trainer_cls = getattr(trainer_module, flags.FLAGS.trainer)
    trainer = trainer_cls(config)
    trainer.train()
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()