From dd650393b8d2d8340f09857d52bd7ddbff09c28e Mon Sep 17 00:00:00 2001 From: Angele Zamarron Date: Thu, 9 Nov 2023 15:23:07 -0800 Subject: [PATCH] Refactor grobid sections (#281) * some kind of progress, still need to address overlap in sentences crossing paragraphs * ok cool this seems to be working! * make heading spans part of section * make sentences have unique ids, give paragraphs and sections ids * fix 'coords' error * pad_x for sentences * IT WORKS we get nice spans for sentences for this one specific sha now * remove spanless results (useless) * lil rename * mmda version bump * just return list * oops delete my thoughts * oops fix my error made when switching to just list being returned * new fix_overlaps param --- pyproject.toml | 2 +- ...grobid_augment_existing_document_parser.py | 140 +++++++++++------- src/mmda/utils/tools.py | 85 ++++++++--- 3 files changed, 156 insertions(+), 71 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ac07d5ec..847bd5d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = 'mmda' -version = '0.9.15' +version = '0.9.16' description = 'MMDA - multimodal document analysis' authors = [ {name = 'Allen Institute for Artificial Intelligence', email = 'contact@allenai.org'}, diff --git a/src/mmda/parsers/grobid_augment_existing_document_parser.py b/src/mmda/parsers/grobid_augment_existing_document_parser.py index 24f3ca27..2f7406bf 100644 --- a/src/mmda/parsers/grobid_augment_existing_document_parser.py +++ b/src/mmda/parsers/grobid_augment_existing_document_parser.py @@ -4,7 +4,7 @@ """ from grobid_client.grobid_client import GrobidClient -from typing import List, Optional +from typing import List, Optional, Tuple, Dict import logging import os @@ -12,7 +12,7 @@ from mmda.parsers.parser import Parser from mmda.types import Metadata -from mmda.types.annotation import BoxGroup, Box, SpanGroup +from mmda.types.annotation import BoxGroup, Box, SpanGroup, Span from mmda.types.document import Document from mmda.types.names import PagesField, RowsField, TokensField from mmda.utils.tools import box_groups_to_span_groups @@ -104,32 +104,69 @@ def _parse_xml_onto_doc(self, xml: str, doc: Document) -> Document: # sentences within the body text, also tagged by paragraphs. # We use these to annotate the document in order to provide a hierarchical structure: # e.g. doc.sections.header, doc.sections[0].paragraphs[0].sentences[0] - section_box_groups, heading_box_groups, paragraph_box_groups, sentence_box_groups = \ - self._get_structured_body_text_box_groups(xml_root) - doc.annotate( - sections=box_groups_to_span_groups( - section_box_groups, doc, center=True - ) - ) - doc.annotate( - headings=box_groups_to_span_groups( - heading_box_groups, doc, center=True - ) - ) - doc.annotate( - paragraphs=box_groups_to_span_groups( - paragraph_box_groups, doc, center=True - ) - ) - doc.annotate( - sentences=box_groups_to_span_groups( - sentence_box_groups, doc, center=True - ) - ) + section_headings_and_sentence_box_groups_in_paragraphs = \ + self._get_structured_sentence_box_groups(xml_root) + + heading_span_groups = [] + paragraph_span_groups = [] + section_span_groups = [] + sentence_span_groups = [] + + unallocated_section_tokens_dict: Dict[int, SpanGroup] = dict() + + for heading_box_group, paragraphs in section_headings_and_sentence_box_groups_in_paragraphs: + section_spans = [] + if heading_box_group: + heading_span_group_in_list = ( + box_groups_to_span_groups( + [heading_box_group], + doc, + center=True, + unallocated_tokens_dict=unallocated_section_tokens_dict, + fix_overlaps=True, + ) + ) + heading_span_group = heading_span_group_in_list[0] + heading_span_groups.append(heading_span_group) + section_spans.extend(heading_span_group.spans) + this_section_paragraph_span_groups = [] + for sentence_box_groups in paragraphs: + this_paragraph_sentence_span_groups = box_groups_to_span_groups( + sentence_box_groups, + doc, + center=True, + pad_x=True, + unallocated_tokens_dict=unallocated_section_tokens_dict, + fix_overlaps=True, + ) + if all([sg.spans for sg in this_paragraph_sentence_span_groups]): + sentence_span_groups.extend(this_paragraph_sentence_span_groups) + paragraph_spans = [] + for sg in this_paragraph_sentence_span_groups: + paragraph_spans.extend(sg.spans) + # TODO add boxes to paragraph spangroups + this_section_paragraph_span_groups.append(SpanGroup(spans=paragraph_spans)) + paragraph_span_groups.extend(this_section_paragraph_span_groups) + for sg in this_section_paragraph_span_groups: + section_spans.extend(sg.spans) + # TODO add boxes to section spangroups + section_span_groups.append(SpanGroup(spans=section_spans)) + + # ensure unique IDs within annotations + all_section_span_groups = [heading_span_groups, sentence_span_groups, paragraph_span_groups, section_span_groups] + for span_groups in all_section_span_groups: + for i, span_group in enumerate(span_groups): + span_group.id = i + + doc.annotate(headings=heading_span_groups) + doc.annotate(sentences=sentence_span_groups) + doc.annotate(paragraphs=paragraph_span_groups) + doc.annotate(sections=section_span_groups) + return doc - def _xml_coords_to_boxes(self, coords_attribute: str): + def _xml_coords_to_boxes(self, coords_attribute: str) -> List[Box]: coords_list = coords_attribute.split(";") boxes = [] for coords in coords_list: @@ -176,7 +213,11 @@ def _get_box_groups( elements = item_list_root.findall(f".//tei:{item_tag}", NS) for e in elements: - coords_string = e.attrib["coords"] + try: + coords_string = e.attrib["coords"] + except KeyError: + logging.warning(f"Element with '{item_tag}' tag missing 'coords' attribute") + continue boxes = self._xml_coords_to_boxes(coords_string) grobid_id = e.attrib[ID_ATTR_KEY] if ID_ATTR_KEY in e.keys() else None @@ -208,7 +249,11 @@ def _get_heading_box_group( box_group = None heading_element = section_div.find(f".//tei:head", NS) if heading_element is not None: # elements evaluate as False if no children - coords_string = heading_element.attrib["coords"] + try: + coords_string = heading_element.attrib["coords"] + except KeyError: + logging.warning(f"Heading element missing 'coords' attribute") + return None boxes = self._xml_coords_to_boxes(coords_string) number = heading_element.attrib["n"] if "n" in heading_element.keys() else None section_title = heading_element.text @@ -218,34 +263,29 @@ def _get_heading_box_group( ) return box_group - def _get_structured_body_text_box_groups( + def _get_structured_sentence_box_groups( self, root: et.Element - ) -> (List[BoxGroup], List[BoxGroup], List[BoxGroup], List[BoxGroup]): + ) -> List[Tuple[Optional[BoxGroup], List[List[BoxGroup]]]]: section_list_root = root.find(f".//tei:body", NS) - - body_sections: List[BoxGroup] = [] - body_headings: List[BoxGroup] = [] - body_paragraphs: List[BoxGroup] = [] - body_sentences: List[BoxGroup] = [] - section_divs = section_list_root.findall(f"./tei:div", NS) + + section_structures = [] for div in section_divs: - section_boxes: List[Box] = [] heading_box_group = self._get_heading_box_group(div) - if heading_box_group: - body_headings.append(heading_box_group) - section_boxes.extend(heading_box_group.boxes) + paragraphs: List[List[BoxGroup]] = [] for p in div.findall(f"./tei:p", NS): - paragraph_boxes: List[Box] = [] - paragraph_sentences: List[BoxGroup] = [] + sentence_box_groups: List[BoxGroup] = [] for s in p.findall(f"./tei:s", NS): - sentence_boxes = self._xml_coords_to_boxes(s.attrib["coords"]) - paragraph_sentences.append(BoxGroup(boxes=sentence_boxes)) - paragraph_boxes.extend(sentence_boxes) - body_paragraphs.append(BoxGroup(boxes=paragraph_boxes)) - section_boxes.extend(paragraph_boxes) - body_sentences.extend(paragraph_sentences) - body_sections.append(BoxGroup(boxes=section_boxes)) - - return body_sections, body_headings, body_paragraphs, body_sentences + try: + coords_string = s.attrib["coords"] + except KeyError: + logging.warning(f"Sentence element missing 'coords' attribute") + continue + sentence_boxes = self._xml_coords_to_boxes(coords_string) + sentence_box_groups.append(BoxGroup(boxes=sentence_boxes)) + paragraphs.append(sentence_box_groups) + + section_structures.append([heading_box_group, paragraphs]) + + return section_structures diff --git a/src/mmda/utils/tools.py b/src/mmda/utils/tools.py index 1095c96c..d1482c1b 100644 --- a/src/mmda/utils/tools.py +++ b/src/mmda/utils/tools.py @@ -4,7 +4,7 @@ from collections import defaultdict from itertools import groupby import itertools -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Optional, Union import numpy as np @@ -41,20 +41,38 @@ def allocate_overlapping_tokens_for_box( def box_groups_to_span_groups( - box_groups: List[BoxGroup], doc: Document, pad_x: bool = False, center: bool = False + box_groups: List[BoxGroup], + doc, + pad_x: bool = False, + center: bool = False, + unallocated_tokens_dict: Optional[Dict[int, SpanGroup]] = None, + fix_overlaps: bool = False, ) -> List[SpanGroup]: - """Generate SpanGroups from BoxGroups. + """Generate SpanGroups from BoxGroups given they can only generate spans of tokens not already allocated Args `box_groups` (List[BoxGroup]) `doc` (Document) base document annotated with pages, tokens, rows to - `center` (bool) if True, considers tokens to be overlapping with boxes only if their centers overlap + `center` (bool) if True, considers tokens to be overlapping with boxes only if their centers overlap + `unallocated_tokens` (Optional[Dict]) of token spangroups keyed by page. If provided, will use as starting + point for determining if token is already allocated. Assumes the tokens within are of the same type as the + `doc` (i.e., tokens from both doc and the dict both have their box data in either Span.box or + SpanGroup.boxgroup) + `fix_overlaps` (bool) if True, will attempt to fix overlapping spans within a SpanGroup by omitting + spans from already allocated tokens that end up contained in the derived_spans that come from MergeSpans. + This allows for the possibility of a BoxGroup that covers text to end up with a SpanGroup that is missing + spans or even has no spans since a previous BoxGroup already allocated all the underlying tokens. This + reduces the possibility of SpanGroup overlap errors, but may not return the desired SpanGroups. Returns - List[SpanGroup] with each SpanGroup.spans corresponding to spans (sans boxes) of allocated tokens per box_group, + Union (either) of: + -List[SpanGroup] with each SpanGroup.spans corresponding to spans (sans boxes) of allocated tokens per box_group, and each SpanGroup.box_group containing original box_groups + or Tuple of: + -List[SpanGroup] as described above, and + -Dictionary of unallocated tokens keyed by page """ assert all([isinstance(group, BoxGroup) for group in box_groups]) - all_page_tokens = dict() + unallocated_tokens = unallocated_tokens_dict if unallocated_tokens_dict is not None else dict() avg_token_widths = dict() derived_span_groups = [] token_box_in_box_group = None @@ -66,8 +84,8 @@ def box_groups_to_span_groups( for box in box_group.boxes: # Caching the page tokens to avoid duplicated search - if box.page not in all_page_tokens: - cur_page_tokens = all_page_tokens[box.page] = doc.pages[ + if box.page not in unallocated_tokens: + cur_page_tokens = unallocated_tokens[box.page] = doc.pages[ box.page ].tokens if token_box_in_box_group is None: @@ -89,7 +107,7 @@ def box_groups_to_span_groups( avg_token_widths[box.page] = np.average([t.spans[0].box.w for t in cur_page_tokens]) else: - cur_page_tokens = all_page_tokens[box.page] + cur_page_tokens = unallocated_tokens[box.page] # Find all the tokens within the box tokens_in_box, remaining_tokens = allocate_overlapping_tokens_for_box( @@ -101,7 +119,7 @@ def box_groups_to_span_groups( y=0.0, center=center ) - all_page_tokens[box.page] = remaining_tokens + unallocated_tokens[box.page] = remaining_tokens all_tokens_overlapping_box_group.extend(tokens_in_box) merge_spans = ( @@ -123,15 +141,47 @@ def box_groups_to_span_groups( # tokens overlapping with derived spans: sg_tokens = doc.find_overlapping(SpanGroup(spans=derived_spans), "tokens") + def omit_span_from_derived_spans(t_span): + # if the sg_token is in the derived_span, cut it out by updating derived_spans. + # this can happen because merge_spans finds min number of spans and can merge spans that + # cover tokens that were already allocated. We update this to avoid spangroup overlap errors. + for i, d_span in enumerate(derived_spans): + if d_span.start == t_span.start and t_span.end < d_span.end: + # unusable token_span is at start of derived_span + d_span.start = t_span.end + elif d_span.end == t_span.end and d_span.start < t_span.start < d_span.end: + # unusable token_span is at end of derived_span + d_span.end = t_span.end + elif d_span.start < t_span.start < d_span.end and t_span.end < d_span.end: + # unusable token_span is encompassed by derived_span + d_span.end = t_span.start + derived_spans.insert(i+1, Span(t_span.end, d_span.end)) + elif d_span.start == t_span.start and d_span.end == t_span.end: + # unusable token_span is equal to derived_span + derived_spans.remove(d_span) + # remove any additional tokens added to the spangroup via MergeSpans from the list of available page tokens # (this can happen if the MergeSpans algorithm merges tokens that are not adjacent, e.g. if `center` is True and # a token is not found to be overlapping with the box, but MergeSpans decides it is close enough to be merged) for sg_token in sg_tokens: if sg_token not in all_tokens_overlapping_box_group: - if token_box_in_box_group and sg_token in all_page_tokens[sg_token.box_group.boxes[0].page]: - all_page_tokens[sg_token.box_group.boxes[0].page].remove(sg_token) - elif not token_box_in_box_group and sg_token in all_page_tokens[sg_token.spans[0].box.page]: - all_page_tokens[sg_token.spans[0].box.page].remove(sg_token) + # if token not removed from unallocated_tokens yet, do it now + if token_box_in_box_group: + if sg_token in unallocated_tokens[sg_token.box_group.boxes[0].page]: + unallocated_tokens[sg_token.box_group.boxes[0].page].remove(sg_token) + # otherwise, if it is in neither all_tokens_overlapping_box_group nor unallocated_tokens, + # the assumption is that the token has already been allocated by a different box_group, so, we need + # to remove it from our derived spans to avoid 'SpanGroup overlap' error. + else: + if fix_overlaps: + omit_span_from_derived_spans(sg_token.spans[0]) + else: + if sg_token in unallocated_tokens[sg_token.spans[0].box.page]: + unallocated_tokens[sg_token.spans[0].box.page].remove(sg_token) + # same scenario as above. + else: + if fix_overlaps: + omit_span_from_derived_spans(sg_token.spans[0]) derived_span_groups.append( SpanGroup( @@ -148,21 +198,16 @@ def box_groups_to_span_groups( "future Spans wont contain box). Ensure Document is annotated with tokens " "having box stored in SpanGroup box_group.boxes") - del all_page_tokens - derived_span_groups = sorted( derived_span_groups, key=lambda span_group: span_group.start ) # ensure they are ordered based on span indices - for box_id, span_group in enumerate(derived_span_groups): span_group.id = box_id - # return self._annotate_span_group( - # span_groups=derived_span_groups, field_name=field_name - # ) return derived_span_groups + class MergeSpans: """ Given w=width and h=height merge neighboring spans which are w, h or less apart or by merging neighboring spans