25% Performance regression from v0.1.1 to 0.2.0 when calculating hessian #989

Open
yueyericardo opened this issue Jul 28, 2022 · 40 comments

Comments

@yueyericardo
Contributor

yueyericardo commented Jul 28, 2022

Hi developers,

After upgrading functorch from v0.1.1 to 0.2.0, I noticed a 25% performance regression when calculating the Hessian. Please see the benchmark results and the attached benchmark script below.

Please let me know if I did anything wrong, and also whether the perf regression could be fixed.
Thanks!

Benchmark result

Benchmark result on NVIDIA A100

# torch 1.11 and functorch 0.1.1
===== benchmark without backward =====
max pred       error: functorch: 0.00e+00
max hessian    error: functorch: 0.00e+00
reference_hessian: 61.837 ms
functorch_hessian: 29.474 ms

# torch 1.12 and functorch 0.2.0
===== benchmark without backward =====
max pred       error: functorch: 1.49e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 62.519 ms
functorch_hessian: 39.666 ms  (0.75 X)

Benchmark result on NVIDIA A6000

# torch 1.11 and functorch 0.1.1
===== benchmark without backward =====
max pred       error: functorch: 1.49e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 65.984 ms
functorch_hessian: 33.662 ms

# torch 1.12 and functorch 0.2.0
===== benchmark without backward =====
max pred       error: functorch: 1.86e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 67.285 ms
functorch_hessian: 49.723 ms (0.68 X)
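
The ratios in parentheses can be recomputed from the timings above. A minimal sketch, assuming the ratio is defined as old time divided by new time (the helper name `relative_speed` is made up for illustration):

```python
# Timings in ms, copied from the benchmark output above.
timings_ms = {
    "A100": {"functorch 0.1.1": 29.474, "functorch 0.2.0": 39.666},
    "A6000": {"functorch 0.1.1": 33.662, "functorch 0.2.0": 49.723},
}


def relative_speed(old_ms: float, new_ms: float) -> float:
    """Speed of the new version relative to the old one (1.0 means unchanged)."""
    return old_ms / new_ms


ratios = {
    gpu: relative_speed(t["functorch 0.1.1"], t["functorch 0.2.0"])
    for gpu, t in timings_ms.items()
}

for gpu, ratio in ratios.items():
    print(f"{gpu}: {ratio:.3f}x the old speed")
```

With these numbers the A100 ratio comes out slightly below the rounded 0.75 quoted above, which is still consistent with a regression of roughly 25%.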

benchmark script

benchmark.py

import time
import argparse
from functorch import vmap, jacrev, jacfwd
import torch
import torch.nn as nn

torch.backends.cuda.matmul.allow_tf32 = False


_ = torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2  # x, y
D2 = 3  # u, v, p
B = 10000
x = torch.randn(B, D1).to(device)
run_backward = False

model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, D2),
).to(device)


def predict(x):
    torch.cuda.nvtx.range_push("forward")
    out = model(x)
    torch.cuda.nvtx.range_pop()
    return out, out  # jacrev(has_aux=True) expects (output, aux), so the prediction is returned twice


def reference_hessian():
    x_ = x.clone().requires_grad_()
    ones = torch.ones(B, device=x.device)
    pred, _ = predict(x_)
    jacobian_rows = [None] * D2
    hessian_rows = [None] * (D2 * D1)
    for i in range(D2):
        torch.cuda.nvtx.range_push("autograd jacobian")
        jacobian_rows[i] = torch.autograd.grad(pred[:, i], x_, ones, create_graph=True)[
            0
        ]
        torch.cuda.nvtx.range_pop()

    for i in range(D2):
        for j in range(D1):
            torch.cuda.nvtx.range_push("autograd hessian")
            hessian_rows[i * D1 + j] = torch.autograd.grad(
                jacobian_rows[i][:, j], x_, ones, create_graph=True
            )[0]
            torch.cuda.nvtx.range_pop()

    jacobian = torch.stack(jacobian_rows)  # [D2, B, D1]
    hessian = torch.stack(hessian_rows)  # [D2 * D1, B, D1]
    if run_backward:
        l = hessian.sum()
        l.backward()
    return hessian.transpose(0, 1), pred


def functorch_hessian():
    x_ = x.clone().requires_grad_()
    hessian, pred = vmap(
        jacfwd(jacrev(predict, argnums=0, has_aux=True), argnums=0, has_aux=True),
        in_dims=0,
    )(
        x_
    )  # [B, D2, D1, D1]
    if run_backward:
        l = hessian.sum()
        l.backward()
    return hessian, pred


def validate_result():
    # test functorch result
    ref_hes, ref_pred = reference_hessian()
    ft_hes, ft_pred = functorch_hessian()
    ref_hes = ref_hes.view_as(ft_hes)
    print(f"max pred       error: functorch: {(ref_pred - ft_pred).max():.2e}")
    print(f"max hessian    error: functorch: {(ref_hes - ft_hes).max():.2e}")


def benchmark(func):
    N = 20

    torch.cuda.synchronize()
    start = time.time()

    for i in range(N):
        torch.cuda.nvtx.range_push(func.__name__)
        _ = func()
        torch.cuda.nvtx.range_pop()

    torch.cuda.synchronize()
    time_ms = ((time.time() - start) / N) * 1000
    print(f"{func.__name__}: {time_ms:.3f} ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--backward", default=False, action="store_true")
    args = parser.parse_args()
    if args.backward:
        run_backward = True
        print("===== benchmark with backward =====")
    else:
        print("===== benchmark without backward =====")

    validate_result()

    # warm up
    for i in range(10):
        reference_hessian()
        functorch_hessian()

    # benchmark hessian
    benchmark(reference_hessian)
    benchmark(functorch_hessian)
@yueyericardo
Contributor Author

ping @samdow @zou3519

@zou3519 added the "high priority" label Jul 28, 2022
@zou3519
Contributor

zou3519 commented Jul 28, 2022

Thanks for the report, we'll take a look soon

@zou3519
Contributor

zou3519 commented Jul 29, 2022

Bisected to https://github.com/pytorch/pytorch/pull/75195/files. That PR by itself may not be the problem; perhaps the problem is our batching rule for mv.

@yueyericardo is the repro you provided the entire model, or is it a subset of some model that you're running?

@zou3519
Contributor

zou3519 commented Jul 29, 2022

cc @lezcano @ezyang for pytorch/pytorch#75195 -- this led to a performance regression in functorch. I'm not sure what the original intent of the PR is (there are no tests). I'm still trying to root cause this, but it is a bit difficult to visualize. What are the chances we could revert that PR?

(I've confirmed that reverting that single PR on pytorch/pytorch master makes the performance regression go away)

@yueyericardo
Contributor Author

Many thanks to @zou3519 for the quick debugging!
I'm working on the NVIDIA Modulus project; we use functorch because it gives us a large speedup for the Jacobian and Hessian calculations.
The minimal repro I provided is only a subset of our model, meant to demonstrate the performance regression.

Thanks again!

@zou3519
Contributor

zou3519 commented Jul 29, 2022

@yueyericardo - for the original model itself, is the performance regression also 25%, or is it a smaller number? Is the original model public? One thing we can do to prevent future regressions is to check the original model into https://github.com/pytorch/benchmark.

I've noticed a lot of other similar models where folks have an nn.Sequential that is just made up of nn.Linear layers and activations and need to compute a vmap(jacrev(...)) or vmap(hessian(...)) of the quantity, so we could also potentially just check your script into torchbench otherwise.

@yueyericardo
Contributor Author

yueyericardo commented Jul 29, 2022

Hi @zou3519, our source code is free to download from the website, but it is not developed on GitHub. Our code base may also be too large to put into pytorch/benchmark.

Yes, exactly! I believe the minimal repro I provided is enough to prevent future regression for our model.
Thanks!

@lezcano
Contributor

lezcano commented Jul 29, 2022

I think that rather than blindly reverting, we should get to the root of the problem, as it is very weird to get such a regression when dispatching from a more general function to a more concrete one (that was the reason for that PR).

Things that come to mind are:

  • Is this regression architecture / cublas version - dependent?
  • Is this regression also happening on regular matmul for that codepath?

If the answer to the above two is no, then this performance issue is likely on the functorch end and should be fixed. Otherwise, it's on the cuBLAS end and should be reported to NVIDIA
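
One way to probe the second question outside of functorch is a micro-benchmark of the two kernels on the same operands. This is only a sketch under assumptions: the shapes are illustrative rather than taken from the regressing codepath, and the helper names `via_mv`, `via_mm` and `median_us` are made up here.

```python
import torch
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
weight = torch.randn(512, 512, device=device)
vec = torch.randn(512, device=device)


def via_mv(w, v):
    # Specialized matrix-vector kernel.
    return torch.mv(w, v)


def via_mm(w, v):
    # General matrix-matrix kernel on the vector promoted to a [512, 1] matrix.
    return torch.mm(w, v.unsqueeze(1)).squeeze(1)


def median_us(fn):
    # benchmark.Timer takes care of CUDA synchronization when timing.
    timer = benchmark.Timer(
        stmt="fn(w, v)", globals={"fn": fn, "w": weight, "v": vec}
    )
    return timer.timeit(100).median * 1e6


mv_us = median_us(via_mv)
mm_us = median_us(via_mm)
print(f"mv: {mv_us:.2f} us, unsqueeze + mm: {mm_us:.2f} us")
```

Running the same script on different GPUs and cuBLAS versions would speak to the first question as well; it says nothing about what functorch does on top of these kernels.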

cc @ngimel @xwang233 @IvanYashchuk

@ezyang
Contributor

ezyang commented Jul 29, 2022

@lezcano It's fair to submit the upstream bugs, but if we know that our upstream library's general kernel has better perf than a specialized one, we might as well use it.

@zou3519
Contributor

zou3519 commented Jul 29, 2022

@lezcano @ezyang let's say that we did revert the PR (because we're trying to release PyTorch 1.12.1 as soon as possible). Would it cause any other problems?

Because the motivation was "dispatching from a more general function to a more concrete [function]", it sounds like this wouldn't change very much else.

@ezyang
Contributor

ezyang commented Jul 29, 2022

No, I don't think so. The PR is supposed to make the kernel run faster.

@lezcano
Contributor

lezcano commented Jul 29, 2022

fwiw, I think this may be related to the open PR I have to avoid copies in matmul. Could you check whether pytorch/pytorch#76828 fixes this?

In any case, I'm fine with reverting, but we should investigate what's causing this regardless

@ngimel

ngimel commented Jul 29, 2022

> @lezcano @ezyang let's say that we did revert the PR (because we're trying to release PyTorch 1.12.1 as soon as possible). Would it cause any other problems?
>
> Because the motivation was "dispatching from a more general function to a more concrete [function]", it sounds like this wouldn't change very much else.

That PR is fixing spurious resize warnings that were previously generated, and by itself is supposed to speedup things by avoiding squeeze/unsqueeze calls which are not free (and especially not free when autograd is needed). As for more general/more concrete function performance, we should investigate this, but I doubt that's the case.

@ngimel

ngimel commented Jul 29, 2022

@zou3519 can you by any chance collect profiling results for old and new versions?

@zou3519
Contributor

zou3519 commented Jul 29, 2022

> In any case, I'm fine with reverting, but we should investigate what's causing this regardless

I agree this warrants more investigation. We've got a problem in that there is a timeline for 1.12.1, and I am not sure how long it is going to take to actually get to the bottom of this.

> @zou3519 can you by any chance collect profiling results for old and new versions?

I can try but it's been a long time since I touched nvprof or nsight profiler, so I will need to relearn the magic invocations.

> but we should investigate what's causing this regardless

Since we changed the mm to mv, functorch generates different code for vmap(jacrev(jacfwd(mm))) as opposed to vmap(jacrev(jacfwd(mv))). It's plausible that the problem is that "functorch should generate better code"; we're still digging into it.
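
For context on the mm versus mv distinction: under vmap, each sample reaches the linear layer as a 1-D vector, and a linear layer on a 1-D input can be written either as a matrix-vector product or as an unsqueeze followed by a matrix-matrix product. A small sketch of that equivalence in plain PyTorch (no functorch involved; the tensor sizes mirror the first layer of the benchmark but are otherwise arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(512, 2)  # same shape as the weight of nn.Linear(2, 512)
bias = torch.randn(512)
x = torch.randn(2)            # a single unbatched sample

out_linear = F.linear(x, weight, bias)

# Matrix-vector formulation.
out_mv = torch.mv(weight, x) + bias

# Matrix-matrix formulation: promote x to a [1, 2] matrix, then drop the extra dim.
out_mm = torch.mm(x.unsqueeze(0), weight.t()).squeeze(0) + bias

print(torch.allclose(out_linear, out_mv, atol=1e-5))
print(torch.allclose(out_linear, out_mm, atol=1e-5))
```

The two formulations agree numerically, so any difference in the generated code comes from which batching rule each operator hits, not from the math.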

@ngimel

ngimel commented Jul 29, 2022

Isn't 1.12.1 done already?
No nsight needed, just the torch profiler should be enough (with torch.profiler.profile() as p: and then print the key averages and export a chrome trace at the end).
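
A minimal sketch of that profiler workflow, assuming a small stand-in model instead of the benchmark script (the model, input sizes and trace file name are placeholders, not from the thread):

```python
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile

# Placeholder workload; swap in functorch_hessian() from the benchmark script.
model = torch.nn.Sequential(
    torch.nn.Linear(2, 64), torch.nn.ReLU(), torch.nn.Linear(64, 3)
)
x = torch.randn(128, 2)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    model, x = model.cuda(), x.cuda()
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as p:
    for _ in range(5):
        model(x)

# Per-operator summary, sorted by total CPU time.
summary = p.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(summary)

# Trace file that can be opened in chrome://tracing or Perfetto.
trace_path = os.path.join(tempfile.gettempdir(), "hessian_trace.json")
p.export_chrome_trace(trace_path)
```

Collecting one trace per PyTorch version and comparing the operator tables side by side is usually enough to spot an extra copy kernel.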

@yueyericardo
Contributor Author

FYI, the nsight profiling result
before

Time	Total Time	Instances	Avg	Med	Min	Max	StdDev	Name
64.7%	928.624 ms	1000	928.623 μs	867.621 μs	297.345 μs	1.686 ms	493.406 μs	ampere_sgemm_128x64_nn
10.3%	147.304 ms	250	589.217 μs	589.187 μs	588.643 μs	590.435 μs	275 ns	ampere_sgemm_128x128_tn
9.4%	134.966 ms	1300	103.819 μs	70.432 μs	5.344 μs	194.209 μs	75.470 μs	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, at::native::AddFunctor<float>>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
6.8%	98.146 ms	900	109.050 μs	94.848 μs	70.081 μs	177.921 μs	43.651 μs	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, void at::native::threshold_kernel_impl<float>(at::TensorIteratorBase &, T1, T1)::[lambda(float, float) (instance 1)]>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
5.4%	77.550 ms	350	221.572 μs	300.321 μs	18.080 μs	302.050 μs	124.914 μs	ampere_sgemm_128x64_tn
0.8%	10.981 ms	1450	7.572 μs	3.552 μs	2.880 μs	26.112 μs	7.973 μs	void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::Array<char *, (int)1>>(int, T2, T3)
0.7%	9.949 ms	150	66.324 μs	82.928 μs	27.968 μs	88.672 μs	26.863 μs	ampere_sgemm_32x32_sliced1x4_nn
0.6%	8.074 ms	300	26.912 μs	26.976 μs	25.696 μs	27.776 μs	390 ns	void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(at::TensorIteratorBase &, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 8)]::operator ()() const::[lambda(float) (instance 1)], at::detail::Array<char *, (int)2>>(int, T2, T3)
0.5%	7.713 ms	50	154.257 μs	154.177 μs	153.184 μs	155.905 μs	634 ns	ampere_sgemm_32x128_nn
0.3%	4.847 ms	100	48.468 μs	47.664 μs	34.688 μs	63.393 μs	13.397 μs	ampere_sgemm_32x32_sliced1x4_tn
0.2%	3.342 ms	650	5.141 μs	5.344 μs	4.256 μs	5.728 μs	423 ns	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::<unnamed>::direct_copy_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 8)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
0.1%	2.028 ms	50	40.569 μs	40.512 μs	40.032 μs	41.536 μs	272 ns	ampere_sgemm_64x64_nn
0.1%	994.305 μs	50	19.886 μs	19.856 μs	19.649 μs	20.896 μs	169 ns	ampere_sgemm_128x32_nn
0.0%	435.235 μs	100	4.352 μs	4.384 μs	4.160 μs	4.576 μs	100 ns	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::FillFunctor<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)

after (the third row shows 18.4% of the time spent in a copy kernel)

Time	Total Time	Instances	Avg	Med	Min	Max	StdDev	Name
21.8%	433.409 ms	500	866.818 μs	866.756 μs	865.028 μs	870.116 μs	881 ns	ampere_sgemm_32x128_nt
21.2%	420.235 ms	250	1.681 ms	1.681 ms	1.678 ms	1.685 ms	813 ns	ampere_sgemm_128x64_nn
18.4%	365.797 ms	2850	128.349 μs	83.809 μs	4.096 μs	608.931 μs	173.023 μs	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::<unnamed>::direct_copy_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 14)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
10.8%	213.750 ms	300	712.501 μs	588.802 μs	588.195 μs	1.335 ms	277.174 μs	ampere_sgemm_128x128_tn
8.1%	160.586 ms	1300	123.527 μs	93.057 μs	5.089 μs	266.497 μs	92.181 μs	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::CUDAFunctor_add<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
7.9%	156.152 ms	500	312.304 μs	312.290 μs	310.145 μs	314.689 μs	781 ns	ampere_sgemm_128x32_nn
5.9%	117.864 ms	900	130.960 μs	129.824 μs	77.504 μs	203.681 μs	49.220 μs	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, void at::native::<unnamed>::threshold_kernel_impl<float>(at::TensorIteratorBase &, T1, T1)::[lambda(float, float) (instance 1)]>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
3.3%	66.223 ms	50	1.324 ms	1.324 ms	1.324 ms	1.326 ms	397 ns	ampere_sgemm_128x128_tt
1.1%	21.094 ms	1850	11.402 μs	3.584 μs	2.816 μs	65.857 μs	12.662 μs	void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::Array<char *, (int)1>>(int, T2, T3)
0.4%	8.071 ms	300	26.904 μs	26.976 μs	25.792 μs	28.128 μs	392 ns	void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(at::TensorIteratorBase &, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 14)]::operator ()() const::[lambda(float) (instance 1)], at::detail::Array<char *, (int)2>>(int, T2, T3)
0.4%	7.903 ms	50	158.063 μs	158.049 μs	156.609 μs	159.617 μs	634 ns	ampere_sgemm_32x128_nn
0.4%	7.359 ms	100	73.594 μs	73.600 μs	72.000 μs	78.368 μs	1.223 μs	ampere_sgemm_32x32_sliced1x4_nt
0.2%	3.242 ms	50	64.844 μs	64.928 μs	62.464 μs	65.761 μs	528 ns	ampere_sgemm_32x32_sliced1x4_tn
0.1%	2.255 ms	100	22.551 μs	22.688 μs	20.865 μs	24.192 μs	1.106 μs	void gemmSN_NN_kernel<float, (int)256, (int)4, (int)2, (int)8, (int)3, (int)4, (bool)0, cublasGemvTensorStridedBatched<const float>, cublasGemvTensorStridedBatched<const float>, cublasGemvTensorStridedBatched<float>>(cublasGemmSmallNParams<T9, T10, T11, T1>)
0.1%	2.043 ms	100	20.427 μs	20.416 μs	20.160 μs	22.656 μs	257 ns	ampere_sgemm_128x32_tn
0.0%	432.292 μs	100	4.322 μs	4.336 μs	4.128 μs	4.640 μs	97 ns	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::FillFunctor<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)

@lezcano
Contributor

lezcano commented Jul 29, 2022

@ngimel the PR that fixed the warnings was already merged; this one is just concerned with avoiding copies. One of the cases where it elides a copy is when you multiply a matrix by a batch of matrices. This is exactly the batched version of the vector-matrix product that's causing the regression. That's why I think it may fix it.

@ngimel

ngimel commented Jul 29, 2022

That must be coming from functorch, as that PR doesn't introduce any additional copies.

@lezcano
Contributor

lezcano commented Jul 29, 2022

yup, I think pytorch/pytorch#76828 would fix the regression

@zou3519
Contributor

zou3519 commented Jul 29, 2022

I patched in pytorch/pytorch#76828 and the above script ends up OOM-ing :/

@zou3519
Contributor

zou3519 commented Jul 29, 2022

> Isn't 1.12.1 done already?

Not yet, we have a chance to change it (or request a 1.12.2 if necessary since this regression is large and these types of models are typical functorch usage)

@lezcano
Contributor

lezcano commented Jul 29, 2022

Re. OOM. Wow, that's certainly unexpected. I'm not sure what's the best way to follow up on that. Probably @ngimel has a better idea how to proceed.

Regardless, landing that PR (the avoid copies matmul...) will be tricky due to some discrepancies in the accuracy of mm vs mv that we found for float16. As such, I think the most realistic way forward would be to revert the offending PR and investigate what's causing that OOM after 1.12.1 is released.

@ngimel

ngimel commented Jul 29, 2022

> the PR that fixed the warnings was already merged

I don't think that's true, #75195 itself is fixing a warning (otherwise a user-supplied correct 1d out was sent to mm, and mm complained about resizing it to 2d).

@lezcano
Contributor

lezcano commented Jul 29, 2022

I think the one that fixed the out version was pytorch/pytorch#75197 and a previous one in that stack, but it may be the case that 75195 also fixed an out warning, I don't remember now.

zou3519 added a commit to pytorch/pytorch that referenced this issue Jul 29, 2022
This is a short-term fix for a serious regression in functorch
(pytorch/functorch#989).

Why is this a partial revert?
- the out= tests for nn.functional.linear fail on a complete revert
- the profiler tests fail on the revert (so I updated the expecttests
for the profiler tests)

Test Plan:
- test offline that the functorch regression was fixed

[ghstack-poisoned]
zou3519 added a commit to pytorch/pytorch that referenced this issue Aug 1, 2022
This is a short-term fix for a serious regression in functorch
(pytorch/functorch#989).

Additional things this PR does:
- the out= tests for nn.functional.linear fail after the revert. I added
some xfails. These xfails were present in the original PR (#75195).
- the profiler tests fail on the revert, so I updated the expecttests
for the profiler tests

Test Plan:
- test offline that the functorch regression was fixed

ghstack-source-id: f127f9be33ba35ceeabbc6a9a4d8c24654defad7
Pull Request resolved: #82504
@zou3519
Contributor

zou3519 commented Aug 1, 2022

@yueyericardo I got access to the modulus repo -- could you point me to which model in the modulus repo contains the nn.Sequential above please? We're figuring out how to check the above code into torchbench and I'm looking for some more context -- is there a name for computing the hessian for the nn.Sequential? Do you generally run .backward() after computing the hessian? What are some representative input sizes (are the sizes in the benchmark script representative?)

@yueyericardo
Contributor Author

yueyericardo commented Aug 1, 2022

Hi @zou3519, thanks for following up! We are having some internal discussions regarding this and will come back to you tomorrow.

@samdow
Contributor

samdow commented Aug 2, 2022

From looking at this a bit, I think what happened is:

  • pytorch changes unsqueeze + mm -> mv in case where first element is vector
  • this example hits that codepath under vmap => functorch's mv kernel
  • vector is batched and matrix is not => functorch moves the vector's batch dim to the end and matmuls. Batch dim is now the 1st dimension, not the 0th
  • relu's output is now BatchedTensor([512, 10000], lvl=0, dim=1). This needs to be saved as an activation for the backwards

During backwards,

  • grad_output is a BatchedTensor(BatchedTensor([10000, 512, 6], lvl=0, dim=0), lvl=1, dim=2)
  • threshold_backward (relu's backward) takes in grad_output and relu (binary pointwise batch rule). Starts with the lvl 1 batch rule
    • grad_output is the only value batched at the highest level, we move its batch dim to the front with a movedim. grad_output is no longer contiguous
    • redispatch on threshold_backward for lvl 0
    • move relu's bdim to front with a movedim and pad to be 3D. relu is no longer contiguous. grad_output does not need to be moved since its lvl 0 bdim is 0. It is still contiguous
    • the view on relu is not contiguous, the view on grad_output is not contiguous => threshold_backward triggers copy
      Locally I've tested where only one of the inputs is not contiguous and haven't seen the same slowdowns.
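
The contiguity changes described in the steps above can be observed directly in eager PyTorch. A small sketch with scaled-down stand-ins for the shapes in the thread (the variable names are illustrative and this does not reproduce functorch's BatchedTensor machinery):

```python
import torch

# Small stand-ins for [10000, 512, 6] and [512, 10000].
grad_output = torch.randn(100, 512, 6)
relu_out = torch.randn(512, 100)

# movedim returns a view: same storage, permuted strides.
grad_front = torch.movedim(grad_output, 2, 0)  # [6, 100, 512]
relu_front = torch.movedim(relu_out, 1, 0)     # [100, 512]

print(grad_front.is_contiguous(), grad_front.stride())
print(relu_front.is_contiguous(), relu_front.stride())

# .contiguous() is the point where data is actually copied.
relu_copied = relu_front.contiguous()
print(relu_copied.data_ptr() != relu_out.data_ptr())
```

Both moved tensors share storage with their originals and report as non-contiguous; only the explicit `.contiguous()` call allocates new memory.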

We can also validate that this is our issue since we hit pre-regression performance numbers by changing functorch's mv batch rule here:

  auto other_ = moveBatchDimToFront(other, other_bdim);
  auto self_ = at::movedim(self, 0, 1);
  auto result = at::matmul(other_, self_);
  return std::make_tuple( std::move(result), 0 );

This doesn't trigger the copy since the batch dimension for the saved relu activation is now the first dimension. However, this may hit other perf issues from transposing both self and result.
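
To make the two layouts concrete, here is a rough Python analogue of a batched matrix-vector product computed both ways. It is an illustration under assumptions, not functorch's batching rule: `mat` plays the role of `self`, `vecs` plays the role of the batched `other`, and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
B, N, M = 8, 4, 5
mat = torch.randn(M, N)   # unbatched matrix
vecs = torch.randn(B, N)  # batched vectors, batch dim at the front

# Layout A: move the batch dim of the vectors to the end, then matmul.
# The batch dim of the result is dim 1.
out_batch_last = torch.matmul(mat, vecs.movedim(0, 1))   # [M, B]

# Layout B: keep the batch dim in front and transpose the matrix instead,
# as in the C++ snippet above. The batch dim of the result is dim 0.
out_batch_first = torch.matmul(vecs, mat.movedim(0, 1))  # [B, M]

print(out_batch_last.shape, out_batch_first.shape)
print(torch.allclose(out_batch_last.movedim(1, 0), out_batch_first, atol=1e-5))
```

Both layouts hold the same values; the difference is which dimension a later op has to move to the front, and therefore whether the saved activation is contiguous when that happens.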


This is the smallest subset of @zou3519's trace where I can see the perf differences

import torch

mm_2 = torch.randn(60000, 512)
relu = torch.randn(10000, 512)
def old_linear_faster():
    mm_2_view = mm_2.view([10000, 6, 1, 512])
    mm_2_view_squeeze = mm_2_view.squeeze(2)

    relu_view = relu.view([10000, 1, 512])

    threshold_backward = torch.ops.aten.threshold_backward(mm_2_view_squeeze, relu_view, 0)
    return threshold_backward

bmm = torch.randn(10000, 512, 6)
other_relu = torch.randn(512, 10000)
def new_linear_faster():
    bmm_view = bmm.view([10000, 512, 6])
    bmm_view_permute = bmm_view.permute([0, 2, 1])

    other_relu_permute = other_relu.permute([1, 0])
    other_relu_permute_view = other_relu_permute.view([10000, 1, 512])

    threshold_backward = torch.ops.aten.threshold_backward(bmm_view_permute, other_relu_permute_view, 0)
    return threshold_backward

Notably, if I change the views on either bmm or other_relu to be contiguous constants, performance is much faster. So it seems that threshold_backward doesn't copy if only one of its inputs is non-contiguous, but does if both are.

@ngimel

ngimel commented Aug 2, 2022

Thanks @samdow, copy from relu definitely seems to affect perf, however there's also another copy coming from MvBackward.
[image: Screen Shot 2022-08-02 at 10 13 10 AM]

Also, why does threshold_backward on discontiguous inputs trigger a copy? In eager mode, threshold_backward should be able to handle them via TensorIterator.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Aug 2, 2022
This is a short-term fix for a serious regression in functorch
(pytorch/functorch#989).

Additional things this PR does:
- the out= tests for nn.functional.linear fail after the revert. I added
some xfails. These xfails were present in the original PR (#75195).
- the profiler tests fail on the revert, so I updated the expecttests
for the profiler tests

Test Plan:
- test offline that the functorch regression was fixed
Pull Request resolved: #82504
Approved by: https://github.com/ngimel, https://github.com/ezyang, https://github.com/atalman
@lezcano
Contributor

lezcano commented Aug 2, 2022

As discussed, I think that mv_backward copy_ would be fixed by pytorch/pytorch#76828. Now, Richard reports in #989 (comment) that patching in this fix results in OOM, so it's not clear what's going on there.

In my opinion, the path forward to fix this regression would be to:

  1. Figure out why it OOMs. I'm not sure how to tackle this, but perhaps @ngimel can help here.
  2. Unblock and land "Avoid copies in matmul" (pytorch/pytorch#76828). I will try to get a small repro tomorrow and pass it on to the NVIDIA folks for them to investigate further.

facebook-github-bot pushed a commit to pytorch/pytorch that referenced this issue Aug 4, 2022
Summary:
This is a short-term fix for a serious regression in functorch
(pytorch/functorch#989).

Additional things this PR does:
- the out= tests for nn.functional.linear fail after the revert. I added
some xfails. These xfails were present in the original PR (#75195).
- the profiler tests fail on the revert, so I updated the expecttests
for the profiler tests

Pull Request resolved: #82504
Approved by: https://github.com/ngimel, https://github.com/ezyang, https://github.com/atalman

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/8f86e361918e3c8a0ee2b569be6c82dfbf32d705

Test plan from GitHub:
- test offline that the functorch regression was fixed

Reviewed By: kit1980

Differential Revision: D38394573

Pulled By: zou3519

fbshipit-source-id: f9185d9cb447fb439d8e402712f2f2617f73b8cc
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Aug 11, 2022
### Introduction

Removing unnecessary weight gradient calculation is very important for applications that need high-order derivatives during training. However, this is not supported by the current Autograd engine.

For more detail: The backward function of a `matmul` operator (e.g., `linear`, `addmm`, `mm`) has two matmuls, one for the `input gradient` and another for the `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t. the nn input `x`, only the `input gradient` of the `matmul` operator is needed, and the `weight gradient` is discarded. However, the current PyTorch autograd engine will always calculate the `weight gradient` if `weight` requires gradient (the calculation of the high-order derivative is performed during training).

The figure attached shows the autograd graph of the following code snippet:
```py
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# second order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```
The path with ❌ is not needed when calculating derivatives.
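For readers who want to run the snippet, here is a self-contained variant. The shapes are assumptions made for the example, and `grad_outputs` is built separately for each call so that input and output dimensions do not have to match.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, requires_grad=True)        # nn input
weight = torch.randn(5, 3, requires_grad=True)   # requires grad, as during training
bias = torch.randn(5, requires_grad=True)

y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)

# first order derivative of y w.r.t. x
(y__x,) = torch.autograd.grad(
    y, x, grad_outputs=torch.ones_like(y), create_graph=True
)
# second order derivative of y w.r.t. x
(y__x__x,) = torch.autograd.grad(
    y__x, x, grad_outputs=torch.ones_like(y__x), create_graph=True
)
```

Neither call asks for a gradient w.r.t. `weight`, yet `weight.requires_grad` is `True`, which is the situation the PR targets.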

<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">

### Issue
<!-- Link to Issue ticket or RFP -->
Related issue: #56500

### Method
When calling `torch.autograd.grad`, `exec_info_` is created for each GraphTask, which allows filtering out paths on the graph that are not needed. However, when the GraphTask calls into a node, the node still does not know which of its edges are needed. In the case of matmul, `weight.requires_grad is True`, so the weight gradient is always calculated.

Following #56500 (comment), this PR passes the graph task's thread_local `exec_info_` into the node, so that it can trim unnecessary edges during `torch.autograd.grad` calls.
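The idea can be sketched with a toy backward node in plain Python. Everything here (`ToyLinearBackward`, the `needed` tuple, the helper functions) is hypothetical and only mimics the concept; it is not how the autograd engine or `exec_info_` is actually implemented.

```python
def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


class ToyLinearBackward:
    """Hypothetical backward node for y = x @ W^T (not a PyTorch class)."""

    def __init__(self, x, weight):
        self.x = x
        self.weight = weight
        self.matmul_calls = 0

    def apply(self, grad_y, needed):
        # `needed` = (need_grad_input, need_grad_weight): a toy stand-in for
        # the information the graph task passes down to the node.
        grad_input = grad_weight = None
        if needed[0]:
            grad_input = matmul(grad_y, self.weight)
            self.matmul_calls += 1
        if needed[1]:
            grad_weight = matmul(transpose(grad_y), self.x)
            self.matmul_calls += 1
        return grad_input, grad_weight


x = [[1.0, 2.0], [3.0, 4.0]]                      # 2 samples, 2 features
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]     # 3 outputs, 2 features
grad_y = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]

node = ToyLinearBackward(x, weight)
# Only the input gradient is requested, so the weight matmul is skipped.
grad_input, grad_weight = node.apply(grad_y, needed=(True, False))
```

With `needed=(True, False)` the node performs one matmul instead of two, which is the saving the PR is after.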

### Benchmark
Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99

Benchmark result:
6 hidden layers, batch size 10000, on A100

FP32 result
| hessian benchmark             | FP32 (before) | FP32 (after)      | FP32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 55.658 ms     | 29.392 ms (1.90X) | 29.547 ms (1.90X)       |
| Linear + ReLU (with backward) | 81.173 ms     | 54.917 ms (1.47X) | 68.988 ms (1.18X)       |

TF32 result
| hessian benchmark             | TF32 (before) | TF32 (after)      | TF32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 19.801 ms     | 11.259 ms (1.76X) | 10.754 ms (1.84X)       |
| Linear + ReLU (with backward) | 29.167 ms     | 20.466 ms (1.42X) | 22.784 ms (1.28X)       |

For FP32, we get a 1.9X speedup for the hessian calculation alone and a 1.47X speedup during training (with backward), which is even faster than functorch's `vmap(jacfwd(jacrev` implementation. (functorch has a performance regression in v0.2.0, pytorch/functorch#989, so we use v0.1.1 for the benchmark.)

@zou3519 does functorch also include similar optimizations for the hessian calculation? If not, what do we need to do so that functorch can also benefit from this PR?

### Testing
<!-- How did you test your change? -->

- [x] we need to figure out a way to unit test this

### Thanks
Thanks for the great blog: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)

cc @zasdfgbnm @albanD
Pull Request resolved: #82544
Approved by: https://github.com/soulitzer
facebook-github-bot pushed a commit that referenced this issue Aug 12, 2022
Summary:
X-link: pytorch/pytorch#82544
Approved by: https://github.com/soulitzer

Reviewed By: seemethere

Differential Revision: D38643340

fbshipit-source-id: 346de0e0971363441c6d06dc83601e0297d5ccc8
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this issue Aug 12, 2022
Summary:
Pull Request resolved: #82544
Approved by: https://github.com/soulitzer

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/382ef1fda75dfce07c37a920e908ce96d99bf970

Reviewed By: seemethere

Differential Revision: D38643340

fbshipit-source-id: 346de0e0971363441c6d06dc83601e0297d5ccc8
@zou3519
Contributor

zou3519 commented Sep 6, 2022

@yueyericardo I got access to the modulus repo -- could you point me to which model in the modulus repo contains the nn.Sequential above please? We're figuring out how to check the above code into torchbench and I'm looking for some more context -- is there a name for computing the hessian for the nn.Sequential? Do you generally run .backward() after computing the hessian? What are some representative input sizes (are the sizes in the benchmark script representative?)

@yueyericardo quick bump on the above. We're looking to merge some form of the above code into our benchmark suite and additional information would be very helpful

@yueyericardo
Contributor Author

> @yueyericardo quick bump on the above. We're looking to merge some form of the above code into our benchmark suite and additional information would be very helpful

@zou3519 Sorry, I already finished my internship there. I believe the next release of Modulus (this month) will include the functorch integration and will be available on GitLab.
Including @NickGeneva from Modulus team for further communications.

@lezcano
Contributor

lezcano commented Sep 7, 2022

For what it's worth, the fix pytorch/pytorch#76828 is completely stalled on some NestedTensor errors on some GPU architectures. If anyone would like to chime in on how to tackle it, you are more than welcome.

@akshaysubr

@zou3519 The modulus version with functorch support was made public last week. Here is the repo.

> could you point me to which model in the modulus repo contains the nn.Sequential above please?

The specific place where functorch is used is in this wrapper class.

> Do you generally run .backward() after computing the hessian? What are some representative input sizes (are the sizes in the benchmark script representative?)

Yes, after computing the hessian, we need to run .backward() to actually get the weight gradients. Representative input sizes are typically O(1)-O(10), and the same holds for the outputs. The sizes in the benchmark script are representative of some of the simpler cases; they may go up from there, but the hidden dimensions are usually the same as in the benchmark.

@IvanYashchuk
Contributor

@akshaysubr the provided links do not work; maybe the project access settings are not set correctly.

@akshaysubr

@IvanYashchuk Sorry, I should've mentioned that you have to apply for access to that repo here; after that, those links should work.

@zou3519 removed the high priority label Oct 10, 2022
@zou3519
Contributor

zou3519 commented Oct 10, 2022

I was able to add a version of the original benchmark to pytorch/benchmark so we now can prevent regressions in this model, so I'm lowering the priority of this issue. Leaving it open to discuss the changes to matmul above though.

@lezcano
Contributor

lezcano commented Oct 10, 2022

fwiw, the matmul PR is as stalled as it was before. Neither Ivan nor I have been able to find the time to put together a standalone C repro for the cublas team.

@lezcano
Contributor

lezcano commented Feb 22, 2023

After patching in the one-liner pytorch/pytorch#75195 on top of the stack pytorch/pytorch#76828, I still get an OOM using the script in the OP.
The issue happens to be that, somehow, we end up with a tensor of shape [10000, 512, 512] and strides equal to zero within bmm_out_cuda. Then, we try to materialize this tensor and we OOM. Interestingly enough, if you print the sizes/strides of the tensors that go through matmul, you get

size t1
[512, 512]
strides t1
[512, 1]
size t2
[10000, 512, 3]
strides t2
[3, 30000, 1]

so the tensor that gets to bmm looks like a batched version of the one that gets to the last if-else within matmul. Perhaps this is a feature of functorch, but I don't really understand it. Note that this would be fine if this tensor, rather than having strides equal to zero, had strides equal to, say, [0, 512, 1], as we would not need to materialise it there. So my question is: is this tensor with strides equal to zero reasonable, is it a bug / optimisation opportunity in functorch, or does it follow from some functorch feature and I am doing something wrong here?
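To put rough numbers on why materialising that tensor runs out of memory, here is a small back-of-the-envelope sketch. The helper `storage_elements` is hypothetical, and 4 bytes per element (float32) is an assumption, not something stated in the comment.

```python
from math import prod


def storage_elements(sizes, strides):
    """Number of storage elements a strided view can address."""
    if any(size == 0 for size in sizes):
        return 0
    return 1 + sum((size - 1) * stride for size, stride in zip(sizes, strides))


shape = (10000, 512, 512)
bytes_per_element = 4  # assumption: float32

# All strides zero: every entry aliases a single stored element.
all_zero = storage_elements(shape, (0, 0, 0))
# Strides [0, 512, 1]: one 512x512 matrix shared across the batch dimension.
batch_broadcast = storage_elements(shape, (0, 512, 1))
# Dense copy: every entry gets its own element.
materialized = prod(shape)

materialized_gib = materialized * bytes_per_element / 2**30
```

Under these assumptions the dense copy needs about 9.77 GiB, whereas the `[0, 512, 1]` view only needs the 512 x 512 elements of a single matrix.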

cc @kshitij12345
