Skip to content

Commit

Permalink
Merge branch 'main' into 313
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Dec 5, 2024
2 parents 615f5fc + 6279faa commit bfdb356
Show file tree
Hide file tree
Showing 14 changed files with 106 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/prototype-tests-linux-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
gpu-arch-type: cuda
gpu-arch-version: "11.8"
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
repository: pytorch/vision
runner: ${{ matrix.runner }}
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
gpu-arch-type: cuda
gpu-arch-version: "11.8"
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
repository: pytorch/vision
runner: ${{ matrix.runner }}
Expand Down Expand Up @@ -107,7 +107,7 @@ jobs:
./.github/scripts/unittest.sh
onnx:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
repository: pytorch/vision
test-infra-ref: main
Expand Down Expand Up @@ -138,7 +138,7 @@ jobs:
echo '::endgroup::'
unittests-extended:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
if: contains(github.event.pull_request.labels.*.name, 'run-extended')
with:
repository: pytorch/vision
Expand Down
4 changes: 3 additions & 1 deletion packaging/pre_build_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ else
conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch-nightly
fi

yum install -y libjpeg-turbo-devel libwebp-devel freetype gnutls
conda install libwebp -yq
conda install libjpeg-turbo -c pytorch
yum install -y freetype gnutls
pip install auditwheel
fi

Expand Down
2 changes: 1 addition & 1 deletion references/depth/stereo/utils/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def make_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
y = torch.arange(kernel_size, dtype=torch.float32)
x = x - (kernel_size - 1) / 2
y = y - (kernel_size - 1) / 2
x, y = torch.meshgrid(x, y)
x, y = torch.meshgrid(x, y, indexing="ij")
grid = (x**2 + y**2) / (2 * sigma**2)
kernel = torch.exp(-grid)
kernel = kernel / kernel.sum()
Expand Down
4 changes: 2 additions & 2 deletions references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def main(args):

if args.cache_dataset and os.path.exists(cache_path):
print(f"Loading dataset_train from {cache_path}")
dataset, _ = torch.load(cache_path, weights_only=True)
dataset, _ = torch.load(cache_path, weights_only=False)
dataset.transform = transform_train
else:
if args.distributed:
Expand Down Expand Up @@ -201,7 +201,7 @@ def main(args):

if args.cache_dataset and os.path.exists(cache_path):
print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path, weights_only=True)
dataset_test, _ = torch.load(cache_path, weights_only=False)
dataset_test.transform = transform_test
else:
if args.distributed:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 12 additions & 0 deletions test/test_backbone_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from copy import deepcopy
from itertools import chain
from typing import Mapping, Sequence

Expand Down Expand Up @@ -322,3 +323,14 @@ def forward(self, x):
out = model(self.inp)
# And backward
out["leaf_module"].float().mean().backward()

def test_deepcopy(self):
# Non-regression test for https://github.com/pytorch/vision/issues/8634
model = models.efficientnet_b3(weights=None)
extractor = create_feature_extractor(model=model, return_nodes={"classifier.0": "out"})

extractor.eval()
extractor.train()
extractor = deepcopy(extractor)
extractor.eval()
extractor.train()
2 changes: 2 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
IS_MACOS = sys.platform == "darwin"
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "")
# See https://github.com/pytorch/vision/pull/8724#issuecomment-2503964558
ROCM_WEBP_MESSAGE = "ROCM not built with webp support."

# Hacky way of figuring out whether we compiled with libavif/libheif (those are
# currenlty disabled by default)
Expand Down
15 changes: 15 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,21 @@ def test_draw_boxes():
assert_equal(img, img_cp)


@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
def test_draw_boxes_with_coloured_labels():
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
labels = ["a", "b", "c", "d"]
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
label_colors = ["green", "red", (0, 255, 0), "#FF00FF"]
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True, label_colors=label_colors)

path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_different_label_colors.png"
)
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)


@pytest.mark.parametrize("fill", [True, False])
def test_draw_boxes_dtypes(fill):
img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
Expand Down
5 changes: 4 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_webp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#if WEBP_FOUND
#include "webp/decode.h"
#include "webp/types.h"
#endif // WEBP_FOUND

