A modified transformer decoder for longer sequences, introduced in a paper by Child et al. (2019).
The paper introduces a different kind of attention called sparse attention, where
each query token attends to just a few key tokens. This speeds up computation
and reduces the originally quadratic memory cost of self-attention to $O(n \sqrt{n})$.
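For a rough sense of the savings (my own back-of-the-envelope arithmetic, not a number from the paper), with $n = 16384$:

$$\frac{n^2}{n\sqrt{n}} = \sqrt{n} = 128,$$

i.e. roughly $2.7 \times 10^8$ query-key pairs for dense attention versus roughly $2.1 \times 10^6$ for the sparse variant.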
The paper also mentions:
- gradient checkpointing
- using fp16
The idea behind sparse attention is to attend to as few tokens as possible
while still providing a path of tokens through which information can travel between
two arbitrary tokens. E.g. even if token 9 never attends to token 1 directly, it can
still receive information about it through an intermediate token that attends to
token 1 in one layer and is attended to by token 9 in a later layer.
Concretely, the authors use two kinds of sparse attention: strided and fixed.
When using strided attention, a query token attends to past tokens (remember we're
dealing with decoders, so no attention to future inputs) in two patterns: the
previous $l$ tokens, and every $l$-th token before that.
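A minimal NumPy sketch of how I'd build the strided mask (my own illustration, not the paper's code; `n` is the sequence length and `l` the stride):

```python
import numpy as np

def strided_mask(n: int, l: int) -> np.ndarray:
    """Boolean mask where mask[i, j] is True if query i may attend to key j."""
    i = np.arange(n)[:, None]  # query positions
    j = np.arange(n)[None, :]  # key positions
    causal = j <= i               # decoder: no attention to the future
    local = (i - j) < l           # the l most recent tokens (including the query itself)
    strided = (i - j) % l == 0    # every l-th token counting back from the query
    return causal & (local | strided)
```

Positions where the mask is `False` would get $-\infty$ scores before the softmax.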
In fixed attention a query token attends to two groups of key tokens. E.g. with
block length $l = 5$ and $c = 3$, query token 8 attends to:
- key tokens 7, 6, 5 (the tokens before it inside its own block of $l$ tokens)
- key tokens 2, 3, 4 (the last $c$ tokens of the previous block; in general the last $c$ tokens of every earlier block act as summary positions)
The image should illustrate this more clearly:
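To complement the image, here is a matching NumPy sketch of the fixed mask (again my own illustration, using the $l$ and $c$ notation; not the paper's code):

```python
import numpy as np

def fixed_mask(n: int, l: int, c: int) -> np.ndarray:
    """Boolean mask where mask[i, j] is True if query i may attend to key j."""
    i = np.arange(n)[:, None]  # query positions
    j = np.arange(n)[None, :]  # key positions
    causal = j <= i
    same_block = (i // l) == (j // l)   # tokens inside the query's own block
    summary = (j % l) >= (l - c)        # last c tokens of each block
    return causal & (same_block | summary)
```

With `fixed_mask(10, 5, 3)`, row 8 allows exactly positions 2-8, matching the example above.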
The authors suggest:
- $l \in \{256, 512, 1024\}$ and $c \in \{8, 16, 32\}$
- using different $c$ values for different heads
- strided attention for inputs with periodic structure like images or some kinds of music, fixed attention for text
For the implementation they used custom kernels, even though the described attention patterns can be computed efficiently by slicing out blocks from the query and key matrices. The only part I am not sure about is the stride in strided attention. The authors say attention with a stride $k$ "can be computed by transposing the matrix and computing a local window", and I am not sure how that works.
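My best reading of that sentence, as a quick sanity check (purely my own interpretation, not code from the paper): if the sequence is reordered so that positions with the same residue mod $k$ become contiguous, the strided component of the pattern turns into ordinary block-local causal attention.

```python
import numpy as np

n, k = 16, 4  # toy sequence length and stride; n is assumed divisible by k

# Strided component of the pattern, built directly:
# query i may attend to key j if j is in the past and (i - j) is a multiple of k.
i = np.arange(n)[:, None]
j = np.arange(n)[None, :]
direct = (j <= i) & ((i - j) % k == 0)

# Reordered ("transposed") view: slot p of the reordered sequence holds original
# position perm[p]; each contiguous group of n//k slots is one residue class mod k.
perm = np.arange(n).reshape(n // k, k).T.reshape(-1)

# Causal attention restricted to each contiguous group, i.e. a local block of
# size n//k in the reordered sequence, mapped back to original positions.
g = n // k
a = np.arange(n)[:, None]
b = np.arange(n)[None, :]
slot_mask = (a // g == b // g) & (perm[b] <= perm[a])
via_transpose = np.zeros((n, n), dtype=bool)
via_transpose[np.ix_(perm, perm)] = slot_mask

assert (direct == via_transpose).all()  # the two constructions agree
```

At least in this toy case the two constructions agree.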