
Add training guide and align text detection model training with recognition model #8

Merged
merged 11 commits on Jan 30, 2024
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
@@ -22,6 +22,8 @@ jobs:
      - name: Install pipenv
        run: pip install pipenv
      - name: Install dependencies
-       run: pipenv install --dev
+       run: |
+         pipenv install --dev
+         pipenv run pip install torch torchvision
      - name: Check formatting and types
        run: pipenv run qa
3 changes: 1 addition & 2 deletions Pipfile
@@ -4,15 +4,14 @@ verify_ssl = true
name = "pypi"

[packages]
-torch = "*"
numpy = "*"
-torchvision = "*"
pillow = "*"
tqdm = "*"
opencv-python = "*"
shapely = "*"
wandb = "*"
pylev = "*"
+onnx = "*"

[dev-packages]
black = "*"
1,147 changes: 539 additions & 608 deletions Pipfile.lock

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions README.md
@@ -3,6 +3,8 @@
This project contains tools for training PyTorch models for use with the
[**Ocrs**](https://github.com/robertknight/ocrs/) OCR engine.

## About the models

The ocrs engine splits text detection and recognition into three phases, each
of which corresponds to a different model in this repository:

@@ -22,3 +24,14 @@ All models can be exported to ONNX for downstream use.

The models are trained exclusively on datasets which are a) open and b) have non-restrictive licenses. This currently includes:
- [HierText](https://github.com/google-research-datasets/hiertext) (CC-BY-SA 4.0)

## Pre-trained models

Pre-trained models are available from [Hugging
Face](https://huggingface.co/robertknight/ocrs) as PyTorch checkpoints,
[ONNX](https://onnx.ai) and [RTen](https://github.com/robertknight/rten) models.

## Training custom models

See the [Training guide](docs/training.md) for a walk-through of the process to
train models from scratch or fine-tune existing models.
190 changes: 190 additions & 0 deletions docs/training.md
@@ -0,0 +1,190 @@
# Training Ocrs models

This document describes how to train models for use with
[ocrs](https://github.com/robertknight/ocrs).

## Prerequisites

To train the models you will need:

- Python 3.10 or later
- A GPU. The initial training was done on NVidia A10G GPUs with 24 GB RAM, via
[AWS EC2 G5 instances](https://aws.amazon.com/ec2/instance-types/g5/) (the
smallest `g5.xlarge` size will work).
- Optional: A Weights and Biases account (https://wandb.ai/) to track training progress

## About Ocrs models

Ocrs splits the OCR process into three stages:

1. Text detection
2. Layout analysis
3. Text recognition

Each of these stages corresponds to a separate PyTorch model. The layout
analysis model is incomplete and is not currently used in Ocrs.

You can mix and match default/pre-trained and custom models for the different
stages. For example, you may wish to use a pre-trained detection model but a
custom recognition model.
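
For example, the `ocrs` CLI (its model flags are shown later in this guide) can
pair the default detection model with a custom recognition model:

```sh
# Default detection model, custom recognition model.
ocrs --rec-model custom-recognition-model.rten image.jpg
```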

## Download the dataset

Following the instructions in
https://github.com/google-research-datasets/hiertext#getting-started, clone the
HierText repository and download the training data.

Note that you do **not** need to follow the step about decompressing the
`.jsonl.gz` files. The training tools will do this for you.

The compressed dataset is ~3.6 GB in total.

```sh
# Clone the HierText repository. This contains the ground truth data.
mkdir -p datasets/
cd datasets
git clone https://github.com/google-research-datasets/hiertext.git
cd hiertext

# Download the training, validation and test images.
aws s3 --no-sign-request cp s3://open-images-dataset/ocr/train.tgz .
aws s3 --no-sign-request cp s3://open-images-dataset/ocr/validation.tgz .
aws s3 --no-sign-request cp s3://open-images-dataset/ocr/test.tgz .

# Decompress the datasets.
tar -xf train.tgz
tar -xf validation.tgz
tar -xf test.tgz
```
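
After these steps, the `datasets/hiertext` checkout should contain roughly the
following layout (a sketch based on the HierText README; exact contents may
vary):

```
datasets/hiertext/
├── gt/           # ground-truth annotations (.jsonl.gz files)
├── train/        # training images
├── validation/   # validation images
└── test/         # test images
```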

## Set up the training environment

1. Install [Pipenv](https://pipenv.pypa.io/en/latest/)
2. Install dependencies, except for PyTorch:

```sh
pipenv install --dev
```

3. Install the appropriate version of PyTorch for your system, in the virtualenv
created by pipenv:

```sh
pipenv run pip install torch torchvision
```

See https://pytorch.org/get-started/locally/ for an appropriate pip command
depending on your platform and GPU.

4. Start a dummy text detection training run to verify everything is working:

```sh
pipenv run python -m ocrs_models.train_detection hiertext datasets/hiertext/ --max-images 100
```

Wait for one successful epoch of training and validation to complete and then
exit the process with Ctrl+C.

## Set up Weights and Biases integration (optional)

The ocrs-models training scripts support tracking training progress using
[Weights and Biases](https://wandb.ai). To enable this you will need to create
an account and then set the `WANDB_API_KEY` environment variable before running
training scripts:

```sh
export WANDB_API_KEY=<your_api_key>
```

## Train the text detection model

To launch a training run for the text detection model, run:

```sh
pipenv run python -m ocrs_models.train_detection hiertext datasets/hiertext/ \
--max-epochs 50 \
--batch-size 28
```

The `--batch-size` flag will need to be varied according to the amount of GPU
memory you have available. One way to do this is to start with a small value,
and then increase it until the training process is using most of the available
GPU memory. The above value was used with a GPU that has 24 GB of memory. When
training with an NVidia GPU, you can use the `nvidia-smi` tool to get memory
usage statistics.
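
For example, this standard `nvidia-smi` invocation (not specific to this
project) reports memory usage every 5 seconds while training runs:

```sh
nvidia-smi --query-gpu=memory.used,memory.total --format=csv -l 5
```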

To fine-tune an existing model, pass the `--checkpoint` flag to specify the
pre-trained model to start with.
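
For example, a fine-tuning run starting from the checkpoint written by a
previous run might look like this (a sketch; the epoch and batch-size values
are illustrative):

```sh
pipenv run python -m ocrs_models.train_detection hiertext datasets/hiertext/ \
  --checkpoint text-detection-checkpoint.pt \
  --max-epochs 10 \
  --batch-size 28
```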

### Export the text detection model

As training progresses, the latest checkpoint will be saved to
`text-detection-checkpoint.pt`. Once training completes, you can export the
model to ONNX via:

```sh
pipenv run python -m ocrs_models.train_detection hiertext datasets/hiertext/ \
--checkpoint text-detection-checkpoint.pt \
--export text-detection.onnx
```
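
As an optional sanity check (an extra step, not part of the documented
workflow), the exported file can be validated with the `onnx` package that is
already a dependency:

```python
import onnx

# Load the exported model and verify that it is structurally valid.
model = onnx.load("text-detection.onnx")
onnx.checker.check_model(model)

# Inspect the expected inputs of the exported graph.
print(model.graph.input)
```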

### Convert the text detection model

To use the exported ONNX model with Ocrs, you will need to convert it to
the `.rten` format used by [RTen][rten].

See the [RTen README](https://github.com/robertknight/rten#getting-started)
for current instructions on how to do this.
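
At the time of writing, conversion is done with RTen's `rten-convert` tool; the
commands below are a sketch of its usage, but treat the RTen README as the
authoritative reference:

```sh
pip install rten-convert
rten-convert text-detection.onnx text-detection.rten
```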

To use the converted model with the `ocrs` CLI tool, you can either pass the
model path via CLI arguments, or replace the default models in the cache
directory (`~/.cache/ocrs`). Example using CLI arguments:

```sh
ocrs --detect-model custom-detection-model.rten image.jpg
```

[rten]: https://github.com/robertknight/rten

## Train the text recognition model

To launch a training run for the text recognition model, run:

```sh
pipenv run python -m ocrs_models.train_rec hiertext datasets/hiertext/ \
--max-epochs 50 \
--batch-size 250
```

The `--batch-size` flag will need to be varied according to the amount of GPU
memory you have available. One way to do this is to start with a small value,
and then increase it until the training process is using most of the available
GPU memory. The above value was used with a GPU that has 24 GB of memory.

To fine-tune an existing model, pass the `--checkpoint` flag to specify the
pre-trained model to start with.

### Export the text recognition model

As training progresses, the latest checkpoint will be saved to
`text-rec-checkpoint.pt`. Once training completes, you can export the model to
ONNX via:

```sh
pipenv run python -m ocrs_models.train_rec hiertext datasets/hiertext/ \
--checkpoint text-rec-checkpoint.pt \
--export text-recognition.onnx
```

### Convert the text recognition model

To use the exported ONNX model with Ocrs, convert it to the `.rten` format
using the same process as for the detection model.

To use the converted model with the `ocrs` CLI tool, you can either pass the
model path via CLI arguments, or replace the default models in the cache
directory (`~/.cache/ocrs`). Example using CLI arguments:

```sh
ocrs --rec-model custom-recognition-model.rten image.jpg
```
6 changes: 0 additions & 6 deletions ocrs_models/eval_detection.py
@@ -22,9 +22,6 @@ def main():
parser.add_argument("model")
parser.add_argument("image")
parser.add_argument("out_basename")
parser.add_argument(
"--export", type=str, help="Export model as ONNX after evaluation"
)
args = parser.parse_args()

model = DetectionModel()
@@ -51,9 +48,6 @@ def main():
    pred_masks = model(img)
    end = time.time()

-    if args.export:
-        torch.onnx.export(model, img, args.export)
-
    print(f"Predicted text in {end - start:.2f}s", file=sys.stderr)

    pred_masks = pred_masks[0]  # Remove dummy batch dimension
35 changes: 34 additions & 1 deletion ocrs_models/postprocess.py
@@ -79,6 +79,29 @@ def expand_quads(quads: torch.Tensor, dist: float) -> torch.Tensor:
return torch.stack([expand_quad(quad, dist) for quad in quads])


+def lines_intersect(a_start: float, a_end: float, b_start: float, b_end: float) -> bool:
+    """
+    Return true if the lines (a_start, a_end) and (b_start, b_end) intersect.
+    """
+    if a_start <= b_start:
+        return a_end > b_start
+    else:
+        return b_end > a_start
+
+
+def bounds_intersect(
+    a: tuple[float, float, float, float], b: tuple[float, float, float, float]
+) -> bool:
+    """
+    Return true if the rects defined by two (min_x, min_y, max_x, max_y) tuples intersect.
+    """
+    a_min_x, a_min_y, a_max_x, a_max_y = a
+    b_min_x, b_min_y, b_max_x, b_max_y = b
+    return lines_intersect(a_min_x, a_max_x, b_min_x, b_max_x) and lines_intersect(
+        a_min_y, a_max_y, b_min_y, b_max_y
+    )


def box_match_metrics(pred: torch.Tensor, target: torch.Tensor) -> dict[str, float]:
"""
Compute metrics for quality of matches between two sets of rotated rects.
@@ -99,12 +122,22 @@ def box_match_metrics(pred: torch.Tensor, target: torch.Tensor) -> dict[str, float]:
    # Areas of unions of predictions and targets
    union = torch.zeros((len(pred), len(target)))

+    # Get bounding boxes of polys for a cheap intersection test.
+    pred_polys_bounds = [poly.bounds for poly in pred_polys]
+    target_polys_bounds = [poly.bounds for poly in target_polys]
+
    pred_areas = torch.zeros((len(pred),))
    for pred_index, pred_poly in enumerate(pred_polys):
        pred_areas[pred_index] = pred_poly.area
+        pred_bounds = pred_polys_bounds[pred_index]

        for target_index, target_poly in enumerate(target_polys):
-            if not pred_poly.intersects(target_poly):
+            # Do a cheap intersection test and skip computing the actual
+            # union/intersection if that fails.
+            target_bounds = target_polys_bounds[target_index]
+            if not bounds_intersect(pred_bounds, target_bounds):
                continue

            pt_intersection = pred_poly.intersection(target_poly)
            intersection[pred_index, target_index] = pt_intersection.area
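
The new helpers implement a standard shortcut: two axis-aligned rectangles can
only overlap if their extents overlap on both the x and y axes, which makes
`bounds_intersect` a cheap pre-filter before Shapely's exact polygon
intersection. A minimal usage sketch (hypothetical values, assuming the
functions above are importable from `ocrs_models.postprocess`):

```python
from ocrs_models.postprocess import bounds_intersect

# Rects are (min_x, min_y, max_x, max_y) tuples, as returned by Shapely's `bounds`.
a = (0.0, 0.0, 10.0, 10.0)
b = (5.0, 5.0, 15.0, 15.0)  # Overlaps `a` on both axes
c = (20.0, 0.0, 30.0, 10.0)  # Disjoint from `a` on the x axis

assert bounds_intersect(a, b)
assert not bounds_intersect(a, c)
```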