namespace vision {
Expand Down Expand Up @@ -44,10 +45,12 @@ torch::Tensor decode_webp(

auto decoded_data =
decoding_func(encoded_data_p, encoded_data_size, &width, &height);

TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed.");

auto deleter = [decoded_data](void*) { WebPFree(decoded_data); };
auto out = torch::from_blob(
decoded_data, {height, width, num_channels}, torch::kUInt8);
decoded_data, {height, width, num_channels}, deleter, torch::kUInt8);

return out.permute({2, 0, 1});
}
Expand Down
22 changes: 16 additions & 6 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
install PyAV on your system.
"""
)
try:
FFmpegError = av.FFmpegError # from av 14 https://github.com/PyAV-Org/PyAV/blob/main/CHANGELOG.rst
except AttributeError:
FFmpegError = av.AVError
except ImportError:
av = ImportError(
"""\
Expand Down Expand Up @@ -155,7 +159,13 @@ def write_video(

for img in video_array:
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
frame.pict_type = "NONE"
try:
frame.pict_type = "NONE"
except TypeError:
from av.video.frame import PictureType # noqa

frame.pict_type = PictureType.NONE

for packet in stream.encode(frame):
container.mux(packet)

Expand Down Expand Up @@ -215,7 +225,7 @@ def _read_from_stream(
try:
# TODO check if stream needs to always be the video stream here or not
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
except av.AVError:
except FFmpegError:
# TODO add some warnings in this case
# print("Corrupted file?", container.name)
return []
Expand All @@ -228,7 +238,7 @@ def _read_from_stream(
buffer_count += 1
continue
break
except av.AVError:
except FFmpegError:
# TODO add a warning
pass
# ensure that the results are sorted wrt the pts
Expand Down Expand Up @@ -350,7 +360,7 @@ def read_video(
)
info["audio_fps"] = container.streams.audio[0].rate

except av.AVError:
except FFmpegError:
# TODO raise a warning?
pass

Expand Down Expand Up @@ -441,10 +451,10 @@ def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[in
video_time_base = video_stream.time_base
try:
pts = _decode_video_timestamps(container)
except av.AVError:
except FFmpegError:
warnings.warn(f"Failed decoding frames for file {filename}")
video_fps = float(video_stream.average_rate)
except av.AVError as e:
except FFmpegError as e:
msg = f"Failed to open container for {filename}; Caught error: {e}"
warnings.warn(msg, RuntimeWarning)

Expand Down
37 changes: 36 additions & 1 deletion torchvision/models/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import inspect
import math
import re
Expand All @@ -10,7 +11,7 @@
import torch
import torchvision
from torch import fx, nn
from torch.fx.graph_module import _copy_attr
from torch.fx.graph_module import _CodeOnlyModule, _copy_attr, _USER_PRESERVED_ATTRIBUTES_KEY


__all__ = ["create_feature_extractor", "get_graph_node_names"]
Expand Down Expand Up @@ -330,6 +331,40 @@ def train(self, mode=True):
self.graph = self.eval_graph
return super().train(mode=mode)

def _deepcopy_init(self):
# See __deepcopy__ below
return DualGraphModule.__init__

def __deepcopy__(self, memo):
# Same as the base class' __deepcopy__ from pytorch, with minor
# modification to account for train_graph and eval_graph
# https://github.com/pytorch/pytorch/blob/f684dbd0026f98f8fa291cab74dbc4d61ba30580/torch/fx/graph_module.py#L875
#
# This is using a bunch of private stuff from torch, so if that breaks,
# we'll likely have to remove this, along with the associated
# non-regression test.
res = type(self).__new__(type(self))
memo[id(self)] = res
fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["train_graph"], fake_mod.__dict__["eval_graph"])

extra_preserved_attrs = [
"_state_dict_hooks",
"_load_state_dict_pre_hooks",
"_load_state_dict_post_hooks",
"_replace_hook",
"_create_node_hooks",
"_erase_node_hooks",
]
for attr in extra_preserved_attrs:
if attr in self.__dict__:
setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
setattr(res, attr_name, attr)
return res


def create_feature_extractor(
model: nn.Module,
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/maxvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List


def _get_relative_position_index(height: int, width: int) -> torch.Tensor:
coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)]))
coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)], indexing="ij"))
coords_flat = torch.flatten(coords, 1)
relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
Expand Down
12 changes: 10 additions & 2 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def draw_bounding_boxes(
width: int = 1,
font: Optional[str] = None,
font_size: Optional[int] = None,
label_colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:

"""
Expand All @@ -184,9 +185,12 @@ def draw_bounding_boxes(
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
font_size (int): The requested font size in points.
label_colors (color or list of colors, optional): Colors for the label text. See the description of the
`colors` argument for details. Defaults to the same colors used for the boxes.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
"""
import torchvision.transforms.v2.functional as F # noqa

Expand Down Expand Up @@ -219,6 +223,10 @@ def draw_bounding_boxes(
)

colors = _parse_colors(colors, num_objects=num_boxes)
if label_colors:
label_colors = _parse_colors(label_colors, num_objects=num_boxes) # type: ignore[assignment]
else:
label_colors = colors.copy() # type: ignore[assignment]

if font is None:
if font_size is not None:
Expand All @@ -243,7 +251,7 @@ def draw_bounding_boxes(
else:
draw = ImageDraw.Draw(img_to_draw)

for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type]
for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type]
if fill:
fill_color = color + (100,)
draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
Expand All @@ -252,7 +260,7 @@ def draw_bounding_boxes(

if label is not None:
margin = width + 1
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=label_color, font=txt_font) # type: ignore[arg-type]

out = F.pil_to_tensor(img_to_draw)
if original_dtype.is_floating_point:
Expand Down

0 comments on commit bfdb356

Please sign in to comment.