diff --git a/tencentpretrain/utils/config.py b/tencentpretrain/utils/config.py index 6f221a3..b5d2f6e 100644 --- a/tencentpretrain/utils/config.py +++ b/tencentpretrain/utils/config.py @@ -3,21 +3,21 @@ from argparse import Namespace -def load_hyperparam(default_args): +def load_hyperparam(args): """ Load arguments form argparse and config file Priority: default options < config file < command line args """ - with open(default_args.config_path, mode="r", encoding="utf-8") as f: + with open(args.config_path, mode="r", encoding="utf-8") as f: config_args_dict = json.load(f) - default_args_dict = vars(default_args) + args_dict = vars(args) - command_line_args_dict = {k: default_args_dict[k] for k in [ + command_line_args_dict = {k: args_dict[k] for k in [ a[2:] for a in sys.argv if (a[:2] == "--" and "local_rank" not in a) ]} - default_args_dict.update(config_args_dict) - default_args_dict.update(command_line_args_dict) - args = Namespace(**default_args_dict) + args_dict.update(config_args_dict) + args_dict.update(command_line_args_dict) + args = Namespace(**args_dict) return args