RoPE Embeddings (#568)
* rope embeddings added

* added sinusoidal embedding

* added rope to mha

* added caching and a compute-on-the-fly approach when no max_seq_len is given, and added process_heads to MHA

* remove `use_rope_embedding` flag

* fixed merge-related errors

* removed unnecessary state_len flag and placed shape checking in if-clause

* addressed review feedback

* export new embeddings

* removed state len again, oops

* add ensure_compile_time_eval

* remove max_seq_len completely

* removed unnecessary if check

* improved docstrings

* improved memory use, adhering to strict JAX config

* fixed dtype promotion

* removed dtype float and used float(seq_len) instead

* jnp.arange(0.0, ...) to force floats

* Adjustments to RoPE:

- Changed how the rotation is done to match the ESM2 implementation.
- Lots of doc tidy-ups.
- Removed SinusoidalPositionalEmbedding. I think I want to be more certain that this is correct before merging it.

* added rope tests

* typo

* fixed tests and annotations

* removed internal_sinus cache

---------

Co-authored-by: Patrick Kidger <[email protected]>
Artur-Galstyan and patrick-kidger committed Apr 14, 2024
1 parent 7988ec2 commit fce75be
Showing 5 changed files with 284 additions and 2 deletions.
8 changes: 8 additions & 0 deletions docs/api/nn/attention.md
@@ -5,3 +5,11 @@
members:
- __init__
- __call__

---

::: equinox.nn.RotaryPositionalEmbedding
selection:
members:
- __init__
- __call__
5 changes: 4 additions & 1 deletion equinox/nn/__init__.py
@@ -12,7 +12,10 @@
ConvTranspose3d as ConvTranspose3d,
)
from ._dropout import Dropout as Dropout
from ._embedding import Embedding as Embedding
from ._embedding import (
Embedding as Embedding,
RotaryPositionalEmbedding as RotaryPositionalEmbedding,
)
from ._inference import inference_mode as inference_mode
from ._linear import Identity as Identity, Linear as Linear
from ._mlp import MLP as MLP
38 changes: 38 additions & 0 deletions equinox/nn/_attention.py
@@ -1,6 +1,7 @@
import functools as ft
import math
import warnings
from collections.abc import Callable
from functools import partial
from typing import cast, Optional, Union

@@ -228,6 +229,20 @@ def __call__(
key: Optional[PRNGKeyArray] = None,
inference: Optional[bool] = None,
deterministic: Optional[bool] = None,
process_heads: Optional[
Callable[
[
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads vo_size"],
],
tuple[
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads vo_size"],
],
]
] = None,
) -> Float[Array, "q_seq o_size"]:
"""**Arguments:**
@@ -246,6 +261,10 @@ def __call__(
- `inference`: As [`equinox.nn.Dropout.__call__`][]. (Keyword only
argument.)
- `deterministic`: (Deprecated in favour of `inference`.)
- `process_heads`: A function that takes in the query, key, and value heads and
returns new query, key, and value heads. This can be used, for example, to
implement relative positional embeddings -
see e.g. `RotaryPositionalEmbedding`. (Keyword only argument.)
**Returns:**
@@ -270,6 +289,25 @@
key_heads = self._project(self.key_proj, key_)
value_heads = self._project(self.value_proj, value)

if process_heads is not None:
q_shape, k_shape, v_shape = (
query_heads.shape,
key_heads.shape,
value_heads.shape,
)
query_heads, key_heads, value_heads = process_heads(
query_heads, key_heads, value_heads
)

if (
query_heads.shape != q_shape
or key_heads.shape != k_shape
or value_heads.shape != v_shape
):
raise ValueError(
"process_heads must not change the shape of the heads."
)

attn_fn = partial(
dot_product_attention, dropout=self.dropout, inference=inference
)
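
For context, a minimal sketch of how the new `process_heads` hook composes with `eqx.nn.RotaryPositionalEmbedding`. This example is not part of the diff; the sizes are illustrative and it assumes a version of Equinox that includes this commit.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

num_heads, qk_size, seq_len = 4, 32, 16
mha = eqx.nn.MultiheadAttention(
    num_heads=num_heads,
    query_size=num_heads * qk_size,
    key=jax.random.PRNGKey(0),
)
rope = eqx.nn.RotaryPositionalEmbedding(qk_size)

def process_heads(query_heads, key_heads, value_heads):
    # Rotate queries and keys per head; values pass through unchanged, so the
    # shape check in `MultiheadAttention.__call__` is satisfied.
    query_heads = jax.vmap(rope, in_axes=1, out_axes=1)(query_heads)
    key_heads = jax.vmap(rope, in_axes=1, out_axes=1)(key_heads)
    return query_heads, key_heads, value_heads

x = jnp.ones((seq_len, num_heads * qk_size))
out = mha(x, x, x, process_heads=process_heads)  # (seq_len, num_heads * qk_size)
```
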
145 changes: 144 additions & 1 deletion equinox/nn/_embedding.py
@@ -3,12 +3,17 @@
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jaxtyping import Array, ArrayLike, Float, Int, PRNGKeyArray
from jaxtyping import Array, ArrayLike, Complex, Float, Int, PRNGKeyArray

from .._caches import cache_clears
from .._filters import is_array_like
from .._module import field, Module


internal_rope_embedding_cache: dict[int, Array] = {}
cache_clears.append(internal_rope_embedding_cache.clear)


class Embedding(Module, strict=True):
"""A simple lookup table that stores embeddings of a fixed size."""

@@ -89,3 +94,141 @@ def __call__(
"`eqx.nn.Embedding()(x)` should be called with a scalar index `x`. "
"Use `jax.vmap` if you would like to index with multiple values."
)


class RotaryPositionalEmbedding(Module, strict=True):
"""A rotary positional encoding module, as described in the paper
"RoFormer: Enhanced Transformer with Rotary Position Embedding". While this module
can be used in any context, it is particularly useful for providing positional
information to transformer models.

!!! Example

The following example demonstrates how to use `RotaryPositionalEmbedding` in
a simple transformer model.

```python
class TransformerBlock(eqx.Module):
rope_embeddings: RotaryPositionalEmbedding
def __init__(...):
self.rope_embeddings = RotaryPositionalEmbedding(...)
def __call__(...):
def process_heads(
query_heads: Float[Array, "seq_length num_heads qk_size"],
key_heads: Float[Array, "seq_length num_heads qk_size"],
value_heads: Float[Array, "seq_length num_heads vo_size"]
) -> tuple[
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads qk_size"],
Float[Array, "seq_length num_heads vo_size"]
]:
query_heads = jax.vmap(self.rope_embeddings,
in_axes=1,
out_axes=1)(query_heads)
key_heads = jax.vmap(self.rope_embeddings,
in_axes=1,
out_axes=1)(key_heads)
return query_heads, key_heads, value_heads
x = self.mha_attention(... process_heads=process_heads)
...
```

??? cite

[RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864)

```bibtex
@misc{su2023roformer,
title={RoFormer: Enhanced Transformer with Rotary Position Embedding},
author={Jianlin Su and Yu Lu and Shengfeng Pan and Ahmed Murtadha and
Bo Wen and Yunfeng Liu},
year={2023},
eprint={arXiv:2104.09864},
}
```
"""

embedding_size: int = field(static=True)

def __check_init__(self):
if self.embedding_size < 0:
raise ValueError("`embedding_size` must not be negative.")
if (self.embedding_size % 2) != 0:
raise ValueError("`embedding_size` must be even.")

@staticmethod
def rotate_half(x: Float[Array, "seq_length embedding_size"]):
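# Swaps and negates the two halves of the last axis: (x1, x2) -> (-x2, x1).
# Combined with the cos/sin terms in `__call__`, this rotates each feature pair.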
d_2 = x.shape[-1] // 2
return jnp.concatenate([-x[..., d_2:], x[..., :d_2]], axis=-1)

@staticmethod
def precompute_freqs_cis(
embedding_size: int, end: int, theta: float = 10000.0
) -> Complex[Array, "end half_of_embedding_size"]:
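# RoFormer frequency schedule: theta**(-2i / embedding_size) for
# i = 0, ..., embedding_size / 2 - 1.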
freqs = 1.0 / (
theta
** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size)
)

t = jnp.arange(float(end))
freqs_outer = jnp.outer(t, freqs)
with jax.numpy_dtype_promotion("standard"):
freqs_cis = jnp.cos(freqs_outer) + jnp.sin(freqs_outer) * 1j

return freqs_cis

@jax.named_scope("eqx.nn.RotaryPositionalEmbedding")
def __call__(
self,
x: Float[Array, "seq_length embedding_size"],
*,
key: Optional[PRNGKeyArray] = None,
) -> Float[Array, "seq_length embedding_size"]:
"""**Arguments:**
- `x`: A JAX array of shape `(seq_length, embedding_size)`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
    (Keyword only argument.)

**Returns:**

A JAX array of shape `(seq_length, embedding_size)`, with the rotary positional
encoding applied to the input.
"""

seq_len, embedding_size = x.shape
if embedding_size != self.embedding_size:
raise ValueError(
f"x.shape[-1] must match self.embedding_size, "
f"but {x.shape[-1]} != {self.embedding_size}"
)

with jax.ensure_compile_time_eval():
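# The cos/sin table is computed eagerly (outside of any JIT trace) and cached
# per embedding size; it is regrown whenever a longer sequence length is seen.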
if embedding_size in internal_rope_embedding_cache:
freqs_cis = internal_rope_embedding_cache[embedding_size]
freqs_cis_seq_len, _ = freqs_cis.shape
if seq_len > freqs_cis_seq_len:
freqs_cis = self.precompute_freqs_cis(embedding_size, seq_len)
internal_rope_embedding_cache[embedding_size] = freqs_cis
else:
freqs_cis = freqs_cis[:seq_len]
else:
freqs_cis = self.precompute_freqs_cis(embedding_size, seq_len)
internal_rope_embedding_cache[embedding_size] = freqs_cis

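# Tile the half-width cos/sin tables to the full embedding width, then apply
# the pairwise rotation: x_rope = x * cos + rotate_half(x) * sin.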
freqs_real = jnp.tile(freqs_cis.real, (1, 2))
freqs_imag = jnp.tile(freqs_cis.imag, (1, 2))

rotate_x = self.rotate_half(x)
x_rope = (x * freqs_real) + (rotate_x * freqs_imag)
return x_rope


RotaryPositionalEmbedding.__init__.__doc__ = """**Arguments:**
- `embedding_size`: Size of the token embeddings. Must be non-negative and even.
"""
90 changes: 90 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,3 +1328,93 @@ def test_prelu(getkey):

assert activation.negative_slope.shape == (x.shape[-1],)
assert jnp.all(output == expected_multi_output)


def test_rope_embeddings_shapes(getkey):
embedding_size = 32
rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)

n_heads = 4
seq_length = 8
query_size = 32
key_size = 32

query_heads = jax.random.normal(
key=getkey(), shape=(seq_length, n_heads, query_size)
)
key_heads = jax.random.normal(key=getkey(), shape=(seq_length, n_heads, key_size))
query_heads = jax.vmap(rope_embeddings, in_axes=1, out_axes=1)(query_heads)
key_heads = jax.vmap(rope_embeddings, in_axes=1, out_axes=1)(key_heads)

assert query_heads.shape == (seq_length, n_heads, query_size)
assert key_heads.shape == (seq_length, n_heads, key_size)


def test_rope_embeddings_freqs_cis():
# Values are generated using Meta's RoPE embedding code. See this gist,
# which generates the expected values:
# https://gist.github.com/Artur-Galstyan/8d0bb5743f00650aa6c0d7427595a0ff
expected_freqs_cis = jnp.array(
[
[1.0000 + 0.0000j, 1.0000 + 0.0000j, 1.0000 + 0.0000j, 1.0000 + 0.0000j],
[0.5403 + 0.8415j, 0.9950 + 0.0998j, 0.9999 + 0.0100j, 1.0000 + 0.0010j],
[-0.4161 + 0.9093j, 0.9801 + 0.1987j, 0.9998 + 0.0200j, 1.0000 + 0.0020j],
[-0.9900 + 0.1411j, 0.9553 + 0.2955j, 0.9996 + 0.0300j, 1.0000 + 0.0030j],
[-0.6536 - 0.7568j, 0.9211 + 0.3894j, 0.9992 + 0.0400j, 1.0000 + 0.0040j],
[0.2837 - 0.9589j, 0.8776 + 0.4794j, 0.9988 + 0.0500j, 1.0000 + 0.0050j],
[0.9602 - 0.2794j, 0.8253 + 0.5646j, 0.9982 + 0.0600j, 1.0000 + 0.0060j],
[0.7539 + 0.6570j, 0.7648 + 0.6442j, 0.9976 + 0.0699j, 1.0000 + 0.0070j],
[-0.1455 + 0.9894j, 0.6967 + 0.7174j, 0.9968 + 0.0799j, 1.0000 + 0.0080j],
[-0.9111 + 0.4121j, 0.6216 + 0.7833j, 0.9960 + 0.0899j, 1.0000 + 0.0090j],
[-0.8391 - 0.5440j, 0.5403 + 0.8415j, 0.9950 + 0.0998j, 0.9999 + 0.0100j],
[0.0044 - 1.0000j, 0.4536 + 0.8912j, 0.9940 + 0.1098j, 0.9999 + 0.0110j],
[0.8439 - 0.5366j, 0.3624 + 0.9320j, 0.9928 + 0.1197j, 0.9999 + 0.0120j],
[0.9074 + 0.4202j, 0.2675 + 0.9636j, 0.9916 + 0.1296j, 0.9999 + 0.0130j],
[0.1367 + 0.9906j, 0.1700 + 0.9854j, 0.9902 + 0.1395j, 0.9999 + 0.0140j],
[-0.7597 + 0.6503j, 0.0707 + 0.9975j, 0.9888 + 0.1494j, 0.9999 + 0.0150j],
]
)
embedding_size = 8
seq_length = 16
freqs_cis = eqx.nn.RotaryPositionalEmbedding.precompute_freqs_cis(
embedding_size, seq_length
)
assert jnp.allclose(freqs_cis, expected_freqs_cis, atol=1e-4)


def test_rope_embeddings_values():
# Values are generated using the script in this gist:
# https://gist.github.com/Artur-Galstyan/d33eda74072fea61545127adb90197b5
# They are based on the HuggingFace implementation of the RoPE embeddings
# (see here:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_flax_llama.py#L169)
expected_values = jnp.array(
[
[
0.0,
1.0,
2.0,
3.0,
],
[-2.887617, 4.9297514, 6.6076975, 7.0496492],
[-12.422148, 8.778215, 3.1129112, 11.177788],
[-13.85559, 12.544218, -12.166454, 15.383192],
[3.1641474, 16.226604, -23.874424, 19.664621],
[26.769577, 19.824234, -12.937918, 24.020819],
[30.30889, 23.335985, 18.258457, 28.450514],
[1.3996639, 26.760752, 41.01269, 32.952423],
]
)

seq_length = 8
embedding_size = 4

x = jnp.arange(seq_length * embedding_size * 1.0).reshape(
seq_length, embedding_size
)

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)
res = rope_embeddings(x)

assert jnp.allclose(res, expected_values, atol=1e-6)
