Skip to content

Commit

Permalink
GH-3488: Add a function to write a ColumnCorpus instance to files
Browse files Browse the repository at this point in the history
  • Loading branch information
chelseagzr committed Jul 10, 2024
1 parent 355e2c0 commit 1c8f4ad
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 1 deletion.
158 changes: 157 additions & 1 deletion flair/datasets/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
cast,
)

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data import ConcatDataset, Dataset, Subset

import flair
from flair.data import (
Expand All @@ -28,6 +28,7 @@
MultiCorpus,
Relation,
Sentence,
Span,
Token,
get_spans_from_bio,
)
Expand Down Expand Up @@ -443,6 +444,161 @@ def __init__(
**corpusargs,
)

@staticmethod
def get_token_level_label_of_each_token(sentence: Sentence, label_type: str) -> List[str]:
"""Generates a label for each token in the sentence. This function requires that the labels corresponding to the label_type are token-level tokens.
Args:
sentence: a flair sentence to generate labels for
label_type: a string representing the type of the labels, e.g., "pos"
"""
list_of_labels = ["O" for _ in range(len(sentence.tokens))]
for label in sentence.get_labels(label_type):
label_token_index = label.data_point._internal_index
list_of_labels[label_token_index - 1] = label.value
return list_of_labels

@staticmethod
def get_span_level_label_of_each_token(sentence: Sentence, label_type: str) -> List[str]:
"""Generates a label for each token in the sentence in BIO format. This function requires that the labels corresponding to the label_type are span-level tokens.
Args:
sentence: a flair sentence to generate labels for
label_type: a string representing the type of the labels, e.g., "ner"
"""
list_of_labels = ["O" for _ in range(len(sentence.tokens))]
for label in sentence.get_labels(label_type):
tokens = label.data_point.tokens
start_token_index = tokens[0]._internal_index
list_of_labels[start_token_index - 1] = f"B-{label.value}"
for token in tokens[1:]:
token_index = token._internal_index
list_of_labels[token_index - 1] = f"I-{label.value}"
return list_of_labels

@staticmethod
def write_dataset_to_file(
dataset: Dataset, file_path: Path, label_type_tuples: List[tuple], column_delimiter: str = "\t"
) -> None:
"""Writes a dataset to a file.
Following these two rules:
(1) the text and the label of every token is represented in one line separated by column_delimiter
(2) every sentence is separated from the previous one by an empty line
"""
with open(file_path, mode="w") as output_file:
for sentence in dataset:
texts = [token.text for token in sentence.tokens]
texts_and_labels = [texts]
for label_type, level in label_type_tuples:
if level is Token:
texts_and_labels.append(ColumnCorpus.get_token_level_label_of_each_token(sentence, label_type))
elif level is Span:
texts_and_labels.append(ColumnCorpus.get_span_level_label_of_each_token(sentence, label_type))
else:
raise NotImplementedError(f"The level of {label_type} is neither token nor span.")

for text_and_labels_of_a_token in zip(*texts_and_labels):
output_file.write(column_delimiter.join(text_and_labels_of_a_token) + "\n")
output_file.write("\n")

@classmethod
def load_corpus_with_meta_data(cls, directory: Path) -> "ColumnCorpus":
"""Creates a ColumnCorpus instance from the directory generated by 'write_corpus_meta_data'."""
with open(directory / "meta_data.json") as file:
meta_data = json.load(file)

meta_data["column_format"] = {int(key): value for key, value in meta_data["column_format"].items()}

return cls(
data_folder=directory,
autofind_splits=True,
skip_first_line=False,
**meta_data,
)

def get_level_of_label(self, label_type: str):
"""Gets level of label type by checking the first label in this corpus."""
for dataset in [self.train, self.dev, self.test]:
if dataset:
for sentence in dataset:
for label in sentence.get_labels(label_type):
if isinstance(label.data_point, Token):
return Token
elif isinstance(label.data_point, Span):
return Span
else:
raise NotImplementedError(
f"The level of {label_type} is neither token nor span. Only token level labels and span level labels can be handled now."
)
raise RuntimeError(f"There is no label of type {label_type} in this corpus.")

