-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflair_conll2003_en.py
72 lines (57 loc) · 2.16 KB
/
flair_conll2003_en.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from flair.data import Corpus
from flair.datasets import ColumnCorpus
from flair.embeddings import (
WordEmbeddings,
StackedEmbeddings,
PooledFlairEmbeddings,
)
"""
based on: "flair/resources/docs/EXPERIMENTS.md"
"""
def build_conll03en_corpus(base_path: str, document_as_sequence: bool = False):
    """Load the English CoNLL-03 corpus for NER training with flair.

    Reads ``train.txt``, ``dev.txt`` and ``test.txt`` from ``base_path``. Each
    file is parsed with four columns: text, pos, np and ner.

    Args:
        base_path: Directory containing the three split files.
        document_as_sequence: If True, ``"-DOCSTART-"`` is passed to
            ``ColumnCorpus`` as ``document_separator_token``. If False (the
            default, matching the previous hard-coded behaviour), no separator
            token is passed.

    Returns:
        Tuple ``(corpus, tag_type, tag_dictionary)``. ``tag_type`` is
        ``"ner"``, and ``tag_dictionary`` is built from ``corpus`` for that tag
        type.
    """
    # Single source of truth for the label column that is predicted.
    tag_type = "ner"
    corpus = ColumnCorpus(
        base_path,
        column_format={0: "text", 1: "pos", 2: "np", 3: tag_type},
        train_file="train.txt",
        dev_file="dev.txt",
        test_file="test.txt",
        # Asks flair to convert this tag column to the BIOES scheme.
        tag_to_bioes=tag_type,
        document_separator_token="-DOCSTART-" if document_as_sequence else None,
    )
    tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
    return corpus, tag_type, tag_dictionary
def build_and_train_conll03en_flair_sequence_tagger(
    corpus,
    tag_type,
    tag_dictionary,
    output_dir: str = "flair_checkpoints",
    max_epochs: int = 40,
):
    """Build the flair CoNLL-03 (English) NER sequence tagger and train it.

    The *model* configuration must not be changed. It is the one described in
    file "flair/resources/docs/EXPERIMENTS.md", section
    "CoNLL-03 Named Entity Recognition (English)":
    - GloVe word embeddings, stacked with min-pooled news-forward and
      news-backward Flair embeddings;
    - a ``SequenceTagger`` with ``hidden_size=256``.

    The *training* call intentionally deviates from that document. The
    document trains with ``train_with_dev=True`` for 150 epochs. Here the
    trainer gets ``train_with_dev=False``, ``max_epochs`` epochs (default 40)
    and ``save_final_model=False``, on a corpus whose test split is emptied.

    Args:
        corpus: Corpus such as the one returned by ``build_conll03en_corpus``.
            Only its ``train`` and ``dev`` splits are used.
        tag_type: Label type the tagger predicts.
        tag_dictionary: Tag dictionary for ``tag_type``.
        output_dir: Output directory passed to ``ModelTrainer.train``.
        max_epochs: Maximum number of training epochs.

    Returns:
        The ``SequenceTagger`` instance that was passed to the trainer.
    """
    # Model configuration: identical to EXPERIMENTS.md -- do not change.
    embeddings: StackedEmbeddings = StackedEmbeddings(
        embeddings=[
            WordEmbeddings("glove"),
            PooledFlairEmbeddings("news-forward", pooling="min"),
            PooledFlairEmbeddings("news-backward", pooling="min"),
        ]
    )

    from flair.models import SequenceTagger

    tagger: SequenceTagger = SequenceTagger(
        hidden_size=256,
        embeddings=embeddings,
        tag_dictionary=tag_dictionary,
        tag_type=tag_type,
    )

    from flair.trainers import ModelTrainer

    # Keep only train/dev; the test split is replaced by an empty list.
    # NOTE(review): presumably [] (not None) so flair neither samples a test
    # split from train nor evaluates on test after training -- confirm against
    # the installed flair version.
    train_dev_corpus = Corpus(train=corpus.train, dev=corpus.dev, test=[])
    trainer: ModelTrainer = ModelTrainer(tagger, train_dev_corpus)

    # EXPERIMENTS.md original (for reference):
    #   trainer.train("resources/taggers/example-ner", train_with_dev=True, max_epochs=150)
    # This run deviates on purpose: dev is not merged into train, fewer
    # epochs, and no final model file is saved.
    trainer.train(
        output_dir,
        train_with_dev=False,
        max_epochs=max_epochs,
        save_final_model=False,
    )
    return tagger
if __name__ == "__main__":
    # Data directory: ~/FARM/data/conll03-en (train.txt, dev.txt, test.txt).
    # os.path.expanduser("~") uses $HOME when it is set and otherwise falls
    # back to the platform default, instead of raising KeyError like
    # os.environ["HOME"] would.
    # Alternative previously left commented out: use ~/hpc as the prefix.
    base_path = os.path.join(os.path.expanduser("~"), "FARM", "data", "conll03-en")
    corpus, tag_type, tag_dictionary = build_conll03en_corpus(base_path)
    tagger = build_and_train_conll03en_flair_sequence_tagger(
        corpus, tag_type, tag_dictionary
    )