
feat: support flatten and reshape via shuffle_layer #2354

Merged (11 commits into pytorch:main on Oct 6, 2023)

Conversation

@zewenli98 (Collaborator) commented Sep 29, 2023

Description

Support flatten and reshape via shuffle_layer

Fixes #2214 (Expose IShuffleLayer in dynamo.conversion.impl)

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions github-actions bot added labels on Sep 29, 2023: component: api [Python] (Issues re: Python API), component: conversion (Issues re: Conversion stage), component: converters (Issues re: Specific op converters), component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths), component: tests (Issues re: Tests)
@github-actions github-actions bot requested a review from peri044 September 29, 2023 03:58
@zewenli98 (Collaborator, Author):

@bowang007

@zewenli98 (Collaborator, Author):

When I test flatten via torch.flatten(inputs, start_dim, end_dim), I get the error: AssertionError: False is not true : expected ops {<OpOverload(op='aten.flatten', overload='using_ints')>}, actual ops {<OpOverload(op='aten.view', overload='default')>, <OpOverloadPacket(op='aten.sym_size')>, <built-in function mul>}
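
(For context, a minimal sketch of the decomposition behind that error; the module, the input shape, and the use of torch.export are illustrative assumptions rather than the actual test harness. Tracing torch.flatten lowers it to aten.view, plus sym_size/mul shape arithmetic when dynamic shapes are involved, so no aten.flatten node reaches the converter.)

import torch


class FlattenModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.flatten(x, 1, 2)


# Exporting lowers torch.flatten; inspect the graph to see the resulting ops.
exported = torch.export.export(FlattenModule(), (torch.randn(2, 3, 4, 5),))
print(exported.graph)  # expect aten.view nodes rather than aten.flatten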

@zewenli98 zewenli98 self-assigned this Sep 29, 2023
@zewenli98 zewenli98 requested review from gs-olive and removed request for peri044 September 29, 2023 17:15
@github-actions github-actions bot requested a review from peri044 September 29, 2023 18:53
@zewenli98 zewenli98 force-pushed the shuffle_dynamo_converter branch from b414928 to d212e51 Compare September 30, 2023 01:37
@gs-olive (Collaborator) left a comment:


Is there a specific reason, performance or otherwise, why flatten should need a different implementation than reshape when using static shapes? Specifically, we can comment out the flatten implementation for now, and any converters needing flatten for static shapes can just use a reshape and flatten the dimensions themselves.

As an alternative, @zewenli98, you can add a utility flatten_dims, which will flatten the dimensions of an input tensor into a reshape-usable form, then you can have @bowang007's converter test that utility.

@zewenli98 (Collaborator, Author):

Is there a specific reason - performance or otherwise, why flatten should need a different implementation than reshape, when using static shapes?

Thanks for the advice! I did this because I noticed there's a flatten op in this schema. I thought our goal was to support these native_functions as much as possible. Anyway, I'll comment out the flatten converter and wrap reshape to perform the flatten op.

@gs-olive (Collaborator) commented Oct 3, 2023

Generally, the focus is to cover as much of this operation set as possible: https://pytorch.org/docs/stable/ir.html#core-aten-ir, though if there are operators that show up which we can directly convert as opposed to lowering, that is certainly a good thing to have.

@zewenli98 zewenli98 requested a review from gs-olive October 3, 2023 22:42
@gs-olive (Collaborator) left a comment:


I do still think flatten_dims can be a utility which gives the shape to pass to reshape. That way, it can get tested as a utility and not as a converter (see tests/py/dynamo/conversion/test_converter_utils.py). Added a suggestion on syntax.

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py (review thread, outdated and resolved)
@zewenli98 (Collaborator, Author) commented Oct 4, 2023

I do still think flatten_dims can be a utility which gives the shape to pass to reshape.

I tried implementing flatten_dims in converter_utils.py. Since flatten_dims needs to call reshape, it caused a circular import. That's why I moved it to the shuffle file. Anyway, I will rewrite it as flatten_dims.

@zewenli98 zewenli98 requested a review from gs-olive October 4, 2023 04:34
@gs-olive (Collaborator) commented Oct 4, 2023

@zewenli98 I see, thanks for the details. To clarify, I was intending for flatten_dims not to change the network at all, since changing the network requires the function to be in impl/. I was thinking it could instead be something like:

# Assumed imports to make the sketch self-contained; the import paths for
# TRTTensor and get_positive_dim are assumptions based on the existing
# converter utilities.
from typing import Tuple, Union

import numpy as np
import torch
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.fx.types import TRTTensor


def flatten_dims(
    input: Union[TRTTensor, torch.Tensor, np.ndarray],
    start_dim: int,
    end_dim: int,
) -> Tuple[int, ...]:
    # Shape-only utility: it computes the flattened shape and never adds
    # layers to the TensorRT network, so it can live outside impl/.
    shape = input.shape
    dim_size = len(shape)
    # Normalize negative dims (e.g. -1) into positive indices.
    start_dim = get_positive_dim(start_dim, dim_size)
    end_dim = get_positive_dim(end_dim, dim_size)

    # Collapse every dimension in [start_dim, end_dim] into a single one.
    num_elements = 1
    for i in range(start_dim, end_dim + 1):
        num_elements *= shape[i]

    new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])

    return new_shape

Then, the user can call flatten_dims to get back the flattened dimension shape, which they can then pass to reshape themselves.
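
(As a usage sketch: the tensor shape is illustrative, flatten_dims is the utility sketched above, and the reshape call at the end is shown schematically; its exact signature is an assumption rather than the merged API.)

import torch

x = torch.randn(2, 3, 4, 5)
new_shape = flatten_dims(x, start_dim=1, end_dim=2)
assert new_shape == (2, 12, 5)  # dims 1 and 2 (3 * 4) are collapsed into one

# A converter that needs a static-shape flatten can then emit a single reshape, e.g.:
# out = impl.shuffle.reshape(network, target, source_ir, f"{name}_flatten", x_trt, new_shape)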

@zewenli98 (Collaborator, Author):

@gs-olive Thanks a lot! This makes more sense. Modified!

@zewenli98 zewenli98 force-pushed the shuffle_dynamo_converter branch from 5847be9 to d48d611 Compare October 4, 2023 22:44
@zewenli98 zewenli98 force-pushed the shuffle_dynamo_converter branch from 4fc3a8b to 9846f23 Compare October 5, 2023 20:34
@@ -65,6 +68,7 @@ def forward(self, x):
self.run_test_with_dynamic_shape(
TestModule(target_shape),
input_specs,
expected_ops={torch.ops.aten.view.default},
Review comment (Collaborator):

This line can be removed, in accordance with the new testing PR
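
(For clarity, the call from the diff context above once that argument is dropped; the closing parenthesis and any remaining keyword arguments are assumed.)

self.run_test_with_dynamic_shape(
    TestModule(target_shape),
    input_specs,
)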

@zewenli98 zewenli98 requested a review from gs-olive October 6, 2023 20:07
@gs-olive (Collaborator) left a comment:


Looks good to me, pending CI pass.

@gs-olive (Collaborator) commented Oct 6, 2023

Relevant tests pass locally on Torch 2.1.0. Merging to main.

@gs-olive gs-olive merged commit d375d10 into pytorch:main Oct 6, 2023
12 of 14 checks passed
@zewenli98 (Collaborator, Author):

@bowang007 Please consult this PR for the shuffle op.

@bowang007 (Collaborator):

@zewenli98 Thanks! Let me update the PR accordingly.
