torch_utils.py
import operator

import torch
from torch import optim, nn, FloatTensor as FT
import torch.nn.parallel
import torch.utils.data
from torch.backends import cudnn
from torchvision import datasets, transforms, utils as vutils
from torch.autograd import Variable

def unit_prefix(x, n=1):
    """Prepend n leading singleton dimensions to tensor x."""
    for i in range(n): x = x.unsqueeze(0)
    return x

def align(x, y, start_dim=2):
    """Broadcast x and y to a common shape, expanding only size-1 dims and
    leaving the last start_dim dimensions untouched."""
    xd, yd = x.dim(), y.dim()
    # Pad the lower-rank tensor with leading singleton dimensions.
    if xd > yd: y = unit_prefix(y, xd - yd)
    elif yd > xd: x = unit_prefix(x, yd - xd)

    xs, ys = list(x.size()), list(y.size())
    nd = len(ys)
    # For each dim outside the trailing start_dim ones, copy the non-unit
    # size across wherever one side has size 1.
    for i in range(start_dim, nd):
        td = nd - i - 1
        if ys[td] == 1: ys[td] = xs[td]
        elif xs[td] == 1: xs[td] = ys[td]
    return x.expand(*xs), y.expand(*ys)

def dot(x, y):
    """Matrix multiply with broadcasting over the leading (batch) dimensions."""
    assert 1 < y.dim() < 5
    x, y = align(x, y)
    if y.dim() == 2: return x.mm(y)     # plain matrix product
    elif y.dim() == 3: return x.bmm(y)  # batched matrix product
    else:
        # 4-d case: batch-multiply each outermost slice into a result tensor
        # allocated with the same dtype/device as x.
        xs, ys = x.size(), y.size()
        res = x.new_zeros(*(xs[:-1] + (ys[-1],)))
        for i in range(xs[0]): res[i].baddbmm_(x[i], y[i])
        return res

def aligned_op(x, y, f):
    """Apply the binary op f after broadcasting x and y over every dimension."""
    return f(*align(x, y, 0))

# Element-wise arithmetic with full broadcasting (start_dim=0).
def add(x, y): return aligned_op(x, y, operator.add)
def sub(x, y): return aligned_op(x, y, operator.sub)
def mul(x, y): return aligned_op(x, y, operator.mul)
def div(x, y): return aligned_op(x, y, operator.truediv)
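
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative example, not from the original file):
# dot and the aligned element-wise helpers broadcast leading dimensions, so a
# batch of matrices can be multiplied by a single weight matrix, and a vector
# added to the result, without manual reshaping.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    xb = torch.randn(4, 2, 5)   # batch of 4 matrices, each 2 x 5
    w = torch.randn(5, 3)       # single 5 x 3 weight matrix
    b = torch.randn(3)          # bias vector

    out = dot(xb, w)            # w broadcast across the batch -> (4, 2, 3)
    out = add(out, b)           # b broadcast over all leading dims -> (4, 2, 3)
    print(out.shape)            # torch.Size([4, 2, 3])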