Skip to content

Commit

Permalink
Merge pull request #368 from MichalMalyska/abbreviation_detector_serialization
Browse files Browse the repository at this point in the history

Add abbreviation serializations
  • Loading branch information
dakinggg authored Jul 15, 2021
2 parents 45e0ca1 + 757b350 commit 3d153dd
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 14 deletions.
37 changes: 36 additions & 1 deletion scispacy/abbreviation.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,22 @@ class AbbreviationDetector:
nlp: `Language`, a required argument for spacy to use this as a factory
name: `str`, a required argument for spacy to use this as a factory
make_serializable: `bool`, a required argument for whether we want to use the serializable
or non serializable version.
"""

def __init__(self, nlp: Language, name: str = "abbreviation_detector") -> None:
def __init__(
    self,
    nlp: Language,
    name: str = "abbreviation_detector",
    make_serializable: bool = False,
) -> None:
    """
    Initialize the abbreviation detector pipeline component.

    nlp: the spacy `Language` object (provides the shared vocab for matchers).
    name: component name, required for spacy's factory registration.
    make_serializable: when True, detected abbreviations are converted to
        plain serializable dicts instead of `Span` objects.
    """
    # force=True so re-creating the component (e.g. in tests) does not raise
    # on already-registered extensions.
    Doc.set_extension("abbreviations", default=[], force=True)
    Span.set_extension("long_form", default=None, force=True)

    # Matches any parenthesized run of tokens, i.e. candidate short forms.
    self.matcher = Matcher(nlp.vocab)
    self.matcher.add("parenthesis", [[{"ORTH": "("}, {"OP": "+"}, {"ORTH": ")"}]])
    self.make_serializable = make_serializable
    # Separate matcher used for document-wide occurrence matching; patterns
    # are added and removed per call (see find_matches_for).
    self.global_matcher = Matcher(nlp.vocab)

def find(self, span: Span, doc: Doc) -> Tuple[Span, Set[Span]]:
Expand Down Expand Up @@ -186,6 +194,12 @@ def __call__(self, doc: Doc) -> Doc:
for short in short_forms:
short._.long_form = long_form
doc._.abbreviations.append(short)
if self.make_serializable:
abbreviations = doc._.abbreviations
doc._.abbreviations = [
self.make_short_form_serializable(abbreviation)
for abbreviation in abbreviations
]
return doc

def find_matches_for(
Expand Down Expand Up @@ -223,3 +237,24 @@ def find_matches_for(
self.global_matcher.remove(key)

return list((k, v) for k, v in all_occurences.items())

def make_short_form_serializable(self, abbreviation: Span):
    """
    Convert a detected abbreviation into a plain, serializable dict so the
    containing Doc can be pickled (e.g. for multiprocessing).

    Parameters
    ----------
    abbreviation: Span
        The abbreviation span identified by the detector.

    Returns
    -------
    dict with the short/long form texts and their token offsets.
    """
    expanded = abbreviation._.long_form
    # Replace the Span stored on the extension with its raw text; Span
    # objects are what break serialization.
    abbreviation._.long_form = expanded.text
    return {
        "short_text": abbreviation.text,
        "short_start": abbreviation.start,
        "short_end": abbreviation.end,
        "long_text": expanded.text,
        "long_start": expanded.start,
        "long_end": expanded.end,
    }
21 changes: 12 additions & 9 deletions scispacy/linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,23 @@ def __init__(
self.umls = self.kb

def __call__(self, doc: Doc) -> Doc:
mentions = []
mention_strings = []
if self.resolve_abbreviations and Doc.has_extension("abbreviations"):

# TODO: This is possibly sub-optimal - we might
# prefer to look up both the long and short forms.
for ent in doc.ents:
# TODO: This is possibly sub-optimal - we might
# prefer to look up both the long and short forms.
if ent._.long_form is not None:
mentions.append(ent._.long_form)
if isinstance(ent._.long_form, Span):
# Long form
mention_strings.append(ent._.long_form.text)
elif isinstance(ent._.long_form, str):
# Long form
mention_strings.append(ent._.long_form)
else:
mentions.append(ent)
# no abbreviations case
mention_strings.append(ent.text)
else:
mentions = doc.ents
mention_strings = [x.text for x in doc.ents]

mention_strings = [x.text for x in mentions]
batch_candidates = self.candidate_generator(mention_strings, self.k)

for mention, candidates in zip(doc.ents, batch_candidates):
Expand Down
14 changes: 11 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Tuple
from typing import Dict, Tuple, Optional
import os

import pytest
Expand All @@ -8,6 +8,7 @@

from scispacy.custom_sentence_segmenter import pysbd_sentencizer
from scispacy.custom_tokenizer import combined_rule_tokenizer, combined_rule_prefixes, remove_new_lines
from scispacy.abbreviation import AbbreviationDetector

LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool], SpacyModelType] = {}

Expand All @@ -19,14 +20,15 @@ def get_spacy_model(
ner: bool,
with_custom_tokenizer: bool = False,
with_sentence_segmenter: bool = False,
with_serializable_abbreviation_detector: Optional[bool] = None,
) -> SpacyModelType:
"""
In order to avoid loading spacy models repeatedly,
we'll save references to them, keyed by the options
we used to create the spacy model, so any particular
configuration only gets loaded once.
"""
options = (spacy_model_name, pos_tags, parse, ner, with_custom_tokenizer, with_sentence_segmenter)
options = (spacy_model_name, pos_tags, parse, ner, with_custom_tokenizer, with_sentence_segmenter, with_serializable_abbreviation_detector)
if options not in LOADED_SPACY_MODELS:
disable = ["vectors", "textcat"]
if not pos_tags:
Expand All @@ -46,6 +48,8 @@ def get_spacy_model(
spacy_model.tokenizer = combined_rule_tokenizer(spacy_model)
if with_sentence_segmenter:
spacy_model.add_pipe("pysbd_sentencizer", first=True)
if with_serializable_abbreviation_detector is not None:
spacy_model.add_pipe("abbreviation_detector", config={"make_serializable": with_serializable_abbreviation_detector})

LOADED_SPACY_MODELS[options] = spacy_model
return LOADED_SPACY_MODELS[options]
Expand Down Expand Up @@ -97,9 +101,13 @@ def test_model_dir():

@pytest.fixture()
def combined_all_model_fixture():
    # Full pipeline: POS tags, parser, NER, custom tokenizer, and the
    # serializable abbreviation detector (no sentence segmenter).
    return get_spacy_model(
        "en_core_sci_sm",
        True,
        True,
        True,
        with_custom_tokenizer=True,
        with_sentence_segmenter=False,
        with_serializable_abbreviation_detector=True,
    )

@pytest.fixture()
def combined_all_model_fixture_non_serializable_abbrev():
    # Same full pipeline as combined_all_model_fixture, but with the
    # non-serializable abbreviation detector variant.
    return get_spacy_model(
        "en_core_sci_sm",
        True,
        True,
        True,
        with_custom_tokenizer=True,
        with_sentence_segmenter=False,
        with_serializable_abbreviation_detector=False,
    )

@pytest.fixture()
def combined_rule_prefixes_fixture():
Expand Down
28 changes: 27 additions & 1 deletion tests/custom_tests/test_all_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import spacy
from spacy.vocab import Vocab
import shutil
import pytest


def test_custom_segmentation(combined_all_model_fixture):
Expand Down Expand Up @@ -36,6 +37,31 @@ def test_custom_segmentation(combined_all_model_fixture):
]
actual_tokens = [t.text for t in doc]
assert expected_tokens == actual_tokens
assert doc.is_parsed
assert doc.has_annotation("DEP")
assert doc[0].dep_ == "ROOT"
assert doc[0].tag_ == "NN"

def test_full_pipe_serializable(combined_all_model_fixture):
    # pipe() with n_process=2 forces the Doc (including extension data)
    # through pickling, so this only passes when the abbreviation output
    # is serializable — and the text is chosen to contain an abbreviation
    # that would break it otherwise.
    text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
    docs = list(combined_all_model_fixture.pipe([text, text], n_process=2))
    doc = docs[0]
    assert len(doc._.abbreviations) > 0
    abbrev = doc._.abbreviations[0]
    assert abbrev["short_text"] == "CEIL"
    assert abbrev["long_text"] == "cytokine expression in leukocytes"
    # Stored token offsets must point back at the matching spans in the doc.
    assert doc[abbrev["short_start"] : abbrev["short_end"]].text == abbrev["short_text"]
    assert doc[abbrev["long_start"] : abbrev["long_end"]].text == abbrev["long_text"]

def test_full_pipe_not_serializable(combined_all_model_fixture_non_serializable_abbrev):
    text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
    # Running the pipeline in-process works even without serializable output.
    doc = combined_all_model_fixture_non_serializable_abbrev(text)
    # But serializing the resulting Doc must fail: with make_serializable=False
    # the abbreviations extension holds Span objects, which to_bytes() rejects.
    with pytest.raises(TypeError):
        doc.to_bytes()

# Below is the test version to be used once we move to spacy v3.1.0 or higher
# def test_full_pipe_not_serializable(combined_all_model_fixture_non_serializable_abbrev):
# text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
# # This line requires the pipeline to be serializable (because it uses 2 processes), so the test should fail here
# with pytest.raises(TypeError):
# list(combined_all_model_fixture_non_serializable_abbrev.pipe([text, text], n_process = 2))

0 comments on commit 3d153dd

Please sign in to comment.