Commit

Single prompt
nikolausWest committed Apr 20, 2023
1 parent 590a505 commit e1eb004
Showing 2 changed files with 36 additions and 31 deletions.
35 changes: 18 additions & 17 deletions main.py
@@ -28,20 +28,21 @@ def log_images_segmentation(args, model: GroundingDINO, predictor: Sam):
         rr.set_time_sequence("image", n)
         image, image_tensor = load_image(image_uri)
         predictor.set_image(image)
-        for prompt in args.prompts:
-            # run grounding dino model
-            logging.info(f"Running GroundedDINO with DETECTION PROMPT {prompt}.")
-            boxes_filt, pred_phrases = get_grounding_output(
-                model, image_tensor, prompt, 0.3, 0.25, device=args.device
-            )
-            # denormalize boxes (from [0, 1] to image size)
-            H, W, _ = image.shape
-            for i in range(boxes_filt.size(0)):
-                boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-                boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-                boxes_filt[i][2:] += boxes_filt[i][:2]
-
-            run_segmentation(predictor, image, boxes_filt, prompt)
+        prompt = args.prompt
+        # run grounding dino model
+        logging.info(f"Running GroundedDINO with DETECTION PROMPT {prompt}.")
+        boxes_filt, box_phrases = get_grounding_output(
+            model, image_tensor, prompt, 0.3, 0.25, device=args.device
+        )
+        logging.info(f"Grounded output with prediction phrases: {box_phrases}")
+        # denormalize boxes (from [0, 1] to image size)
+        H, W, _ = image.shape
+        for i in range(boxes_filt.size(0)):
+            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+            boxes_filt[i][2:] += boxes_filt[i][:2]
+
+        run_segmentation(predictor, image, boxes_filt, box_phrases)


 def log_video_segmentation(args, model: GroundingDINO, predictor: Sam):
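For reference, the denormalization loop in this hunk converts GroundingDINO's normalized center-size boxes into pixel-space corner coordinates. A minimal vectorized sketch of the same math (the helper name and example values are illustrative, not part of the commit):

```python
import torch

def denormalize_boxes(boxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Convert normalized [cx, cy, w, h] boxes to pixel-space [x0, y0, x1, y1]."""
    # scale from [0, 1] to pixel units
    boxes = boxes * torch.tensor([width, height, width, height], dtype=boxes.dtype)
    xyxy = boxes.clone()
    xyxy[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2  # center -> top-left corner
    xyxy[:, 2:] = xyxy[:, :2] + boxes[:, 2:]       # top-left + size -> bottom-right
    return xyxy

# e.g. a full-width, half-height box centered in a 640x480 image:
print(denormalize_boxes(torch.tensor([[0.5, 0.5, 1.0, 0.5]]), 640, 480))
# -> tensor([[  0., 120., 640., 360.]])
```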
@@ -68,6 +69,7 @@ def log_video_segmentation(args, model: GroundingDINO, predictor: Sam):
         boxes_filt, pred_phrases = get_grounding_output(
             model, image_tensor, prompt, 0.3, 0.25, device=args.device
         )
+        logging.info(f"Grounded output with prediction phrases: {pred_phrases}")
         # denormalize boxes (from [0, 1] to image size)
         H, W, _ = rgb.shape
         for i in range(boxes_filt.size(0)):
@@ -103,9 +105,8 @@ def main() -> None:
     )

     parser.add_argument(
-        "--prompts",
-        default=["tires", "windows"],
-        nargs="+",
+        "--prompt",
+        default="tires and windows",
         type=str,
         help="Prompt to use for bounding box detection.",
     )
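The CLI therefore takes one free-text prompt instead of a list of prompts. A small sketch of how the new flag behaves (flag name and default taken from the diff; the per-box phrase output is hypothetical):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--prompt",
    default="tires and windows",
    type=str,
    help="Prompt to use for bounding box detection.",
)
args = parser.parse_args([])  # same as running: python main.py

# GroundingDINO grounds one combined caption and can still label each detected
# box with the sub-phrase it matched, e.g.:
box_phrases = ["tires", "tires", "windows"]  # hypothetical per-box phrases
print(args.prompt, "->", box_phrases)
```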
32 changes: 18 additions & 14 deletions models.py
@@ -1,7 +1,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import Final, Tuple
+from typing import Final, List, Tuple
 from urllib.parse import urlparse

 import cv2
@@ -85,7 +85,7 @@ def create_sam(model: str, device: str) -> Sam:


 def run_segmentation(
-    predictor: SamPredictor, image: Mat, boxes_filt, prompt: str
+    predictor: SamPredictor, image: Mat, boxes_filt, box_phrases: List[str]
 ) -> None:
     """Run segmentation on a single image."""
     rr.log_image("image", image)
@@ -103,34 +103,38 @@ def run_segmentation(
         multimask_output=False,
     )
-
-
+
     logging.info("Found {} masks".format(len(masks)))
     # mask_tensor = masks.squeeze().numpy().astype("uint8") * 128
     # rr.log_tensor(f"query_{idx}/mask_tensor", mask_tensor)

     # TODO(jleibs): we could instead draw each mask as a separate image layer, but the current layer-stacking
     # does not produce great results.
-    masks_with_ids = list(enumerate(masks.cpu(), start=1))
+    id_from_phrase = {phrase: i for i, phrase in enumerate(set(box_phrases), start=1)}
+    mask_ids = [id_from_phrase[phrase] for phrase in box_phrases]  # One mask per box

     # Work-around for https://github.com/rerun-io/rerun/issues/1782
     # Make sure we have an AnnotationInfo present for every class-id used in this image
     # TODO(jleibs): Remove when fix is released
     rr.log_annotation_context(
         "image",
-        [rr.AnnotationInfo(id) for id, _ in masks_with_ids],
+        [rr.AnnotationInfo(id=id, label=phrase)
+         for phrase, id in id_from_phrase.items()],
         timeless=False,
     )

-    # Layer all of the masks together, using the id as class-id in the segmentation
-    segmentation_img = np.zeros((image.shape[0], image.shape[1]))
-    for id, m in masks_with_ids:
-        segmentation_img[m.squeeze()] = id
+    # Layer all of the masks that belong to a single phrase together
+    for phrase, id in id_from_phrase.items():
+        segmentation_img = np.zeros((image.shape[0], image.shape[1]))
+        for mask_id, mask in zip(mask_ids, masks.cpu()):
+            if mask_id == id:
+                segmentation_img[mask.squeeze()] = id

-    rr.log_segmentation_image(f"image/{prompt}/masks", segmentation_img)
+        rr.log_segmentation_image(f"image/{phrase}/mask", segmentation_img)

     rr.log_rects(
-        f"image/{prompt}/boxes",
+        "image/boxes",
         rects=boxes_filt.numpy(),
-        class_ids=[id for id, _ in masks_with_ids],
+        class_ids=mask_ids,
         rect_format=RectFormat.XYXY,
     )
+
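The reworked run_segmentation groups every box that shares a detection phrase under one class-id and layers those masks into a single segmentation image per phrase. A toy numpy sketch of that grouping (shapes and values are made up; start=1 keeps class-id 0 free for the background):

```python
import numpy as np

box_phrases = ["tires", "windows", "tires"]  # one phrase per detected box
masks = np.zeros((3, 4, 4), dtype=bool)      # one HxW mask per box, stand-ins for SAM output
masks[0, 0, 0] = masks[1, 1, 1] = masks[2, 2, 2] = True

# Map each distinct phrase to a class-id, then look up one id per box.
id_from_phrase = {phrase: i for i, phrase in enumerate(set(box_phrases), start=1)}
mask_ids = [id_from_phrase[phrase] for phrase in box_phrases]

# All boxes that share a phrase are layered into one segmentation image,
# using the phrase's class-id as the pixel value.
for phrase, id in id_from_phrase.items():
    segmentation_img = np.zeros((4, 4))
    for mask_id, mask in zip(mask_ids, masks):
        if mask_id == id:
            segmentation_img[mask] = id
    print(phrase, int(segmentation_img.sum() / id), "pixels")
# -> tires 2 pixels, windows 1 pixels (both "tires" boxes end up in one image)
```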
@@ -213,7 +217,7 @@ def get_grounding_output(
     caption: str,
     box_threshold: float,
     text_threshold: float,
-    with_logits: bool = True,
+    with_logits: bool = False,
     device: str = "cpu",
 ):
     caption = caption.lower()
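Flipping with_logits to False keeps the returned phrases as plain text. In the upstream Grounded-Segment-Anything demo this flag appends each phrase's confidence to the label string (an assumption here; the exact formatting may differ), which would break the phrases' use above as dictionary keys and entity-path segments:

```python
# Sketch of the demo-style phrase formatting that with_logits controls
# (function name and format are illustrative, not from this repo).
def format_phrase(phrase: str, score: float, with_logits: bool) -> str:
    return f"{phrase}({score:.4f})" if with_logits else phrase

print(format_phrase("tires", 0.7213, with_logits=True))   # -> tires(0.7213)
print(format_phrase("tires", 0.7213, with_logits=False))  # -> tires
# Plain phrases make stable dict keys and entity paths like "image/tires/mask".
```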
