Commit 5f841c7: Add check for target_sizes is None in `post_process_image_guided_detection` for owlv2 (#31934)

* Add check for target_sizes is None in post_process_image_guided_detection

* Make sure Owlvit and Owlv2 are in sync

* Fix incorrect indentation; add check for correct size of target_sizes
catalys1 authored Jul 26, 2024
1 parent f9756d9 commit 5f841c7
Showing 2 changed files with 20 additions and 10 deletions.
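
For context, a minimal usage sketch of the code path this commit touches. The checkpoint name, local image paths, and thresholds below are illustrative assumptions, not part of the commit; the point is that `post_process_image_guided_detection` can now be called with `target_sizes=None`, in which case the returned boxes stay in relative [0, 1] coordinates instead of being scaled to pixel sizes.

import torch
from PIL import Image
from transformers import Owlv2ForObjectDetection, Owlv2Processor

# Illustrative checkpoint and local image paths -- assumptions, not from the commit.
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

image = Image.open("scene.jpg")   # image to search in
query = Image.open("query.jpg")   # example image of the object to find

inputs = processor(images=image, query_images=query, return_tensors="pt")
with torch.no_grad():
    outputs = model.image_guided_detection(**inputs)

# The method changed in this commit: with target_sizes=None it no longer trips
# over len(target_sizes); boxes come back in relative [0, 1] coordinates.
relative = processor.image_processor.post_process_image_guided_detection(
    outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None
)

# Passing sizes still scales boxes to absolute (height, width) pixel coordinates.
sizes = torch.tensor([[image.height, image.width]])
absolute = processor.image_processor.post_process_image_guided_detection(
    outputs, threshold=0.0, nms_threshold=0.3, target_sizes=sizes
)
print(relative[0]["boxes"].shape, absolute[0]["boxes"].shape)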
src/transformers/models/owlv2/image_processing_owlv2.py (10 additions & 5 deletions)
@@ -565,9 +565,9 @@ def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_thresh
         """
         logits, target_boxes = outputs.logits, outputs.target_pred_boxes
 
-        if len(logits) != len(target_sizes):
+        if target_sizes is not None and len(logits) != len(target_sizes):
             raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
-        if target_sizes.shape[1] != 2:
+        if target_sizes is not None and target_sizes.shape[1] != 2:
             raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
 
         probs = torch.max(logits, dim=-1)
@@ -588,9 +588,14 @@ def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_thresh
                     scores[idx][ious > nms_threshold] = 0.0
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = target_sizes.unbind(1)
-        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
-        target_boxes = target_boxes * scale_fct[:, None, :]
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.tensor([i[0] for i in target_sizes])
+                img_w = torch.tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
+            target_boxes = target_boxes * scale_fct[:, None, :]
 
         # Compute box display alphas based on prediction scores
         results = []
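
The second hunk also lets target_sizes be either a (batch_size, 2) tensor or a plain Python list of (height, width) pairs; both branches reduce to the same per-image scale factors. A standalone sketch of that logic with made-up sizes (the variable names and values here are illustrative, not from the repository):

import torch

# Hypothetical batch of two image sizes, as a list of (h, w) pairs and as a tensor.
sizes_list = [(480, 640), (768, 1024)]
sizes_tensor = torch.tensor(sizes_list)

# List branch from the patch: pull heights and widths out of the pairs.
img_h = torch.tensor([s[0] for s in sizes_list])
img_w = torch.tensor([s[1] for s in sizes_list])
scale_from_list = torch.stack([img_w, img_h, img_w, img_h], dim=1)

# Tensor branch from the patch: unbind the (h, w) columns.
img_h, img_w = sizes_tensor.unbind(1)
scale_from_tensor = torch.stack([img_w, img_h, img_w, img_h], dim=1)

assert torch.equal(scale_from_list, scale_from_tensor)

# Relative boxes in [0, 1] are scaled per image, exactly as in the diff above.
relative_boxes = torch.rand(2, 3, 4)            # (batch, boxes, 4), made-up shape
absolute_boxes = relative_boxes * scale_from_tensor[:, None, :]
print(absolute_boxes.shape)                     # torch.Size([2, 3, 4])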
src/transformers/models/owlvit/image_processing_owlvit.py (10 additions & 5 deletions)
@@ -556,9 +556,9 @@ def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_thresh
         """
         logits, target_boxes = outputs.logits, outputs.target_pred_boxes
 
-        if len(logits) != len(target_sizes):
+        if target_sizes is not None and len(logits) != len(target_sizes):
             raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
-        if target_sizes.shape[1] != 2:
+        if target_sizes is not None and target_sizes.shape[1] != 2:
             raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
 
         probs = torch.max(logits, dim=-1)
@@ -579,9 +579,14 @@ def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_thresh
                     scores[idx][ious > nms_threshold] = 0.0
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = target_sizes.unbind(1)
-        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
-        target_boxes = target_boxes * scale_fct[:, None, :]
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.tensor([i[0] for i in target_sizes])
+                img_w = torch.tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
+            target_boxes = target_boxes * scale_fct[:, None, :]
 
         # Compute box display alphas based on prediction scores
         results = []
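
The same change is mirrored in OwlViT so the two image processors behave identically. The validation at the top of the method now only fires when target_sizes is actually provided; when it is, a batch-size mismatch or an entry that is not an (h, w) pair still raises. A small isolated sketch of that guard (the helper function and tensor shapes are hypothetical, for illustration only):

import torch

def check_target_sizes(logits, target_sizes):
    # Hypothetical standalone copy of the guards added in this commit.
    if target_sizes is not None and len(logits) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
    if target_sizes is not None and target_sizes.shape[1] != 2:
        raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

logits = torch.randn(2, 16, 4)  # (batch, num_patches, num_queries), made-up shape

check_target_sizes(logits, None)                                      # accepted now: None skips both checks
check_target_sizes(logits, torch.tensor([[480, 640], [768, 1024]]))   # accepted: two (h, w) pairs

try:
    check_target_sizes(logits, torch.tensor([[480, 640]]))            # batch mismatch still raises
except ValueError as err:
    print(err)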