def write_corpus_meta_data(self, label_types: List[str], file_path: Path, column_delimiter: str) -> None:
"""Writes meta data of this corpus to a json file.
Note:
Currently, the whitespace_after attribute of each token will not be preserved. Only default_whitespace_after attribute of each dataset will be written to the file.
"""
meta_data = {
"name": self.name,
"sample_missing_splits": False,
"column_delimiter": column_delimiter,
}

column_format = {0: "text"}
for label_type_index, label_type in enumerate(label_types):
column_format[label_type_index + 1] = label_type
meta_data["column_format"] = column_format

nonempty_dataset = self.train or self.dev or self.test
MAX_DEPTH = 5
for _ in range(MAX_DEPTH):
if type(nonempty_dataset) is ColumnDataset:
break
elif type(nonempty_dataset) is ConcatDataset:
nonempty_dataset = nonempty_dataset.datasets[0]
elif type(nonempty_dataset) is Subset:
nonempty_dataset = nonempty_dataset.dataset
else:
raise NotImplementedError("Unsupported type")

if type(nonempty_dataset) is not ColumnDataset:
raise NotImplementedError("Unsupported type")

meta_data["encoding"] = nonempty_dataset.encoding
meta_data["in_memory"] = nonempty_dataset.in_memory
meta_data["banned_sentences"] = nonempty_dataset.banned_sentences
meta_data["default_whitespace_after"] = nonempty_dataset.default_whitespace_after

with open(file_path, mode="w") as output_file:
json.dump(meta_data, output_file)

def write_to_directory(self, label_types: List[str], output_directory: Path, column_delimiter: str = "\t") -> None:
"""Writes train, dev, test dataset (if exist) and the meta data of the corpus to a directory.
Note:
Only labels corresponding to label_types will be written.
Only token level or span level sequence tagging labels are supported.
Currently, the whitespace_after attribute of each token will not be preserved in the written file.
"""
label_type_tuples = [(label_type, self.get_level_of_label(label_type)) for label_type in label_types]

os.makedirs(output_directory, exist_ok=True)
if self.train:
ColumnCorpus.write_dataset_to_file(
self.train, output_directory / "train.conll", label_type_tuples, column_delimiter
)
if self.dev:
ColumnCorpus.write_dataset_to_file(
self.dev, output_directory / "dev.conll", label_type_tuples, column_delimiter
)
if self.test:
ColumnCorpus.write_dataset_to_file(
self.test, output_directory / "test.conll", label_type_tuples, column_delimiter
)

self.write_corpus_meta_data(label_types, output_directory / "meta_data.json", column_delimiter)


class ColumnDataset(FlairDataset):
# special key for space after
Expand Down
18 changes: 18 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,24 @@ def test_load_universal_dependencies_conllu_corpus(tasks_base_path):
_assert_universal_dependencies_conllu_dataset(corpus.train)


def test_write_to_and_load_from_directory(tasks_base_path):
from pathlib import Path

corpus = ColumnCorpus(
tasks_base_path / "column_with_whitespaces",
train_file="eng.train",
column_format={0: "text", 1: "ner"},
column_delimiter=" ",
skip_first_line=False,
sample_missing_splits=False,
)
directory = Path("resources/taggers/")
corpus.write_to_directory(["ner"], directory, column_delimiter="\t")
loaded_corpus = ColumnCorpus.load_corpus_with_meta_data(directory)
assert len(loaded_corpus.train) == len(corpus.train)
assert loaded_corpus.train[0].to_tagged_string() == corpus.train[0].to_tagged_string()


def test_hipe_2022_corpus(tasks_base_path):
# This test covers the complete HIPE 2022 dataset.
# https://github.com/hipe-eval/HIPE-2022-data
Expand Down

0 comments on commit 1c8f4ad

Please sign in to comment.