-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathpos_stats.py
47 lines (36 loc) · 1.84 KB
/
pos_stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""Calculation of statistics that require a pos-tagger in the pipeline"""
import collections
from typing import Counter, Union

from spacy.language import Language
from spacy.tokens import Doc, Span
@Language.factory("pos_stats", default_config={"use_pos": True})
def create_pos_stats_component(nlp: Language, name: str, use_pos: bool):
    """Allows PosStats to be added to a spaCy pipe using nlp.add_pipe("pos_stats").

    Args:
        nlp: The pipeline the component is being added to.
        name: The component instance name (unused here).
        use_pos: If True, statistics are computed over coarse-grained tags
            (``token.pos_``); otherwise over fine-grained tags (``token.tag_``).

    Returns:
        A configured ``POSStatistics`` component.

    Raises:
        ValueError: If the pipeline contains no component that can supply the
            requested kind of tag.
    """
    # A morphologizer predicts coarse-grained UPOS tags (token.pos_) itself, so
    # it is accepted alongside a tagger when use_pos is True. Fine-grained tags
    # (token.tag_) still require a tagger.
    producers = {"tagger", "morphologizer"} if use_pos else {"tagger"}
    if producers.isdisjoint(nlp.pipe_names):
        if use_pos:
            msg = (
                "The pipeline does not contain a tagger or morphologizer. "
                "Please load a spaCy model which includes a 'tagger' or "
                "'morphologizer' component."
            )
        else:
            msg = (
                "The pipeline does not contain a tagger. Please load a spaCy "
                "model which includes a 'tagger' component."
            )
        raise ValueError(msg)
    return POSStatistics(nlp, use_pos=use_pos)
class POSStatistics:
    """spaCy v.3.0 component that adds attributes for POS statistics to `Doc` and `Span` objects.

    Registers a ``pos_proportions`` getter extension on both ``Doc`` and
    ``Span``. The statistics are computed lazily whenever the attribute is
    read; running the component in the pipeline leaves the document unchanged.
    """

    def __init__(self, nlp: Language, use_pos: bool):
        """Initialise the component and register the ``pos_proportions`` extensions.

        Args:
            nlp: The pipeline the component belongs to (not used by this method).
            use_pos: If True, count coarse-grained tags (``token.pos_``);
                otherwise count fine-grained tags (``token.tag_``).
        """
        self.use_pos = use_pos
        # Extensions are registered on the Doc/Span classes, not per pipeline.
        # force=True rebinds the getter to this instance so that re-adding the
        # component with a different `use_pos` takes effect, instead of silently
        # keeping the configuration of whichever instance registered first.
        Doc.set_extension("pos_proportions", getter=self.pos_proportions, force=True)
        Span.set_extension("pos_proportions", getter=self.pos_proportions, force=True)

    def __call__(self, doc):
        """Run the pipeline component.

        This is a pass-through: the statistics are computed on demand by the
        extension getters, so the document is returned unchanged.
        """
        return doc

    def pos_proportions(self, text: Union[Doc, Span]) -> dict:
        """Compute the share of tokens carrying each POS tag.

        Args:
            text: The Doc or Span to summarise.

        Returns:
            Dict containing {pos_prop_POSTAG: proportion of all tokens tagged
            with POSTAG}. Does not create a key if no tokens in the document fit
            the POSTAG; an empty ``text`` yields ``{}``.
        """
        # Choose the token attribute once rather than branching per token.
        attr = "pos_" if self.use_pos else "tag_"
        counts = collections.Counter(getattr(token, attr) for token in text)
        n_tokens = len(text)
        # When text is empty, counts is empty and no division is performed.
        return {f"pos_prop_{tag}": count / n_tokens for tag, count in counts.items()}