
Revert #75195
This is a short-term fix for a serious regression in functorch
(pytorch/functorch#989).

Additional things this PR does:
- the out= tests for nn.functional.linear fail after the revert, so I added
  some xfails. These xfails were present in the original PR (#75195).
- the profiler tests fail after the revert, so I updated the expecttests
  for the profiler tests.

Test Plan:
- tested offline that the functorch regression is fixed (see the vmap sketch below)

ghstack-source-id: f127f9be33ba35ceeabbc6a9a4d8c24654defad7
Pull Request resolved: #82504
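
[Editorial note] A sketch of the kind of offline check the test plan implies. The exact repro lives in pytorch/functorch#989 and is not shown on this page; that the regression involved vmap over a vector-matrix matmul is an assumption for illustration only.

import torch
from functorch import vmap  # the functorch package as of this commit (Aug 2022)

# Hypothetical repro shape: a batch of 1-D vectors multiplied by a 2-D matrix,
# which exercises matmul's dim_tensor1 == 1 && dim_tensor2 == 2 branch per sample.
M = torch.randn(3, 4)
batch = torch.randn(5, 3)
out = vmap(lambda v: torch.matmul(v, M))(batch)
assert out.shape == (5, 4)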
zou3519 committed Aug 1, 2022
1 parent 6592259 commit b2bc557
Showing 4 changed files with 23 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-b3342319e96a0becd139019620d8665605b78475
+73c64a55fb096f1e132029d3decbb6f4e532cc7b
3 changes: 2 additions & 1 deletion aten/src/ATen/native/LinearAlgebra.cpp
@@ -1760,7 +1760,8 @@ Tensor _matmul_impl(
   } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
     return has_out ? at::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
   } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
-    return has_out ? at::mv_out(out, tensor2.t(), tensor1) : tensor2.t().mv(tensor1);
+    return has_out ? at::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0)
+                   : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
   } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
     return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
   } else if (should_fold(tensor1, dim_tensor2) || should_fold(tensor2, dim_tensor1)) {
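[Editorial note] For context (not part of the diff): a minimal Python sketch of the two equivalent formulations of the vector-matrix case that this hunk switches between.

import torch

v = torch.randn(3)     # 1-D tensor (dim_tensor1 == 1)
M = torch.randn(3, 4)  # 2-D tensor (dim_tensor2 == 2)

# Path restored by the revert: lift v to a 1xN matrix, mm, drop the extra dim.
via_mm = v.unsqueeze(0).mm(M).squeeze(0)

# Path removed by the revert: transpose M and take a matrix-vector product.
via_mv = M.t().mv(v)

assert torch.allclose(via_mm, via_mv)
assert torch.allclose(via_mm, torch.matmul(v, M))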
28 changes: 16 additions & 12 deletions test/test_profiler_tree.py
@@ -535,12 +535,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
               aten::transpose
                 aten::as_strided
             aten::matmul
-              aten::t
-                aten::transpose
-                  aten::as_strided
-              aten::mv
-                aten::empty
-                aten::addmv_
+              aten::unsqueeze
+                aten::as_strided
+              aten::mm
+                aten::resolve_conj
+                aten::resolve_conj
+                aten::resolve_conj
+              aten::squeeze_
+                aten::as_strided_
             aten::add_
           nn.Module: ReLU_1
             <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
@@ -576,12 +578,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
               aten::transpose
                 aten::as_strided
             aten::matmul
-              aten::t
-                aten::transpose
-                  aten::as_strided
-              aten::mv
-                aten::empty
-                aten::addmv_
+              aten::unsqueeze
+                aten::as_strided
+              aten::mm
+                aten::resolve_conj
+                aten::resolve_conj
+                aten::resolve_conj
+              aten::squeeze_
+                aten::as_strided_
             aten::add_
           nn.Module: ReLU_1
             <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
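[Editorial note] Not part of the commit: a sketch, using the public torch.profiler API, of why these expecttests change. The ops recorded under aten::matmul depend on which implementation _matmul_impl dispatches to, so the profiler tree shifts with the revert.

import torch

linear = torch.nn.Linear(3, 4)
x = torch.randn(3)  # a 1-D input drives matmul's vector-matrix branch

with torch.profiler.profile() as prof:
    linear(x)

# After the revert, aten::unsqueeze/aten::mm/aten::squeeze_ appear under the
# matmul call instead of aten::t/aten::mv.
print(prof.key_averages().table(sort_by="cpu_time_total"))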
4 changes: 4 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -13459,6 +13459,8 @@ def error_inputs_mean(op_info, device, **kwargs):
                      'TestCommon', 'test_noncontiguous_samples',
                      device_type='cpu'), ],
        skips=(
+           # Strides are not the same!
+           DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
            # https://github.com/pytorch/pytorch/issues/67470
            DecorateInfo(unittest.skip("67470!"),
                         'TestCommon', 'test_noncontiguous_samples',
@@ -14861,6 +14863,8 @@ def error_inputs_mean(op_info, device, **kwargs):
        check_batched_forward_grad=False,
        supports_expanded_weight=True,
        decorators=(
+           # Strides are not the same!
+           DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
                         'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
        )),
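[Editorial note] Not part of the commit: a minimal sketch of the mechanism the new entries rely on. DecorateInfo attaches a decorator (here unittest.expectedFailure) to a specific test class and test name, so the known stride mismatch in test_out is reported as an expected failure rather than an error; "Strides are not the same!" is the assertion message the xfail'd test raises.

import unittest

class TestCommon(unittest.TestCase):
    @unittest.expectedFailure  # what the DecorateInfo entries above apply
    def test_out(self):
        # Stands in for the real OpInfo out= check.
        self.assertEqual((2,), (1,), "Strides are not the same!")

if __name__ == "__main__":
    unittest.main()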
