Skip to content

Commit

Permalink
Make v2 transforms authoring public (#8787)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Dec 9, 2024
1 parent 48f01de commit a9a726a
Show file tree
Hide file tree
Showing 18 changed files with 260 additions and 148 deletions.
9 changes: 9 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -508,11 +508,20 @@ are combining pairs of images together. These can be used after the dataloader
Developer tools
^^^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:template: class.rst

v2.Transform

.. autosummary::
:toctree: generated/
:template: function.rst

v2.functional.register_kernel
v2.query_size
v2.query_chw
v2.get_bounding_boxes


V1 API Reference
Expand Down
117 changes: 98 additions & 19 deletions gallery/transforms/plot_custom_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
"""

# %%
from typing import Any, Dict, List

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
Expand Down Expand Up @@ -89,33 +91,110 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured
# A key feature of the builtin Torchvision V2 transforms is that they can accept
# arbitrary input structure and return the same structure as output (with
# transformed entries). For example, transforms can accept a single image, or a
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input. Here's
# an example on the built-in transform :class:`~torchvision.transforms.v2.RandomHorizontalFlip`:

structured_input = {
"img": img,
"annotations": (bboxes, label),
"something_that_will_be_ignored": (1, "hello")
"something that will be ignored": (1, "hello"),
"another tensor that is ignored": torch.arange(10),
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
# Basics: override the `transform()` method
# -----------------------------------------
#
# In order to support arbitrary inputs in your custom transform, you will need
# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the
# `.transform()` method (not the `forward()` method!). Below is a basic example:


class MyCustomTransform(v2.Transform):
def transform(self, inpt: Any, params: Dict[str, Any]):
if type(inpt) == torch.Tensor:
print(f"I'm transforming an image of shape {inpt.shape}")
return inpt + 1 # dummy transformation
elif isinstance(inpt, tv_tensors.BoundingBoxes):
print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
return tv_tensors.wrap(inpt + 100, like=inpt) # dummy transformation


my_custom_transform = MyCustomTransform()
structured_output = my_custom_transform(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
# If you want to reproduce this behavior in your own transform, we invite you to
# look at our `code
# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
# and adapt it to your needs.
#
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all TVTensors are
# tensor-subclasses) plus some custom logic that is out of score here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
#
# We do not provide public dev-facing tools to achieve that at this time, but if
# this is something that would be valuable to you, please let us know by opening
# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
# An important thing to note is that when we call ``my_custom_transform`` on
# ``structured_input``, the input is flattened and then each individual part is
# passed to ``transform()``. That is, ``transform()``` receives the input image,
# then the bounding boxes, etc. Within ``transform()``, you can decide how to
# transform each input, based on their type.
#
# If you're curious why the other tensor (``torch.arange()``) didn't get passed
# to ``transform()``, see :ref:`this note <passthrough_heuristic>` for more
# details.
#
# Advanced: The ``make_params()`` method
# --------------------------------------
#
# The ``make_params()`` method is called internally before calling
# ``transform()`` on each input. This is typically useful to generate random
# parameter values. In the example below, we use it to randomly apply the
# transformation with a probability of 0.5


class MyRandomTransform(MyCustomTransform):
def __init__(self, p=0.5):
self.p = p
super().__init__()

def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
apply_transform = (torch.rand(size=(1,)) < self.p).item()
params = dict(apply_transform=apply_transform)
return params

def transform(self, inpt: Any, params: Dict[str, Any]):
if not params["apply_transform"]:
print("Not transforming anything!")
return inpt
else:
return super().transform(inpt, params)


my_random_transform = MyRandomTransform()

torch.manual_seed(0)
_ = my_random_transform(structured_input) # transforms
_ = my_random_transform(structured_input) # doesn't transform

# %%
#
# .. note::
#
# It's important for such random parameter generation to happen within
# ``make_params()`` and not within ``transform()``, so that for a given
# transform call, the same RNG applies to all the inputs in the same way. If
# we were to perform the RNG within ``transform()``, we would risk e.g.
# transforming the image while *not* transforming the bounding boxes.
#
# The ``make_params()`` method takes the list of all the inputs as parameter
# (each of the elements in this list will later be pased to ``transform()``).
# You can use ``flat_inputs`` to e.g. figure out the dimensions on the input,
# using :func:`~torchvision.transforms.v2.query_chw` or
# :func:`~torchvision.transforms.v2.query_size`.
#
# ``make_params()`` should return a dict (or actually, anything you want) that
# will then be passed to ``transform()``.
4 changes: 2 additions & 2 deletions references/segmentation/v2_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ def __init__(self, size, fill=0):
self.size = size
self.fill = v2._utils._setup_fill_arg(fill)

def _get_params(self, sample):
def make_params(self, sample):
_, height, width = v2._utils.query_chw(sample)
padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
needs_padding = any(padding)
return dict(padding=padding, needs_padding=needs_padding)

def _transform(self, inpt, params):
def transform(self, inpt, params):
if not params["needs_padding"]:
return inpt

Expand Down
8 changes: 4 additions & 4 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test__copy_paste(self, label_type):


class TestFixedSizeCrop:
def test__get_params(self, mocker):
def test_make_params(self, mocker):
crop_size = (7, 7)
batch_shape = (10,)
canvas_size = (11, 5)
Expand All @@ -170,7 +170,7 @@ def test__get_params(self, mocker):
make_image(size=canvas_size, color_space="RGB"),
make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_shape[0]),
]
params = transform._get_params(flat_inputs)
params = transform.make_params(flat_inputs)

assert params["needs_crop"]
assert params["height"] <= crop_size[0]
Expand All @@ -191,7 +191,7 @@ def test__transform_culling(self, mocker):

is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
"torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params",
return_value=dict(
needs_crop=True,
top=0,
Expand Down Expand Up @@ -229,7 +229,7 @@ def test__transform_bounding_boxes_clamping(self, mocker):
canvas_size = (10, 10)

mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
"torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params",
return_value=dict(
needs_crop=True,
top=0,
Expand Down
Loading

0 comments on commit a9a726a

Please sign in to comment.