-
Notifications
You must be signed in to change notification settings - Fork 441
/
Copy pathllama_transformer.py
249 lines (203 loc) · 9.2 KB
/
llama_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
# @lint-ignore-every LICENSELINT
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Llama 2 is licensed under the LLAMA 2 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
# Please refer to README.md in the same folder for more information.
from typing import Any, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from executorch.examples.models.llama.attention import (
ATTENTION_REGISTRY,
ForwardOptions,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from torch import nn
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The input is divided by ``sqrt(mean(x * x) + eps)`` taken along its last
    axis, then multiplied element-wise by the learnable ``weight`` vector.

    Args:
        dim: Size of the normalized (last) dimension and length of ``weight``.
        eps: Constant added to the mean square before ``rsqrt`` for numerical
            stability. Defaults to 1e-6.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Starts at ones, so a freshly built layer applies no extra scaling.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Return ``x`` multiplied by the reciprocal RMS of its last dimension."""
        mean_square = (x * x).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize ``x`` in float32, cast back to ``x``'s dtype, then scale by ``weight``."""
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
class FeedForward(nn.Module):
    """Dense gated feed-forward network computing ``w2(silu(w1(x)) * w3(x))``.

    Args:
        args: Model configuration. ``args.dim`` is the model width and
            ``args.hidden_dim`` the inner width; ``hidden_dim`` must be set.

    Raises:
        ValueError: if ``args.hidden_dim`` is ``None``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an opaque nn.Linear error.
        # ValueError also matches how TransformerBlock reports bad configs.
        if args.hidden_dim is None:
            raise ValueError("FeedForward requires args.hidden_dim to be set")
        hidden_dim: int = args.hidden_dim
        # Creation order (w1, w2, w3) is kept so seeded initialization is unchanged.
        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)

    def forward(self, x):
        """Apply the gated MLP to the last dimension of ``x``."""
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
class ConditionalFeedForward(nn.Module):
    """Bank of ``num_experts`` gated feed-forward experts, evaluated per token.

    Expert ``e`` computes ``(silu(x @ w1[e].T) * (x @ w3[e].T)) @ w2[e]``.
    ``forward`` gathers and evaluates only the experts listed in
    ``expert_indices``. All three weight tensors are stored with shape
    ``[num_experts, hidden_dim, dim]`` and are randomly initialized here.

    Args:
        args: Model configuration. If ``args.hidden_dim`` is ``None`` it is
            derived from ``args.dim`` and ``args.multiple_of``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        hidden_dim = args.hidden_dim
        if hidden_dim is None:
            # Implicit width: two thirds of 4 * dim, rounded up to the next
            # multiple of args.multiple_of.
            step = args.multiple_of
            hidden_dim = int(2 * (4 * self.dim) / 3)
            hidden_dim = step * ((hidden_dim + step - 1) // step)
        expert_shape = (args.num_experts, hidden_dim, self.dim)
        # Creation order (w1, w2, w3) is kept so seeded initialization is unchanged.
        self.w1 = nn.Parameter(torch.randn(expert_shape))
        self.w2 = nn.Parameter(torch.randn(expert_shape))
        self.w3 = nn.Parameter(torch.randn(expert_shape))
        self.num_experts = args.num_experts

    def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
        """Evaluate the selected experts for every token.

        Shape letters: T = tokens, A = experts per token, D = dim, H = hidden_dim.

        Args:
            x: ``[T, D]`` token activations.
            expert_indices: ``[T, A]`` expert ids chosen for each token.

        Returns:
            ``[T, A, D]`` output of each selected expert for each token.
        """
        gate_w = self.w1[expert_indices].transpose(-1, -2)  # [T, A, D, H]
        up_w = self.w3[expert_indices].transpose(-1, -2)  # [T, A, D, H]
        down_w = self.w2[expert_indices]  # [T, A, H, D]
        gate = F.silu(torch.einsum("ti,taio -> tao", x, gate_w))  # [T, A, H]
        up = torch.einsum("ti, taio -> tao", x, up_w)  # [T, A, H]
        return torch.einsum("tao, taoi -> tai", (gate * up), down_w)  # [T, A, D]
class MOEFeedForward(nn.Module):
    """Top-k mixture-of-experts feed-forward layer.

    A linear ``gate`` scores every expert per token. The top
    ``num_activated_experts`` scores are softmax-normalized, and the selected
    experts' outputs (from ``cond_ffn``) are combined with those weights.

    Args:
        config: Model configuration providing ``dim`` and ``num_experts``, plus
            whatever ``ConditionalFeedForward`` reads. ``num_activated_experts``
            is read if present; otherwise top-2 routing is used, as before.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
        self.cond_ffn = ConditionalFeedForward(config)
        self.dim = config.dim
        # Previously hard-coded to 2 in forward(); the default keeps that behavior
        # for configs that do not define the field.
        self.num_activated_experts = getattr(config, "num_activated_experts", 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route each token to its top experts and mix their outputs.

        Shape letters: T = tokens, E = num_experts, A = activated experts, D = dim.

        Args:
            x: Activations of any leading shape whose elements can be regrouped
                into rows of size ``dim`` (e.g. ``[B, S, D]`` or ``[T, D]``).

        Returns:
            Tensor with the same shape as ``x``. The output used to be returned
            flattened to ``[T, D]``, which broke the residual add in the caller
            for batch sizes > 1.
        """
        orig_shape = x.shape
        x = x.reshape(-1, self.dim)  # [T, D]
        scores = self.gate(x)  # [T, E]
        expert_weights, expert_indices = torch.topk(
            scores, self.num_activated_experts, dim=-1
        )  # [T, A], [T, A]
        expert_weights = expert_weights.softmax(dim=-1)  # [T, A]
        expert_outs = self.cond_ffn(x, expert_indices)  # [T, A, D]
        out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)  # [T, D]
        return out.reshape(orig_shape)
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer: attention, then feed-forward, each with a residual.

    ``forward`` computes ``h = x + attention(attention_norm(x))`` and then
    ``out = h + ffn(ffn_norm(h))``. Here ``ffn`` is an ``MOEFeedForward``
    (attribute ``block_sparse_moe``) when ``args.moe`` is truthy, and a dense
    ``FeedForward`` (attribute ``feed_forward``) otherwise.

    Args:
        layer_id: Index of this layer, passed to the attention constructor.
        args: Model configuration.
        rope: Rotary-embedding helper, passed to the attention constructor.

    Raises:
        ValueError: if ``args.attention_type`` is not in ``ATTENTION_REGISTRY``.
    """

    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
        super().__init__()
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.head_dim

        attention_type = args.attention_type
        if attention_type not in ATTENTION_REGISTRY:
            raise ValueError(
                f"Unknown attention type: {attention_type}. "
                f"Available: {list(ATTENTION_REGISTRY.keys())}"
            )
        attention_cls = ATTENTION_REGISTRY[attention_type]
        self.attention = attention_cls(args, layer_id, rope)

        # Exactly one feed-forward attribute is registered; forward() checks which.
        if args.moe:
            self.block_sparse_moe = MOEFeedForward(args)
        else:
            self.feed_forward = FeedForward(args)

        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):
        """Run the attention and feed-forward sublayers on ``x``.

        Args:
            x: Layer input. NOTE(review): presumably ``[batch, seq, dim]`` —
                confirm against the attention implementation.
            freqs_cos, freqs_sin: Rotary frequencies, forwarded to attention.
            attn_options: Mapping unpacked as keyword arguments to attention.

        Returns:
            ``(out, attn_options_update)``. The second element is whatever the
            attention module returned alongside its output.
        """
        # .forward is called directly, bypassing nn.Module.__call__ and thus any
        # hooks registered on self.attention.
        attn_out, attn_options_update = self.attention.forward(
            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
        )
        h = x + attn_out

        if hasattr(self, "block_sparse_moe"):
            ffn = self.block_sparse_moe
        else:
            ffn = self.feed_forward
        out = h + ffn(self.ffn_norm(h))
        return out, attn_options_update
class Transformer(nn.Module):
    """Llama-style decoder: embeddings, ``n_layers`` blocks, final RMSNorm, output head.

    Every ``TransformerBlock`` is constructed with the same ``Rope`` instance
    (``self.rope``). ``forward`` computes the rotary frequencies once per call
    and passes them to every block.

    Args:
        params: Model configuration.
    """

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        # Submodule creation order is significant for seeded init and state_dict order.
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.rope = Rope(params)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params, self.rope))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.use_kv_cache = params.use_kv_cache
        self.generate_full_logits = params.generate_full_logits
        self.max_seq_len = params.max_seq_len
        self.max_context_len = params.max_context_len
        # Only output_prune_map is read by this class. input_prune_map is just
        # stored here — NOTE(review): confirm where it is consumed.
        self.input_prune_map = params.input_prune_map
        self.output_prune_map = params.output_prune_map

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,  # tokens
        attn_options: Optional[ForwardOptions] = None,
        h: Optional[torch.FloatTensor] = None,  # embeddings
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[Any]]]:
        """Run the model on token ids or on precomputed embeddings.

        Args:
            tokens: Token ids, embedded with ``tok_embeddings``. Mutually
                exclusive with ``h``.
            attn_options: Options forwarded to every layer's attention as
                keyword arguments. ``"input_pos"`` (if present) is also passed
                to ``rope.get_freqs``. The caller's dict is never modified.
            h: Precomputed input embeddings, used instead of ``tokens``.

        Returns:
            Logits for every position when ``generate_full_logits`` is set,
            otherwise only for the last position. If ``output_prune_map`` is
            set, the logits are expanded back to ``vocab_size``. When the last
            layer returned a non-None options update, the result is
            ``(logits, attn_options_update)`` instead.

        Raises:
            ValueError: unless exactly one of ``tokens`` and ``h`` is given.
        """
        # Equivalent to requiring exactly one of tokens / h.
        if (tokens is None) == (h is None):
            raise ValueError(
                "You cannot specify both tokens and h at the same time, and must specify either one"
            )
        if tokens is not None:
            h = self.tok_embeddings(tokens)

        if attn_options is None:
            attn_options = {}
        seqlen = h.shape[1]
        freqs_cos, freqs_sin = self.rope.get_freqs(
            attn_options.get("input_pos"), seqlen
        )

        # Make a shallow copy so the updates don't get captured by export
        attn_options_ = attn_options.copy()
        attn_options_update = None
        for layer in self.layers:
            h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options_)
            if attn_options_update is not None:
                # Each layer's update is visible to the later layers. Only the
                # last layer's update is returned to the caller.
                attn_options_.update(**attn_options_update)

        if not self.generate_full_logits:
            # Only the last logit is used for the new generated token
            h = h[:, -1, :]

        h = self.norm(h)
        logits = self.output(h)

        if self.output_prune_map is not None:
            # expand to original size so that downstream applications can use the logits as-is.
            logits = self._expand_pruned_logits(logits)

        if attn_options_update is not None:
            return logits, attn_options_update

        return logits

    def _expand_pruned_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Scatter pruned-vocabulary logits into a ``-inf``-filled full-vocab tensor.

        The last dimension of ``logits`` is written to the positions given by
        ``output_prune_map.values()``, in the dict's iteration order. This covers
        both ``(batch, pruned)`` and ``(batch, seq, pruned)`` logits.
        NOTE(review): presumably the map is pruned-id -> original-id with keys in
        pruned order; confirm where the map is built.
        """
        expanded_logits = torch.full(
            [*logits.shape[:-1], self.vocab_size],
            float("-inf"),
            device=logits.device,
            dtype=logits.dtype,
        )
        expanded_logits[..., list(self.output_prune_map.values())] = logits
        return expanded_logits