From a6588eae153fdf7c3f48a7eeb33f08b8f9338a4e Mon Sep 17 00:00:00 2001
From: Dongli He
Date: Wed, 19 Jun 2024 21:03:23 +0400
Subject: [PATCH] upgrade to 2.0-style for ray.tune

---
 search_params.py | 53 +++++++++++++++++++++++++++++++-----------------
 1 file changed, 34 insertions(+), 19 deletions(-)

diff --git a/search_params.py b/search_params.py
index 911e7344..a9d9e197 100644
--- a/search_params.py
+++ b/search_params.py
@@ -141,7 +141,10 @@ def init_search_algorithm(search_alg, metric=None, mode=None):
         from ray.tune.search.bayesopt import BayesOptSearch

         return BayesOptSearch(metric=metric, mode=mode)
-    logging.info(f"{search_alg} search is found, run BasicVariantGenerator().")
+    elif search_alg == "basic_variant":
+        pass
+    else:
+        logging.info(f"Unrecognized search algorithm {search_alg}; falling back to the default BasicVariantGenerator.")


 def prepare_retrain_config(best_config, best_log_dir, retrain):
@@ -171,13 +174,11 @@ def prepare_retrain_config(best_config, best_log_dir, retrain):
         best_config.merge_train_val = False


-def load_static_data(config, merge_train_val=False):
+def load_static_data(config):
     """Preload static data once for multiple trials.

     Args:
         config (AttributeDict): Config of the experiment.
-        merge_train_val (bool, optional): Whether to merge the training and validation data.
-            Defaults to False.

     Returns:
         dict: A dict of static data containing datasets, classes, and word_dict.
@@ -187,7 +188,7 @@
         test_data=config.test_file,
         val_data=config.val_file,
         val_size=config.val_size,
-        merge_train_val=merge_train_val,
+        merge_train_val=config.merge_train_val,
         tokenize_text="lm_weight" not in config.network_config,
         remove_no_label_data=config.remove_no_label_data,
     )
@@ -231,7 +232,7 @@ def retrain_best_model(exp_name, best_config, best_log_dir, retrain):
     with open(os.path.join(checkpoint_dir, "params.yml"), "w") as fp:
         yaml.dump(dict(best_config), fp)

-    data = load_static_data(best_config, merge_train_val=best_config.merge_train_val)
+    data = load_static_data(best_config)

     if retrain:
         logging.info(f"Re-training with best config: \n{best_config}")
@@ -303,7 +304,7 @@ def main():
         config = init_search_params_spaces(config, parameter_columns, prefix="")
         parser.set_defaults(**config)
         config = AttributeDict(vars(parser.parse_args()))
-    # no need to include validation during parameter search
+    # Validation sets are mandatory during parameter search
     config.merge_train_val = False
     config.mode = "min" if config.val_metric == "Loss" else "max"

@@ -344,20 +345,34 @@ def main():
         Path(config.config).stem if config.config else config.model_name,
         datetime.now().strftime("%Y%m%d%H%M%S"),
     )
-    analysis = tune.run(
-        tune.with_parameters(train_libmultilabel_tune, **data),
-        search_alg=init_search_algorithm(config.search_alg, metric=f"val_{config.val_metric}", mode=config.mode),
-        scheduler=scheduler,
-        storage_path=config.result_dir,
-        num_samples=config.num_samples,
-        resources_per_trial={"cpu": config.cpu_count, "gpu": config.gpu_count},
-        progress_reporter=reporter,
-        config=config,
-        name=exp_name,
+
+    tuner = tune.Tuner(
+        tune.with_resources(
+            tune.with_parameters(train_libmultilabel_tune, **data),
+            resources={"cpu": config.cpu_count, "gpu": config.gpu_count},
+        ),
+        param_space=config,
+        tune_config=tune.TuneConfig(
+            metric=f"val_{config.val_metric}",
+            mode=config.mode,
+            scheduler=scheduler,
+            num_samples=config.num_samples,
+            search_alg=init_search_algorithm(
+                search_alg=config.search_alg,
+                metric=f"val_{config.val_metric}",
+                mode=config.mode,
+            ),
+        ),
+        run_config=ray_train.RunConfig(
+            name=exp_name,
+            storage_path=config.result_dir,
+            progress_reporter=reporter,
+        ),
     )
+    results = tuner.fit()

     # Save best model after parameter search.
-    best_trial = analysis.get_best_trial(metric=f"val_{config.val_metric}", mode=config.mode, scope="all")
-    retrain_best_model(exp_name, best_trial.config, best_trial.local_path, retrain=not config.no_retrain)
+    best_result = results.get_best_result(metric=f"val_{config.val_metric}", mode=config.mode, scope="all")
+    retrain_best_model(exp_name, best_result.config, best_result.path, retrain=not config.no_retrain)

 if __name__ == "__main__":
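
Note for reviewers: the sketch below is a minimal, self-contained illustration of
the same tune.run -> tune.Tuner migration pattern this patch applies. The
objective function, metric name, experiment name, and search space are
hypothetical placeholders, not code from search_params.py, and it assumes
Ray >= 2.7, where RunConfig lives under ray.train and a Result exposes .path.

    from ray import train, tune


    def objective(config):
        # Hypothetical trainable: report a score derived from the sampled value.
        train.report({"val_loss": (config["lr"] - 0.01) ** 2})


    # Legacy 1.x style (what this patch removes):
    #     analysis = tune.run(objective, config=space, num_samples=8,
    #                         resources_per_trial={"cpu": 1})
    #     best_trial = analysis.get_best_trial(metric="val_loss", mode="min")

    # 2.0 style: Tuner bundles the trainable, search space, tuning and run settings.
    tuner = tune.Tuner(
        tune.with_resources(objective, resources={"cpu": 1}),  # was resources_per_trial=
        param_space={"lr": tune.loguniform(1e-4, 1e-1)},       # was config=
        tune_config=tune.TuneConfig(metric="val_loss", mode="min", num_samples=8),
        run_config=train.RunConfig(name="tuner-demo"),         # was name=
    )
    results = tuner.fit()

    # ResultGrid replaces ExperimentAnalysis; best_result.config and .path stand
    # in for best_trial.config and .local_path.
    best_result = results.get_best_result(metric="val_loss", mode="min")
    print(best_result.config, best_result.path)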