Skip to content

Commit

Permalink
fix styling
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon committed Nov 5, 2023
1 parent 6a98030 commit a23587d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
28 changes: 18 additions & 10 deletions sahi/models/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
import logging
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
import torch
import cv2

logger = logging.getLogger(__name__)

from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.import_utils import check_requirements
from sahi.utils.cv import get_bbox_from_bool_mask
from sahi.utils.import_utils import check_requirements


class Yolov8DetectionModel(DetectionModel):
Expand Down Expand Up @@ -57,7 +57,7 @@ def perform_inference(self, image: np.ndarray):
image: np.ndarray
A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
"""

from ultralytics.yolo.engine.results import Masks

# Confirm model is loaded
Expand All @@ -67,11 +67,17 @@ def perform_inference(self, image: np.ndarray):
if self.has_mask:

if not prediction_result[0].masks:
prediction_result[0].masks = Masks(torch.tensor([], device=self.model.device), prediction_result[0].boxes.orig_shape)
prediction_result[0].masks = Masks(
torch.tensor([], device=self.model.device), prediction_result[0].boxes.orig_shape
)

prediction_result_ = [(result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold],
result.masks.data[result.boxes.data[:, 4] >= self.confidence_threshold])
for result in prediction_result]
prediction_result_ = [
(
result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold],
result.masks.data[result.boxes.data[:, 4] >= self.confidence_threshold],
)
for result in prediction_result
]

else:
prediction_result_ = []
Expand Down Expand Up @@ -101,7 +107,7 @@ def has_mask(self):
Returns if model output contains segmentation mask
"""
# return True
return self.model.overrides['task'] == 'segment'
return self.model.overrides["task"] == "segment"

def _create_object_prediction_list_from_original_predictions(
self,
Expand Down Expand Up @@ -130,14 +136,16 @@ def _create_object_prediction_list_from_original_predictions(
for image_ind, image_predictions in enumerate(original_predictions):

image_predictions_in_xyxy_format = image_predictions[0]
image_predictions_masks = image_predictions[1]
image_predictions_masks = image_predictions[1]

shift_amount = shift_amount_list[image_ind]
full_shape = None if full_shape_list is None else full_shape_list[image_ind]
object_prediction_list = []

# process predictions
for prediction, bool_mask in zip(image_predictions_in_xyxy_format.cpu().detach().numpy(), image_predictions_masks.cpu().detach().numpy()):
for prediction, bool_mask in zip(
image_predictions_in_xyxy_format.cpu().detach().numpy(), image_predictions_masks.cpu().detach().numpy()
):
x1 = prediction[0]
y1 = prediction[1]
x2 = prediction[2]
Expand Down
5 changes: 5 additions & 0 deletions sahi/utils/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def download_yolov8x_model(destination_path: Optional[str] = None):
destination_path,
)


def download_yolov8n_seg_model(destination_path: Optional[str] = None):

if destination_path is None:
Expand All @@ -113,6 +114,7 @@ def download_yolov8n_seg_model(destination_path: Optional[str] = None):
destination_path,
)


def download_yolov8s_seg_model(destination_path: Optional[str] = None):

if destination_path is None:
Expand All @@ -126,6 +128,7 @@ def download_yolov8s_seg_model(destination_path: Optional[str] = None):
destination_path,
)


def download_yolov8m_seg_model(destination_path: Optional[str] = None):

if destination_path is None:
Expand All @@ -139,6 +142,7 @@ def download_yolov8m_seg_model(destination_path: Optional[str] = None):
destination_path,
)


def download_yolov8l_seg_model(destination_path: Optional[str] = None):

if destination_path is None:
Expand All @@ -152,6 +156,7 @@ def download_yolov8l_seg_model(destination_path: Optional[str] = None):
destination_path,
)


def download_yolov8x_seg_model(destination_path: Optional[str] = None):

if destination_path is None:
Expand Down

0 comments on commit a23587d

Please sign in to comment.