Skip to content

Commit

Permalink
GH-474: Model interface for sequence labeling, classification and reg…
Browse files Browse the repository at this point in the history
…ression
  • Loading branch information
aakbik committed Apr 20, 2019
1 parent 56d0c82 commit 9d32043
Show file tree
Hide file tree
Showing 10 changed files with 872 additions and 880 deletions.
529 changes: 256 additions & 273 deletions flair/models/sequence_tagger_model.py

Large diffs are not rendered by default.

236 changes: 136 additions & 100 deletions flair/models/text_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
import flair.embeddings
from flair.data import Dictionary, Sentence, Label
from flair.file_utils import cached_path
from flair.training_utils import convert_labels_to_one_hot, clear_embeddings

from flair.training_utils import (
convert_labels_to_one_hot,
clear_embeddings,
Metric,
Result,
)

log = logging.getLogger("flair")

Expand Down Expand Up @@ -66,98 +70,26 @@ def forward(self, sentences) -> List[List[float]]:

return label_scores

def save(self, model_file: Union[str, Path]):
"""
Saves the current model to the provided file.
:param model_file: the model file
"""
model_state = {
"state_dict": self.state_dict(),
"document_embeddings": self.document_embeddings,
"label_dictionary": self.label_dictionary,
"multi_label": self.multi_label,
}
torch.save(model_state, str(model_file), pickle_protocol=4)

def save_checkpoint(
self,
model_file: Union[str, Path],
optimizer_state: dict,
scheduler_state: dict,
epoch: int,
loss: float,
):
"""
Saves the current model to the provided file.
:param model_file: the model file
"""
def _get_state_dict(self):
model_state = {
"state_dict": self.state_dict(),
"document_embeddings": self.document_embeddings,
"label_dictionary": self.label_dictionary,
"multi_label": self.multi_label,
"optimizer_state_dict": optimizer_state,
"scheduler_state_dict": scheduler_state,
"epoch": epoch,
"loss": loss,
}
torch.save(model_state, str(model_file), pickle_protocol=4)
return model_state

@classmethod
def load_from_file(cls, model_file: Union[str, Path]):
"""
Loads the model from the given file.
:param model_file: the model file
:return: the loaded text classifier model
"""
state = TextClassifier._load_state(model_file)
def _init_model_with_state_dict(state):

model = TextClassifier(
document_embeddings=state["document_embeddings"],
label_dictionary=state["label_dictionary"],
multi_label=state["multi_label"],
)
model.load_state_dict(state["state_dict"])
model.eval()
model.to(flair.device)

model.load_state_dict(state["state_dict"])
return model

@classmethod
def load_checkpoint(cls, model_file: Union[str, Path]):
state = TextClassifier._load_state(model_file)
model = TextClassifier.load_from_file(model_file)

epoch = state["epoch"] if "epoch" in state else None
loss = state["loss"] if "loss" in state else None
optimizer_state_dict = (
state["optimizer_state_dict"] if "optimizer_state_dict" in state else None
)
scheduler_state_dict = (
state["scheduler_state_dict"] if "scheduler_state_dict" in state else None
)

return {
"model": model,
"epoch": epoch,
"loss": loss,
"optimizer_state_dict": optimizer_state_dict,
"scheduler_state_dict": scheduler_state_dict,
}

@classmethod
def _load_state(cls, model_file: Union[str, Path]):
# ATTENTION: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive
# serialization of torch objects
# https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
# load_big_file is a workaround by https://github.com/highway11git to load models on some Mac/Windows setups
# see https://github.com/zalandoresearch/flair/issues/351
f = flair.file_utils.load_big_file(str(model_file))
state = torch.load(f, map_location=flair.device)
return state

