From 5e77d62653263f4dfef4369b3b781add39371abc Mon Sep 17 00:00:00 2001
From: aakbik
Date: Sat, 20 Apr 2019 18:14:39 +0200
Subject: [PATCH] GH-474: Model interface for sequence labeling,
 classification and regression

---
 flair/hyperparameter/param_selection.py |   2 +
 flair/models/sequence_tagger_model.py   |   2 +
 flair/trainers/trainer.py               | 132 +++++++++++-------------
 flair/training_utils.py                 |   2 +
 tests/test_text_regressor.py            |   6 +-
 5 files changed, 71 insertions(+), 73 deletions(-)

diff --git a/flair/hyperparameter/param_selection.py b/flair/hyperparameter/param_selection.py
index 6f234269bf..e85164b47c 100644
--- a/flair/hyperparameter/param_selection.py
+++ b/flair/hyperparameter/param_selection.py
@@ -120,6 +120,8 @@ def _objective(self, params: dict):
             curr_scores = list(
                 map(lambda s: 1 - s, result["dev_score_history"][-3:])
             )
+            print(result)
+            print(curr_scores)
             score = sum(curr_scores) / float(len(curr_scores))
             var = np.var(curr_scores)
             scores.append(score)
diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index 21d5699182..cb68533171 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -282,6 +282,8 @@ def evaluate(
             with open(out_path, "w", encoding="utf-8") as outfile:
                 outfile.write("".join(lines))
 
+        print(metric)
+
         detailed_result = (
             f"\nMICRO_AVG: acc {metric.micro_avg_accuracy()} - f1-score {metric.micro_avg_f_score()}"
             f"\nMACRO_AVG: acc {metric.macro_avg_accuracy()} - f1-score {metric.macro_avg_f_score()}"
diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 2cb0e9af4f..f685e13e68 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -52,7 +52,6 @@ def train(
         max_epochs: int = 100,
         anneal_factor: float = 0.5,
         patience: int = 3,
-        anneal_against_train_loss: bool = True,
         train_with_dev: bool = False,
         monitor_train: bool = False,
         embeddings_in_memory: bool = True,
@@ -81,31 +80,29 @@ def train(
         log_test = True if (not param_selection_mode and self.corpus.test) else False
         log_dev = True if not train_with_dev else False
 
-        if not param_selection_mode:
-            loss_txt = init_output_file(base_path, "loss.tsv")
-            with open(loss_txt, "a") as f:
-                f.write(f"EPOCH\tTIMESTAMP\tBAD_EPOCHS\tLEARNING_RATE\tTRAIN_LOSS")
+        loss_txt = init_output_file(base_path, "loss.tsv")
+        with open(loss_txt, "a") as f:
+            f.write(f"EPOCH\tTIMESTAMP\tBAD_EPOCHS\tLEARNING_RATE\tTRAIN_LOSS")
 
-                dummy_result, _ = self.model.evaluate(
-                    [Sentence("d", labels=["0.1"])],
-                    eval_mini_batch_size,
-                    embeddings_in_memory,
+            dummy_result, _ = self.model.evaluate(
+                [Sentence("d", labels=["0.1"])],
+                eval_mini_batch_size,
+                embeddings_in_memory,
+            )
+            if log_train:
+                f.write(
+                    "\tTRAIN_" + "\tTRAIN_".join(dummy_result.log_header.split("\t"))
+                )
+            if log_dev:
+                f.write(
+                    "\tDEV_LOSS\tDEV_"
+                    + "\tDEV_".join(dummy_result.log_header.split("\t"))
+                )
+            if log_test:
+                f.write(
+                    "\tTEST_LOSS\tTEST_"
+                    + "\tTEST_".join(dummy_result.log_header.split("\t"))
                 )
-                if log_train:
-                    f.write(
-                        "\tTRAIN_"
-                        + "\tTRAIN_".join(dummy_result.log_header.split("\t"))
-                    )
-                if log_dev:
-                    f.write(
-                        "\tDEV_LOSS\tDEV_"
-                        + "\tDEV_".join(dummy_result.log_header.split("\t"))
-                    )
-                if log_test:
-                    f.write(
-                        "\tTEST_LOSS\tTEST_"
-                        + "\tTEST_".join(dummy_result.log_header.split("\t"))
-                    )
 
         weight_extractor = WeightExtractor(base_path)
 
@@ -113,8 +110,9 @@ def train(
         if self.optimizer_state is not None:
             optimizer.load_state_dict(self.optimizer_state)
 
-        # annealing scheduler
-        anneal_mode = "min" if anneal_against_train_loss else "max"
+        # minimize training loss if training with dev data, else maximize dev score
+        anneal_mode = "min" if train_with_dev else "max"
+
         if isinstance(optimizer, (AdamW, SGDW)):
             scheduler = ReduceLRWDOnPlateau(
                 optimizer,
@@ -224,51 +222,47 @@ def train(
                 f"EPOCH {epoch + 1} done: loss {train_loss:.4f} - lr {learning_rate:.4f} - bad epochs {bad_epochs}"
             )
 
-            dev_loss = "_"
-            if not param_selection_mode:
-                with open(loss_txt, "a") as f:
+            # anneal against train loss if training with dev, otherwise anneal against dev score
+            current_score = train_loss
 
-                    f.write(
-                        f"\n{epoch}\t{datetime.datetime.now():%H:%M:%S}\t{bad_epochs}\t{learning_rate:.4f}\t{train_loss}"
+            with open(loss_txt, "a") as f:
+
+                f.write(
+                    f"\n{epoch}\t{datetime.datetime.now():%H:%M:%S}\t{bad_epochs}\t{learning_rate:.4f}\t{train_loss}"
+                )
+
+                if log_train:
+                    train_eval_result, train_loss = self.model.evaluate(
+                        self.corpus.train,
+                        eval_mini_batch_size,
+                        embeddings_in_memory,
                     )
+                    f.write(f"\t{train_eval_result.log_line}")
 
-                    if log_train:
-                        train_eval_result, train_loss = self.model.evaluate(
-                            self.corpus.train,
-                            eval_mini_batch_size,
-                            embeddings_in_memory,
-                        )
-                        f.write(f"\t{train_eval_result.log_line}")
+                if log_dev:
+                    dev_eval_result, dev_loss = self.model.evaluate(
+                        self.corpus.dev, eval_mini_batch_size, embeddings_in_memory
+                    )
+                    f.write(f"\t{dev_loss}\t{dev_eval_result.log_line}")
 
-                    if log_dev:
-                        dev_eval_result, dev_loss = self.model.evaluate(
-                            self.corpus.dev,
-                            eval_mini_batch_size,
-                            embeddings_in_memory,
-                        )
-                        f.write(f"\t{dev_loss}\t{dev_eval_result.log_line}")
-
-                    if log_test:
-                        test_eval_result, test_loss = self.model.evaluate(
-                            self.corpus.test,
-                            eval_mini_batch_size,
-                            embeddings_in_memory,
-                            base_path / "test.tsv",
-                        )
-                        f.write(f"\t{test_loss}\t{test_eval_result.log_line}")
-                        log.info(
-                            f"TEST : loss {test_loss} - score {test_eval_result.main_score}"
-                        )
+                    # calculate scores using dev data if available
+                    # append dev score to score history
+                    dev_score_history.append(dev_eval_result.main_score)
+                    dev_loss_history.append(dev_loss)
 
-            # calculate scores using dev data if available
-            dev_score = 0.0
-            if not train_with_dev:
-                # append dev score to score history
-                dev_score_history.append(dev_score)
-                dev_loss_history.append(dev_loss.item())
+                    current_score = dev_eval_result.main_score
 
-            # anneal against train loss if training with dev, otherwise anneal against dev score
-            current_score = train_loss if anneal_against_train_loss else dev_score
+                if log_test:
+                    test_eval_result, test_loss = self.model.evaluate(
+                        self.corpus.test,
+                        eval_mini_batch_size,
+                        embeddings_in_memory,
+                        base_path / "test.tsv",
+                    )
+                    f.write(f"\t{test_loss}\t{test_eval_result.log_line}")
+                    log.info(
+                        f"TEST : loss {test_loss} - score {test_eval_result.main_score}"
+                    )
 
             scheduler.step(current_score)
 
@@ -351,13 +345,11 @@ def final_test(
         if type(self.corpus) is MultiCorpus:
             for subcorpus in self.corpus.corpora:
                 log_line(log)
-                self._calculate_evaluation_results_for(
-                    subcorpus.name,
+                self.model.evaluate(
                     subcorpus.test,
-                    evaluation_metric,
-                    embeddings_in_memory,
                     eval_mini_batch_size,
-                    base_path / "test.tsv",
+                    embeddings_in_memory,
+                    base_path / f"{subcorpus.name}-test.tsv",
                 )
 
         # get and return the final test score of best model
diff --git a/flair/training_utils.py b/flair/training_utils.py
index 6176d4be9f..a4de2a8c72 100644
--- a/flair/training_utils.py
+++ b/flair/training_utils.py
@@ -112,6 +112,8 @@ def micro_avg_f_score(self):
 
     def macro_avg_f_score(self):
         class_f_scores = [self.f_score(class_name) for class_name in self.get_classes()]
+        if len(class_f_scores) == 0:
+            return 0.0
         macro_f_score = sum(class_f_scores) / len(class_f_scores)
         return macro_f_score
 
diff --git a/tests/test_text_regressor.py b/tests/test_text_regressor.py
index 602a663f0b..c5abe82925 100644
--- a/tests/test_text_regressor.py
+++ b/tests/test_text_regressor.py
@@ -10,7 +10,7 @@
 from flair.trainers import ModelTrainer
 
 
-def init(tasks_base_path) -> Tuple[TaggedCorpus, TextRegressor]:
+def init(tasks_base_path) -> Tuple[TaggedCorpus, TextRegressor, ModelTrainer]:
     corpus = NLPTaskDataFetcher.load_corpus(NLPTask.REGRESSION, tasks_base_path)
 
     glove_embedding: WordEmbeddings = WordEmbeddings("glove")
@@ -18,7 +18,7 @@ def init(tasks_base_path) -> Tuple[TaggedCorpus, TextRegressor]:
         [glove_embedding], 128, 1, False, 64, False, False
     )
 
-    model = TextRegressor(document_embeddings, Dictionary(), False)
+    model = TextRegressor(document_embeddings)
 
     trainer = ModelTrainer(model, corpus)
 
@@ -40,7 +40,7 @@ def test_labels_to_indices(tasks_base_path):
 
 def test_trainer_evaluation(tasks_base_path):
     corpus, model, trainer = init(tasks_base_path)
 
-    expected = trainer._evaluate_text_regressor(model, corpus.dev)
+    expected = model.evaluate(corpus.dev)
 
     assert expected is not None
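
Note: for reference, a minimal usage sketch of the unified evaluate() interface
this patch gives sequence labeling, classification and regression models alike.
The setup mirrors tests/test_text_regressor.py above; the import paths, the
"tasks" data path, the class name DocumentLSTMEmbeddings and the inline
argument comments are assumptions inferred from the calls in this diff, not
verified against the full codebase.

    from pathlib import Path

    from flair.data_fetcher import NLPTaskDataFetcher, NLPTask
    from flair.embeddings import WordEmbeddings, DocumentLSTMEmbeddings
    from flair.models.text_regression_model import TextRegressor

    # corpus and model construction as in the regression test above
    corpus = NLPTaskDataFetcher.load_corpus(NLPTask.REGRESSION, Path("tasks"))
    glove_embedding = WordEmbeddings("glove")
    document_embeddings = DocumentLSTMEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False
    )
    model = TextRegressor(document_embeddings)

    # every model now returns a (Result, loss) pair from evaluate(), taking the
    # same positional arguments the trainer passes: sentences,
    # eval_mini_batch_size, embeddings_in_memory and an optional out_path
    result, loss = model.evaluate(
        corpus.dev,
        32,               # eval_mini_batch_size
        True,             # embeddings_in_memory
        Path("dev.tsv"),  # out_path for per-sentence predictions
    )

    # Result exposes log_header / log_line (the loss.tsv columns) and
    # main_score (what the scheduler anneals against when a dev split exists)
    print(result.log_header)
    print(result.log_line)
    print(result.main_score)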