Make v2 transforms authoring public #8787

Merged (8 commits) on Dec 9, 2024
9 changes: 9 additions & 0 deletions docs/source/transforms.rst
@@ -508,11 +508,20 @@ are combining pairs of images together. These can be used after the dataloader
Developer tools
^^^^^^^^^^^^^^^

.. autosummary::
    :toctree: generated/
    :template: class.rst

    v2.Transform

.. autosummary::
    :toctree: generated/
    :template: function.rst

    v2.functional.register_kernel
    v2.query_size
    v2.query_chw
    v2.get_bounding_boxes
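
For illustration, here is a rough editorial sketch (not part of this diff) of how the
``v2.functional.register_kernel`` tool listed above can be used: it registers a kernel
so that the built-in functionals dispatch to it for a custom TVTensor subclass. The
``MyTVTensor`` class and the kernel below are hypothetical::

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    class MyTVTensor(tv_tensors.TVTensor):
        # A hypothetical custom TVTensor subclass.
        pass

    @F.register_kernel(functional=F.horizontal_flip, tv_tensor_cls=MyTVTensor)
    def hflip_my_tv_tensor(inpt, *args, **kwargs):
        # Flip the width (last) dimension and re-wrap the result like the input.
        return tv_tensors.wrap(inpt.flip(-1), like=inpt)

    t = torch.rand(3, 8, 8).as_subclass(MyTVTensor)
    out = F.horizontal_flip(t)  # dispatches to hflip_my_tv_tensor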


V1 API Reference
117 changes: 98 additions & 19 deletions gallery/transforms/plot_custom_transforms.py
@@ -12,6 +12,8 @@
"""

# %%
from typing import Any, Dict, List

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
@@ -89,33 +91,110 @@ def forward(self, img, bboxes, label):  # we assume inputs are always structured
# A key feature of the builtin Torchvision V2 transforms is that they can accept
# arbitrary input structure and return the same structure as output (with
# transformed entries). For example, transforms can accept a single image, or a
-# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:
+# tuple of ``(img, label)``, or an arbitrary nested dictionary as input. Here's
+# an example on the built-in transform :class:`~torchvision.transforms.v2.RandomHorizontalFlip`:

structured_input = {
    "img": img,
    "annotations": (bboxes, label),
-    "something_that_will_be_ignored": (1, "hello")
+    "something that will be ignored": (1, "hello"),
+    "another tensor that is ignored": torch.arange(10),
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
-assert structured_output["something_that_will_be_ignored"] == (1, "hello")
+assert structured_output["something that will be ignored"] == (1, "hello")
+assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
# Basics: override the `transform()` method
# -----------------------------------------
#
# In order to support arbitrary inputs in your custom transform, you will need
# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the
# `.transform()` method (not the `forward()` method!). Below is a basic example:


class MyCustomTransform(v2.Transform):
    def transform(self, inpt: Any, params: Dict[str, Any]):
        if type(inpt) == torch.Tensor:
            print(f"I'm transforming an image of shape {inpt.shape}")
            return inpt + 1  # dummy transformation
        elif isinstance(inpt, tv_tensors.BoundingBoxes):
            print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
            return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation


my_custom_transform = MyCustomTransform()
structured_output = my_custom_transform(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
-# If you want to reproduce this behavior in your own transform, we invite you to
-# look at our `code
-# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
-# and adapt it to your needs.
-#
-# In brief, the core logic is to unpack the input into a flat list using `pytree
-# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
-# then transform only the entries that can be transformed (the decision is made
-# based on the **class** of the entries, as all TVTensors are
-# tensor-subclasses) plus some custom logic that is out of scope here - check the
-# code for details. The (potentially transformed) entries are then repacked and
-# returned, in the same structure as the input.
-#
-# We do not provide public dev-facing tools to achieve that at this time, but if
-# this is something that would be valuable to you, please let us know by opening
-# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
+# An important thing to note is that when we call ``my_custom_transform`` on
+# ``structured_input``, the input is flattened and then each individual part is
+# passed to ``transform()``. That is, ``transform()`` receives the input image,
+# then the bounding boxes, etc. Within ``transform()``, you can decide how to
+# transform each input, based on their type.
+#
+# If you're curious why the other tensor (``torch.arange()``) didn't get passed
+# to ``transform()``, see :ref:`this note <passthrough_heuristic>` for more
+# details.
#
# Advanced: The ``make_params()`` method
# --------------------------------------
#
# The ``make_params()`` method is called internally before calling
# ``transform()`` on each input. This is typically useful to generate random
# parameter values. In the example below, we use it to randomly apply the
# transformation with a probability of 0.5.


class MyRandomTransform(MyCustomTransform):
    def __init__(self, p=0.5):
        self.p = p
        super().__init__()

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        apply_transform = (torch.rand(size=(1,)) < self.p).item()
        params = dict(apply_transform=apply_transform)
        return params

    def transform(self, inpt: Any, params: Dict[str, Any]):
        if not params["apply_transform"]:
            print("Not transforming anything!")
            return inpt
        else:
            return super().transform(inpt, params)


my_random_transform = MyRandomTransform()

torch.manual_seed(0)
_ = my_random_transform(structured_input) # transforms
_ = my_random_transform(structured_input) # doesn't transform

# %%
#
# .. note::
#
# It's important for such random parameter generation to happen within
# ``make_params()`` and not within ``transform()``, so that for a given
# transform call, the same RNG applies to all the inputs in the same way. If
# we were to perform the RNG within ``transform()``, we would risk e.g.
# transforming the image while *not* transforming the bounding boxes.
#
# The ``make_params()`` method takes the list of all the inputs as a parameter
# (each of the elements in this list will later be passed to ``transform()``).
# You can use ``flat_inputs`` to e.g. figure out the dimensions of the input,
# using :func:`~torchvision.transforms.v2.query_chw` or
# :func:`~torchvision.transforms.v2.query_size`.
#
# ``make_params()`` should return a dict (or actually, anything you want) that
# will then be passed to ``transform()``.
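# %%
# To tie these pieces together, here is an editorial sketch (not part of this
# PR) of a transform that uses :func:`~torchvision.transforms.v2.query_size`
# inside ``make_params()``: the crop location is sampled once, from the shared
# height and width of the inputs, so the image and its bounding boxes are
# cropped consistently. The class name and the half-size crop are illustrative
# choices.

from torchvision.transforms.v2 import functional as F  # noqa: E402


class MyRandomCropSketch(v2.Transform):
    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = v2.query_size(flat_inputs)  # (H, W), shared by all inputs
        crop_h, crop_w = height // 2, width // 2
        top = int(torch.randint(0, height - crop_h + 1, size=(1,)))
        left = int(torch.randint(0, width - crop_w + 1, size=(1,)))
        return dict(top=top, left=left, height=crop_h, width=crop_w)

    def transform(self, inpt: Any, params: Dict[str, Any]):
        # F.crop dispatches on the input type, so the image and the bounding
        # boxes are cropped with the exact same parameters.
        return F.crop(inpt, **params)


torch.manual_seed(0)
cropped = MyRandomCropSketch()(structured_input)
print(f"The cropped bboxes are:\n{cropped['annotations'][0]}")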
4 changes: 2 additions & 2 deletions references/segmentation/v2_extras.py
@@ -10,13 +10,13 @@ def __init__(self, size, fill=0):
        self.size = size
        self.fill = v2._utils._setup_fill_arg(fill)

-    def _get_params(self, sample):
+    def make_params(self, sample):
        _, height, width = v2._utils.query_chw(sample)
        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
        needs_padding = any(padding)
        return dict(padding=padding, needs_padding=needs_padding)

-    def _transform(self, inpt, params):
+    def transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt

8 changes: 4 additions & 4 deletions test/test_prototype_transforms.py
@@ -159,7 +159,7 @@ def test__copy_paste(self, label_type):


class TestFixedSizeCrop:
-    def test__get_params(self, mocker):
+    def test_make_params(self, mocker):
        crop_size = (7, 7)
        batch_shape = (10,)
        canvas_size = (11, 5)
@@ -170,7 +170,7 @@ def test__get_params(self, mocker):
            make_image(size=canvas_size, color_space="RGB"),
            make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_shape[0]),
        ]
-        params = transform._get_params(flat_inputs)
+        params = transform.make_params(flat_inputs)

        assert params["needs_crop"]
        assert params["height"] <= crop_size[0]
@@ -191,7 +191,7 @@ def test__transform_culling(self, mocker):

        is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
        mocker.patch(
-            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params",
            return_value=dict(
                needs_crop=True,
                top=0,
@@ -229,7 +229,7 @@ def test__transform_bounding_boxes_clamping(self, mocker):
        canvas_size = (10, 10)

        mocker.patch(
-            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
+            "torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params",
            return_value=dict(
                needs_crop=True,
                top=0,