
thomaspwang/check-classification


Check Classification and Fraud Detection

Overview

This repo provides modules and functions for extracting information from check images and classifying check types. It supports several extraction strategies, including LLaVA (a vision-language model), DocTR (an OCR library), and Amazon Textract.

Folder Structure

extraction:
    analyze_checks:
        Script that runs a chosen extraction strategy over a folder of check images and writes the extracted data to a CSV file.
        Supported strategies include LLaVA- and Textract-based ones.

    extract_bboxes:
        Module for extracting and visualizing bounding boxes from check images using AWS Textract.

    classify_treasury:
        Module for determining whether a check is a United States Treasury check using the LLaVA model.

    parse_bbox:
        Module for extracting specific text from images using various OCR models, such as LLaVA and DocTR.
        It handles the extraction process by cropping the image to the bounding box and passing it to the appropriate model.

    extract:
        Module for extracting all text data from a check image using either LLaVA or DocTR models.

    extract_micr:
        Module for extracting MICR (Magnetic Ink Character Recognition) data from check images using AWS Textract.

    extraction_utils:
        Utility functions for check data extraction. Provides functions for cropping images to bounding boxes, 
        merging bounding boxes, and stretching bounding boxes.

    llava_wrapper:
        Wraps the LLaVA library in a callable class that takes a prompt and an image
        and returns a text response.

scripts:
    compare_predictions_to_labels:
        Script that compares the predictions generated by analyze_checks.py against the true labels and reports aggregate statistics.
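The merging and stretching helpers described for extraction_utils can be illustrated with a minimal sketch. The Box dataclass and the merge_boxes and stretch_box functions below are hypothetical and are not the repo's actual API; check extraction/extraction_utils.py for the real names and coordinate conventions. This sketch assumes pixel coordinates stored as (left, top, right, bottom).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Box:
    """Hypothetical axis-aligned box in pixel coordinates (not the repo's BoundingBox)."""
    left: float
    top: float
    right: float
    bottom: float


def merge_boxes(boxes: list[Box]) -> Box:
    """Return the smallest box that contains every box in `boxes`."""
    if not boxes:
        raise ValueError("merge_boxes() needs at least one box")
    return Box(
        left=min(b.left for b in boxes),
        top=min(b.top for b in boxes),
        right=max(b.right for b in boxes),
        bottom=max(b.bottom for b in boxes),
    )


def stretch_box(box: Box, dx: float, dy: float, img_w: float, img_h: float) -> Box:
    """Pad a box by dx on the left/right and dy on the top/bottom, clamped to the image."""
    return Box(
        left=max(0, box.left - dx),
        top=max(0, box.top - dy),
        right=min(img_w, box.right + dx),
        bottom=min(img_h, box.bottom + dy),
    )
```

For example, merging Box(10, 20, 50, 40) and Box(45, 15, 90, 30) gives Box(10, 15, 90, 40). Padding a tight box like this before cropping gives an OCR model or LLaVA a little context around the text.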

Environment Setup


Step 0
We used AWS SageMaker notebooks throughout the development of this project. Make sure your notebook uses at least an ml.g5.2xlarge instance. To run the models in their non-quantized form, you need at least an ml.g5.12xlarge instance. We also recommend an EBS volume of at least 500GB, since the LLaVA model is large.


Step 1
Make sure you have an updated Python version with

python --version  # should be 3.10 or higher
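If you would rather check the version from inside Python (for example, in a notebook cell), a small sketch like the following works. python_is_supported is a hypothetical helper, not part of the repo.

```python
import sys


def python_is_supported(version_info=sys.version_info, minimum=(3, 10)) -> bool:
    """Return True if the interpreter meets this README's minimum version (3.10)."""
    return tuple(version_info[:2]) >= minimum


if __name__ == "__main__":
    if not python_is_supported():
        sys.exit(f"Python {sys.version.split()[0]} found; 3.10 or higher is required")
    print("Python version OK")
```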


Step 2
Navigate to the root of this repository (sofi-check-classification) and create a new virtual environment with either

python -m venv venv
# or
python3 -m venv venv


Step 3
Activate the venv with

source venv/bin/activate


Step 4
Install requirements with

pip install -r requirements.txt


Step 5
Even after installing requirements.txt, you may need to manually install these:

pip install python-doctr
pip install torch torchvision torchaudio

Note: We found this necessary on EC2 instances (ml.g5.12xlarge). Also note that pip install torch torchvision torchaudio assumes that your system has CUDA 12.1.


Step 6
To install llava, run

pip install -e git+https://github.com/haotian-liu/LLaVA.git@c121f04#egg=llava


Step 7
Once within the notebook terminal, run

export HOME="/home/ec2-user/SageMaker"

This is necessary because the 500GB volume is mounted at /home/ec2-user/SageMaker, while /home/ec2-user itself sits on a smaller 50GB volume. With the default HOME=/home/ec2-user, libraries and model weights download to the smaller volume and you will get an out-of-disk-space error.
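To confirm the change took effect before downloading model weights, you could run a quick check like this sketch, which reports where HOME points and how much space is free on that volume. The helper name is hypothetical; the expected path comes from Step 7 above. Libraries that cache under ~ (the Hugging Face cache defaults to ~/.cache/huggingface, for example) follow HOME, which is why this matters.

```python
import shutil
from pathlib import Path

# Expected HOME on SageMaker notebooks, per Step 7 of this README.
EXPECTED_HOME = Path("/home/ec2-user/SageMaker")


def home_free_gb() -> tuple[Path, float]:
    """Return the HOME directory and the free space (in GB) on its volume."""
    home = Path.home()  # resolved from the HOME environment variable on Linux
    return home, shutil.disk_usage(home).free / 1e9


if __name__ == "__main__":
    home, free_gb = home_free_gb()
    print(f"HOME={home} ({free_gb:.1f} GB free)")
    if home != EXPECTED_HOME:
        print(f"Warning: HOME is not {EXPECTED_HOME}; downloads may fill the smaller 50GB volume.")
```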


Step 8
For local testing and development, we recommend creating a local folder such as sofi-check-classification/data for PII images and labeled data.

Usage Example

from pathlib import Path

import boto3
from classify_treasury import is_treasury_check
from extract_micr import extract_micr_data, MICRData, MICRExtractionError
from extract import extract_data
from extract_bboxes import BoundingBox
from parse_bbox import parse_bbox, ExtractMode, generate_LLaVA_model

AWS_PROFILE_NAME = ...
AWS_REGION_NAME = ...

INPUT_IMAGE_PATH = Path("...")

# Generating models
llava_model = generate_LLaVA_model()

session = boto3.Session(profile_name=AWS_PROFILE_NAME)
textract_client = session.client('textract', region_name=AWS_REGION_NAME)

# Classifying Treasury checks (the result gets its own name so that the
# is_treasury_check function is not shadowed)
is_treasury: bool = is_treasury_check(INPUT_IMAGE_PATH, llava_model)

# Extracting MICR data
try:
    micr_data: MICRData = extract_micr_data(INPUT_IMAGE_PATH, textract_client)
except MICRExtractionError as e:
    print(f"MICR extraction failed: {e}")
    raise

# Scraping specific check data using LLaVA (None is passed for the bounding box)
PROMPT = "Scan this check and output the check amount as a string"
llava_check_amount_output: str = parse_bbox(INPUT_IMAGE_PATH, None, ExtractMode.LLAVA, llava_model, PROMPT)

# Scraping all check data using DocTR and LLaVA
check_data_doctr: list[str] = extract_data(INPUT_IMAGE_PATH, textract_client, ExtractMode.DOC_TR)
check_data_llava: list[str] = extract_data(INPUT_IMAGE_PATH, textract_client, ExtractMode.LLAVA)

Demos


Extracting Bounding Boxes
Writes a full-sized check image with the bounding boxes drawn on it to a specified output file.
python extract_bboxes.py ../data/mcd-test-3-front-images/mcd-test-3-front-93.jpg output_image.jpg


Extracting MICR from an image
Prints out a MICRData dataclass object generated from a full-sized check image to the console.
python extract_micr.py ../data/mcd-test-3-front-images/mcd-test-3-front-93.jpg
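Because MICR text comes from OCR, it can be worth sanity-checking the routing number before trusting it. The sketch below applies the standard ABA routing-number checksum (weights 3, 7, 1 repeated over the nine digits; a valid number's weighted sum is divisible by 10). This is a hypothetical post-processing step, not something extract_micr.py is known to do, and it assumes you have already pulled the routing number out of the MICRData object as a string (the field name is not shown here).

```python
import re


def is_valid_aba_routing_number(routing: str) -> bool:
    """Check a 9-digit ABA routing number with the 3-7-1 weighted checksum."""
    if not re.fullmatch(r"[0-9]{9}", routing):
        return False
    d = [int(c) for c in routing]
    total = (
        3 * (d[0] + d[3] + d[6])
        + 7 * (d[1] + d[4] + d[7])
        + (d[2] + d[5] + d[8])
    )
    return total % 10 == 0
```

A failing checksum most often points to an OCR misread; treat it as a signal to review the check, not as proof of fraud.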


Treasury Check Classification
Prints whether a given full-sized check image is a Treasury check.
python classify_treasury.py ../data/mcd-test-4-front-images/mcd-test-4-front-70.jpg


Extracting all data from a check image
Prints out all text data extracted from a full-sized check image as a list of strings.
python extract.py ../data/mcd-test-3-front-images/mcd-test-3-front-93.jpg --model llava

Benchmarking Algorithms


1. Add or Select a Strategy
Go to extraction/analyze_checks.py and follow the module docstring to add a new strategy or to select one of the existing strategies: {LLAVA_AMOUNT_AND_NAME, TEXTRACT_MICR, LLAVA_TREASURY}.


2. Obtain a Dataset and Labels
Obtain a dataset of check images and put it in a folder. The image file names must follow the format mcd-test-N-front-##.jpg, where N is a single digit and ## is one or more digits.

The label file should be a CSV in which each row corresponds directly to a check number: row 0 holds the headers, and row 1 corresponds to mcd-test-N-front-1.jpg. The label file must be contiguous, but the dataset does not need to include every file covered by the label file. The datasets provided by @jtongseely (N=1-5) should all work, and they include example label files: converting mcd-test-N-image-details.numbers to CSV produces a valid label file.
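The naming rule and the row-to-check mapping above can be expressed as a short sketch. The function names are hypothetical and the regex is an assumption derived from the format described here; analyze_checks.py and compare_predictions_to_labels.py may parse names differently.

```python
import csv
import re
from pathlib import Path

# mcd-test-N-front-##.jpg: N is a single digit, ## is one or more digits.
FILENAME_RE = re.compile(r"mcd-test-(\d)-front-(\d+)\.jpg")


def check_number(path: Path) -> int:
    """Extract the check number ## from a file named mcd-test-N-front-##.jpg."""
    match = FILENAME_RE.fullmatch(path.name)
    if match is None:
        raise ValueError(f"unexpected file name: {path.name}")
    return int(match.group(2))


def load_labels(csv_path: Path) -> list[dict[str, str]]:
    """Read the label CSV; DictReader consumes the header row (row 0)."""
    with csv_path.open(newline="") as f:
        return list(csv.DictReader(f))


def label_for(path: Path, label_rows: list[dict[str, str]]) -> dict[str, str]:
    """Row k of the CSV belongs to check k, i.e. index k - 1 once the header is consumed."""
    index = check_number(path) - 1
    if not 0 <= index < len(label_rows):
        raise IndexError(f"no label row for {path.name}")
    return label_rows[index]
```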


3. Run Inference on the Dataset
cd sofi-check-classification
python extraction/analyze_checks.py <path_to_dataset_folder> <path_to_output_csv> <strategy>

example:
python extraction/analyze_checks.py ../mcd-test-3-front-images/ ../LLAVA_TREASURY_PREDICTIONS.csv LLAVA_TREASURY

This will also compute some statistics, such as seconds / inference and $ / inference.
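analyze_checks.py reports these numbers itself; the sketch below only illustrates the arithmetic under stated assumptions. Seconds per inference is wall-clock time divided by the number of calls. For a model running on an hourly-billed instance (as with LLaVA on SageMaker), one way to estimate dollars per inference is hourly rate * seconds per inference / 3600. The function names are hypothetical, the hourly rate is a placeholder you would look up in AWS pricing, and per-call services such as Textract are billed per page instead, so this formula does not apply to them.

```python
import time
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def time_per_inference(fn: Callable[[T], R], inputs: Iterable[T]) -> tuple[list[R], float]:
    """Run fn on every input and return (outputs, mean wall-clock seconds per call)."""
    outputs: list[R] = []
    start = time.perf_counter()
    for item in inputs:
        outputs.append(fn(item))
    elapsed = time.perf_counter() - start
    return outputs, elapsed / max(len(outputs), 1)


def hourly_cost_per_inference(seconds_per_inference: float, hourly_rate_usd: float) -> float:
    """Amortized cost of one call on an instance billed by the hour."""
    return hourly_rate_usd * seconds_per_inference / 3600
```

For example, at a placeholder rate of $1.50/hour and 7.2 seconds per inference, the estimate is $0.003 per inference.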


4. Compute Aggregate Statistics
To compute statistics like hit rate, accuracy and edit distance, run python scripts/compare_predictions_to_labels.py <path_to_dataset_folder> <path_to_prediction_csv> <path_to_labels> --verbose

example:
python scripts/compare_predictions_to_labels.py ../mcd-test-3-front-images/ ../LLAVA_TREASURY_PREDICTIONS.csv ../mcd-test-3-image-details.csv --verbose
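For reference, edit distance usually means Levenshtein distance: the minimum number of single-character insertions, deletions, and substitutions needed to turn one string into another. The sketch below shows that metric alongside exact-match accuracy over paired predictions and labels. compare_predictions_to_labels.py may define these metrics differently (for example, by normalizing strings first), and hit rate is not sketched because its definition is not given here.

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions, and substitutions turning a into b."""
    if len(a) < len(b):
        a, b = b, a  # distance is symmetric; keep the shorter string as the row
    previous = list(range(len(b) + 1))  # distances from "" to each prefix of b
    for i, char_a in enumerate(a, start=1):
        current = [i]  # distance from a[:i] to ""
        for j, char_b in enumerate(b, start=1):
            current.append(min(
                previous[j] + 1,                       # deletion
                current[j - 1] + 1,                    # insertion
                previous[j - 1] + (char_a != char_b),  # substitution (free on a match)
            ))
        previous = current
    return previous[-1]


def exact_match_accuracy(predictions: list[str], labels: list[str]) -> float:
    """Fraction of predictions that equal their label exactly."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        return 0.0
    return sum(pred == label for pred, label in zip(predictions, labels)) / len(labels)
```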

Possible TO-DOs

  • Configure environment variables automatically (for example, with python-dotenv) instead of repeating top-level variables such as AWS_REGION at the top of every file.
  • Configure logging
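One possible shape for the environment-variable TO-DO is sketched below. It uses python-dotenv's load_dotenv if that package is installed and otherwise falls back to variables already exported in the shell. get_setting and variable names such as AWS_REGION_NAME are hypothetical; nothing in the repo currently reads them.

```python
import os
from typing import Optional

try:
    from dotenv import load_dotenv  # provided by the python-dotenv package, if installed
    load_dotenv(dotenv_path=".env")  # copy KEY=value pairs from ./.env into os.environ
except ImportError:
    pass  # fall back to variables already exported in the shell


def get_setting(name: str, default: Optional[str] = None) -> str:
    """Read a setting from the environment, failing loudly if it is missing."""
    value = os.environ.get(name, default)
    if value is None:
        raise RuntimeError(f"Set {name} in your shell or in a .env file")
    return value
```

Each module could then call, for example, get_setting("AWS_REGION_NAME") instead of defining its own top-level constant. By default, load_dotenv does not override variables that are already set in the environment.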

Debugging Tips

  • Using draw_bounding_boxes_on_image in extraction/extract_bboxes.py can be useful for visualizing bounding boxes. Note that bounding box coordinates are specific to a particular image, so boxes can only be drawn on the images they were generated on.
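Textract reports each Geometry.BoundingBox as Left, Top, Width, and Height ratios of the page it analyzed, so a box maps to pixels only through the dimensions of the image it came from. Assuming the boxes you are debugging carry those Textract ratios (check extract_bboxes.py; the repo's BoundingBox class may store them differently), a hypothetical helper like this converts one box to pixel corners.

```python
def to_pixel_box(bbox: dict[str, float], image_width: int, image_height: int) -> tuple[int, int, int, int]:
    """Convert a Textract Geometry.BoundingBox (ratios of the page) to (left, top, right, bottom) pixels."""
    left = round(bbox["Left"] * image_width)
    top = round(bbox["Top"] * image_height)
    right = round((bbox["Left"] + bbox["Width"]) * image_width)
    bottom = round((bbox["Top"] + bbox["Height"]) * image_height)
    return left, top, right, bottom
```

For a 1000x400 image, {"Left": 0.1, "Top": 0.2, "Width": 0.5, "Height": 0.25} becomes (100, 80, 600, 180). A tuple in this form can be passed to Pillow's ImageDraw.Draw(image).rectangle(...) to draw the box.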

Initial Contributors

[email protected] [email protected]
