Skip to content

Commit

Permalink
Customize YOLOv8 image_size & device + Allow Saving Slices (#929)
Browse files Browse the repository at this point in the history
Co-authored-by: fatih <[email protected]>
Co-authored-by: fcakyon <[email protected]>
  • Loading branch information
3 people authored Nov 6, 2023
1 parent 739b3dd commit 4c5f6b1
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,7 @@ python -m scripts.run_code_style format

<a align="left" href="https://github.com/pranavdurai10" target="_blank">Pranav Durai</a>

<a align="left" href="https://github.com/lakshaymehra" target="_blank">Lakshay Mehra</a>

</div>

9 changes: 8 additions & 1 deletion sahi/models/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,14 @@ def perform_inference(self, image: np.ndarray):
# Confirm model is loaded
if self.model is None:
raise ValueError("Model is not loaded, load it by calling .load_model()")
prediction_result = self.model(image[:, :, ::-1], verbose=False) # YOLOv8 expects numpy arrays to have BGR
if self.image_size is not None: # ADDED IMAGE SIZE OPTION FOR YOLOV8 MODELS:
prediction_result = self.model(
image[:, :, ::-1], imgsz=self.image_size, verbose=False, device=self.device
) # YOLOv8 expects numpy arrays to have BGR
else:
prediction_result = self.model(
image[:, :, ::-1], verbose=False, device=self.device
) # YOLOv8 expects numpy arrays to have BGR
prediction_result = [
result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold] for result in prediction_result
]
Expand Down
4 changes: 4 additions & 0 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def get_prediction(
def get_sliced_prediction(
image,
detection_model=None,
output_file_name=None, # ADDED OUTPUT FILE NAME TO (OPTIONALLY) SAVE SLICES
interim_dir="slices/", # ADDED INTERIM DIRECTORY TO (OPTIONALLY) SAVE SLICES
slice_height: int = None,
slice_width: int = None,
overlap_height_ratio: float = 0.2,
Expand Down Expand Up @@ -199,6 +201,8 @@ def get_sliced_prediction(
time_start = time.time()
slice_image_result = slice_image(
image=image,
output_file_name=output_file_name, # ADDED OUTPUT FILE NAME TO (OPTIONALLY) SAVE SLICES
output_dir=interim_dir, # ADDED INTERIM DIRECTORY TO (OPTIONALLY) SAVE SLICES
slice_height=slice_height,
slice_width=slice_width,
overlap_height_ratio=overlap_height_ratio,
Expand Down

0 comments on commit 4c5f6b1

Please sign in to comment.