def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.tensor:
scores = self.forward(sentences)
return self._calculate_loss(scores, sentences)
Expand Down Expand Up @@ -201,6 +133,112 @@ def predict(

return sentences

def evaluate(
self,
sentences: List[Sentence],
eval_mini_batch_size: int = 32,
embeddings_in_memory: bool = False,
out_path: Path = None,
) -> (Result, float):

with torch.no_grad():
eval_loss = 0

batches = [
sentences[x : x + eval_mini_batch_size]
for x in range(0, len(sentences), eval_mini_batch_size)
]

metric = Metric("Evaluation")

lines: List[str] = []
for batch in batches:

labels, loss = self.forward_labels_and_loss(batch)

clear_embeddings(
batch, also_clear_word_embeddings=not embeddings_in_memory
)

eval_loss += loss

sentences_for_batch = [sent.to_plain_string() for sent in batch]
confidences_for_batch = [
[label.score for label in sent_labels] for sent_labels in labels
]
predictions_for_batch = [
[label.value for label in sent_labels] for sent_labels in labels
]
true_values_for_batch = [
sentence.get_label_names() for sentence in batch
]
available_labels = self.label_dictionary.get_items()

for sentence, confidence, prediction, true_value in zip(
sentences_for_batch,
confidences_for_batch,
predictions_for_batch,
true_values_for_batch,
):
eval_line = "{}\t{}\t{}\t{}\n".format(
sentence, true_value, prediction, confidence
)
lines.append(eval_line)

for predictions_for_sentence, true_values_for_sentence in zip(
predictions_for_batch, true_values_for_batch
):

for label in available_labels:
if (
label in predictions_for_sentence
and label in true_values_for_sentence
):
metric.add_tp(label)
elif (
label in predictions_for_sentence
and label not in true_values_for_sentence
):
metric.add_fp(label)
elif (
label not in predictions_for_sentence
and label in true_values_for_sentence
):
metric.add_fn(label)
elif (
label not in predictions_for_sentence
and label not in true_values_for_sentence
):
metric.add_tn(label)

eval_loss /= len(sentences)

detailed_result = (
f"\nMICRO_AVG: acc {metric.micro_avg_accuracy()} - f1-score {metric.micro_avg_f_score()}"
f"\nMACRO_AVG: acc {metric.macro_avg_accuracy()} - f1-score {metric.macro_avg_f_score()}"
)
for class_name in metric.get_classes():
detailed_result += (
f"\n{class_name:<10} tp: {metric.get_tp(class_name)} - fp: {metric.get_fp(class_name)} - "
f"fn: {metric.get_fn(class_name)} - tn: {metric.get_tn(class_name)} - precision: "
f"{metric.precision(class_name):.4f} - recall: {metric.recall(class_name):.4f} - "
f"accuracy: {metric.accuracy(class_name):.4f} - f1-score: "
f"{metric.f_score(class_name):.4f}"
)

result = Result(
main_score=metric.micro_avg_f_score(),
log_line=f"{metric.precision()}\t{metric.recall()}\t{metric.micro_avg_f_score()}",
log_header="PRECISION\tRECALL\tF1",
detailed_results=detailed_result,
)

if out_path is not None:
with open(out_path, "w", encoding="utf-8") as outfile:
outfile.write("".join(lines))

return result, eval_loss

@staticmethod
def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
Expand All @@ -213,8 +251,8 @@ def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
return filtered_sentences

def _calculate_loss(
self, scores: List[List[float]], sentences: List[Sentence]
) -> float:
self, scores: torch.tensor, sentences: List[Sentence]
) -> torch.tensor:
"""
Calculates the loss.
:param scores: the prediction scores from the model
Expand Down Expand Up @@ -293,29 +331,27 @@ def _labels_to_indices(self, sentences: List[Sentence]):

return vec

@staticmethod
def load(model: str):
model_file = None
def _fetch_model(model_name) -> str:

model_map = {}
aws_resource_path = (
"https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models-v0.4"
)
cache_dir = Path("models")

if model.lower() == "de-offensive-language":
base_path = "/".join(
[
aws_resource_path,
"TEXT-CLASSIFICATION_germ-eval-2018_task-1",
"germ-eval-2018-task-1.pt",
]
)
model_file = cached_path(base_path, cache_dir=cache_dir)
model_map["de-offensive-language"] = "/".join(
[
aws_resource_path,
"TEXT-CLASSIFICATION_germ-eval-2018_task-1",
"germ-eval-2018-task-1.pt",
]
)

elif model.lower() == "en-sentiment":
base_path = "/".join(
[aws_resource_path, "TEXT-CLASSIFICATION_imdb", "imdb.pt"]
)
model_file = cached_path(base_path, cache_dir=cache_dir)
model_map["en-sentiment"] = "/".join(
[aws_resource_path, "TEXT-CLASSIFICATION_imdb", "imdb.pt"]
)

cache_dir = Path("models")
if model_name in model_map:
model_name = cached_path(model_map[model_name], cache_dir=cache_dir)

if model_file is not None:
return TextClassifier.load_from_file(model_file)
return model_name
Loading

0 comments on commit 9d32043

Please sign in to comment.