Add meta function for channel_shuffle operation (#123033)
This commit introduces a meta function for the `channel_shuffle` operation, enabling PyTorch to perform shape inference and optimizations related to this operation without actual computation. The meta function assumes input shape (*, C, H, W) and validates that the number of channels (C) is divisible by the specified number of groups.
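
As a minimal sketch (not part of the diff), the snippet below shows the behavior this registration enables: calling `channel_shuffle` on a meta-device tensor to infer the output shape without allocating or computing any data. It assumes a PyTorch build that includes this change.

```python
import torch

# A "meta" tensor carries only shape/dtype/device metadata; no storage is allocated.
x = torch.empty(2, 6, 8, 8, device="meta")

# With the meta function registered, channel_shuffle can be shape-checked/traced
# without running the actual kernel. C=6 is divisible by groups=3, so this succeeds.
out = torch.nn.functional.channel_shuffle(x, 3)
print(out.shape, out.device)  # torch.Size([2, 6, 8, 8]) meta
```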

Fixes #122771

Pull Request resolved: #123033
Approved by: https://github.com/ezyang, https://github.com/mikaylagawarecki
Episkey0109 authored and pytorchmergebot committed Apr 11, 2024
1 parent 84580f7 commit 02b29e7
Showing 2 changed files with 49 additions and 0 deletions.
16 changes: 16 additions & 0 deletions torch/_meta_registrations.py
@@ -6161,6 +6161,22 @@ def meta_polygamma(n: int, self: Tensor) -> Tensor:
    return torch.empty_like(self, dtype=result_dtype)


@register_meta(aten.channel_shuffle.default)
def meta_channel_shuffle(input, groups):
    # Assume the input shape is (*, C, H, W), where * represents any number of leading dimensions
    *leading_dims, C, H, W = input.size()
    # The output shape is the same as the input
    return torch.empty(
        *leading_dims,
        C,
        H,
        W,
        dtype=input.dtype,
        layout=input.layout,
        device=input.device,
    )


def _create_unary_float_meta_func(func):
    @register_meta(func)
    @out_wrapper()
33 changes: 33 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -8828,6 +8828,20 @@ def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwarg
        ]
    )

def sample_inputs_channel_shuffle(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    shapes_groups = [
        ((1, 4, 10, 10), 2),
        ((2, 6, 8, 8), 3),
        ((2, 8, 5, 5), 4),
    ]

    yield from (
        SampleInput(make_arg(shape), args=(groups,))
        for shape, groups in shapes_groups
    )

def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype)
    # Lower bounds must be greater than 'eps' defined in gradcheck.py::gradgradcheck() -> eps
@@ -19610,6 +19624,25 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            ),
        ),
    ),
    OpInfo(
        "nn.functional.channel_shuffle",
        sample_inputs_func=sample_inputs_channel_shuffle,
        dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
        backward_dtypes=integral_types_and(torch.bool),
        supports_out=False,
        supports_autograd=False,
        allow_cow_input_materialize_forward=[0],
        skips=(
            # Skip due to NotImplementedError for MPS device.
            DecorateInfo(unittest.expectedFailure, 'TestConsistency'),
            # vmap: calling random operator not supported
            DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
            DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
            DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
            DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
        ),
    ),
    OpInfo(
        "nn.functional.kl_div",
        sample_inputs_func=sample_inputs_kl_div,
