-
Notifications
You must be signed in to change notification settings - Fork 0
/
patterns.py
53 lines (47 loc) · 1.8 KB
/
patterns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import re
from collections.abc import Mapping

from einops import rearrange
# Reshape patterns for einops, keyed by data kind: 'image' (2-D, spatial axes
# h w) and 'mesh' (3-D, spatial axes d h w).  The two 'reshape' lists are
# parallel: entry i of 'mesh' is the 3-D counterpart of entry i of 'image'.
# In a composite axis such as (n2 p2), n* is the number of blocks and p* the
# block (patch) size along that spatial axis (n2/p2 split h, n1/p1 split w,
# n3/p3 split d).  Their values must be supplied to einops as keyword sizes,
# e.g. through einops_f's hparams.
patterns_reshape = {
    'image': {
        'reshape': [
            # add a leading size-1 batch axis and move channels first
            'h w c -> 1 c h w',
            # inverse of the above: drop the size-1 batch axis, channels last
            '1 c h w -> h w c',
            # split a batched channel-first image into patches:
            # -> (num_patches, pixels_per_patch, c)
            '1 c (n2 p2) (n1 p1) -> (n2 n1) (p2 p1) c',
            # same patch split, starting from an unbatched channel-last image
            '(n2 p2) (n1 p1) c -> (n2 n1) (p2 p1) c',
            # merge patches back into a batched channel-first image
            '(n2 n1) (p2 p1) c -> 1 c (n2 p2) (n1 p1)',
            # reshape one value per patch into a (1, n2, n1) grid of blocks
            '(n2 n1) -> 1 n2 n1',
            # broadcast each block value over its p2 x p1 patch.  p2/p1 occur
            # only on the output side, so this must be applied with
            # einops.repeat; einops.rearrange rejects it.
            '1 n2 n1 -> 1 (n2 p2) (n1 p1)',
            # flatten the spatial axes into a single sequence axis
            '1 h w c -> 1 (h w) c',
        ],
        # presumably the interpolation mode paired with this data kind
        # (e.g. for torch.nn.functional.interpolate) - TODO confirm with callers
        'mode': 'bilinear',
    },
    'mesh': {
        'reshape': [
            # add a leading size-1 batch axis and move channels first
            'd h w c -> 1 c d h w',
            # inverse of the above
            '1 c d h w -> d h w c',
            # batched channel-first volume -> (num_patches, voxels_per_patch, c)
            '1 c (n3 p3) (n2 p2) (n1 p1) -> (n3 n2 n1) (p3 p2 p1) c',
            # same patch split, starting from an unbatched channel-last volume
            '(n3 p3) (n2 p2) (n1 p1) c -> (n3 n2 n1) (p3 p2 p1) c',
            # merge patches back into a batched channel-first volume
            '(n3 n2 n1) (p3 p2 p1) c -> 1 c (n3 p3) (n2 p2) (n1 p1)',
            # reshape one value per patch into a (1, n3, n2, n1) grid of blocks
            '(n3 n2 n1) -> 1 n3 n2 n1',
            # broadcast each block value over its patch; needs einops.repeat
            # because p3/p2/p1 occur only on the output side
            '1 n3 n2 n1 -> 1 (n3 p3) (n2 p2) (n1 p1)',
            # flatten the spatial axes into a single sequence axis
            '1 d h w c -> 1 (d h w) c',
        ],
        # presumably the interpolation mode paired with this data kind
        # (e.g. for torch.nn.functional.interpolate) - TODO confirm with callers
        'mode': 'trilinear',
    },
}
def einops_f(x, pattern, hparams=None, f=rearrange):
    """Apply an einops operation to ``x``, taking axis sizes from ``hparams``.

    Args:
        x: Tensor/array passed straight through to ``f``.
        pattern: A single einops pattern string, e.g. ``'h w c -> 1 c h w'``.
        hparams: Optional source of axis sizes: either a mapping, or an object
            whose instance ``__dict__`` holds them.  Only entries whose names
            occur as axis identifiers in ``pattern`` are forwarded, so
            unrelated hyper-parameters do not make einops raise.
        f: einops-style callable invoked as ``f(x, pattern, **axes_lengths)``;
            defaults to ``einops.rearrange``.  Patterns that introduce new axes
            only on the output side need ``einops.repeat`` instead.

    Returns:
        Whatever ``f`` returns.
    """
    if hparams is None:
        return f(x, pattern)
    # einops axis names are Python identifiers.  Pulling them out with a regex
    # also works when '(', ')' or '->' are not surrounded by spaces
    # (e.g. '(n2 p2)(n1 p1)'), where splitting on ' ' merged adjacent names.
    required_keys = set(re.findall(r'[A-Za-z_]\w*', pattern))
    source = hparams if isinstance(hparams, Mapping) else hparams.__dict__
    axes_lengths = {k: v for k, v in source.items() if k in required_keys}
    return f(x, pattern, **axes_lengths)