Skip to content

Commit

Permalink
iblack & isort
Browse files Browse the repository at this point in the history
  • Loading branch information
edugzlez authored and fcakyon committed Nov 6, 2023
1 parent 2ddf2ef commit f04e446
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
5 changes: 3 additions & 2 deletions sahi/models/rtdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
logger = logging.getLogger(__name__)

from sahi.models.base import DetectionModel
from sahi.models.yolov8 import Yolov8DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.import_utils import check_requirements
from sahi.models.yolov8 import Yolov8DetectionModel


class RTDetrDetectionModel(Yolov8DetectionModel):
def check_dependencies(self) -> None:
Expand All @@ -29,4 +30,4 @@ def load_model(self):
model = RTDETR(self.model_path)
self.set_model(model)
except Exception as e:
raise TypeError("model_path is not a valid rtdet model path: ", e)
raise TypeError("model_path is not a valid rtdet model path: ", e)
4 changes: 3 additions & 1 deletion sahi/models/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def perform_inference(self, image: np.ndarray):
# Confirm model is loaded
if self.model is None:
raise ValueError("Model is not loaded, load it by calling .load_model()")
prediction_result = self.model(image[:, :, ::-1], verbose=False, device=self.device) # YOLOv8 expects numpy arrays to have BGR
prediction_result = self.model(
image[:, :, ::-1], verbose=False, device=self.device
) # YOLOv8 expects numpy arrays to have BGR
prediction_result = [
result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold] for result in prediction_result
]
Expand Down
4 changes: 2 additions & 2 deletions sahi/utils/rtdetr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import urllib.request
from os import path
from pathlib import Path
Expand All @@ -25,6 +24,7 @@ def download_rtdetrl_model(destination_path: Optional[str] = None):
destination_path,
)


def download_rtdetrx_model(destination_path: Optional[str] = None):
if destination_path is None:
destination_path = Yolov8TestConstants.RTDETRX_MODEL_PATH
Expand All @@ -35,4 +35,4 @@ def download_rtdetrx_model(destination_path: Optional[str] = None):
urllib.request.urlretrieve(
Yolov8TestConstants.RTDETRX_MODEL_URL,
destination_path,
)
)

0 comments on commit f04e446

Please sign in to comment.