diff --git a/sahi/models/yolov8.py b/sahi/models/yolov8.py
index 3ae40bda7..f4165d530 100644
--- a/sahi/models/yolov8.py
+++ b/sahi/models/yolov8.py
@@ -4,13 +4,16 @@
 import logging
 from typing import Any, Dict, List, Optional
 
+import cv2
 import numpy as np
+import torch
 
 logger = logging.getLogger(__name__)
 
 from sahi.models.base import DetectionModel
 from sahi.prediction import ObjectPrediction
 from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
+from sahi.utils.cv import get_bbox_from_bool_mask
 from sahi.utils.import_utils import check_requirements
 
 
@@ -55,9 +58,12 @@ def perform_inference(self, image: np.ndarray):
                 A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
         """
 
+        from ultralytics.engine.results import Masks
+
         # Confirm model is loaded
         if self.model is None:
             raise ValueError("Model is not loaded, load it by calling .load_model()")
+
         if self.image_size is not None:  # ADDED IMAGE SIZE OPTION FOR YOLOV8 MODELS:
             prediction_result = self.model(
                 image[:, :, ::-1], imgsz=self.image_size, verbose=False, device=self.device
@@ -66,11 +72,32 @@ def perform_inference(self, image: np.ndarray):
             prediction_result = self.model(
                 image[:, :, ::-1], verbose=False, device=self.device
             )  # YOLOv8 expects numpy arrays to have BGR
-        prediction_result = [
-            result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold] for result in prediction_result
-        ]
 
-        self._original_predictions = prediction_result
+        if self.has_mask:
+
+            if not prediction_result[0].masks:
+                prediction_result[0].masks = Masks(
+                    torch.tensor([], device=self.model.device), prediction_result[0].boxes.orig_shape
+                )
+
+            prediction_result_ = [
+                (
+                    result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold],
+                    result.masks.data[result.boxes.data[:, 4] >= self.confidence_threshold],
+                )
+                for result in prediction_result
+            ]
+
+        else:
+            prediction_result_ = []
+            for result in prediction_result:
+                result_boxes = result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold]
+                result_masks = torch.tensor([[] for _ in range(result_boxes.size()[0])])
+                # result_masks = [torch.tensor([], device=self.model.device) for _ in result_boxes]
+                prediction_result_.append((result_boxes, result_masks))
+
+        self._original_predictions = prediction_result_
+        self._original_shape = image.shape
 
     @property
     def category_names(self):
@@ -88,7 +115,8 @@ def has_mask(self):
         """
         Returns if model output contains segmentation mask
         """
-        return False  # fix when yolov5 supports segmentation models
+        # return True
+        return self.model.overrides["task"] == "segment"
 
     def _create_object_prediction_list_from_original_predictions(
         self,
@@ -114,13 +142,19 @@ def _create_object_prediction_list_from_original_predictions(
 
         # handle all predictions
         object_prediction_list_per_image = []
-        for image_ind, image_predictions_in_xyxy_format in enumerate(original_predictions):
+        for image_ind, image_predictions in enumerate(original_predictions):
+
+            image_predictions_in_xyxy_format = image_predictions[0]
+            image_predictions_masks = image_predictions[1]
+
             shift_amount = shift_amount_list[image_ind]
             full_shape = None if full_shape_list is None else full_shape_list[image_ind]
             object_prediction_list = []
 
             # process predictions
-            for prediction in image_predictions_in_xyxy_format.cpu().detach().numpy():
+            for prediction, bool_mask in zip(
+                image_predictions_in_xyxy_format.cpu().detach().numpy(), image_predictions_masks.cpu().detach().numpy()
+            ):
                 x1 = prediction[0]
                 y1 = prediction[1]
                 x2 = prediction[2]
@@ -130,6 +164,20 @@ def _create_object_prediction_list_from_original_predictions(
                 category_id = int(prediction[5])
                 category_name = self.category_mapping[str(category_id)]
 
+                # parse prediction mask
+                if not self.has_mask:
+                    # bool_mask = bool_mask
+                    # check if mask is valid
+                    # https://github.com/obss/sahi/discussions/696
+                    # if get_bbox_from_bool_mask(bool_mask) is None:
+                    #     continue
+                    # else:
+                    bool_mask = None
+                else:
+                    bool_mask = cv2.resize(bool_mask, (self._original_shape[1], self._original_shape[0]))
+                    bool_mask[bool_mask >= 0.5] = 1
+                    bool_mask[bool_mask < 0.5] = 0
+
                 # fix negative box coords
                 bbox[0] = max(0, bbox[0])
                 bbox[1] = max(0, bbox[1])
@@ -152,7 +200,7 @@ def _create_object_prediction_list_from_original_predictions(
                     bbox=bbox,
                     category_id=category_id,
                     score=score,
-                    bool_mask=None,
+                    bool_mask=bool_mask,
                     category_name=category_name,
                     shift_amount=shift_amount,
                     full_shape=full_shape,
diff --git a/sahi/utils/yolov8.py b/sahi/utils/yolov8.py
index aca2780b5..ea4246c63 100644
--- a/sahi/utils/yolov8.py
+++ b/sahi/utils/yolov8.py
@@ -20,6 +20,21 @@ class Yolov8TestConstants:
     YOLOV8X_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x.pt"
     YOLOV8X_MODEL_PATH = "tests/data/models/yolov8/yolov8x.pt"
 
+    YOLOV8N_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-seg.pt"
+    YOLOV8N_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8n-seg.pt"
+
+    YOLOV8S_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-seg.pt"
+    YOLOV8S_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8s-seg.pt"
+
+    YOLOV8M_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-seg.pt"
+    YOLOV8M_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8m-seg.pt"
+
+    YOLOV8L_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-seg.pt"
+    YOLOV8L_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8l-seg.pt"
+
+    YOLOV8X_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-seg.pt"
+    YOLOV8X_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8x-seg.pt"
+
 
 def download_yolov8n_model(destination_path: Optional[str] = None):
     if destination_path is None:
@@ -84,3 +99,73 @@ def download_yolov8x_model(destination_path: Optional[str] = None):
             Yolov8TestConstants.YOLOV8X_MODEL_URL,
             destination_path,
         )
+
+
+def download_yolov8n_seg_model(destination_path: Optional[str] = None):
+    """Fetch yolov8n-seg.pt to ``destination_path`` unless that file already exists."""
+    if destination_path is None:
+        destination_path = Yolov8TestConstants.YOLOV8N_SEG_MODEL_PATH
+
+    Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
+
+    if not path.exists(destination_path):
+        urllib.request.urlretrieve(
+            Yolov8TestConstants.YOLOV8N_SEG_MODEL_URL,
+            destination_path,
+        )
+
+
+def download_yolov8s_seg_model(destination_path: Optional[str] = None):
+    """Fetch yolov8s-seg.pt to ``destination_path`` unless that file already exists."""
+    if destination_path is None:
+        destination_path = Yolov8TestConstants.YOLOV8S_SEG_MODEL_PATH
+
+    Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
+
+    if not path.exists(destination_path):
+        urllib.request.urlretrieve(
+            Yolov8TestConstants.YOLOV8S_SEG_MODEL_URL,
+            destination_path,
+        )
+
+
+def download_yolov8m_seg_model(destination_path: Optional[str] = None):
+    """Fetch yolov8m-seg.pt to ``destination_path`` unless that file already exists."""
+    if destination_path is None:
+        destination_path = Yolov8TestConstants.YOLOV8M_SEG_MODEL_PATH
+
+    Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
+
+    if not path.exists(destination_path):
+        urllib.request.urlretrieve(
+            Yolov8TestConstants.YOLOV8M_SEG_MODEL_URL,
+            destination_path,
+        )
+
+
+def download_yolov8l_seg_model(destination_path: Optional[str] = None):
+    """Fetch yolov8l-seg.pt to ``destination_path`` unless that file already exists."""
+    if destination_path is None:
+        destination_path = Yolov8TestConstants.YOLOV8L_SEG_MODEL_PATH
+
+    Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
+
+    if not path.exists(destination_path):
+        urllib.request.urlretrieve(
+            Yolov8TestConstants.YOLOV8L_SEG_MODEL_URL,
+            destination_path,
+        )
+
+
+def download_yolov8x_seg_model(destination_path: Optional[str] = None):
+    """Fetch yolov8x-seg.pt to ``destination_path`` unless that file already exists."""
+    if destination_path is None:
+        destination_path = Yolov8TestConstants.YOLOV8X_SEG_MODEL_PATH
+
+    Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
+
+    if not path.exists(destination_path):
+        urllib.request.urlretrieve(
+            Yolov8TestConstants.YOLOV8X_SEG_MODEL_URL,
+            destination_path,
+        )