[Relay][Pytorch] Add support for aten::linalg_vector_norm
#16123
Conversation
Force-pushed from b4e0e40 to f2e6107 (compare)
@@ -3844,6 +3844,30 @@ def inplace_copy(self, inputs, input_types):
        # Return
        return _op.scatter_nd(source, indices, values, mode)

    def linalg_vector_norm(self, inputs, input_types):
        data = inputs[0]
It looks like this method is based on torch.linalg.vector_norm. The latter assumes that the input data type is float, double, or complex; dtype should also be real or complex. Would you check it?
Thanks for your review. That's true: it's based on torch.linalg.vector_norm, which supports float, double, and complex dtypes as input. It seems the PyTorch frontend doesn't support the complex data type, though (see convert_pt_to_tvm_type).
Hello @mshr-h! My idea was to add something like assert data.dtype == float or data.dtype == double, and maybe add a TODO for supporting complex values later, but I do not think that is needed right now.
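In the Relay frontend the converter receives Relay expressions plus dtype strings, so such a check would most likely be written against input_types rather than a .dtype attribute. A minimal sketch, assuming input_types[0] carries the data dtype as a string such as "float32":

    dtype = input_types[0]
    # Restrict to real floating-point inputs for now; complex support is a possible TODO.
    assert dtype in ("float32", "float64"), "linalg_vector_norm: unsupported dtype " + dtype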
Hi @vvchernov! I've added an assertion and test cases for double-precision input data.
Thank you for your work! LGTM, see my small comment
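For reference, a rough sketch of how such a converter could look in the Relay PyTorch frontend, including the dtype assertion discussed above. This is an illustrative sketch under assumptions, not necessarily the code merged in this PR: it assumes the aten::linalg_vector_norm argument order (data, ord, dim, keepdim, dtype) and the frontend's usual _op/_expr helpers.

    import numpy as np
    from tvm.relay import expr as _expr
    from tvm.relay import op as _op

    def linalg_vector_norm(self, inputs, input_types):
        # Assumed argument order: (data, ord, dim, keepdim, dtype).
        data = inputs[0]
        dtype = input_types[0]
        order = inputs[1]
        dim = inputs[2]
        keepdim = inputs[3]

        # Only real floating-point inputs are handled; complex is left as a TODO.
        assert dtype in ("float32", "float64"), "linalg_vector_norm: unsupported dtype " + dtype

        if order == 0:
            # L0 "norm": count of non-zero elements.
            return _op.reduce.sum(
                _op.cast(_op.not_equal(data, _expr.const(0, dtype=dtype)), dtype=dtype),
                axis=dim,
                keepdims=keepdim,
            )
        if order == np.inf:
            # inf norm: maximum absolute value.
            return _op.reduce.max(_op.abs(data), axis=dim, keepdims=keepdim)
        if order == -np.inf:
            # -inf norm: minimum absolute value.
            return _op.reduce.min(_op.abs(data), axis=dim, keepdims=keepdim)

        # General case: (sum(|x| ** ord)) ** (1 / ord).
        reci_order = _expr.const(1.0 / order, dtype=dtype)
        order = _expr.const(order, dtype=dtype)
        return _op.power(
            _op.reduce.sum(_op.power(_op.abs(data), order), axis=dim, keepdims=keepdim),
            reci_order,
        )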
class VectorNorm5(torch.nn.Module):
    def forward(self, x):
        return torch.linalg.vector_norm(x, ord=0)
If you only need a different ord for each test case, you don't need to create a separate class for each one. Please clean them up.
@masahi Revised as suggested. Please review. Thanks!
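A possible consolidation, roughly along the lines of what was suggested: one module parameterized by the norm order (names here are illustrative, not necessarily the PR's final test code).

    import torch

    class VectorNorm(torch.nn.Module):
        def __init__(self, order=2):
            super().__init__()
            self.order = order

        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=self.order)

    # One instance per order instead of a separate class for each, e.g.:
    # for order in [0, 0.5, 1, 2, float("inf"), float("-inf")]:
    #     verify_model(VectorNorm(order), input_data=torch.rand(3, 4))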
Fix #16096
Support torch.linalg.vector_norm.
cc @jikechao @vvchernov @Hzfengsy @junrushao
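For context, an illustrative example of what torch.linalg.vector_norm computes for a few ord values, i.e. the operator this PR teaches the frontend to translate:

    import torch

    x = torch.tensor([[3.0, -4.0], [0.0, 12.0]])
    print(torch.linalg.vector_norm(x))                    # default ord=2: sqrt(9 + 16 + 0 + 144) = 13.0
    print(torch.linalg.vector_norm(x, ord=1))             # sum of |x|: 19.0
    print(torch.linalg.vector_norm(x, ord=0))             # number of non-zero elements: 3.0
    print(torch.linalg.vector_norm(x, ord=float("inf")))  # max |x|: 12.0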