Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pytorch] Add support for aten::linalg_vector_norm #16123

Merged
merged 6 commits into from
Nov 26, 2023

Conversation

mshr-h
Copy link
Contributor

@mshr-h mshr-h commented Nov 14, 2023

@github-actions github-actions bot requested a review from Hzfengsy November 14, 2023 02:43
@mshr-h mshr-h force-pushed the pytorch-linalg_vector_norm branch from b4e0e40 to f2e6107 Compare November 14, 2023 04:19
@@ -3844,6 +3844,30 @@ def inplace_copy(self, inputs, input_types):
# Return
return _op.scatter_nd(source, indices, values, mode)

def linalg_vector_norm(self, inputs, input_types):
data = inputs[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this method is based on torch.linalg.vector_norm. The latter assums that input data type is float, double or complex one; dtype also should be real or complex. Would you check it?

Copy link
Contributor Author

@mshr-h mshr-h Nov 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your review. That's true. It's based on torch.linalg.vector_norm and it supports float, double and complex dtypes as input. Seems like PyTorch doesn't support complex data type. convert_pt_to_tvm_type

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @mshr-h! My idea was to add something like assert data.dtype == float or data.dtype == double. And may be add TODO for further support complex values, but I do not think it is needed just now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @vvchernov ! I've added assertion and testcases for double-precision input data.

Copy link
Contributor

@vvchernov vvchernov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your work! LGTM, see my small comment

@mshr-h mshr-h marked this pull request as ready for review November 15, 2023 12:50
@github-actions github-actions bot requested a review from junrushao November 20, 2023 01:40

class VectorNorm5(torch.nn.Module):
def forward(self, x):
return torch.linalg.vector_norm(x, ord=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you only need different ord for each test case, you don't need to create a different class for each. Please clean them up.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@masahi Revised as suggested. Please review. Thanks!

@masahi masahi merged commit 3fd3a63 into apache:main Nov 26, 2023
@mshr-h mshr-h deleted the pytorch-linalg_vector_norm branch November 27, 2023 00:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants