diff --git a/scispacy/abbreviation.py b/scispacy/abbreviation.py
index 732b068..cf5e4aa 100644
--- a/scispacy/abbreviation.py
+++ b/scispacy/abbreviation.py
@@ -151,14 +151,22 @@ class AbbreviationDetector:

     nlp: `Language`, a required argument for spacy to use this as a factory
     name: `str`, a required argument for spacy to use this as a factory
+    make_serializable: `bool`, a required argument for whether we want to use the serializable
+        or non serializable version.
     """

-    def __init__(self, nlp: Language, name: str = "abbreviation_detector") -> None:
+    def __init__(
+        self,
+        nlp: Language,
+        name: str = "abbreviation_detector",
+        make_serializable: bool = False,
+    ) -> None:
         Doc.set_extension("abbreviations", default=[], force=True)
         Span.set_extension("long_form", default=None, force=True)

         self.matcher = Matcher(nlp.vocab)
         self.matcher.add("parenthesis", [[{"ORTH": "("}, {"OP": "+"}, {"ORTH": ")"}]])
+        self.make_serializable = make_serializable
         self.global_matcher = Matcher(nlp.vocab)

     def find(self, span: Span, doc: Doc) -> Tuple[Span, Set[Span]]:
@@ -186,6 +194,12 @@ def __call__(self, doc: Doc) -> Doc:
             for short in short_forms:
                 short._.long_form = long_form
                 doc._.abbreviations.append(short)
+        if self.make_serializable:
+            abbreviations = doc._.abbreviations
+            doc._.abbreviations = [
+                self.make_short_form_serializable(abbreviation)
+                for abbreviation in abbreviations
+            ]
         return doc

     def find_matches_for(
@@ -223,3 +237,24 @@
             self.global_matcher.remove(key)

         return list((k, v) for k, v in all_occurences.items())
+
+    def make_short_form_serializable(self, abbreviation: Span):
+        """
+        Converts the abbreviations into a short form that is serializable to enable multiprocessing
+
+        Parameters
+        ----------
+        abbreviation: Span
+            The abbreviation span identified by the detector
+        """
+        long_form = abbreviation._.long_form
+        abbreviation._.long_form = long_form.text
+        serializable_abbr = {
+            "short_text": abbreviation.text,
+            "short_start": abbreviation.start,
+            "short_end": abbreviation.end,
+            "long_text": long_form.text,
+            "long_start": long_form.start,
+            "long_end": long_form.end,
+        }
+        return serializable_abbr
diff --git a/scispacy/linking.py b/scispacy/linking.py
index 974bbea..1e247bd 100644
--- a/scispacy/linking.py
+++ b/scispacy/linking.py
@@ -96,20 +96,23 @@ def __init__(
         self.umls = self.kb

     def __call__(self, doc: Doc) -> Doc:
-        mentions = []
+        mention_strings = []
         if self.resolve_abbreviations and Doc.has_extension("abbreviations"):
-
+            # TODO: This is possibly sub-optimal - we might
+            # prefer to look up both the long and short forms.
             for ent in doc.ents:
-                # TODO: This is possibly sub-optimal - we might
-                # prefer to look up both the long and short forms.
-                if ent._.long_form is not None:
-                    mentions.append(ent._.long_form)
+                if isinstance(ent._.long_form, Span):
+                    # Long form
+                    mention_strings.append(ent._.long_form.text)
+                elif isinstance(ent._.long_form, str):
+                    # Long form
+                    mention_strings.append(ent._.long_form)
                 else:
-                    mentions.append(ent)
+                    # no abbreviations case
+                    mention_strings.append(ent.text)
         else:
-            mentions = doc.ents
+            mention_strings = [x.text for x in doc.ents]

-        mention_strings = [x.text for x in mentions]
         batch_candidates = self.candidate_generator(mention_strings, self.k)

         for mention, candidates in zip(doc.ents, batch_candidates):
diff --git a/tests/conftest.py b/tests/conftest.py
index 68622fb..705dfe1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Optional
 import os

 import pytest
@@ -8,6 +8,7 @@

 from scispacy.custom_sentence_segmenter import pysbd_sentencizer
 from scispacy.custom_tokenizer import combined_rule_tokenizer, combined_rule_prefixes, remove_new_lines
+from scispacy.abbreviation import AbbreviationDetector

 LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool], SpacyModelType] = {}

@@ -19,6 +20,7 @@ def get_spacy_model(
     ner: bool,
     with_custom_tokenizer: bool = False,
     with_sentence_segmenter: bool = False,
+    with_serializable_abbreviation_detector: Optional[bool] = None,
 ) -> SpacyModelType:
     """
     In order to avoid loading spacy models repeatedly,
