# Copyright 2020 The FlaxBERT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers used in a Transformer."""
from flax import nn
import jax.numpy as jnp
LAYER_NORM_EPSILON = 1e-6
class PositionalEncoding(nn.Module):
  """Learned positional embeddings for the Transformer."""

  def apply(self,
            inputs, *,
            max_len: int = 2048,
            posemb_init=nn.initializers.xavier_normal()):
    """Applies PositionalEncoding module."""
    assert inputs.ndim == 3, (
        f'Number of dimensions should be 3, but it is: {inputs.ndim}')
    length = inputs.shape[1]
    pos_emb_shape = (1, max_len, inputs.shape[-1])
    pos_embedding = self.param('embedding', pos_emb_shape, posemb_init)
    return pos_embedding[:, :length, :]
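
# Illustrative usage sketch (not part of the original file), assuming the
# pre-Linen flax.nn API used throughout this module: inside a parent module's
# `apply`, the learned table is sliced to the current sequence length and the
# caller adds it to the token embeddings, e.g.
#
#   pos_emb = PositionalEncoding(
#       token_embeddings, max_len=512, name='position_embeddings')
#   hidden_states = token_embeddings + pos_emb
#
# Here `token_embeddings` has shape (batch, length, d_model), `max_len=512` and
# the module name are assumed values, and the leading singleton dimension of
# the returned slice broadcasts over the batch.
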
class FeedForward(nn.Module):
  """Feed-forward layer for a Transformer model."""

  # TODO(kitaev): support chunking
  def apply(self,
            hidden_states, *,
            d_ff: int,
            dropout_rate: float = 0.0,
            intermediate_activation=nn.gelu,
            # TODO(kitaev): chunk_size hparam for chunking
            kernel_init=nn.initializers.xavier_uniform(),
            deterministic: bool = False):
    """Applies FeedForward module."""
    d_model = hidden_states.shape[-1]
    hidden_states = nn.Dense(
        hidden_states,
        d_ff,
        kernel_init=kernel_init,
        name='intermediate')
    hidden_states = intermediate_activation(hidden_states)
    hidden_states = nn.Dense(
        hidden_states,
        d_model,
        kernel_init=kernel_init,
        name='output')
    hidden_states = nn.dropout(
        hidden_states, rate=dropout_rate, deterministic=deterministic)
    return hidden_states
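
# Illustrative sketch (not part of the original file), assuming the pre-Linen
# flax.nn API: hyperparameters are typically bound with `Module.partial` so the
# resulting module definition can be handed to TransformerBlock below, e.g.
#
#   feed_forward = FeedForward.partial(
#       d_ff=3072, dropout_rate=0.1, intermediate_activation=nn.gelu)
#
# The values 3072 and 0.1 are assumed (BERT-base-style) settings, not taken
# from this file.
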
class TransformerBlock(nn.Module):
  """Post-norm transformer block."""

  def apply(self,
            hidden_states, mask=None, *,
            feed_forward,
            attention,
            deterministic: bool = False):
    """Applies TransformerBlock module."""
    attention_output = attention(hidden_states, mask,
                                 deterministic=deterministic,
                                 name='self_attention')
    hidden_states = nn.LayerNorm(hidden_states + attention_output,
                                 epsilon=LAYER_NORM_EPSILON,
                                 name='self_attention_layer_norm')
    feed_forward_output = feed_forward(hidden_states,
                                       deterministic=deterministic,
                                       name='feed_forward')
    hidden_states = nn.LayerNorm(hidden_states + feed_forward_output,
                                 epsilon=LAYER_NORM_EPSILON,
                                 name='output_layer_norm')
    return hidden_states
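
# Illustrative sketch (not part of the original file), assuming the pre-Linen
# flax.nn API: `feed_forward` and `attention` are module definitions (e.g.
# bound with `.partial(...)`) that the block applies inside a parent module's
# `apply` with residual connections and post-layer-norm, e.g.
#
#   block = TransformerBlock.partial(
#       feed_forward=FeedForward.partial(d_ff=3072),
#       attention=nn.SelfAttention.partial(num_heads=12))
#   hidden_states = block(hidden_states, mask, deterministic=False)
#
# `nn.SelfAttention`, `num_heads=12`, and `d_ff=3072` are assumptions for
# illustration; the actual attention module is supplied by the caller.
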
class OutputProjection(nn.Module):
  """A dense projection layer for computing output logits."""

  def apply(self,
            inputs, kernel=None, *,
            n_out=None,
            bias=True,
            kernel_init=nn.initializers.lecun_normal(),
            bias_init=nn.initializers.zeros):
    """Applies OutputProjection module."""
    if kernel is None:
      assert n_out is not None, (
          'n_out argument is required when not re-using an embedding matrix')
      kernel = self.param('kernel', (n_out, inputs.shape[-1]), kernel_init)
    y = jnp.matmul(inputs, jnp.transpose(kernel, (1, 0)))
    if bias:
      bias = self.param('bias', (y.shape[-1],), bias_init)
      y = y + bias
    return y
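
# Note (not part of the original file): passing `kernel` lets the caller reuse
# (tie) an embedding table of shape (vocab_size, d_model) as the output
# projection, in which case no new kernel parameter is created; otherwise a
# fresh (n_out, d_model) kernel is allocated. A hedged sketch of the tied case:
#
#   logits = OutputProjection(hidden_states, kernel=word_embedding_table,
#                             name='predictions_output')
#
# `word_embedding_table` and the module name are placeholders for illustration.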