Skip to content

Commit

Permalink
Merge pull request #4429 from voxel51/bugfix/iss-4428
Browse files Browse the repository at this point in the history
Fix YOLO-NAS inference
  • Loading branch information
brimoor authored May 28, 2024
2 parents 0b0cc99 + 12c7014 commit 0cbb722
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions fiftyone/utils/super_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,7 @@ def convert_super_gradients_model(model):


def _convert_yolo_nas_detection_model(model):
config_model = {
"model": model,
"labels_path": "{{eta-resources}}/ms-coco-labels.txt",
"output_processor_cls": "fiftyone.utils.torch.ClassifierOutputProcessor",
"raw_inputs": True,
}

config_model = {"model": model, "raw_inputs": True}
config = TorchYoloNasModelConfig(config_model)
return TorchYoloNasModel(config)

Expand Down Expand Up @@ -95,18 +89,16 @@ class TorchYoloNasModel(fout.TorchImageModel):

def _load_model(self, config):
if config.model is not None:
self._model = config.model
model = config.model
else:
if self._using_gpu:
self._model = super_gradients.training.models.get(
config.yolo_nas_model, pretrained_weights=config.pretrained
).cuda()
else:
self._model = super_gradients.training.models.get(
config.yolo_nas_model, pretrained_weights=config.pretrained
)
model = super_gradients.training.models.get(
config.yolo_nas_model, pretrained_weights=config.pretrained
)

if self._using_gpu:
model = model.cuda()

return self._model
return model

def _convert_bboxes(self, bboxes, w, h):
tmp = np.copy(bboxes[:, 1])
Expand Down Expand Up @@ -135,6 +127,7 @@ def _generate_detections(self, p):

if 0 in bboxes.shape:
return fo.Detections(detections=[])

bboxes = self._convert_bboxes(bboxes, width, height)
labels = [class_names[l] for l in labels]

Expand All @@ -144,6 +137,9 @@ def _generate_detections(self, p):
]
return fo.Detections(detections=detections)

def _predict_all(self, imgs):
preds = self._model.predict(imgs, conf=self.config.confidence_thresh)
def predict(self, img):
preds = self._model.predict(img, conf=self.config.confidence_thresh)
return self._generate_detections(preds)

def predict_all(self, imgs):
return [self.predict(img) for img in imgs]

0 comments on commit 0cbb722

Please sign in to comment.