Skip to content

Commit

Permalink
Revert some changes
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Dec 18, 2023
1 parent adf249d commit 3dce586
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def draw_bounding_boxes(
font_size (int): The requested font size in points.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 or float32 with bounding boxes plotted.
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
"""

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
Expand Down Expand Up @@ -346,7 +346,7 @@ def draw_keypoints(
width (int): Integer denoting width of line connecting keypoints.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 or float32 with keypoints drawn.
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
"""

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
Expand Down Expand Up @@ -389,7 +389,7 @@ def draw_keypoints(
width=width,
)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=image.dtype)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)


# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
Expand Down

0 comments on commit 3dce586

Please sign in to comment.