Add ONNXRT model layoutlmv2 (#1089)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Jul 25, 2023
1 parent a050af7 commit 5f0b17e
Showing 20 changed files with 2,645 additions and 0 deletions.
14 changes: 14 additions & 0 deletions examples/.config/model_params_onnxrt.json
@@ -798,6 +798,20 @@
"main_script": "main.py",
"batch_size": 1
},
"hf_layoutlmv2_dynamic": {
"model_src_dir": "nlp/huggingface_model/token_classification/layoutlmv2/quantization/ptq_dynamic",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/hf_layoutlmv2/layoutlmv2-model.onnx",
"main_script": "main.py",
"batch_size": 1
},
"hf_layoutlmv2": {
"model_src_dir": "nlp/huggingface_model/token_classification/layoutlmv2/quantization/ptq_static",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/hf_layoutlmv2/layoutlmv2-model.onnx",
"main_script": "main.py",
"batch_size": 1
},
"hf_layoutlm_dynamic": {
"model_src_dir": "nlp/huggingface_model/token_classification/layoutlm/quantization/ptq_dynamic",
"dataset_location": "",
8 changes: 8 additions & 0 deletions examples/README.md
@@ -1338,6 +1338,14 @@ Intel® Neural Compressor validated examples with multiple compression technique
<a href="./onnxrt/nlp/huggingface_model/token_classification/layoutlmv3/quantization/ptq_dynamic">integerops</a> / <a href="./onnxrt/nlp/huggingface_model/token_classification/layoutlmv3/quantization/ptq_static">qlinearops</a>
</td>
</tr>
<tr>
<td>LayoutLMv2 FUNSD (HuggingFace)</td>
<td>Natural Language Processing</td>
<td>Post-Training Dynamic / Static Quantization</td>
<td>
<a href="./onnxrt/nlp/huggingface_model/token_classification/layoutlmv2/quantization/ptq_dynamic">integerops</a> / <a href="./onnxrt/nlp/huggingface_model/token_classification/layoutlmv2/quantization/ptq_static">qlinearops</a>
</td>
</tr>
<tr>
<td>LayoutLM FUNSD (HuggingFace)</td>
<td>Natural Language Processing</td>
@@ -0,0 +1,43 @@
Step-by-Step
============

This example quantizes the [LayoutLMv2](https://huggingface.co/microsoft/layoutlmv2-base-uncased) model that is fine-tuned on the [FUNSD](https://huggingface.co/datasets/nielsr/funsd) dataset.

# Prerequisite

## 1. Environment
```shell
pip install neural-compressor
pip install -r requirements.txt
```
> Note: See the validated ONNX Runtime [versions](/docs/source/installation_guide.md#validated-software-environment).
## 2. Prepare ONNX Model
Export the [nielsr/layoutlmv2-finetuned-funsd](https://huggingface.co/nielsr/layoutlmv2-finetuned-funsd) model to ONNX.

```bash
# fine-tuned model https://huggingface.co/nielsr/layoutlmv2-finetuned-funsd
python export.py --torch_model_name_or_path=/fine-tuned/torch/model/name/or/path
```
> Note: To export LayoutLMv2, first install [detectron2](https://github.com/facebookresearch/detectron2), e.g. `python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'`.
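
After export, you can sanity-check the graph with ONNX Runtime before quantizing. Below is a minimal sketch; it assumes the default output name from `export.py` and the int64 input dtypes that the export script produces, and feeds zero-filled placeholder tensors rather than real FUNSD data:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("layoutlmv2-finetuned-funsd-exported.onnx",
                            providers=["CPUExecutionProvider"])
feed = {
    "input_ids": np.zeros((1, 512), dtype=np.int64),
    "attention_mask": np.ones((1, 512), dtype=np.int64),
    "bbox": np.zeros((1, 512, 4), dtype=np.int64),
    "image": np.zeros((1, 3, 224, 224), dtype=np.int64),
}
logits = sess.run(None, feed)[0]
print(logits.shape)  # expect (1, 512, 7) for the 7 FUNSD NER labels
```
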
# Run

## 1. Quantization

Static quantization with QOperator format:

```bash
# --input_model: the exported ONNX model path (*.onnx)
bash run_tuning.sh --input_model=./layoutlmv2-finetuned-funsd-exported.onnx \
                   --output_model=/path/to/model_tune \
                   --quant_format="QOperator"
```


## 2. Benchmark

```bash
# --input_model: the ONNX model path (*.onnx)
bash run_benchmark.sh --input_model=/path/to/model \
                      --batch_size=batch_size \
                      --mode=performance # or accuracy
```
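
Under the hood, `run_tuning.sh` drives Intel Neural Compressor's post-training quantization API. The sketch below shows the dynamic variant, which needs no calibration data (the static recipe above additionally supplies a FUNSD calibration dataloader and an accuracy `eval_func` via `main.py`); the input and output file names are assumptions following the export step:

```python
from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig

# Dynamic PTQ: weights are quantized ahead of time, activations at runtime,
# so no calibration dataloader is required.
config = PostTrainingQuantConfig(approach="dynamic")
q_model = quantization.fit(model="layoutlmv2-finetuned-funsd-exported.onnx", conf=config)
q_model.save("layoutlmv2-finetuned-funsd-int8.onnx")  # assumed output name
```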
@@ -0,0 +1,91 @@
#!/usr/bin/env python
# coding=utf-8
# Export a fine-tuned PyTorch LayoutLMv2 model to ONNX.


def get_dummy_input(model_name_or_path):
import torch
from transformers import LayoutLMv2Processor
from PIL import Image

processor = LayoutLMv2Processor.from_pretrained(model_name_or_path)

    import numpy as np

    width = 762
    height = 800
    # Create a new RGB image with the specified size
    image = Image.new("RGB", (width, height))
    # Fill it with random pixel values (randint(256) covers the full 0-255 range)
    red = np.random.randint(256, size=(width * height))
    green = np.random.randint(256, size=(width * height))
    blue = np.random.randint(256, size=(width * height))
    pixels = [(r, g, b) for r, g, b in zip(red, green, blue)]
    image.putdata(pixels)

encoding = processor(image, return_tensors="pt", max_length=512, padding="max_length")
print(encoding.keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'image'])
for key, val in encoding.items():
print(f"key: {key}; val: {val.shape}")
    dummy_input = {}
    dummy_input["input_ids"] = encoding["input_ids"].to(torch.int64)
    dummy_input["attention_mask"] = encoding["attention_mask"].to(torch.int64)
    dummy_input["bbox"] = encoding["bbox"].to(torch.int64)
    # NOTE: the image tensor is also cast to int64, so the exported graph
    # expects int64 pixel values for the "image" input.
    dummy_input["image"] = encoding["image"].to(torch.int64)
    # Expected shapes for a single example (batch size 1):
    #   input_ids      [1, 512]
    #   attention_mask [1, 512]
    #   bbox           [1, 512, 4]
    #   image          [1, 3, 224, 224]
    return dummy_input


def export_model_to_onnx(model_name_or_path, export_model_path):
from torch.onnx import export as onnx_export
from collections import OrderedDict
from itertools import chain
from transformers import LayoutLMv2ForTokenClassification
    # The FUNSD NER task has 7 labels: O, B-/I-HEADER, B-/I-QUESTION, B-/I-ANSWER.
model = LayoutLMv2ForTokenClassification.from_pretrained(model_name_or_path, num_labels=7)
dummy_input = get_dummy_input(model_name_or_path)
inputs = OrderedDict(
{
"input_ids": {0: "batch_size", 1: "sequence_length"},
"bbox": {0: "batch_size", 1: "sequence_length"},
"image": {0: "batch_size", 1: "num_channels"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}
)
assert len(inputs.keys()) == len(dummy_input.keys())
outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}})

onnx_export(
model=model,
args=(dummy_input,),
f=export_model_path,
input_names=list(inputs.keys()),
output_names=list(outputs.keys()),
dynamic_axes=dict(chain(inputs.items(), outputs.items())),
do_constant_folding=True,
)
print(f"The model was successfully exported and saved as {export_model_path}.")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Export a Hugging Face LayoutLMv2 model to ONNX",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
parser.add_argument("--torch_model_name_or_path", type=str, help="fine-tuned pytorch model name or path")
parser.add_argument("--max_len", type=int, default=512, help="Maximum length of the sentence pairs")
args = parser.parse_args()
args.output_model = args.torch_model_name_or_path.split("/")[-1] + "-exported.onnx"
export_model_to_onnx(model_name_or_path=args.torch_model_name_or_path, export_model_path=args.output_model)
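    # Example invocation (a sketch; the model id comes from the README and is
    # fetched from the Hugging Face Hub):
    #   python export.py --torch_model_name_or_path=nielsr/layoutlmv2-finetuned-funsd
    #   -> writes layoutlmv2-finetuned-funsd-exported.onnx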
@@ -0,0 +1,135 @@
# coding=utf-8
# Adapted from https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py

import json
import os

import datasets

from PIL import Image

logger = datasets.logging.get_logger(__name__)


_CITATION = """\
@article{Jaume2019FUNSDAD,
title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
year={2019},
volume={2},
pages={1-6}
}
"""
_DESCRIPTION = """\
https://guillaumejaume.github.io/FUNSD/
"""


def load_image(image_path):
image = Image.open(image_path).convert("RGB")
w, h = image.size
return image, (w, h)


def normalize_bbox(bbox, size):
    # LayoutLM-family models expect box coordinates normalized to a 0-1000 scale.
    return [
        int(1000 * bbox[0] / size[0]),
        int(1000 * bbox[1] / size[1]),
        int(1000 * bbox[2] / size[0]),
        int(1000 * bbox[3] / size[1]),
    ]


class FunsdConfig(datasets.BuilderConfig):
"""BuilderConfig for FUNSD"""

def __init__(self, **kwargs):
"""BuilderConfig for FUNSD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(FunsdConfig, self).__init__(**kwargs)


class Funsd(datasets.GeneratorBasedBuilder):
"""FUNSD dataset."""

BUILDER_CONFIGS = [
FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
]

def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"words": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
)
),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
homepage="https://guillaumejaume.github.io/FUNSD/",
citation=_CITATION,
)

def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
),
]

def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "annotations")
img_dir = os.path.join(filepath, "images")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
words = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["form"]:
words_example, label = item["words"], item["label"]
words_example = [w for w in words_example if w["text"].strip() != ""]
if len(words_example) == 0:
continue
if label == "other":
for w in words_example:
words.append(w["text"])
ner_tags.append("O")
bboxes.append(normalize_bbox(w["box"], size))
else:
words.append(words_example[0]["text"])
ner_tags.append("B-" + label.upper())
bboxes.append(normalize_bbox(words_example[0]["box"], size))
for w in words_example[1:]:
words.append(w["text"])
ner_tags.append("I-" + label.upper())
bboxes.append(normalize_bbox(w["box"], size))
yield guid, {
"id": str(guid),
"words": words,
"bboxes": bboxes,
"ner_tags": ner_tags,
"image_path": image_path,
}
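

# Example usage (a sketch; assumes this script is saved locally as funsd.py):
#   from datasets import load_dataset
#   ds = load_dataset("funsd.py")  # DatasetDict with "train" and "test" splits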