-
Notifications
You must be signed in to change notification settings - Fork 293
/
vmap_hessian_fc.py
51 lines (43 loc) · 1.37 KB
/
vmap_hessian_fc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn
from torch.func import jacfwd, jacrev, vmap
from .util import BenchmarkCase
# Batched Hessians of fully connected layers are a popular quantity
# in physics-related models.
# This test case is from https://github.com/pytorch/functorch/issues/989
# We haven't been able to get the full model yet, so this test case
# is going into the functorch userbenchmark instead of torchbenchmark.
class VmapHessianFC(BenchmarkCase):
    """Benchmark ``vmap`` over per-sample Hessians of a fully connected MLP.

    The model maps ``D1 = 2`` inputs (x, y) to ``D2 = 3`` outputs (u, v, p)
    through six ``Linear`` + ``ReLU`` hidden layers of width 512 followed by a
    final ``Linear`` output layer. ``run`` computes, for every sample in the
    batch, the Hessian of the outputs with respect to the inputs using
    ``jacfwd(jacrev(...))`` vectorized with ``vmap``.
    """

    def __init__(self, device="cuda"):
        """Create the input batch and the model on ``device``.

        Args:
            device: Device the input batch and model are moved to. Defaults
                to ``"cuda"`` (the original fixed configuration); passing
                ``"cpu"`` allows running the case without a GPU.
        """
        # Stored privately to avoid clashing with any attribute BenchmarkCase
        # might define; used by name().
        self._device = device
        D1 = 2  # x, y
        D2 = 3  # u, v, p
        B = 10000
        hidden = 512
        num_hidden_layers = 6  # Linear+ReLU pairs before the output layer

        # Sample x before constructing the model (and on CPU, then move it) so
        # the default RNG is consumed in the same order as the original
        # hand-written version: same inputs and weights for a given seed.
        x = torch.randn(B, D1).to(device)

        layers = []
        in_features = D1
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(in_features, hidden), nn.ReLU()]
            in_features = hidden
        layers.append(nn.Linear(hidden, D2))

        self.model = nn.Sequential(*layers).to(device)
        self.x = x

    def name(self):
        """Return the benchmark identifier, e.g. ``"vmap_hessian_fc_cuda"``."""
        return f"vmap_hessian_fc_{self._device}"

    def run(self):
        """Compute per-sample Hessians of the model output w.r.t. its input.

        ``predict`` returns the model output twice; the second copy is passed
        through as auxiliary data. ``jacrev(..., has_aux=True)`` therefore
        yields ``(jacobian, out)`` and the outer ``jacfwd(..., has_aux=True)``
        differentiates only the Jacobian, yielding ``(hessian, out)``. Under
        ``vmap`` over dim 0 of ``self.x`` (shape ``(B, D1)``), each sample has
        shape ``(D1,)``, so the batched Hessian has shape ``(B, D2, D1, D1)``.

        The results are not returned or stored; presumably the benchmark
        harness only times this call.
        """
        def predict(x):
            out = self.model(x)
            return out, out

        # Forward-over-reverse composition: reverse mode for the first
        # derivative, forward mode for the second.
        hessian_fn = jacfwd(
            jacrev(predict, argnums=0, has_aux=True),
            argnums=0,
            has_aux=True,
        )
        vmap(hessian_fn, in_dims=0)(self.x)