-
Notifications
You must be signed in to change notification settings - Fork 906
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: export document pages as multimodal output (#54)
* feat: export document pages as multimodal output Signed-off-by: Michele Dolfi <[email protected]> * create a single parquet output Signed-off-by: Michele Dolfi <[email protected]> * add loading into HF datasets library Signed-off-by: Michele Dolfi <[email protected]> * renaming Signed-off-by: Michele Dolfi <[email protected]> * cleanup Signed-off-by: Michele Dolfi <[email protected]> --------- Signed-off-by: Michele Dolfi <[email protected]>
- Loading branch information
1 parent
69e5d95
commit 1de2e4f
Showing
5 changed files
with
1,025 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
import logging | ||
from typing import Any, Dict, Iterable, List, Tuple | ||
|
||
from docling_core.types.doc.base import BaseCell, Ref, Table, TableCell | ||
|
||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell | ||
from docling.datamodel.document import ConvertedDocument, Page | ||
|
||
_log = logging.getLogger(__name__) | ||
|
||
|
||
def _export_table_to_html(table: Table) -> str:
    """Render ``table`` as an HTML ``<table>`` string.

    The grid ``table.data`` is walked row by row (``table.num_rows`` x
    ``table.num_cols``). A cell is emitted only at the first row/column of
    its span; the other grid positions that carry the same span are skipped.
    Cells labelled ``col_header`` or ``col_multi_header`` become ``<th>``,
    every other cell becomes ``<td>``. ``rowspan``/``colspan`` attributes are
    added when the span covers more than one distinct row/column index.

    Cell text is stripped and HTML-escaped so that characters such as ``<``
    or ``&`` coming from the document cannot break the generated markup.

    Args:
        table: Table whose cells expose ``text``, ``obj_type`` and ``spans``.

    Returns:
        The HTML markup, without any surrounding document structure.
    """
    # TODO: this is flagged as internal, because we will move it
    # to the docling-core package.
    import html  # function-scope import keeps this block self-contained

    def _get_tablecell_span(cell: TableCell, ix: int):
        """Return ``(extent, first, last)`` of ``cell`` along axis ``ix``.

        ``ix`` selects the component of each entry in ``cell.spans``
        (the caller uses 0 for rows and 1 for columns). A cell with no span
        entries (``None`` or empty) counts as extent 1 with unknown position.
        """
        span = {s[ix] for s in (cell.spans or ())}
        if not span:
            return 1, None, None
        return len(span), min(span), max(span)

    rows_html = []
    for i in range(table.num_rows):
        cells_html = []
        for j in range(table.num_cols):
            cell: TableCell = table.data[i][j]

            rowspan, rowstart, _ = _get_tablecell_span(cell, 0)
            colspan, colstart, _ = _get_tablecell_span(cell, 1)

            # Only the anchor position of a span produces markup.
            if rowstart is not None and rowstart != i:
                continue
            if colstart is not None and colstart != j:
                continue

            content = html.escape(cell.text.strip())
            if cell.obj_type in ("col_header", "col_multi_header"):
                celltag = "th"
            else:
                celltag = "td"

            opening_tag = celltag
            if rowspan > 1:
                opening_tag += f' rowspan="{rowspan}"'
            if colspan > 1:
                opening_tag += f' colspan="{colspan}"'

            cells_html.append(f"<{opening_tag}>{content}</{celltag}>")
        rows_html.append("<tr>" + "".join(cells_html) + "</tr>")

    return "<table>" + "".join(rows_html) + "</table>"
|
||
|
||
def generate_multimodal_pages(
    doc_result: ConvertedDocument,
) -> Iterable[Tuple[str, str, List[Dict[str, Any]], List[Dict[str, Any]], Page]]:
    """Yield one multimodal record per page of a converted document.

    The items of ``doc_result.output.main_text`` are walked in order and
    grouped by the page of their first provenance entry. Whenever an item
    with a higher page number shows up, the page collected so far is emitted;
    the last collected page is emitted after the loop. Items that cannot be
    resolved or that carry no provenance are skipped.

    Each yielded tuple is ``(content_text, content_md, page_cells,
    page_segments, page)``:

    * ``content_text``: the non-empty item texts of the page, each followed
      by a single space.
    * ``content_md``: ``export_to_markdown`` output for the item index range
      of the page.
    * ``page_cells``: one dict per entry of ``page.cells`` (text, bbox, OCR
      flag and confidence).
    * ``page_segments``: one dict per item whose ``obj_type`` is present in
      the label mapping; tables additionally carry their HTML rendering.
    * ``page``: the entry ``doc_result.pages[page_no - 1]``.
    """
    doclaynet_labels = {
        "title": "title",
        "table-of-contents": "document_index",
        "subtitle-level-1": "section_header",
        "checkbox-selected": "checkbox_selected",
        "checkbox-unselected": "checkbox_unselected",
        "caption": "caption",
        "page-header": "page_header",
        "page-footer": "page_footer",
        "footnote": "footnote",
        "table": "table",
        "formula": "formula",
        "list-item": "list_item",
        "code": "code",
        "figure": "picture",
        "picture": "picture",
        "reference": "text",
        "paragraph": "text",
        "text": "text",
    }

    doc = doc_result.output

    def _to_page_bbox(bbox, page: Page):
        # Convert to a top-left origin, then normalize with the page size.
        top_left = bbox.to_top_left_origin(page_height=page.size.height)
        return top_left.normalized(page_size=page.size)

    def _segment_records(page_items: list[Tuple[int, BaseCell]], page: Page):
        # One record per item with a known label; unknown labels are dropped.
        records = []
        for doc_ix, element in page_items:
            label = doclaynet_labels.get(element.obj_type)
            if label is None:
                continue

            source_bbox = BoundingBox.from_tuple(
                element.prov[0].bbox, origin=CoordOrigin.BOTTOMLEFT
            )
            record = {
                "index_in_doc": doc_ix,
                "label": label,
                "text": "" if element.text is None else element.text,
                "bbox": _to_page_bbox(source_bbox, page).as_tuple(),
                "data": [],
            }
            if isinstance(element, Table):
                record["data"].append(
                    {
                        "html_seq": _export_table_to_html(element),
                        "otsl_seq": "",
                    }
                )
            records.append(record)
        return records

    def _cell_record(cell, page: Page):
        page_bbox = _to_page_bbox(cell.bbox, page)
        from_ocr = isinstance(cell, OcrCell)
        return {
            "text": cell.text,
            "bbox": page_bbox.as_tuple(),
            "ocr": from_ocr,
            # Non-OCR cells get a fixed confidence of 1.0.
            "ocr_confidence": cell.confidence if from_ocr else 1.0,
        }

    def _build_page(number, first_ix, last_ix, page_items, text_chunks):
        page = doc_result.pages[number - 1]

        cells = [_cell_record(cell, page) for cell in page.cells]
        segments = _segment_records(page_items, page)
        # NOTE(review): last_ix is the index of the page's last item; if
        # main_text_stop is exclusive, that item is missing from the
        # markdown -- confirm against the export_to_markdown implementation.
        markdown = doc.export_to_markdown(
            main_text_start=first_ix, main_text_stop=last_ix
        )
        return "".join(text_chunks), markdown, cells, segments, page

    current_page = 0
    first_ix = 0
    last_ix = 0
    page_items = []
    text_chunks = []

    for ix, raw_item in enumerate(doc.main_text):
        if isinstance(raw_item, Ref):
            item = doc._resolve_ref(raw_item)
        else:
            item = raw_item

        if item is None or not item.prov:
            _log.debug(f"Skipping item {raw_item}")
            continue

        item_page = item.prov[0].page

        # A later page has started: flush what was collected so far.
        if 0 < current_page < item_page:
            yield _build_page(
                current_page, first_ix, last_ix, page_items, text_chunks
            )
            first_ix = ix
            page_items = []
            text_chunks = []

        current_page = item_page
        last_ix = ix
        page_items.append((ix, item))
        if item.text:
            text_chunks.append(item.text + " ")

    if page_items:
        yield _build_page(current_page, first_ix, last_ix, page_items, text_chunks)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import datetime | ||
import logging | ||
import time | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
|
||
from docling.datamodel.base_models import AssembleOptions, ConversionStatus | ||
from docling.datamodel.document import DocumentConversionInput | ||
from docling.document_converter import DocumentConverter | ||
from docling.utils.export import generate_multimodal_pages | ||
|
||
_log = logging.getLogger(__name__) | ||
|
||
IMAGE_RESOLUTION_SCALE = 2.0 | ||
|
||
|
||
def main():
    """Convert the sample PDF(s) and export every page as a parquet row.

    Each page produced by ``generate_multimodal_pages`` becomes one row that
    holds the raw page image bytes, the page cells, the plain-text and
    markdown contents, the layout segments and some page metadata. The rows
    of all successfully converted documents are written to a single
    timestamped parquet file in ``./scratch``.

    Raises:
        RuntimeError: If at least one input document failed to convert.
    """
    logging.basicConfig(level=logging.INFO)

    input_doc_paths = [
        Path("./tests/data/2206.01062.pdf"),
    ]
    output_dir = Path("./scratch")

    input_files = DocumentConversionInput.from_paths(input_doc_paths)

    # Important: For operating with page images, we must keep them, otherwise
    # the DocumentConverter will destroy them for cleaning up memory.
    # This is done by setting AssembleOptions.images_scale, which also defines
    # the scale of images. scale=1 corresponds to a standard 72 DPI image.
    assemble_options = AssembleOptions()
    assemble_options.images_scale = IMAGE_RESOLUTION_SCALE

    doc_converter = DocumentConverter(assemble_options=assemble_options)

    start_time = time.time()

    converted_docs = doc_converter.convert(input_files)

    failure_count = 0
    output_dir.mkdir(parents=True, exist_ok=True)

    # Collected across ALL documents: initialising this inside the loop would
    # keep only the last document's pages and leave it undefined when every
    # document fails.
    rows = []
    for doc in converted_docs:
        if doc.status != ConversionStatus.SUCCESS:
            _log.info(f"Document {doc.input.file} failed to convert.")
            failure_count += 1
            continue

        for (
            content_text,
            content_md,
            page_cells,
            page_segments,
            page,
        ) in generate_multimodal_pages(doc):

            dpi = page._default_image_scale * 72

            rows.append(
                {
                    "document": doc.input.file.name,
                    "hash": doc.input.document_hash,
                    "page_hash": page.page_hash,
                    "image": {
                        "width": page.image.width,
                        "height": page.image.height,
                        "bytes": page.image.tobytes(),
                    },
                    "cells": page_cells,
                    "contents": content_text,
                    "contents_md": content_md,
                    "segments": page_segments,
                    "extra": {
                        "page_num": page.page_no + 1,
                        "width_in_points": page.size.width,
                        "height_in_points": page.size.height,
                        "dpi": dpi,
                    },
                }
            )

    # Generate one parquet from all documents
    now = datetime.datetime.now()
    output_filename = output_dir / f"multimodal_{now:%Y-%m-%d_%H%M%S}.parquet"
    if rows:
        df = pd.json_normalize(rows)
        df.to_parquet(output_filename)

    elapsed = time.time() - start_time

    _log.info(f"All documents were converted in {elapsed:.2f} seconds.")

    if failure_count > 0:
        raise RuntimeError(
            f"The example failed converting {failure_count} on {len(input_doc_paths)}."
        )

    # This block demonstrates how the file can be opened with the HF datasets library
    # from datasets import Dataset
    # from PIL import Image
    # multimodal_df = pd.read_parquet(output_filename)

    # # Convert pandas DataFrame to Hugging Face Dataset and load bytes into image
    # dataset = Dataset.from_pandas(multimodal_df)
    # def transforms(examples):
    #     examples["image"] = Image.frombytes('RGB', (examples["image.width"], examples["image.height"]), examples["image.bytes"], 'raw')
    #     return examples
    # dataset = dataset.map(transforms)


if __name__ == "__main__":
    main()
Oops, something went wrong.