Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finally, the working version of YoloV8 Instance Segmentation #918

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions sahi/models/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import logging
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
import torch

logger = logging.getLogger(__name__)

from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.cv import get_bbox_from_bool_mask
from sahi.utils.import_utils import check_requirements


Expand Down Expand Up @@ -55,9 +58,12 @@ def perform_inference(self, image: np.ndarray):
A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
"""

from ultralytics.engine.results import Masks

# Confirm model is loaded
if self.model is None:
raise ValueError("Model is not loaded, load it by calling .load_model()")

if self.image_size is not None: # ADDED IMAGE SIZE OPTION FOR YOLOV8 MODELS:
prediction_result = self.model(
image[:, :, ::-1], imgsz=self.image_size, verbose=False, device=self.device
Expand All @@ -66,11 +72,32 @@ def perform_inference(self, image: np.ndarray):
prediction_result = self.model(
image[:, :, ::-1], verbose=False, device=self.device
) # YOLOv8 expects numpy arrays to have BGR
prediction_result = [
result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold] for result in prediction_result
]

self._original_predictions = prediction_result
if self.has_mask:

if not prediction_result[0].masks:
prediction_result[0].masks = Masks(
torch.tensor([], device=self.model.device), prediction_result[0].boxes.orig_shape
)

prediction_result_ = [
(
result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold],
result.masks.data[result.boxes.data[:, 4] >= self.confidence_threshold],
)
for result in prediction_result
]

else:
prediction_result_ = []
for result in prediction_result:
result_boxes = result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold]
result_masks = torch.tensor([[] for _ in range(result_boxes.size()[0])])
# result_masks = [torch.tensor([], device=self.model.device) for _ in result_boxes]
prediction_result_.append((result_boxes, result_masks))

self._original_predictions = prediction_result_
self._original_shape = image.shape

@property
def category_names(self):
Expand All @@ -88,7 +115,8 @@ def has_mask(self):
"""
Returns if model output contains segmentation mask
"""
return False # fix when yolov5 supports segmentation models
# return True
return self.model.overrides["task"] == "segment"

def _create_object_prediction_list_from_original_predictions(
self,
Expand All @@ -114,13 +142,19 @@ def _create_object_prediction_list_from_original_predictions(

# handle all predictions
object_prediction_list_per_image = []
for image_ind, image_predictions_in_xyxy_format in enumerate(original_predictions):
for image_ind, image_predictions in enumerate(original_predictions):

image_predictions_in_xyxy_format = image_predictions[0]
image_predictions_masks = image_predictions[1]

shift_amount = shift_amount_list[image_ind]
full_shape = None if full_shape_list is None else full_shape_list[image_ind]
object_prediction_list = []

# process predictions
for prediction in image_predictions_in_xyxy_format.cpu().detach().numpy():
for prediction, bool_mask in zip(
image_predictions_in_xyxy_format.cpu().detach().numpy(), image_predictions_masks.cpu().detach().numpy()
):
x1 = prediction[0]
y1 = prediction[1]
x2 = prediction[2]
Expand All @@ -130,6 +164,20 @@ def _create_object_prediction_list_from_original_predictions(
category_id = int(prediction[5])
category_name = self.category_mapping[str(category_id)]

# parse prediction mask
if not self.has_mask:
# bool_mask = bool_mask
# check if mask is valid
# https://github.com/obss/sahi/discussions/696
# if get_bbox_from_bool_mask(bool_mask) is None:
# continue
# else:
bool_mask = None
else:
bool_mask = cv2.resize(bool_mask, (self._original_shape[1], self._original_shape[0]))
bool_mask[bool_mask >= 0.5] = 1
bool_mask[bool_mask < 0.5] = 0

# fix negative box coords
bbox[0] = max(0, bbox[0])
bbox[1] = max(0, bbox[1])
Expand All @@ -152,7 +200,7 @@ def _create_object_prediction_list_from_original_predictions(
bbox=bbox,
category_id=category_id,
score=score,
bool_mask=None,
bool_mask=bool_mask,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
Expand Down
85 changes: 85 additions & 0 deletions sahi/utils/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ class Yolov8TestConstants:
YOLOV8X_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x.pt"
YOLOV8X_MODEL_PATH = "tests/data/models/yolov8/yolov8x.pt"

YOLOV8N_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-seg.pt"
YOLOV8N_SEG_PATH = "tests/data/models/yolov8/yolov8n-seg.pt"

YOLOV8S_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-seg.pt"
YOLOV8S_SEG_PATH = "tests/data/models/yolov8/yolov8s-seg.pt"

YOLOV8M_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-seg.pt"
YOLOV8M_SEG_PATH = "tests/data/models/yolov8/yolov8m-seg.pt"

YOLOV8L_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-seg.pt"
YOLOV8L_SEG_PATH = "tests/data/models/yolov8/yolov8l-seg.pt"

YOLOV8X_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-seg.pt"
YOLOV8X_SEG_PATH = "tests/data/models/yolov8/yolov8x-seg.pt"


def download_yolov8n_model(destination_path: Optional[str] = None):
if destination_path is None:
Expand Down Expand Up @@ -84,3 +99,73 @@ def download_yolov8x_model(destination_path: Optional[str] = None):
Yolov8TestConstants.YOLOV8X_MODEL_URL,
destination_path,
)


def download_yolov8n_seg_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8N_SEG_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov8TestConstants.YOLOV8N_SEG_MODEL_URL,
destination_path,
)


def download_yolov8s_seg_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8S_SEG_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov8TestConstants.YOLOV8S_SEG_MODEL_URL,
destination_path,
)


def download_yolov8m_seg_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8M_SEG_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov8TestConstants.YOLOV8M_SEG_MODEL_URL,
destination_path,
)


def download_yolov8l_seg_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8L_SEG_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov8TestConstants.YOLOV8L_SEG_MODEL_URL,
destination_path,
)


def download_yolov8x_seg_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8X_SEG_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov8TestConstants.YOLOV8X_SEG_MODEL_URL,
destination_path,
)
Loading