
Commit 6e36bfd
GH-474: adapt Plotter to different types of tasks
aakbik committed Apr 23, 2019
1 parent dc559f3 commit 6e36bfd
Showing 3 changed files with 73 additions and 86 deletions.
2 changes: 0 additions & 2 deletions flair/models/sequence_tagger_model.py
@@ -282,8 +282,6 @@ def evaluate(
        with open(out_path, "w", encoding="utf-8") as outfile:
            outfile.write("".join(lines))

-        print(metric)
-
        detailed_result = (
            f"\nMICRO_AVG: acc {metric.micro_avg_accuracy()} - f1-score {metric.micro_avg_f_score()}"
            f"\nMACRO_AVG: acc {metric.macro_avg_accuracy()} - f1-score {metric.macro_avg_f_score()}"
13 changes: 10 additions & 3 deletions flair/nn.py
@@ -13,19 +13,20 @@


class Model(torch.nn.Module):
-    """Abstract base class for all models. Every new type of model must implement these methods."""
+    """Abstract base class for all downstream task models in Flair, such as SequenceTagger and TextClassifier.
+    Every new type of model must implement these methods."""

    @abstractmethod
    def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.tensor:
-        """Performs a forward pass and returns the loss."""
+        """Performs a forward pass and returns a loss tensor for backpropagation. Implement this to enable training."""
        pass

    @abstractmethod
    def predict(
        self, sentences: Union[List[Sentence], Sentence], mini_batch_size=32
    ) -> List[Sentence]:
        """Predicts the labels/tags for the given list of sentences. The labels/tags are added directly to the
-        sentences."""
+        sentences. Implement this to enable prediction."""
        pass

    @abstractmethod

@@ -36,14 +37,20 @@ def evaluate(
        embeddings_in_memory: bool = False,
        out_path: Path = None,
    ) -> (Result, float):
+        """Evaluates the model on a list of gold-labeled Sentences. Returns a Result object containing evaluation
+        results and a loss value. Implement this to enable evaluation."""
        pass

    @abstractmethod
    def _get_state_dict(self):
+        """Returns the state dictionary for this model. Implementing this enables the save() and save_checkpoint()
+        functionality."""
        pass

    @abstractmethod
    def _init_model_with_state_dict(state):
+        """Initialize the model from a state dictionary. Implementing this enables the load() and load_checkpoint()
+        functionality."""
        pass

    @abstractmethod
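For orientation, a minimal sketch of a subclass satisfying the abstract methods shown in this hunk. Everything in it is illustrative, not part of this commit: the class name ConstantTagger, the constant-label logic, the zero loss, the leading parameters of evaluate (collapsed in the diff above), and the Result constructor arguments, which are assumed to match flair.training_utils at this revision.

from pathlib import Path
from typing import List, Union

import torch

from flair.data import Sentence
from flair.nn import Model
from flair.training_utils import Result


class ConstantTagger(Model):
    """Hypothetical toy model that tags every token with one fixed label."""

    def __init__(self, label: str = "O"):
        super().__init__()
        self.label = label
        # dummy parameter so the module has something for an optimizer to hold
        self.dummy = torch.nn.Parameter(torch.zeros(1))

    def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.tensor:
        # a real model computes a task loss here; this one is trivially zero
        return self.dummy.sum() * 0.0

    def predict(
        self, sentences: Union[List[Sentence], Sentence], mini_batch_size=32
    ) -> List[Sentence]:
        if isinstance(sentences, Sentence):
            sentences = [sentences]
        for sentence in sentences:
            for token in sentence:
                token.add_tag("ner", self.label)  # tags are added in-place
        return sentences

    def evaluate(
        self,
        sentences: List[Sentence],
        eval_mini_batch_size: int = 32,
        embeddings_in_memory: bool = False,
        out_path: Path = None,
    ) -> (Result, float):
        self.predict(sentences, mini_batch_size=eval_mini_batch_size)
        # assumed Result signature: (main_score, log_header, log_line, detailed_results)
        return Result(0.0, "ACCURACY", "0.0", "constant baseline"), 0.0

    def _get_state_dict(self):
        return {"state_dict": self.state_dict(), "label": self.label}

    def _init_model_with_state_dict(state):
        model = ConstantTagger(label=state["label"])
        model.load_state_dict(state["state_dict"])
        return model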
144 changes: 63 additions & 81 deletions flair/visual/training_curves.py
@@ -1,6 +1,7 @@
import logging
from collections import defaultdict
from pathlib import Path
-from typing import Union
+from typing import Union, List

import numpy as np
import csv
@@ -17,6 +18,8 @@
WEIGHT_NUMBER = 2
WEIGHT_VALUE = 3

+log = logging.getLogger("flair")
+


class Plotter(object):
"""
Expand All @@ -26,50 +29,50 @@ class Plotter(object):
"""

    @staticmethod
-    def _extract_evaluation_data(file_name: Path) -> dict:
+    def _extract_evaluation_data(file_name: Path, score: str = "PEARSON") -> dict:
        training_curves = {
-            "train": {"loss": [], "f_score": [], "acc": []},
-            "test": {"loss": [], "f_score": [], "acc": []},
-            "dev": {"loss": [], "f_score": [], "acc": []},
+            "train": {"loss": [], "score": []},
+            "test": {"loss": [], "score": []},
+            "dev": {"loss": [], "score": []},
        }

        with open(file_name, "r") as tsvin:
            tsvin = csv.reader(tsvin, delimiter="\t")

            # determine the column index of loss, f-score and accuracy for train, dev and test split
            row = next(tsvin, None)
-            TRAIN_LOSS = row.index("TRAIN_LOSS")
-            TRAIN_F_SCORE = row.index("TRAIN_F-SCORE")
-            TRAIN_ACCURACY = row.index("TRAIN_ACCURACY")
-            DEV_LOSS = row.index("DEV_LOSS")
-            DEV_F_SCORE = row.index("DEV_F-SCORE")
-            DEV_ACCURACY = row.index("DEV_ACCURACY")
-            TEST_LOSS = row.index("TEST_LOSS")
-            TEST_F_SCORE = row.index("TEST_F-SCORE")
-            TEST_ACCURACY = row.index("TEST_ACCURACY")
+            score = score.upper()
+
+            if f"TEST_{score}" not in row:
+                log.warning("-" * 100)
+                log.warning(f"WARNING: No {score} found for test split in this data.")
+                log.warning(
+                    f"Are you sure you want to plot {score} and not another value such as PEARSON?"
+                )
+                log.warning("-" * 100)
+
+            TRAIN_SCORE = (
+                row.index(f"TRAIN_{score}") if f"TRAIN_{score}" in row else None
+            )
+            DEV_SCORE = row.index(f"DEV_{score}") if f"DEV_{score}" in row else None
+            TEST_SCORE = row.index(f"TEST_{score}")

            # then get all relevant values from the tsv
            for row in tsvin:
-                if row[TRAIN_LOSS] != "_":
-                    training_curves["train"]["loss"].append(float(row[TRAIN_LOSS]))
-                if row[TRAIN_F_SCORE] != "_":
-                    training_curves["train"]["f_score"].append(
-                        float(row[TRAIN_F_SCORE])
-                    )
-                if row[TRAIN_ACCURACY] != "_":
-                    training_curves["train"]["acc"].append(float(row[TRAIN_ACCURACY]))
-                if row[DEV_LOSS] != "_":
-                    training_curves["dev"]["loss"].append(float(row[DEV_LOSS]))
-                if row[DEV_F_SCORE] != "_":
-                    training_curves["dev"]["f_score"].append(float(row[DEV_F_SCORE]))
-                if row[DEV_ACCURACY] != "_":
-                    training_curves["dev"]["acc"].append(float(row[DEV_ACCURACY]))
-                if row[TEST_LOSS] != "_":
-                    training_curves["test"]["loss"].append(float(row[TEST_LOSS]))
-                if row[TEST_F_SCORE] != "_":
-                    training_curves["test"]["f_score"].append(float(row[TEST_F_SCORE]))
-                if row[TEST_ACCURACY] != "_":
-                    training_curves["test"]["acc"].append(float(row[TEST_ACCURACY]))
+                if TRAIN_SCORE is not None:
+                    if row[TRAIN_SCORE] != "_":
+                        training_curves["train"]["score"].append(
+                            float(row[TRAIN_SCORE])
+                        )
+
+                if DEV_SCORE is not None:
+                    if row[DEV_SCORE] != "_":
+                        training_curves["dev"]["score"].append(float(row[DEV_SCORE]))
+
+                if row[TEST_SCORE] != "_":
+                    training_curves["test"]["score"].append(float(row[TEST_SCORE]))

        return training_curves

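The column lookup above is driven entirely by the loss.tsv header: the requested score name is upper-cased and prefixed with TRAIN_/DEV_/TEST_. A small worked sketch of that logic (the header row below is illustrative, not necessarily the exact layout flair's ModelTrainer writes):

header = ["EPOCH", "TIMESTAMP", "TRAIN_LOSS", "DEV_LOSS", "DEV_F1", "TEST_LOSS", "TEST_F1"]

score = "f1".upper()  # upper-casing makes the requested name case-insensitive

TRAIN_SCORE = header.index(f"TRAIN_{score}") if f"TRAIN_{score}" in header else None
DEV_SCORE = header.index(f"DEV_{score}") if f"DEV_{score}" in header else None
TEST_SCORE = header.index(f"TEST_{score}")

print(TRAIN_SCORE, DEV_SCORE, TEST_SCORE)  # -> None 4 6

Note that missing TRAIN_ or DEV_ columns degrade gracefully to None, while a missing TEST_ column still raises ValueError at row.index, despite the warning logged just before.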
@@ -156,58 +159,37 @@ def plot_weights(self, file_name: Union[str, Path]):

        plt.close(fig)

-    def plot_training_curves(self, file_name: Union[str, Path]):
+    def plot_training_curves(
+        self, file_name: Union[str, Path], plot_values: List[str] = ["loss", "F1"]
+    ):
        if type(file_name) is str:
            file_name = Path(file_name)

        fig = plt.figure(figsize=(15, 10))

-        training_curves = self._extract_evaluation_data(file_name)
-
-        # plot 1
-        plt.subplot(3, 1, 1)
-        if training_curves["train"]["loss"]:
-            x = np.arange(0, len(training_curves["train"]["loss"]))
-            plt.plot(x, training_curves["train"]["loss"], label="training loss")
-        if training_curves["dev"]["loss"]:
-            x = np.arange(0, len(training_curves["dev"]["loss"]))
-            plt.plot(x, training_curves["dev"]["loss"], label="validation loss")
-        if training_curves["test"]["loss"]:
-            x = np.arange(0, len(training_curves["test"]["loss"]))
-            plt.plot(x, training_curves["test"]["loss"], label="test loss")
-        plt.legend(bbox_to_anchor=(1.04, 0), loc="lower left", borderaxespad=0)
-        plt.ylabel("loss")
-        plt.xlabel("epochs")
-
-        # plot 2
-        plt.subplot(3, 1, 2)
-        if training_curves["train"]["acc"]:
-            x = np.arange(0, len(training_curves["train"]["acc"]))
-            plt.plot(x, training_curves["train"]["acc"], label="training accuracy")
-        if training_curves["dev"]["acc"]:
-            x = np.arange(0, len(training_curves["dev"]["acc"]))
-            plt.plot(x, training_curves["dev"]["acc"], label="validation accuracy")
-        if training_curves["test"]["acc"]:
-            x = np.arange(0, len(training_curves["test"]["acc"]))
-            plt.plot(x, training_curves["test"]["acc"], label="test accuracy")
-        plt.legend(bbox_to_anchor=(1.04, 0), loc="lower left", borderaxespad=0)
-        plt.ylabel("accuracy")
-        plt.xlabel("epochs")
-
-        # plot 3
-        plt.subplot(3, 1, 3)
-        if training_curves["train"]["f_score"]:
-            x = np.arange(0, len(training_curves["train"]["f_score"]))
-            plt.plot(x, training_curves["train"]["f_score"], label="training f1-score")
-        if training_curves["dev"]["f_score"]:
-            x = np.arange(0, len(training_curves["dev"]["f_score"]))
-            plt.plot(x, training_curves["dev"]["f_score"], label="validation f1-score")
-        if training_curves["test"]["f_score"]:
-            x = np.arange(0, len(training_curves["test"]["f_score"]))
-            plt.plot(x, training_curves["test"]["f_score"], label="test f1-score")
-        plt.legend(bbox_to_anchor=(1.04, 0), loc="lower left", borderaxespad=0)
-        plt.ylabel("f1-score")
-        plt.xlabel("epochs")
+        for plot_no, plot_value in enumerate(plot_values):
+
+            training_curves = self._extract_evaluation_data(file_name, plot_value)
+
+            plt.subplot(len(plot_values), 1, plot_no + 1)
+            if training_curves["train"]["score"]:
+                x = np.arange(0, len(training_curves["train"]["score"]))
+                plt.plot(
+                    x, training_curves["train"]["score"], label=f"training {plot_value}"
+                )
+            if training_curves["dev"]["score"]:
+                x = np.arange(0, len(training_curves["dev"]["score"]))
+                plt.plot(
+                    x, training_curves["dev"]["score"], label=f"validation {plot_value}"
+                )
+            if training_curves["test"]["score"]:
+                x = np.arange(0, len(training_curves["test"]["score"]))
+                plt.plot(
+                    x, training_curves["test"]["score"], label=f"test {plot_value}"
+                )
+            plt.legend(bbox_to_anchor=(1.04, 0), loc="lower left", borderaxespad=0)
+            plt.ylabel(plot_value)
+            plt.xlabel("epochs")

        # save plots
        plt.tight_layout(pad=1.0)
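Taken together, a plausible usage sketch of the adapted API (the loss.tsv paths are illustrative):

from flair.visual.training_curves import Plotter

plotter = Plotter()

# classification/tagging runs: plot loss and F1 (the new default)
plotter.plot_training_curves("resources/taggers/example-ner/loss.tsv")

# regression-style runs: plot loss and Pearson correlation instead
plotter.plot_training_curves(
    "resources/taggers/example-regression/loss.tsv",
    plot_values=["loss", "PEARSON"],
)

Each entry in plot_values becomes one subplot; "loss" works like any other score name because _extract_evaluation_data simply looks up the TRAIN_LOSS, DEV_LOSS and TEST_LOSS columns for it. Note that _extract_evaluation_data's own default score of "PEARSON" only matters when it is called directly, since plot_training_curves always passes an explicit plot_value, and that the mutable default plot_values=["loss", "F1"] is shared across calls.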
