"""
this extremely minimal Decision Transformer model is based on
the following causal transformer (GPT) implementation:
Misha Laskin's tweet:
https://twitter.com/MishaLaskin/status/1481767788775628801?cxt=HHwWgoCzmYD9pZApAAAA
and its corresponding notebook:
https://colab.research.google.com/drive/1NUBqyboDcGte5qAJKOl8gaJC28V_73Iv?usp=sharing
** the above colab notebook has a bug while applying masked_fill
which is fixed in the following code
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedCausalAttention(nn.Module):
    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()

        self.n_heads = n_heads
        self.max_T = max_T

        self.q_net = nn.Linear(h_dim, h_dim)
        self.k_net = nn.Linear(h_dim, h_dim)
        self.v_net = nn.Linear(h_dim, h_dim)

        self.proj_net = nn.Linear(h_dim, h_dim)

        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

        ones = torch.ones((max_T, max_T))
        mask = torch.tril(ones).view(1, 1, max_T, max_T)

        # register_buffer stores the mask with the module but does not make it
        # a parameter, so it never gets updated during backpropagation
        self.register_buffer('mask', mask)

    def forward(self, x):
        B, T, C = x.shape  # batch size, seq length, hidden dim (h_dim)
        N, D = self.n_heads, C // self.n_heads  # N = num heads, D = per-head attention dim

        # rearrange q, k, v as (B, N, T, D)
        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)

        # attention weights (B, N, T, T)
        weights = q @ k.transpose(2, 3) / math.sqrt(D)
        # causal mask applied to weights
        weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))
        # normalize weights; masked (-inf) positions become 0 after softmax
        normalized_weights = F.softmax(weights, dim=-1)

        # attention (B, N, T, D)
        attention = self.att_drop(normalized_weights @ v)

        # gather heads and project (B, N, T, D) -> (B, T, N*D)
        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)

        out = self.proj_drop(self.proj_net(attention))
        return out


class Block(nn.Module):
    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()
        self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
        self.mlp = nn.Sequential(
            nn.Linear(h_dim, 4 * h_dim),
            nn.GELU(),
            nn.Linear(4 * h_dim, h_dim),
            nn.Dropout(drop_p),
        )
        self.ln1 = nn.LayerNorm(h_dim)
        self.ln2 = nn.LayerNorm(h_dim)

    def forward(self, x):
        # Attention -> LayerNorm -> MLP -> LayerNorm
        x = x + self.attention(x)  # residual
        x = self.ln1(x)
        x = x + self.mlp(x)  # residual
        x = self.ln2(x)
        return x


class DecisionTransformer(nn.Module):
    def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len,
                 n_heads, drop_p, max_timestep=4096):
        super().__init__()

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.h_dim = h_dim

        ### transformer blocks
        input_seq_len = 3 * context_len
        blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)]
        self.transformer = nn.Sequential(*blocks)

        ### projection heads (project to embedding)
        self.embed_ln = nn.LayerNorm(h_dim)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        self.embed_rtg = torch.nn.Linear(1, h_dim)
        self.embed_state = torch.nn.Linear(state_dim, h_dim)

        # # discrete actions
        # self.embed_action = torch.nn.Embedding(act_dim, h_dim)
        # use_action_tanh = False  # False for discrete actions

        # continuous actions
        self.embed_action = torch.nn.Linear(act_dim, h_dim)
        use_action_tanh = True  # True for continuous actions

        ### prediction heads
        self.predict_rtg = torch.nn.Linear(h_dim, 1)
        self.predict_state = torch.nn.Linear(h_dim, state_dim)
        self.predict_action = nn.Sequential(
            *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))
        )

    def forward(self, timesteps, states, actions, returns_to_go):
        B, T, _ = states.shape

        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similarly to positional embeddings
        state_embeddings = self.embed_state(states) + time_embeddings
        action_embeddings = self.embed_action(actions) + time_embeddings
        returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings

        # stack rtg, states and actions and reshape the sequence as
        # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2, ...)
        h = torch.stack(
            (returns_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)

        h = self.embed_ln(h)

        # transformer and prediction
        h = self.transformer(h)

        # reshape h so that its size = (B x 3 x T x h_dim) and
        # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t
        # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t
        # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t
        # that is, for each timestep (t) we have 3 output embeddings from the transformer,
        # each conditioned on all previous timesteps plus
        # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence.
        h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)

        # get predictions
        return_preds = self.predict_rtg(h[:, 2])     # predict next rtg given r, s, a
        state_preds = self.predict_state(h[:, 2])    # predict next state given r, s, a
        action_preds = self.predict_action(h[:, 1])  # predict action given r, s

        return state_preds, action_preds, return_preds
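

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original module): a minimal sketch
# showing the input shapes expected by DecisionTransformer.forward() and the
# shapes of the three prediction heads. All dimensions below (state_dim=17,
# act_dim=6, h_dim=128, etc.) are made-up example values, not taken from any
# particular environment or training config. Run `python model.py` to execute.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    B, T = 4, 20                 # batch size, context length (hypothetical)
    state_dim, act_dim = 17, 6   # hypothetical environment dimensions

    model = DecisionTransformer(
        state_dim=state_dim, act_dim=act_dim, n_blocks=3,
        h_dim=128, context_len=T, n_heads=8, drop_p=0.1,
    )

    timesteps = torch.randint(0, 4096, (B, T))   # (B, T) integer timestep indices
    states = torch.randn(B, T, state_dim)        # (B, T, state_dim)
    actions = torch.randn(B, T, act_dim)         # (B, T, act_dim)
    returns_to_go = torch.randn(B, T, 1)         # (B, T, 1)

    state_preds, action_preds, return_preds = model(
        timesteps, states, actions, returns_to_go
    )

    # expected shapes: (4, 20, 17), (4, 20, 6), (4, 20, 1)
    print(state_preds.shape, action_preds.shape, return_preds.shape)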