-
Notifications
You must be signed in to change notification settings - Fork 520
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/sg 1442 sliding window inference for yolonas (#1979)
* wip * wip * wip2 * working version, hard coded nms params * moved post prediction callback to utils * moved back to wrapper * added abstract class, small refactoring for pipeline * rolled back customizable detector, solved pretrained weights setting of proccessing for the wrapper * temp cleanup * support for fuse model in predict * example added for predict * added support for forward wrappers in trainer * added test for validation forward wrapper * added option for None as post prediction callback in DetectionMetrics * wip adding set_model before using wrapper * commit changes before removal of validation during training support * refined docs * removed old test for forward wrapper, fixed defaults * fixed test and added clarifications * forward wrapper test removed * updated wrong threshold extraction and test result * fixed docstring format
- Loading branch information
Showing
6 changed files
with
506 additions
and
13 deletions.
There are no files selected for viewing
20 changes: 20 additions & 0 deletions
20
src/super_gradients/examples/predict/sliding_sindow_detection_predict.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import torch | ||
from super_gradients.common.object_names import Models | ||
from super_gradients.training import models | ||
|
||
|
||
# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported. | ||
from super_gradients.training.models.detection_models.sliding_window_detection_forward_wrapper import SlidingWindowInferenceDetectionWrapper | ||
|
||
model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco") | ||
|
||
# We want to use cuda if available to speed up inference. | ||
model = model.to("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
model = SlidingWindowInferenceDetectionWrapper(model=model, tile_size=640, tile_step=160, tile_nms_conf=0.35) | ||
|
||
predictions = model.predict( | ||
"https://images.pexels.com/photos/7968254/pexels-photo-7968254.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2", skip_image_resizing=True | ||
) | ||
predictions.show() | ||
predictions.save(output_path="2.jpg") # Save in working directory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.