@@ -26,7 +28,7 @@
     we used to create the spacy model, so any particular configuration
     only gets loaded once.
     """
-    options = (spacy_model_name, pos_tags, parse, ner, with_custom_tokenizer, with_sentence_segmenter)
+    options = (spacy_model_name, pos_tags, parse, ner, with_custom_tokenizer, with_sentence_segmenter, with_serializable_abbreviation_detector)
     if options not in LOADED_SPACY_MODELS:
         disable = ["vectors", "textcat"]
         if not pos_tags:
@@ -46,6 +48,8 @@
             spacy_model.tokenizer = combined_rule_tokenizer(spacy_model)
         if with_sentence_segmenter:
             spacy_model.add_pipe("pysbd_sentencizer", first=True)
+        if with_serializable_abbreviation_detector is not None:
+            spacy_model.add_pipe("abbreviation_detector", config={"make_serializable": with_serializable_abbreviation_detector})
         LOADED_SPACY_MODELS[options] = spacy_model

     return LOADED_SPACY_MODELS[options]
@@ -97,9 +101,13 @@ def test_model_dir():

 @pytest.fixture()
 def combined_all_model_fixture():
-    nlp = get_spacy_model("en_core_sci_sm", True, True, True, with_custom_tokenizer=True, with_sentence_segmenter=False)
+    nlp = get_spacy_model("en_core_sci_sm", True, True, True, with_custom_tokenizer=True, with_sentence_segmenter=False, with_serializable_abbreviation_detector=True)
     return nlp

+@pytest.fixture()
+def combined_all_model_fixture_non_serializable_abbrev():
+    nlp = get_spacy_model("en_core_sci_sm", True, True, True, with_custom_tokenizer=True, with_sentence_segmenter=False, with_serializable_abbreviation_detector=False)
+    return nlp

 @pytest.fixture()
 def combined_rule_prefixes_fixture():
diff --git a/tests/custom_tests/test_all_model.py b/tests/custom_tests/test_all_model.py
index 56991a3..1da5d0e 100644
--- a/tests/custom_tests/test_all_model.py
+++ b/tests/custom_tests/test_all_model.py
@@ -3,6 +3,7 @@
 import spacy
 from spacy.vocab import Vocab
 import shutil
+import pytest


 def test_custom_segmentation(combined_all_model_fixture):
@@ -36,6 +37,31 @@
     ]
     actual_tokens = [t.text for t in doc]
     assert expected_tokens == actual_tokens
-    assert doc.is_parsed
+    assert doc.has_annotation("DEP")
     assert doc[0].dep_ == "ROOT"
     assert doc[0].tag_ == "NN"
+
+def test_full_pipe_serializable(combined_all_model_fixture):
+    text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
+    doc = [doc for doc in combined_all_model_fixture.pipe([text, text], n_process = 2)][0]
+    # If we got here this means that both model is serializable and there is an abbreviation that would break if it wasn't
+    assert len(doc._.abbreviations) > 0
+    abbrev = doc._.abbreviations[0]
+    assert abbrev["short_text"] == "CEIL"
+    assert abbrev["long_text"] == "cytokine expression in leukocytes"
+    assert doc[abbrev["short_start"] : abbrev["short_end"]].text == abbrev["short_text"]
+    assert doc[abbrev["long_start"] : abbrev["long_end"]].text == abbrev["long_text"]
+
+def test_full_pipe_not_serializable(combined_all_model_fixture_non_serializable_abbrev):
+    text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
+    # This line requires the pipeline to be serializable, so the test should fail here
+    doc = combined_all_model_fixture_non_serializable_abbrev(text)
+    with pytest.raises(TypeError):
+        doc.to_bytes()
+
+# Below is the test version to be used once we move to spacy v3.1.0 or higher
+# def test_full_pipe_not_serializable(combined_all_model_fixture_non_serializable_abbrev):
+#     text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
+#     # This line requires the pipeline to be serializable (because it uses 2 processes), so the test should fail here
+#     with pytest.raises(TypeError):
+#         list(combined_all_model_fixture_non_serializable_abbrev.pipe([text, text], n_process = 2))
\ No newline at end of file