-
Notifications
You must be signed in to change notification settings - Fork 516
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOSA] Fix aten.view and aten.slice.tensor #1768
Conversation
e8021a4
to
694fc38
Compare
|
|
|
b0a5f34
to
7feb66a
Compare
4232ba5
to
e18993c
Compare
@@ -3061,13 +3079,18 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite( | |||
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) | |||
return rewriter.notifyMatchFailure(op, "start must be a Scalar constant"); | |||
|
|||
if (start < 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add an e2e test that checks these changes?
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x1x2xf32> -> !torch.vtensor<[1,1,2],f32> | ||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,1,2],f32> | ||
// CHECK: } | ||
func.func @torch.aten.slice(%arg0: !torch.vtensor<[1,128,2],f32>) -> !torch.vtensor<[1,1,2],f32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not needed
@@ -2645,6 +2645,24 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite( | |||
return rewriter.notifyMatchFailure(op, | |||
"size must consist of Scalar constants"); | |||
|
|||
// # the size -1 is inferred from other dimensions |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need to first make sure there is at most one -1
in the list and return a notifyMatchFailure
if that is not the case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
} | ||
for (size_t i = 0; i < outShape.size(); i++) { | ||
if (outShape[i] < 0) { | ||
outShape[i] = totalSize / otherSize; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add a break
after this line to make it very clear that this is expected to only runs once
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
if (start < 0) | ||
return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0"); | ||
if (start < 0) { | ||
start = start + selfType.getShape()[dim]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need to check that it is positive after this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
int64_t end; | ||
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) | ||
return rewriter.notifyMatchFailure(op, "end must be a Scalar constant"); | ||
|
||
if (end <= 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is correct. For the end == 0
case, end
should remain zero. Make sure to also e2e test this edge case
Split the aten.view op into a new patch #1815 |
Find this error in nod-ai/SHARK-Studio#494
Deal with aten.view input -1
Deal with aten.slice start -1
slice op explaination: https://cran.r-project.org/web/packages/torch/vignettes/indexing.html