fix: update ONNX provider search and warn if GPU will not be used #401

Open · wants to merge 1 commit into base: main
10 changes: 8 additions & 2 deletions unstructured_inference/models/yolox.py
@@ -4,9 +4,9 @@
 # https://github.com/Megvii-BaseDetection/YOLOX/blob/ac379df3c97d1835ebd319afad0c031c36d03f36/yolox/utils/demo_utils.py
 
 import cv2
+import logging
 import numpy as np
 import onnxruntime
-from onnxruntime.capi import _pybind_state as C
 from PIL import Image as PILImage
 
 from unstructured_inference.constants import ElementType, Source
@@ -20,6 +20,7 @@
     download_if_needed_and_get_local_path,
 )
 
+logger = logging.getLogger(__name__)
 YOLOX_LABEL_MAP = {
     0: ElementType.CAPTION,
     1: ElementType.FOOTNOTE,
@@ -72,13 +73,18 @@ def initialize(self, model_path: str, label_map: dict):
         """Start inference session for YoloX model."""
         self.model_path = model_path
 
-        available_providers = C.get_available_providers()
+        available_providers = onnxruntime.get_available_providers()
         ordered_providers = [
             "TensorrtExecutionProvider",
             "CUDAExecutionProvider",
             "CPUExecutionProvider",
         ]
         providers = [provider for provider in ordered_providers if provider in available_providers]
+        logger.info("Available ONNX runtime providers: %r", providers)
+        if "CUDAExecutionProvider" not in providers:
+            logger.info(
+                "If you expected to see CUDAExecutionProvider and it is not there, "
+                "you may need to install the appropriate version of onnxruntime-gpu "
+                "for your CUDA toolkit."
+            )
 
         self.model = onnxruntime.InferenceSession(
             model_path,
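The provider-selection logic in the hunk above can be exercised on its own. Below is a minimal sketch where the list of installed providers is passed in as an argument (in the real code it comes from `onnxruntime.get_available_providers()`); the helper name `select_providers` is hypothetical, not part of the PR:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Preference order: TensorRT first, then CUDA, then CPU as the fallback.
ORDERED_PROVIDERS = [
    "TensorrtExecutionProvider",
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]


def select_providers(available: list[str]) -> list[str]:
    """Filter the preferred providers down to those actually installed,
    logging a hint when CUDA was expected but is missing."""
    providers = [p for p in ORDERED_PROVIDERS if p in available]
    logger.info("Available ONNX runtime providers: %r", providers)
    if "CUDAExecutionProvider" not in providers:
        logger.info(
            "If you expected to see CUDAExecutionProvider and it is not there, "
            "you may need to install the appropriate version of onnxruntime-gpu "
            "for your CUDA toolkit."
        )
    return providers
```

On a CPU-only install, `select_providers(["AzureExecutionProvider", "CPUExecutionProvider"])` returns `["CPUExecutionProvider"]` and logs the onnxruntime-gpu hint; the ordering of the result always follows `ORDERED_PROVIDERS`, not the input.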