import os
import time
from typing import List

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from loguru import logger
from pydantic import BaseModel, Field
from transformers import LayoutLMv3ForTokenClassification

from v3.helpers import MAX_LEN, parse_logits, prepare_inputs, boxes2inputs

# Run on the GPU when CUDA reports itself available, otherwise on the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# The checkpoint to load comes from the LOAD_PATH environment variable and
# falls back to "hantian/layoutreader" when that variable is unset.
model = LayoutLMv3ForTokenClassification.from_pretrained(
    os.getenv("LOAD_PATH", "hantian/layoutreader")
)
# Cast the weights to bfloat16, move them to the selected device and switch
# the module to evaluation mode; the single instance is shared by every request.
model = model.bfloat16()
model = model.to(device)
model = model.eval()

# ASGI application the route handlers below are registered on.
app = FastAPI()


@app.get("/config")
def get_config():
    # Report the MAX_LEN constant imported from v3.helpers under the
    # "max_len" key so clients can read it before calling /predict.
    config = dict(max_len=MAX_LEN)
    return config


class PredictRequest(BaseModel):
    # Request body of POST /predict: the boxes of one page plus the page size,
    # all expressed in the same (caller-chosen) coordinate units.

    # One [left, top, right, bottom] entry per box.
    boxes: List[List[float]] = Field(
        ...,
        examples=[[[2, 2, 3, 3], [1, 1, 2, 2]]],
        description="Boxes of [left, top, right, bottom]",
    )
    # The handler divides 1000 by width and by height to build its scale
    # factors, so both must be strictly positive; `gt=0` makes a zero or
    # negative page size fail request validation instead of crashing the
    # handler with a ZeroDivisionError.
    width: float = Field(..., gt=0, examples=[5], description="Page width")
    height: float = Field(..., gt=0, examples=[5], description="Page height")


class PredictResponse(BaseModel):
    # Response body of POST /predict.

    # Predicted ordering, as returned by do_predict for the submitted boxes.
    orders: List[int] = Field(
        default=...,
        examples=[[1, 0]],
        description="The order of spans.",
    )
    # Wall-clock duration of the do_predict call, measured by the handler.
    elapsed: float = Field(
        default=...,
        examples=[0.123],
        description="Elapsed time in seconds.",
    )


def do_predict(boxes: List[List[int]]) -> List[int]:
    """Run the module-level model on ``boxes`` and return the parsed order.

    Args:
        boxes: Boxes as ``[left, top, right, bottom]`` integer lists. The
            ``/predict`` handler passes coordinates already scaled to the
            0..1000 range.

    Returns:
        The result of ``parse_logits`` for ``len(boxes)`` boxes, or an empty
        list when ``boxes`` is empty (nothing to order, so the model is not
        invoked at all).
    """
    if not boxes:
        return []

    inputs = boxes2inputs(boxes)
    inputs = prepare_inputs(inputs, model)
    # Inference only: disable autograd so no graph or intermediate
    # activations are retained for a backward pass that never happens.
    with torch.no_grad():
        # Bring the logits to the CPU and drop the leading batch dimension
        # before handing them to parse_logits.
        logits = model(**inputs).logits.cpu().squeeze(0)
    return parse_logits(logits, len(boxes))


def _scale_box(box: List[float], x_scale: float, y_scale: float) -> List[int]:
    """Scale one box onto the 0..1000 integer grid and validate it.

    Args:
        box: ``[left, top, right, bottom]`` in page units.
        x_scale: Multiplier applied to ``left`` and ``right``.
        y_scale: Multiplier applied to ``top`` and ``bottom``.

    Returns:
        ``[left, top, right, bottom]`` as rounded integers.

    Raises:
        HTTPException: 422 when the box does not have exactly four values,
            contains a non-finite coordinate, or falls outside
            ``0 <= left <= right <= 1000`` / ``0 <= top <= bottom <= 1000``
            after scaling.
    """
    if len(box) != 4:
        raise HTTPException(
            status_code=422,
            detail=(
                "Invalid box. Expected [left, top, right, bottom], "
                f"got {len(box)} values"
            ),
        )
    try:
        left = round(box[0] * x_scale)
        top = round(box[1] * y_scale)
        right = round(box[2] * x_scale)
        bottom = round(box[3] * y_scale)
    except (ValueError, OverflowError) as err:
        # round() raises ValueError for NaN and OverflowError for infinity.
        raise HTTPException(
            status_code=422,
            detail=f"Invalid box. Non-finite coordinate in {box}",
        ) from err
    if not (1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0):
        raise HTTPException(
            status_code=422,
            detail=(
                f"Invalid box. right: {right}, left: {left}, "
                f"bottom: {bottom}, top: {top}"
            ),
        )
    return [left, top, right, bottom]


@app.post("/predict")
def predict(request: PredictRequest) -> PredictResponse:
    """Scale the request boxes to the 0..1000 grid and predict their order.

    Invalid client input is reported as an HTTP 422 error rather than through
    ``assert``, which is removed under ``python -O`` and would otherwise
    surface as an internal server error.

    Args:
        request: Boxes plus the page width and height they are expressed in.

    Returns:
        The predicted order together with the time spent in ``do_predict``.

    Raises:
        HTTPException: 422 when the page size is not a finite positive number
            or any box is malformed or out of range.
    """
    page_ok = (
        0 < request.width < float("inf") and 0 < request.height < float("inf")
    )
    if not page_ok:
        # Also rejects NaN, for which every comparison is False.
        raise HTTPException(
            status_code=422,
            detail=(
                f"Invalid page size. width: {request.width}, "
                f"height: {request.height}"
            ),
        )

    x_scale = 1000.0 / request.width
    y_scale = 1000.0 / request.height

    logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(request.boxes)}")
    boxes = [_scale_box(box, x_scale, y_scale) for box in request.boxes]

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    start = time.perf_counter()
    try:
        orders = do_predict(boxes)
        elapsed = time.perf_counter() - start
    finally:
        # Release cached CUDA memory even when inference fails.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    ret = PredictResponse(orders=orders, elapsed=elapsed)
    logger.info(f"Input Length: {len(boxes)}, Predicted in {ret.elapsed:.3f}s.")
    return ret


if __name__ == "__main__":
    # Executed as a script: serve the app with uvicorn on port 8000, bound to
    # 0.0.0.0 (all network interfaces).
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)