25% Performance regression from v0.1.1 to 0.2.0 when calculating hessian #989
Comments
Thanks for the report, we'll take a look soon
Bisected to https://github.com/pytorch/pytorch/pull/75195/files. That PR by itself may not be a problem; perhaps the problem is our batching rule for mv. @yueyericardo is the repro you provided the entire model, or is it a subset of some model that you're running?
cc @lezcano @ezyang for pytorch/pytorch#75195 -- this led to a performance regression in functorch. I'm not sure what the original intent of the PR is (there are no tests). I'm still trying to root-cause this, but it is a bit difficult to visualize. What are the chances we could revert that PR? (I've confirmed that reverting that single PR on pytorch/pytorch master makes the performance regression go away)
Many thanks to @zou3519 for the quick debugging! Thanks again!
@yueyericardo - for the original model itself, is the performance regression also 25%, or is it a smaller number? Is the original model public? One thing we can do to prevent future regressions is to check the original model into https://github.com/pytorch/benchmark. I've noticed a lot of other similar models where folks have an nn.Sequential made up of just nn.Linear and activations and need to compute a vmap(jacrev(...)) or vmap(hessian(...)) of some quantity, so otherwise we could also potentially just check your script into torchbench.
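For reference, the kind of workload being described (an nn.Sequential of nn.Linear + activations with a batched Hessian via vmap(hessian(...))) might look like the following minimal sketch; the sizes, depth, and names are illustrative and this is not the original repro script:

```py
import torch
from functorch import vmap, hessian

# Illustrative sizes only; the original benchmark's sizes may differ.
D_in, D_hidden, D_out, batch = 3, 256, 3, 1024
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, D_hidden), torch.nn.ReLU(),
    torch.nn.Linear(D_hidden, D_hidden), torch.nn.ReLU(),
    torch.nn.Linear(D_hidden, D_out),
)

def f(x):
    # x is a single sample of shape (D_in,) under vmap
    return model(x)

x = torch.randn(batch, D_in)
# Per-sample Hessian of each output w.r.t. the input: (batch, D_out, D_in, D_in)
H = vmap(hessian(f))(x)
print(H.shape)
```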
Hi @zou3519, our source code is free to download from the website, but it is not developed on GitHub, and our code base is probably too large to put into pytorch/benchmark. Yes, exactly! I believe the minimal repro I provided is enough to prevent future regressions for our model.
I think that rather than blindly reverting, we should get to the root of the problem, as it is very weird to get such a regression when dispatching from a more general function to a more concrete one (that was the reason for that PR). Things that come to mind are:
If the answer to the above two is no, then this performance issue is likely on the functorch end and should be fixed. Otherwise, it's on the cuBLAS end and should be reported to NVIDIA.
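One way to sanity-check the "general vs. concrete kernel" question above is a direct micro-benchmark of mv against the equivalent single-column mm. This is a hypothetical sketch, not from this thread; it requires a CUDA device and the sizes are illustrative:

```py
import time
import torch

A = torch.randn(4096, 4096, device="cuda")
x = torch.randn(4096, device="cuda")

def bench(fn, iters=200):
    # Warm up, then time with synchronization around the loop.
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

print("mv:", bench(lambda: torch.mv(A, x)))
print("mm:", bench(lambda: torch.mm(A, x.unsqueeze(1))))
```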
@lezcano It's fair to submit the upstream bugs, but if we know that our upstream library's general kernel has better perf than a specialized one, we might as well use it.
@lezcano @ezyang let's say that we did revert the PR (because we're trying to release PyTorch 1.12.1 as soon as possible). Would it cause any other problems? Because the motivation was "dispatching from a more general function to a more concrete [function]", it sounds like this wouldn't change very much else.
No, I don't think so. The PR is supposed to make the kernel run faster.
fwiw, I think this may be related to the open PR I have to avoid copies in matmul. Could you check whether pytorch/pytorch#76828 fixes this? In any case, I'm fine with reverting, but we should investigate what's causing this regardless.
That PR is fixing spurious resize warnings that were previously generated, and by itself is supposed to speed things up by avoiding squeeze/unsqueeze calls, which are not free (and especially not free when autograd is needed). As for more general vs. more concrete function performance, we should investigate, but I doubt that's the case.
@zou3519 can you by any chance collect profiling results for the old and new versions?
I agree this warrants more investigation. We've got a problem in that there is a timeline for 1.12.1, and I am not sure how long it is going to take to actually get to the bottom of this.
I can try but it's been a long time since I touched nvprof or nsight profiler, so I will need to relearn the magic invocations.
Since we changed the mm to mv, functorch generates different code for vmap(jacrev(jacfwd(mm))) as opposed to vmap(jacrev(jacfwd(mv))). It's plausible that the problem is that "functorch should generate better code"; we're still digging into it.
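To make the "different generated code" point concrete: batching mv over the vector argument is mathematically a single matmul against a transposed layout, which is roughly the kind of rewrite a batching rule has to produce. The snippet below illustrates the semantics only; it is not functorch's actual rule:

```py
import torch
from functorch import vmap

W = torch.randn(5, 3)
xs = torch.randn(10, 3)  # a batch of vectors

# Per-sample W @ xs[i] via vmap...
out_vmap = vmap(torch.mv, in_dims=(None, 0))(W, xs)  # (10, 5)
# ...is the same computation as one matrix-matrix product with a transpose.
out_mm = xs @ W.t()                                  # (10, 5)
print(torch.allclose(out_vmap, out_mm, atol=1e-5))
```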
Isn't 1.12.1 done already?
FYI, the nsight profiling results (screenshots not preserved in this extract): in the "after" profile, the third row shows 18.4% of the time spent in a copy kernel.
@ngimel the PR that fixed the warnings was already merged; this one is just concerned with avoiding copies. One of the cases where it elides a copy is when you multiply a matrix by a batch of matrices. This is exactly the batched version of the vector-matrix product that's causing the regression. That's why I think it may fix it.
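For reference, this is the shape pattern being described: a plain matrix multiplied by a batch of matrices, of which the batched matrix-vector product is the single-column case (illustrative sizes):

```py
import torch

A = torch.randn(4, 5)             # plain matrix
Bm = torch.randn(8, 5, 3)         # batch of matrices
print(torch.matmul(A, Bm).shape)  # (8, 4, 3): A is broadcast over the batch

v = torch.randn(8, 5, 1)          # batch of column vectors
print(torch.matmul(A, v).shape)   # (8, 4, 1): the batched matrix-vector case
```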
That must be coming from functorch, as that PR doesn't introduce any additional copies.
yup, I think pytorch/pytorch#76828 would fix the regression
I patched in pytorch/pytorch#76828 and the above script ends up OOM-ing :/
Not yet, we have a chance to change it (or request a 1.12.2 if necessary, since this regression is large and these types of models are typical functorch usage)
Re. the OOM: wow, that's certainly unexpected. I'm not sure what the best way to follow up on that is; probably @ngimel has a better idea of how to proceed. Regardless, landing that PR (the avoid-copies matmul one) will be tricky due to some discrepancies in the accuracy of mm vs mv that we found for float16. As such, I think the most realistic way forward is to revert the offending PR and investigate what's causing that OOM after 1.12.1 is released.
I don't think that's true: #75195 itself is fixing a warning (otherwise a user-supplied correct 1d out was sent to mm, and mm complained about resizing it to 2d).
I think the one that fixed the out version was pytorch/pytorch#75197 and a previous one in that stack, but it may be the case that 75195 also fixed an out warning; I don't remember now.
This is a short-term fix for a serious regression in functorch (pytorch/functorch#989).

Why is this a partial revert?
- the out= tests for nn.functional.linear fail on a complete revert
- the profiler tests fail on the revert (so I updated the expecttests for the profiler tests)

Test Plan:
- test offline that the functorch regression was fixed

[ghstack-poisoned]
@yueyericardo I got access to the modulus repo -- could you point me to which model in the modulus repo contains the nn.Sequential above, please? We're figuring out how to check the above code into torchbench and I'm looking for some more context -- is there a name for computing the hessian of the nn.Sequential? Do you generally run .backward() after computing the hessian? What are some representative input sizes (are the sizes in the benchmark script representative?)
Hi @zou3519, thanks for following up! We are having some internal discussions regarding this and will come back to you tomorrow.
From looking at this a bit, I think what happened is:
During backwards,
We can also validate that this is our issue, since we hit pre-regression performance numbers by changing functorch's mv batch rule here to:

```cpp
auto other_ = moveBatchDimToFront(other, other_bdim);
auto self_ = at::movedim(self, 0, 1);
auto result = at::matmul(other_, self_);
return std::make_tuple(std::move(result), 0);
```

This doesn't trigger the copy, since the batch dimension for the saved relu activation is now the first dimension. However, it may hit other perf issues from transposing both `self` and the result.

This is the smallest subset of @zou3519's trace where I can see the perf difference. Notably, if I change the views on either `bmm` or `other_relu` to be contiguous constants, it has much faster performance. So it seems like `threshold_backward` doesn't copy if only one of its inputs is non-contiguous, but does if both are.
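One way to poke at that last observation outside of functorch is to profile relu's backward (which lowers to threshold_backward) with contiguous vs. non-contiguous inputs and check for aten::copy_ calls. This is a hypothetical check, not code from the thread:

```py
import torch
from torch.profiler import profile, ProfilerActivity

def run(transpose_input, transpose_grad):
    x = torch.randn(512, 512, requires_grad=True)
    inp = x.t() if transpose_input else x            # optionally non-contiguous input
    y = torch.relu(inp)
    if transpose_grad:
        g = torch.randn(y.shape[1], y.shape[0]).t()  # non-contiguous grad_output
    else:
        g = torch.randn(y.shape)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        y.backward(g)
    copies = [(e.key, e.count) for e in prof.key_averages() if "copy_" in e.key]
    print(f"transpose_input={transpose_input} transpose_grad={transpose_grad} copies={copies}")

for ti in (False, True):
    for tg in (False, True):
        run(ti, tg)
```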
Thanks @samdow, the copy from relu definitely seems to affect perf; however, there's also another copy coming from MvBackward. Also, why does threshold_backward on discontiguous inputs trigger a copy? In eager, threshold_backward should be able to handle them via TensorIterator.
This is a short-term fix for a serious regression in functorch (pytorch/functorch#989).

Additional things this PR does:
- the out= tests for nn.functional.linear fail after the revert. I added some xfails. These xfails were present in the original PR (#75195).
- the profiler tests fail on the revert, so I updated the expecttests for the profiler tests

Test Plan:
- test offline that the functorch regression was fixed

Pull Request resolved: #82504
Approved by: https://github.com/ngimel, https://github.com/ezyang, https://github.com/atalman
As discussed, in my opinion the path forward to fix this regression would be to:
### Introduction

Removing unnecessary weight gradient calculation is very important for applications that need high-order derivatives during training. However, this is not supported by the current Autograd engine.

In more detail: the backward function of a `matmul` operator (e.g., `linear`, `addmm`, `mm`) has two matmuls, one for the `input gradient` and another for the `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t. the nn input `x`, only the `input gradient` of the `matmul` operator is needed, and the `weight gradient` is discarded. However, the current PyTorch autograd engine will always calculate the `weight gradient` if `weight` requires gradient (the calculation of the high-order derivative is performed during training).

The figure attached shows the autograd graph of the following code snippet:

```py
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# second order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```

The path with ❌ is not needed when calculating derivatives.

<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">

### Issue

Related issue: #56500

### Method

When calling `torch.autograd.grad`, `exec_info_` is created for each GraphTask, which allows filtering paths on the graph that are not needed. However, when the GraphTask calls into the node, the node still does not know whether the edges are needed or not. In the case of matmul, `weight.requires_grad is True`, so the weight gradient is always calculated. Following #56500 (comment), this PR passes the graph task's thread-local `exec_info_` into the node, so it can trim unnecessary edges during `torch.autograd.grad` calls.

### Benchmark

Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99

Benchmark result: 6 hidden layers, batch size 10000, on A100.

FP32 result

| hessian benchmark             | FP32 (before) | FP32 (after)      | FP32 (functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 55.658 ms     | 29.392 ms (1.90X) | 29.547 ms (1.90X)       |
| Linear + ReLU (with backward) | 81.173 ms     | 54.917 ms (1.47X) | 68.988 ms (1.18X)       |

TF32 result

| hessian benchmark             | TF32 (before) | TF32 (after)      | TF32 (functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 19.801 ms     | 11.259 ms (1.76X) | 10.754 ms (1.84X)       |
| Linear + ReLU (with backward) | 29.167 ms     | 20.466 ms (1.42X) | 22.784 ms (1.28X)       |

For FP32 we get a 1.9X speed-up for the hessian calculation and a 1.47X speed-up during training, which is even faster than functorch's `vmap(jacfwd(jacrev(...)))` implementation. (functorch has a performance regression in v0.2.0, pytorch/functorch#989, so we are using v0.1.1 for the benchmark.)

@zou3519 does functorch also include similar optimizations during hessian calculation? If not, what do we need to do so that functorch could also benefit from this PR?

### Testing

- [x] we need to figure out a way for unittest

### Thanks

Thanks for the great blog: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)

cc @zasdfgbnm @albanD

Pull Request resolved: #82544
Approved by: https://github.com/soulitzer
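For readers, here is a self-contained, runnable version of the snippet above with illustrative shapes and grad_outputs filled in; only derivatives w.r.t. x are requested, so the weight-gradient branches are exactly the prunable edges described in the Method section:

```py
import torch

batch, dim = 8, 4
x = torch.randn(batch, dim, requires_grad=True)
weight = torch.randn(dim, dim, requires_grad=True)
bias = torch.randn(dim, requires_grad=True)

y = torch.nn.functional.linear(x, weight, bias).pow(2)
grad_outputs = torch.ones_like(y)

# first-order derivative dy/dx (graph kept for the second call)
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# second-order derivative, contracted with ones
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=torch.ones_like(y__x), create_graph=True)
print(y__x.shape, y__x__x.shape)
```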
@yueyericardo quick bump on the above. We're looking to merge some form of the above code into our benchmark suite and additional information would be very helpful.
@zou3519 Sorry, I already finished my internship there. I believe the next release of Modulus (this month) will include the functorch integration and will be available on GitLab.
For what it's worth, the fix pytorch/pytorch#76828 is completely stalled on some errors on
@zou3519 The modulus version with functorch support was made public last week. Here is the repo.
The specific place where functorch is used is in this wrapper class.
Yes, after computing the hessian, we need to run `.backward()`.
@akshaysubr the provided links do not work; maybe the project access settings are not set correctly.
@IvanYashchuk Sorry, I should've mentioned that you have to apply for access to that repo here, and after that those links should work.
I was able to add a version of the original benchmark to pytorch/benchmark, so we can now prevent regressions in this model, and I'm lowering the priority of this issue. Leaving it open to discuss the changes to matmul above, though.
fwiw, the |
After patching in the one-liner pytorch/pytorch#75195 on top of the stack pytorch/pytorch#76828, I still get an OOM using the script in the OP.
so that tensor that gets to
Hi developers,

After I upgraded functorch from `v0.1.1` to `0.2.0`, I noticed a 25% performance regression when calculating the hessian. Please check the following benchmark results and the attached benchmark script. Please let me know if I did anything wrong, and also whether the perf regression could be fixed.

Thanks!

Benchmark result
- Benchmark result on NVIDIA A100
- Benchmark result on NVIDIA A6000

Benchmark script
- benchmark.py