update dataset create, test eval, update readme
ppaanngggg committed Mar 4, 2024
1 parent 2c09cbb commit a1abc18
Showing 6 changed files with 78 additions and 128 deletions.
71 changes: 42 additions & 29 deletions README.md
@@ -4,73 +4,86 @@
🤗 <a href="https://huggingface.co/hantian/layoutreader">Hugging Face</a>
</p>

TODO: a result example

## Why this repo?

The original [LayoutReader](https://github.com/microsoft/unilm/tree/master/layoutreader) is published by Microsoft Research. It is based on `LayoutLM` and uses a `seq2seq` architecture to predict the reading order of the words in a document. There are several problems with the original repo:
1. Because it doesn't use `transformers`, the code is full of one-off experiments and not well organized, which makes it hard to train and deploy.
2. `seq2seq` decoding is too slow in production; I want to get all predictions in one pass.
3. The [pre-trained model]()'s input is English word-level, which doesn't match the real use case. The real inputs are the spans extracted by a PDF parser or OCR.
3. The [pre-trained model](https://huggingface.co/nielsr/layoutreader-readingbank)'s input is English word-level, which doesn't match the real use case. The real inputs are the spans extracted by a PDF parser or OCR.
4. I want a multilingual model. I noticed that using only the bbox is only slightly worse than using bbox+text, so I want to train a model that uses only the bbox and ignores the text.

## What did I do?

1. Use `LayoutLMv3ForTokenClassification` of `transformers` to train and eval.
1. Refactor the code to use `LayoutLMv3ForTokenClassification` from `transformers` for training and evaluation.
2. Offer a script that turns the original word-level dataset into a span-level dataset.
3. Use a better post-processor to avoid duplicate predictions.
4. Offer a Docker image with an API service.
3. Implement a better post-processor to avoid duplicate predictions (see the sketch below).
4. Release a [pre-trained model](https://huggingface.co/hantian/layoutreader) fine-tuned from [layoutlmv3-large](https://huggingface.co/microsoft/layoutlmv3-large).
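
The post-processor itself isn't included in this commit, so the following is only a minimal sketch of duplicate-free decoding, assuming the model emits a score for every (span, reading-position) pair; the function name and the greedy strategy are illustrative, not the repo's actual code.

```python
import torch


def decode_order(logits: torch.Tensor) -> list[int]:
    """Greedy sketch: give each span a unique reading-order position.

    logits: (num_spans, num_positions) scores from the token classifier.
    """
    scores = logits.softmax(dim=-1)
    order = [-1] * scores.shape[0]
    taken: set[int] = set()
    # Resolve the most confident spans first so conflicts are decided
    # in favor of stronger predictions.
    for i in scores.max(dim=-1).values.argsort(descending=True).tolist():
        for pos in scores[i].argsort(descending=True).tolist():
            if pos not in taken:  # skip positions already claimed by another span
                order[i] = pos
                taken.add(pos)
                break
    return order
```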

## How to use?

```python
from transformers import LayoutLMv3ForTokenClassification

model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
# TODO
```
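
Until the `TODO` above is filled in, here is a rough inference sketch. The box scaling convention (a 1000x1000 coordinate space) follows `main.py`; the placeholder `input_ids` (the layout-only model is meant to ignore text) and the plain argmax decode are assumptions rather than the repo's exact pipeline; see `v3/helpers.py` and `main.py` for the real collation and post-processing.

```python
import torch
from transformers import LayoutLMv3ForTokenClassification

model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader").eval()

# Hypothetical spans from a PDF parser or OCR, in arbitrary input order;
# each box is (x1, y1, x2, y2), already scaled into the 0-1000 range.
boxes = [[50, 120, 400, 140], [50, 50, 400, 70], [50, 85, 400, 105]]

bbox = torch.tensor([boxes])                              # (1, num_spans, 4)
input_ids = torch.ones(bbox.shape[:2], dtype=torch.long)  # dummy tokens (assumption)
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    logits = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask).logits

# Naive decode: each span's most likely reading position. The repo's
# post-processor additionally prevents duplicate positions.
print(logits[0].argmax(dim=-1).tolist())
```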

## Dataset

### Download Original Dataset

The original dataset can be downloaded from [ReadingBank](https://layoutlm.blob.core.windows.net/readingbank/dataset/ReadingBank.zip?sv=2022-11-02&ss=b&srt=o&sp=r&se=2033-06-08T16:48:15Z&st=2023-06-08T08:48:15Z&spr=https&sig=a9VXrihTzbWyVfaIDlIT1Z0FoR1073VB0RLQUMuudD4%3D). More details can be found in the original [repo](https://aka.ms/readingbank).

### Build Dataset
### Build Span-Level Dataset

```bash
python tools.py cache-dataset-spans --help
unzip ReadingBank.zip
python tools.py ./train/ train.jsonl.gz
python tools.py ./dev/ dev.jsonl.gz
python tools.py ./test/ test.jsonl.gz --src-shuffle-rate=0
python tools.py ./test/ test_shuf.jsonl.gz --src-shuffle-rate=1
```
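
Each line of the output is one JSON record (gzip-compressed). The field names follow `tools.py`; the concrete values below are made up for illustration:

```python
# One decompressed line of train.jsonl.gz (hypothetical values):
{
    "source_boxes": [[120, 80, 400, 96], [40, 40, 380, 58]],  # input order (possibly shuffled)
    "source_texts": ["second line", "first line"],
    "target_boxes": [[40, 40, 380, 58], [120, 80, 400, 96]],  # correct reading order
    "target_texts": ["first line", "second line"],
    "target_index": [2, 1],  # 1-indexed reading position of each source span
    "bleu": 1.0,
}
```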

### Train
## Train & Eval

```bash
bash train.sh
```

### Eval
The core code is in the `./v3` folder; `train.sh` and `eval.py` are the entry points.

```bash
python eval.py --help
bash train.sh
python eval.py ../test.jsonl.gz hantian/layoutreader
python eval.py ../test_shuf.jsonl.gz hantian/layoutreader
```

## Span-Level Results

1. `shuf` indicates whether the input order is shuffled.
2. `BLEU Idx` is the BLEU score of the predicted reading-order indices.
3. `BLEU Token` is the BLEU score of the final merged text.

I only train the `layout only` model and test on the span-level dataset, so the `Heuristic Method` result is quite different from the original word-level result. I mainly focus on `BLEU Token`; it's only 0.4 lower than the original word-level result, but the speed is much faster.
3. `BLEU Text` is the BLEU score of the final merged text.

> Only the first part of the test file is used.
I only train the `layout only` model and test on the span-level dataset, so the `Heuristic Method` result is quite different from the original word-level result. I mainly focus on `BLEU Text`; it's only a bit lower than the original word-level result, but the speed is much faster.

| Method | shuf | BLEU Idx | BLEU Token |
|----------------------------|------|----------|------------|
| Heuristic Method | no | 44.4 | 70.7 |
| LayoutReader (layout only) | no | 95.3 | 97.8 |
| LayoutReader (layout only) | yes | 95.0 | 97.6 |
| Method | shuf | BLEU Idx | BLEU Text |
|----------------------------|------|----------|-----------|
| Heuristic Method | no | 44.4 | 70.7 |
| LayoutReader (layout only) | no | 94.9 | 97.5 |
| LayoutReader (layout only) | yes | 95.0 | 97.6 |

## Word-Level Results

### My eval script

I trained the `layout only` model myself using the original code; the `public model` is the pre-trained model. The `layout only` model is nearly as good as the `public model`, and `shuf` has only a small effect on the results.

> Only the first part of the test file is used.
> Only the first part of the test dataset is evaluated, because it's too slow...

| Method | shuf | BLEU Idx | BLEU Token |
|-----------------------------|------|----------|------------|
| Heuristic Method | no | 78.3 | 79.4 |
| LayoutReader (layout only) | no | 98.0 | 98.2 |
| LayoutReader (layout only) | yes | 97.8 | 98.0 |
| LayoutReader (public model) | no | 98.0 | 98.3 |
| Method | shuf | BLEU Idx | BLEU Text |
|-----------------------------|------|----------|-----------|
| Heuristic Method | no | 78.3 | 79.4 |
| LayoutReader (layout only) | no | 98.0 | 98.2 |
| LayoutReader (layout only) | yes | 97.8 | 98.0 |
| LayoutReader (public model) | no | 98.0 | 98.3 |

### Old eval script (copied from the original paper)

13 changes: 4 additions & 9 deletions main.py
@@ -13,12 +13,14 @@

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = (
LayoutLMv3ForTokenClassification.from_pretrained(os.getenv("LOAD_PATH", "./model"))
LayoutLMv3ForTokenClassification.from_pretrained(
os.getenv("LOAD_PATH", "hantian/layoutreader")
)
.bfloat16()
.to(device)
.eval()
)
data_collator = DataCollator(skip_input_ids=False, hidden_size=model.config.hidden_size)
data_collator = DataCollator()
app = FastAPI()


@@ -30,9 +32,6 @@ def get_config():


class PredictRequest(BaseModel):
spans: List[str] = Field(
..., examples=[["Hello", "World"]], description="Spans of text"
)
boxes: List[List[float]] = Field(
...,
examples=[[[2, 2, 3, 3], [1, 1, 2, 2]]],
@@ -65,10 +64,6 @@ def do_predict(boxes: List[List[int]]) -> List[int]:

@app.post("/predict")
def predict(request: PredictRequest) -> PredictResponse:
assert len(request.spans) == len(
request.boxes
), "The length of spans and boxes must be equal."

x_scale = 1000.0 / request.width
y_scale = 1000.0 / request.height

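
For context, a client call against the updated `/predict` endpoint (which no longer takes `spans`) might look like the sketch below; the service URL and the response schema are guesses, since `PredictResponse` isn't shown in this diff.

```python
import requests

# Hypothetical client for the FastAPI service in main.py.
resp = requests.post(
    "http://localhost:8000/predict",
    json={
        "boxes": [[2, 2, 3, 3], [1, 1, 2, 2]],  # one (x1, y1, x2, y2) box per span
        "width": 4,   # page size, used server-side to rescale boxes into 0-1000
        "height": 4,
    },
)
print(resp.json())  # expected: the predicted reading order of the input boxes
```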
104 changes: 29 additions & 75 deletions tools.py
@@ -2,60 +2,18 @@
import gzip
import json
import os
import pickle
import random
import re

import numpy as np
import tqdm
import typer
from loguru import logger

app = typer.Typer()


def draw_one(ids: list[list], texts: list[str]):
import cv2

img = np.empty((1000, 1000, 3), dtype="uint8")
img.fill(255)
for i, (_, left, top, right, bottom) in enumerate(ids):
cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0), 1)
cv2.putText(
img,
str(i),
(left, top - 1),
cv2.FONT_HERSHEY_SIMPLEX,
0.3,
(0, 0, 255),
1,
)
cv2.putText(
img,
texts[i],
(left, bottom - 3),
cv2.FONT_HERSHEY_SIMPLEX,
0.4,
(255, 0, 0),
1,
)
cv2.imshow("img", img)
return cv2.waitKey(0)


@app.command()
def draw(input_file: str):
with open(input_file, "rb") as f:
datas = pickle.load(f)

for data in datas:
if draw_one(data["target_ids"], data["target_texts"]) == ord("q"):
break


def read_raws(path: str) -> list:
def read_raws(path: str):
logger.info("Creating features from dataset at {}", path)
examples = []
if os.path.isdir(path):
text_files = glob.glob(f"{path}/*text*.json")
layout_files = [re.sub("text|txt", "layout", x, 1) for x in text_files]
@@ -68,24 +26,26 @@ def read_raws(path: str) -> list:
for i, (text_line, layout_line) in enumerate(
zip(text_reader, layout_reader)
):
if (i + 1) % 10000 == 0:
logger.info(f"{i + 1} lines ...")
examples.append((json.loads(text_line), json.loads(layout_line)))
return examples
yield json.loads(text_line), json.loads(layout_line)


@app.command()
def cache_dataset_spans(
path: str,
output_file: str,
shuffle: bool = True,
src_shuffle_rate: float = 0.5,
def create_dataset_spans(
path: str = typer.Argument(
...,
help="Path to the dataset, like `./train/`",
),
output_file: str = typer.Argument(
..., help="Path to the output file, like `./train.jsonl.gz`"
),
src_shuffle_rate: float = typer.Option(
0.5, help="The rate to shuffle input's order"
),
):
random.seed(42)
examples = read_raws(path)

features = []
for text, layout in tqdm.tqdm(examples):
logger.info("Saving features into file {}", output_file)
f_out = gzip.open(output_file, "wt")
for text, layout in tqdm.tqdm(read_raws(path)):
target_boxes = []
target_texts = []
last_box = [0, 0, 0, 0]
@@ -136,26 +96,20 @@ def cache_dataset_spans(
target_index[i] = j
j += 1

# if draw_one(source_boxes, source_texts) == ord("q"):
# return
features.append(
{
"source_boxes": source_boxes,
"source_texts": source_texts,
"target_boxes": target_boxes,
"target_texts": target_texts,
"target_index": target_index,
"bleu": text["bleu"],
}
f_out.write(
json.dumps(
{
"source_boxes": source_boxes,
"source_texts": source_texts,
"target_boxes": target_boxes,
"target_texts": target_texts,
"target_index": target_index,
"bleu": text["bleu"],
}
)
+ "\n"
)

if shuffle:
random.shuffle(features)

logger.info("Saving features into cached file {}", output_file)
with gzip.open(output_file, "wt") as f:
for feature in tqdm.tqdm(features):
f.write(json.dumps(feature) + "\n")
f_out.close()


if __name__ == "__main__":
11 changes: 1 addition & 10 deletions v3/helpers.py
@@ -11,10 +11,6 @@


class DataCollator:
def __init__(self, skip_input_ids: bool, hidden_size: int):
self.skip_input_ids = skip_input_ids
self.hidden_size = hidden_size

def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
bbox = []
labels = []
@@ -58,13 +54,8 @@ def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
"bbox": torch.tensor(bbox),
"attention_mask": torch.tensor(attention_mask),
"labels": torch.tensor(labels),
"input_ids": torch.tensor(input_ids),
}
if self.skip_input_ids:
ret["inputs_embeds"] = torch.zeros(
(len(features), max_len, self.hidden_size)
)
else:
ret["input_ids"] = torch.tensor(input_ids)
# set label > MAX_LEN to -100, because original labels may be > MAX_LEN
ret["labels"][ret["labels"] > MAX_LEN] = -100
# set label > 0 to label-1, because original labels are 1-indexed
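
As a toy illustration of the two label fix-ups above (the real `MAX_LEN` is defined in the `v3` package; 510 here is an assumption):

```python
import torch

MAX_LEN = 510  # assumption, for illustration only

labels = torch.tensor([1, 3, 600, 2])  # 1-indexed reading positions; 600 overflows
labels[labels > MAX_LEN] = -100        # overflow positions are ignored by the loss
labels[labels > 0] -= 1                # shift the rest to 0-indexed class ids
print(labels)                          # tensor([   0,    2, -100,    1])
```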
6 changes: 1 addition & 5 deletions v3/train.py
@@ -24,10 +24,6 @@ class Arguments(TrainingArguments):
default=None,
metadata={"help": "Path to dataset"},
)
skip_input_ids: bool = field(
default=False,
metadata={"help": "Whether to skip input ids"},
)


def load_train_and_dev_dataset(path: str) -> (Dataset, Dataset):
@@ -56,7 +52,7 @@ def main():
model = LayoutLMv3ForTokenClassification.from_pretrained(
args.model_dir, num_labels=MAX_LEN, visual_embed=False
)
data_collator = DataCollator(args.skip_input_ids, model.config.hidden_size)
data_collator = DataCollator()
trainer = Trainer(
model=model,
args=args,
1 change: 1 addition & 0 deletions v3/train.sh
100644 → 100755
@@ -10,6 +10,7 @@ DATA_DIR="${DIR}/ReadingBank/"
mkdir -p "${OUTPUT_DIR}"

deepspeed train.py \
--model_dir 'microsoft/layoutlmv3-large' \
--dataset_dir "${DATA_DIR}" \
--dataloader_num_workers 1 \
--deepspeed ds_config.json \
