Skip to content

Commit

Permalink
Fix TorchScript exporting for custom checkpoints (#267)
Browse files Browse the repository at this point in the history
* Try to reproduce the bug in GH

* Fix torch.jit.script for loading custom checkpoint

* Use real image to test_load_from_yolov5_torchscript

* Minor fix in YOLOTRTModule

* Use YOLO.load_from_yolov5 in test_load_from_yolov5_torchscript

* Update images to test
  • Loading branch information
zhiqwang authored Jan 9, 2022
1 parent 9ebcea9 commit c9b62db
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 23 deletions.
49 changes: 46 additions & 3 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from torch import Tensor
from yolort import models
from yolort.models import YOLOv5
from yolort.models import YOLO, YOLOv5
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.box_head import YOLOHead, PostProcess, SetCriterion
Expand Down Expand Up @@ -303,7 +303,6 @@ def test_criterion(self, use_p6=False):
head_outputs = self._get_head_outputs(N, H, W)
strides = self._get_strides(use_p6)
anchor_grids = self._get_anchor_grids(use_p6)
num_anchors = len(anchor_grids)
num_classes = self.num_classes

targets = torch.tensor(
Expand All @@ -314,7 +313,7 @@ def test_criterion(self, use_p6=False):
[3.0000, 3.0000, 0.6305, 0.3290, 0.3274, 0.2270],
]
)
criterion = SetCriterion(num_anchors, strides, anchor_grids, num_classes)
criterion = SetCriterion(strides, anchor_grids, num_classes)
losses = criterion(targets, head_outputs)
assert isinstance(losses, dict)
assert isinstance(losses["cls_logits"], Tensor)
Expand Down Expand Up @@ -386,3 +385,47 @@ def test_load_from_yolov5(
torch.testing.assert_close(out_from_yolov5[0]["scores"], out[0]["scores"], rtol=0, atol=0)
torch.testing.assert_close(out_from_yolov5[0]["labels"], out[0]["labels"], rtol=0, atol=0)
torch.testing.assert_close(out_from_yolov5[0]["boxes"], out[0]["boxes"], rtol=0, atol=0)


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642"),
("yolov5n", "r6.0", "v6.0", "649e089f"),
("yolov5s", "r6.0", "v6.0", "c3b140f3"),
("yolov5n6", "r6.0", "v6.0", "beecbbae"),
],
)
def test_load_from_yolov5_torchscript(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
):
import cv2
from yolort.utils import read_image_to_tensor
from yolort.v5 import letterbox

# Loading and pre-processing the image
img_path = "test/assets/zidane.jpg"
img_raw = cv2.imread(img_path)
img = letterbox(img_raw, new_shape=(640, 640))[0]
img = read_image_to_tensor(img)

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

score_thresh = 0.25

model = YOLO.load_from_yolov5(checkpoint_path, score_thresh=score_thresh, version=version)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.eval()

out = model(img[None])
out_script = scripted_model(img[None])

torch.testing.assert_close(out[0]["scores"], out_script[1][0]["scores"], rtol=0, atol=0)
torch.testing.assert_close(out[0]["labels"], out_script[1][0]["labels"], rtol=0, atol=0)
torch.testing.assert_close(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0, atol=0)
2 changes: 1 addition & 1 deletion yolort/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def decode_single(
rel_codes: Tensor,
grid: Tensor,
shift: Tensor,
stride: int,
stride: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
From a set of original boxes and encoded relative box offsets,
Expand Down
4 changes: 2 additions & 2 deletions yolort/models/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def _generate_shifts(
device: torch.device = torch.device("cpu"),
) -> List[Tensor]:

anchors = torch.tensor(self.anchor_grids, dtype=dtype, device=device)
strides = torch.tensor(self.strides, dtype=dtype, device=device)
anchors = torch.as_tensor(self.anchor_grids, dtype=torch.float32, device=device).to(dtype=dtype)
strides = torch.as_tensor(self.strides, dtype=torch.float32, device=device).to(dtype=dtype)
anchors = anchors.view(self.num_layers, -1, 2) / strides.view(-1, 1, 1)

shifts = []
Expand Down
15 changes: 9 additions & 6 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def forward(self, x: List[Tensor]) -> List[Tensor]:
return all_pred_logits


class SetCriterion:
class SetCriterion(nn.Module):
"""
This class computes the loss for YOLOv5.
Expand All @@ -98,7 +98,6 @@ class SetCriterion:

def __init__(
self,
num_anchors: int,
strides: List[int],
anchor_grids: List[List[float]],
num_classes: int,
Expand All @@ -112,12 +111,13 @@ def __init__(
label_smoothing: float = 0.0,
auto_balance: bool = False,
) -> None:
super().__init__()
assert len(strides) == len(anchor_grids)

self.num_anchors = num_anchors
self.num_classes = num_classes
self.strides = strides
self.anchor_grids = anchor_grids
self.num_anchors = len(anchor_grids[0]) // 2

self.balance = [4.0, 1.0, 0.4]
self.ssi = 0 # stride 16 index
Expand All @@ -142,7 +142,7 @@ def __init__(
self.obj_gain = obj_gain
self.anchor_thresh = anchor_thresh

def __call__(
def forward(
self,
targets: Tensor,
head_outputs: List[Tensor],
Expand Down Expand Up @@ -322,7 +322,7 @@ def _concat_pred_logits(
head_outputs: List[Tensor],
grids: List[Tensor],
shifts: List[Tensor],
strides: List[int],
strides: Tensor,
) -> Tensor:
# Concat all pred logits
batch_size, _, _, _, K = head_outputs[0].shape
Expand Down Expand Up @@ -397,8 +397,11 @@ def forward(
shifts (List[Tensor]): Anchor shifts.
"""
batch_size = len(head_outputs[0])
device = head_outputs[0].device
dtype = head_outputs[0].dtype
strides = torch.as_tensor(self.strides, dtype=torch.float32, device=device).to(dtype=dtype)

all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, self.strides)
all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, strides)
detections: List[Dict[str, Tensor]] = []

for idx in range(batch_size): # image idx, image inference
Expand Down
11 changes: 1 addition & 10 deletions yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ class YOLO(nn.Module):
- scores (``Tensor[N]``): the scores or each prediction
"""

__annotations__ = {
"compute_loss": SetCriterion,
}

def __init__(
self,
backbone: nn.Module,
Expand Down Expand Up @@ -107,12 +103,7 @@ def __init__(
self.anchor_generator = anchor_generator

if criterion is None:
criterion = SetCriterion(
anchor_generator.num_anchors,
anchor_generator.strides,
anchor_generator.anchor_grids,
num_classes,
)
criterion = SetCriterion(strides, anchor_grids, num_classes)
self.compute_loss = criterion

if head is None:
Expand Down
5 changes: 4 additions & 1 deletion yolort/runtime/trt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ def forward(
shifts (List[Tensor]): Anchor shifts.
"""
batch_size = len(head_outputs[0])
device = head_outputs[0].device
dtype = head_outputs[0].dtype
strides = torch.as_tensor(self.strides, dtype=torch.float32, device=device).to(dtype=dtype)

all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, self.strides)
all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, strides)

bbox_regression = []
pred_scores = []
Expand Down

0 comments on commit c9b62db

Please sign in to comment.