
Possible memory leak? #673

Closed
AakashKumarNain opened this issue Mar 4, 2024 · 34 comments

@AakashKumarNain
Contributor

AakashKumarNain commented Mar 4, 2024

@patrick-kidger as you already know, I have been working on a Mistral port, and I have now put that code in a public repo. There are some TODOs remaining that I will complete anyway, but I am dealing with a big issue. Here is a condensed version of the implementation for reference.

import gc
import jax
import numpy as np
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from collections import namedtuple


# Utility to convert dtypes
def to_dtype(model, dtype):
    def _to_dtype(leaf):
        if isinstance(leaf, jax.Array):  # not eqx.is_array, which also detects NumPy arrays
            leaf_with_dtype = leaf.astype(dtype)
            del leaf
            gc.collect()  # just in case?
            return leaf_with_dtype
        else:
            return leaf
    return jtu.tree_map(_to_dtype, model)


def count_jax_parameters(model):
    return sum(x.size for x in jtu.tree_leaves(eqx.filter(model, eqx.is_array)))


# 1. RoPE
def precompute_frequencies(dim, max_pos, theta=10000.0):
    inv_freq = 1.0 / (
        theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32)[: (dim // 2)] / dim)
    )
    t = jnp.arange(0, max_pos, dtype=jnp.float32)
    freqs = jnp.outer(t, inv_freq)
    return jnp.cos(freqs), jnp.sin(freqs)


def calculate_rope(x, cos_freq, sin_freq, offset=0):
    # x shape  is [seqlen, num_heads, head_dim]

    # Get the sequence length
    seqlen = x.shape[0]

    # Get the corresponding positional embeddings
    sin = sin_freq[offset : offset + seqlen, :]
    cos = cos_freq[offset : offset + seqlen, :]

    # Positional embeddings are 2D while our input is 3D
    # if `num_heads` dimension is present in the inputs.
    # We need to add another dimension to our positional embeddings
    sin = sin[:, jnp.newaxis, :]
    cos = cos[:, jnp.newaxis, :]

    # Get the even-odd positions from the inputs
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]

    # Matmul with the rotation matrix
    # [cos_nθ, -sin_nθ] [x1]
    # [sin_nθ,  cos_nθ] [x2]
    # => [x1 * cos_nθ - x2 * sin_nθ, x1 * sin_nθ + x2 * cos_nθ]
    pos_embed = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    pos_embed = jax.lax.collapse(pos_embed, -2)
    return pos_embed.astype(x.dtype)


# 2. Attention layer
class Attention(eqx.Module):
    n_heads: int
    n_kv_heads: int
    sliding_window: int
    scale: float
    kv_repeats: int
    head_dim: int
    wq: eqx.nn.Linear
    wk: eqx.nn.Linear
    wv: eqx.nn.Linear
    wo: eqx.nn.Linear
    
    def __init__(self, args: namedtuple, key: jax.Array):
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.kv_repeats = self.n_heads // self.n_kv_heads
        self.sliding_window = args.sliding_window
        self.scale = args.head_dim ** -0.5
        self.head_dim = args.head_dim

        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.wq = eqx.nn.Linear(args.dim, args.n_heads * args.head_dim, use_bias=False, key=key1)
        self.wk = eqx.nn.Linear(args.dim, args.n_kv_heads * args.head_dim, use_bias=False, key=key2)
        self.wv = eqx.nn.Linear(args.dim, args.n_kv_heads * args.head_dim, use_bias=False, key=key3)
        self.wo = eqx.nn.Linear(args.n_heads * args.head_dim, args.dim, use_bias=False, key=key4)

    def __call__(self, x, cos_freq, sin_freq, positions, mask):
        seqlen = x.shape[0]

        xq = jax.vmap(self.wq)(x)
        xk = jax.vmap(self.wk)(x)
        xv = jax.vmap(self.wv)(x)

        xq = jnp.reshape(xq, (seqlen, self.n_heads, self.head_dim))
        xk = jnp.reshape(xk, (seqlen, self.n_kv_heads, self.head_dim))
        xv = jnp.reshape(xv, (seqlen, self.n_kv_heads, self.head_dim))

        xq = calculate_rope(xq, cos_freq, sin_freq, 0)
        xk = calculate_rope(xk, cos_freq, sin_freq, 0)

        if positions.shape[0] > 1:
            # prefill
            key = jnp.repeat(xk, self.kv_repeats, axis=1)
            value = jnp.repeat(xv, self.kv_repeats, axis=1)
        # TODO: else fill from cache

        query = jnp.transpose(xq, (1, 0, 2)) # [seqlen, num_heads, head_dim] -> [num_heads, seqlen, head_dim]
        key = jnp.transpose(key, (1, 0, 2)) # [seqlen, num_heads, head_dim] -> [num_heads, seqlen, head_dim]
        value = jnp.transpose(value, (1, 0, 2)) # [seqlen, num_heads, head_dim] -> [num_heads, seqlen, head_dim]

        # scores : [n_heads, seqlen | 1, seqlen]
        scores = jnp.matmul(query, jnp.transpose(key, (0, 2, 1))) * self.scale

        if mask is not None:
            # Mask will be of shape [seqlen, seqlen] but our scores
            # have shape [num_heads, seqlen, seqlen], hence we need
            # to introduce another dimension in the mask
            mask = mask[jnp.newaxis, ...]
            scores = scores + mask

        scores = jax.nn.softmax(scores.astype(jnp.float32)).astype(query.dtype)
        output = jnp.matmul(scores, value)
        output = jnp.transpose(output, (1, 0, 2))  # [n_heads, seqlen, head_dim] -> [seqlen, n_heads, head_dim]
        output = jnp.reshape(output, (seqlen, -1))
        output = jax.vmap(self.wo)(output)
        return output


# 3. FeedForward
class FeedForward(eqx.Module):
    w1: eqx.nn.Linear
    w2: eqx.nn.Linear
    w3: eqx.nn.Linear

    def __init__(self, args, key):
        super().__init__()
        key1, key2, key3 = jax.random.split(key, 3)

        self.w1 = eqx.nn.Linear(args.dim, args.hidden_dim, use_bias=False, key=key1)
        self.w2 = eqx.nn.Linear(args.hidden_dim, args.dim, use_bias=False, key=key2)
        self.w3 = eqx.nn.Linear(args.dim, args.hidden_dim, use_bias=False, key=key3)

    def __call__(self, x):
        return self.w2(jax.nn.silu(self.w1(x)) * self.w3(x))


# 4. TransformerBlock
class TransformerBlock(eqx.Module):
    dim: int
    n_heads: int
    attention: Attention
    attention_norm: eqx.nn.RMSNorm
    feed_forward: FeedForward
    ffn_norm: eqx.nn.RMSNorm

    def __init__(self, args, key):
        key1, key2 = jax.random.split(key, 2)
        self.n_heads = args.n_heads
        self.dim = args.dim

        self.attention = Attention(args, key=key1)
        self.attention_norm = eqx.nn.RMSNorm(args.dim, eps=args.norm_eps, use_bias=False, use_weight=True)

        self.feed_forward = FeedForward(args, key=key2)
        self.ffn_norm = eqx.nn.RMSNorm(args.dim, eps=args.norm_eps, use_bias=False, use_weight=True)

    def __call__(self, x, cos_freq, sin_freq, positions, mask):
        normed_x = jax.vmap(self.attention_norm)(x.astype(jnp.float32)).astype(jnp.float16)
        r = self.attention(normed_x, cos_freq, sin_freq, positions, mask)
        h1 = x + r
        h2 = jax.vmap(self.ffn_norm)(h1.astype(jnp.float32)).astype(jnp.float16)
        h2 = jax.vmap(self.feed_forward)(h2)
        out = h1 + h2
        return out


# 5. Transformer
class Transformer(eqx.Module):
    tok_embeddings: eqx.nn.Embedding
    layers: TransformerBlock
    norm: eqx.nn.RMSNorm
    output: eqx.nn.Linear
    vocab_size: int
    n_layers: int
    
    def __init__(self, args, key):
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        keys = jax.random.split(key, args.n_layers + 2)
        embed_key, linear_key, tf_layers_keys = keys[0], keys[-1], keys[1:-1]
        
        self.tok_embeddings = eqx.nn.Embedding(args.vocab_size, args.dim, key=embed_key)
        self.norm = eqx.nn.RMSNorm(args.dim, eps=args.norm_eps, use_bias=False, use_weight=True)
        self.output = eqx.nn.Linear(args.dim, args.vocab_size, use_bias=False, key=linear_key)

        make_tf_layers = lambda k: TransformerBlock(args, key=k)
        self.layers = eqx.filter_vmap(make_tf_layers)(tf_layers_keys)

    def __call__(self, x, positions):
        # x is of shape (seqlen, ). We need to use vmap
        # as the embedding layer expects single token (scalar)
        # as input.
        h = jax.vmap(self.tok_embeddings)(x) # output shape: [seqlen, embed_size]
        # precomputed_sin_freq / precomputed_cos_freq are globals assumed to come
        # from precompute_frequencies(...) above
        sin_freq = precomputed_sin_freq[positions]
        cos_freq = precomputed_cos_freq[positions]
        
        if x.shape[-1] > 1: 
            seq_len = x.shape[-1]
            t = jnp.full((seq_len, seq_len), dtype=h.dtype, fill_value=1)
            mask = jnp.tril(t, k=0)
            # make the mask banded to account for sliding window
            mask = jnp.triu(mask, k=-args.sliding_window)
        else:
            mask = None

        # We need to call all the transformer blocks in a loop. Better to use lax.scan
        # as it would reduce compilation overhead and will be much faster.
        dynamic_tf_layers, static_tf_layers = eqx.partition(self.layers, eqx.is_array)
        
        def f(_x, _dynamic_tf_layers):
            tf_layer = eqx.combine(_dynamic_tf_layers, static_tf_layers)
            return tf_layer(_x, cos_freq, sin_freq, positions, mask), None

        h, _ = jax.lax.scan(f, h, dynamic_tf_layers)
        h = jax.vmap(self.norm)(h)
        h = jax.vmap(self.output)(h)
        # TODO: Calculate logits in this block
        return h


ModelArgs = namedtuple(
    "ModelArgs",
    [
        "dim",
        "n_layers",
        "hidden_dim",
        "n_heads",
        "head_dim",
        "n_kv_heads",
        "sliding_window",
        "norm_eps",
        "vocab_size",
        "max_batch_size"
    ]
)

# Same hparams as used in the original code
args = ModelArgs(
    dim=4096,
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    hidden_dim=14336,
    vocab_size=32000,
    max_batch_size=1,
    sliding_window=4096,
    norm_eps=1e-5
)

# Initialize the model
transformer = to_dtype(Transformer(args, key=jax.random.PRNGKey(1)), jnp.float16)

Given that this model has 7B parameters, I should be able to load it on an A100 40G GPU as well as on a TPUv3-8. The problem is that no matter which device I use, it always errors out with OOM.

I suspected the problem might be that I am initializing the transformer in full precision and only then converting it to float16, so I tried the other way round: converting each layer to float16 inside each block as soon as it is initialized. Even after that it errors out with OOM. I suspect there is a memory leak somewhere, but I am not able to pinpoint it.

Any suggestions/pointers would be very helpful. Thanks in advance 🙏
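
For reference, a rough back-of-the-envelope estimate of the parameter count and weight memory for the hyperparameters above (just a sketch; it ignores activations and any temporary buffers):

dim, n_layers, hidden_dim = 4096, 32, 14336
n_heads, n_kv_heads, head_dim, vocab_size = 32, 8, 128, 32000

per_layer = (
    dim * n_heads * head_dim           # wq
    + 2 * dim * n_kv_heads * head_dim  # wk, wv
    + n_heads * head_dim * dim         # wo
    + 3 * dim * hidden_dim             # w1, w2, w3
    + 2 * dim                          # attention_norm, ffn_norm
)
total = n_layers * per_layer + 2 * vocab_size * dim + dim  # + embeddings, output head, final norm
print(f"{total / 1e9:.2f}B parameters")         # ~7.24B
print(f"float32: {total * 4 / 2**30:.1f} GiB")  # ~27.0 GiB
print(f"float16: {total * 2 / 2**30:.1f} GiB")  # ~13.5 GiB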

@Artur-Galstyan
Contributor

From skimming the code, I suspect the issue might not lie with Equinox, since you're only using Embedding, Linear and RMSNorm from the Equinox namespace, and none of those should be responsible for your problem.

Here's what I would do:

  1. Disable JAX's default 75% GPU memory preallocation by setting XLA_PYTHON_CLIENT_PREALLOCATE=false (see the snippet below)
  2. Decrease some values in your ModelArgs
  3. Check GPU usage
  4. Slowly increase the values in ModelArgs
  5. Repeat until you run out of memory
  6. Investigate your parameters and the code again.

That would be my first naive experimentation.
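
For step 1, note that the flag only takes effect if it is set before JAX initialises its backend, e.g.:

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # must happen before importing jax

import jax  # JAX now allocates GPU memory on demand rather than preallocating 75%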

@AakashKumarNain
Contributor Author

Thanks @Artur-Galstyan, but I already performed all of that analysis before opening the issue. On an A100 it runs OOM after 4 transformer layers, and on a TPUv3-8 it runs OOM after 16 layers. Both devices have enough memory to handle a 7B model; in fact, I can load the entire torch model with no issues at all. There is definitely a memory leak, and I don't know if it's because of scan or something else.

@Artur-Galstyan
Contributor

I'll try to investigate this later, too. Unfortunately, I don't have an A100, but if my math is correct, I should (theoretically) be able to load a 7B model on a 3090 at 16-bit precision. I'll let you know if I find something.

@AakashKumarNain
Contributor Author

Sure thing. Thanks. BTW you can test it on TPUs on Kaggle for free

@patrick-kidger
Owner

You need a leaf.delete(). Until the end of to_dtype, model still exists, so everything it references is still in memory. (The del leaf just removes the local Python variable inside _to_dtype.)
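
Concretely, a sketch of your to_dtype utility with that change (note it only works on concrete arrays, as the traceback below shows):

def to_dtype(model, dtype):
    def _to_dtype(leaf):
        if isinstance(leaf, jax.Array):
            leaf_with_dtype = leaf.astype(dtype)
            leaf.delete()  # frees the underlying device buffer, not just the local name
            return leaf_with_dtype
        else:
            return leaf
    return jtu.tree_map(_to_dtype, model)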

@AakashKumarNain
Contributor Author

Oh okay. Let me try that. Thanks

@AakashKumarNain
Contributor Author

@patrick-kidger this throws a ConcretizationTypeError:

---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
Cell In[6], line 1
----> 1 transformer = Transformer(args, key=jax.random.PRNGKey(1))
      2 jitted_transformer = eqx.filter_jit(transformer)

File /usr/local/lib/python3.10/site-packages/equinox/_module.py:514, in _ModuleMeta.__call__(cls, *args, **kwargs)
    512 initable_cls = _make_initable(cls, cls.__init__, post_init, wraps=False)
    513 # [Step 2] Instantiate the class as normal.
--> 514 self = super(_ModuleMeta, initable_cls).__call__(*args, **kwargs)
    515 assert not _is_abstract(cls)
    516 # [Step 3] Check that all fields are occupied.

    [... skipping hidden 2 frame]

Cell In[4], line 203, in Transformer.__init__(self, args, key)
    200 self.output = to_dtype(eqx.nn.Linear(args.dim, args.vocab_size, use_bias=False, key=linear_key), jnp.float16)
    202 make_tf_layers = lambda k: to_dtype(TransformerBlock(args, key=k), jnp.float16)
--> 203 self.layers = eqx.filter_vmap(make_tf_layers)(tf_layers_keys)

File /usr/local/lib/python3.10/site-packages/equinox/_vmap_pmap.py:219, in _VmapWrapper.__call__(***failed resolving arguments***)
    214         raise ValueError(
    215             "Cannot resolve batch dimension. Non-`None` `out_axes` requires "
    216             "either `in_axes` or `axis_size` to be not `None`."
    217         )
    218 else:
--> 219     vmapd, static = jax.vmap(
    220         _fun_wrapper,
    221         in_axes=(in_axes,),
    222         out_axes=(0, None),
    223         axis_name=self._axis_name,
    224         axis_size=self._axis_size,
    225         **self._vmapkwargs,
    226     )(dynamic_args)
    227 nonvmapd, out_axes = static.value
    229 assert jtu.tree_structure(vmapd) == jtu.tree_structure(out_axes)

    [... skipping hidden 4 frame]

Cell In[4], line 202, in Transformer.__init__.<locals>.<lambda>(k)
    199 self.norm = to_dtype(eqx.nn.RMSNorm(args.dim, eps=args.norm_eps, use_bias=False, use_weight=True), jnp.float16)
    200 self.output = to_dtype(eqx.nn.Linear(args.dim, args.vocab_size, use_bias=False, key=linear_key), jnp.float16)
--> 202 make_tf_layers = lambda k: to_dtype(TransformerBlock(args, key=k), jnp.float16)
    203 self.layers = eqx.filter_vmap(make_tf_layers)(tf_layers_keys)

    [... skipping hidden 3 frame]

Cell In[4], line 166, in TransformerBlock.__init__(self, args, key)
    163 self.n_heads = args.n_heads
    164 self.dim = args.dim
--> 166 self.attention = Attention(args, key=key1)
    167 self.attention_norm = eqx.nn.RMSNorm(args.dim, eps=args.norm_eps, use_bias=False, use_weight=True)
    169 self.feed_forward = FeedForward(args, key=key2)

    [... skipping hidden 3 frame]

Cell In[4], line 80, in Attention.__init__(self, args, key)
     77 self.head_dim = args.head_dim
     79 key1, key2, key3, key4 = jax.random.split(key, 4)
---> 80 self.wq = to_dtype(eqx.nn.Linear(args.dim, args.n_heads * args.head_dim, use_bias=False, key=key1), jnp.float16)
     81 self.wk = to_dtype(eqx.nn.Linear(args.dim, args.n_kv_heads * args.head_dim, use_bias=False, key=key2), jnp.float16)
     82 self.wv = to_dtype(eqx.nn.Linear(args.dim, args.n_kv_heads * args.head_dim,use_bias=False, key=key3), jnp.float16)

Cell In[4], line 11, in to_dtype(model, dtype)
      9     else:
     10         return leaf
---> 11 return jtu.tree_map(_to_dtype, model)

    [... skipping hidden 2 frame]

Cell In[4], line 6, in to_dtype.<locals>._to_dtype(leaf)
      4 if isinstance(leaf, jax.Array):  # not eqx.is_array, which also detects NumPy arrays
      5     leaf_with_dtype = leaf.astype(dtype)
----> 6     leaf.delete()
      7     gc.collect()  # just in case?
      8     return leaf_with_dtype

File /usr/local/lib/python3.10/site-packages/jax/_src/core.py:870, in Tracer.delete(self)
    869 def delete(self):
--> 870   raise ConcretizationTypeError(self,
    871     f"The delete() method was called on {self._error_repr()}."
    872     f"{self._origin_msg()}")

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[4096,4096].
The delete() method was called on traced array with shape float32[4096,4096]..
This BatchTracer with object id 134203306918672 was created on line:
  /tmp/ipykernel_13/2973233627.py:80 (__init__)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

@patrick-kidger
Owner

Do the dtype conversion after model creation, not inside of it. In this case you can't delete the old buffer whilst inside a vmap.

@AakashKumarNain
Contributor Author

AakashKumarNain commented Mar 5, 2024

@patrick-kidger fair enough, but the other change didn't make any difference. BTW, I figured out where the issue lies, and it is definitely with equinox.filter_vmap(...). If I use a for loop to create a list of n TransformerBlock layers, it works, i.e. I am able to load the model successfully, at least on a TPUv3-8.

In a nutshell, inside the Transformer module:

This fails, and leads to OOM

make_tf_layers = lambda k: TransformerBlock(args, key=k)
self.layers = eqx.filter_vmap(make_tf_layers)(tf_layers_keys)
self.layers = to_dtype(self.layers, jnp.float16)

This works though

self.layers = [to_dtype(TransformerBlock(args, key=k), jnp.float16) for k in tf_layers_keys]

@patrick-kidger
Owner

Oh! That's super weird. What about when there is no to_dtype call inside the model at all, and only the one at the very outside (on the last line of your original example)?

(In all cases including the leaf.delete() call.)

@AakashKumarNain
Contributor Author

AakashKumarNain commented Mar 6, 2024

What about when there is no to_dtype call inside the model at all, and only the one at the very outside (on the last line of your original example)? (In all cases including the leaf.delete() call.)

Still OOM. And I have some idea why it is happening. When we use eqx.filter_vmap(...) over n layers, it actually creates stacked weights of shape [n_layers, d1, d2, ...]. Unless each layer is converted to half precision as soon as it is initialized, those stacked arrays, especially at full precision, simply don't fit in memory.
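
A small self-contained illustration of that stacking behaviour (toy sizes, not the Mistral ones):

make_linear = lambda k: eqx.nn.Linear(4, 8, use_bias=False, key=k)
keys = jax.random.split(jax.random.PRNGKey(0), 3)

stacked = eqx.filter_vmap(make_linear)(keys)
print(stacked.weight.shape)  # (3, 8, 4): every layer's weight is materialised at once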

Quick question: is there any particular reason for not supporting a dtype argument for each layer explicitly?

@patrick-kidger
Owner

Ah, maybe I misunderstood the memory constraints of your model. Indeed, filter_vmap will create all of the weights for all layers simultaneously, all in full precision.

Quick question: Any trivial reason for not supporting dtype for each layer explicitly?

You mean adding e.g. a MultiheadAttention(..., dtype=...) argument? We've considered this (along with other kinds of custom initialisation strategies, e.g. sampling from other distributions), but have generally decided not to include it, as it fairly rapidly gets out of hand (you end up wanting lots of arguments for how to initialise every sub-sub-sub-layer in a model). The recommended approach has been to create the layer and then manipulate it using tree_map or tree_at as appropriate.
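
For example, a sketch of the tree_at route for a single sub-layer (here block is assumed to be a TransformerBlock like the one above):

block = eqx.tree_at(
    lambda b: b.attention.wq.weight,
    block,
    replace_fn=lambda w: w.astype(jnp.float16),  # convert just this weight, leave the rest untouched
)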

That said, this does seem like a case where that turns out to be fairly complicated. Hmm, it might be that we should reconsider the above approach, or else find some other way to directly create the weights with the desired precision.

@AakashKumarNain
Contributor Author

AakashKumarNain commented Mar 6, 2024

Indeed, filter_vmap will create all of the weights for all layers simultaneously, all in full precision.

Thanks for the confirmation. Now that the only solution seems to be creating a list of TransformerBlocks in a for loop, what's your recommended way of calling these layers? I want to jit the whole thing, but I am not aware of any method other than what I tried above. How can I call self.layers efficiently?

self.layers = [to_dtype(TransformerBlock(args, key=k), jnp.float16) for k in tf_layers_keys]

but have generally decided not to include this, as it fairly rapidly gets out of hand (perhaps wanting lots of arguments for how to initialise every sub-sub-sub-layer in a model).

Agreed, but I think dtype is far more important than any other parameter; otherwise, loading models/layers at different precisions becomes a headache.

@patrick-kidger
Owner

How can I call self.layers efficiently?

One simple thing that will work at the moment would be to stack all the layers after creating them:

layers = [to_dtype(...) ...]
self.layers = jtu.tree_map(lambda *x: jnp.stack(x) if eqx.is_array(x) else x, layers)

On dtype: agreed that this seems like it could be thought of differently to the other initialisations. I'm not currently sure whether we should (a) bite the bullet and add a general initialiser, (b) add just a dtype argument, or (c) find a way to finesse this with the current API. Either way this seems like a problem worth solving.

@AakashKumarNain
Contributor Author

Thanks for the suggestion. I will try that out.

On dtype: agreed that this seems like it could be thought of differently to the other initialisations. I'm not currently sure if we should (a) bite the bullet and add general initialiser, (b) add just a dtype argument, or (c) maybe there is a way to finesse this with the current API. Either way this seems like a problem worth solving.

I can help contribute to it. Do you want to open an issue to track this?

@patrick-kidger
Owner

Sure, go ahead and open an issue to track this.
If there is some cunning way to do this without modifying the existing API then that would definitely be best.

@AakashKumarNain
Contributor Author

self.layers = jtu.tree_map(lambda *x: jnp.stack(x) if eqx.is_array(x) else x, layers)

I don't know if this is the intended behavior here but this doesn't create stacked arrays. self.layers is still a list.

@patrick-kidger
Owner

Sorry, should be *layers to unpack as arguments.

@AakashKumarNain
Contributor Author

No worries. Thanks for the clarification. Also, once the weights are stacked, I guess it will be a simple vmapped call?

@patrick-kidger
Owner

Probably a lax.scan to iterate through the layers one at a time, rather than evaluating them in parallel. See https://docs.kidger.site/equinox/tricks/#improve-compilation-speed-with-scan-over-layers

@AakashKumarNain
Contributor Author

AakashKumarNain commented Mar 10, 2024

A couple of issues here:

  1. If we stack the layers after creating a list of TransformerBlocks, then we need to comment out the leaf.delete() statement in the to_dtype function, otherwise it complains that the array has been deleted. This is weird, as it wasn't happening when I was using eqx.filter_vmap(...) to create the stacked layers.
  2. The scan example doesn't work in this case. If you look at the code given above, I have done exactly the same thing inside the __call__ of the Transformer module:
dynamic_tf_layers, static_tf_layers = eqx.partition(self.layers, eqx.is_array)

def f(_x, _dynamic_tf_layers):
      tf_layer = eqx.combine(_dynamic_tf_layers, static_tf_layers)
      return tf_layer(_x, cos_freq, sin_freq, positions, mask), None

h, _ = jax.lax.scan(f, h, dynamic_tf_layers)

This throws ValueError: scan got values with different leading axis sizes:. This is again weird, as nothing has fundamentally changed in the input/output shapes of the layers.

Pardon me if this sounds naive, but isn't there a way to inspect the input/output shapes of each layer in an Equinox module iteratively?

@patrick-kidger
Owner

  1. Not sure I completely understand what you're getting at, but I think this sounds like a bug in how you've implemented things. You probably want to delete each thing only once, after all.
  2. You can check the sizes of each layer just by printing it out: print(self.layers); this will include the shape of every array (see also the sketch below).
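
If you only want the array shapes rather than the full repr, one option (just a sketch) is:

shapes = jtu.tree_map(lambda x: x.shape, eqx.filter(self.layers, eqx.is_array))
print(shapes)  # same pytree structure, with each array leaf replaced by its shape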

@AakashKumarNain
Contributor Author

AakashKumarNain commented Mar 16, 2024

I figured out what's wrong with the list -> stack approach. Here is a simple code example for your reference:

class FeedForward(eqx.Module):
    layer1: eqx.nn.Linear
    layer2: eqx.nn.Linear

    def __init__(self, in_features, out_features, key):
        key1, key2 = jax.random.split(key)
        self.layer1 = eqx.nn.Linear(in_features, out_features, key=key1)
        self.layer2 = eqx.nn.Linear(out_features, out_features, key=key2)

    def __call__(self, x):
        x = jax.vmap(self.layer1)(x)
        x = jax.vmap(self.layer2)(x)
        return x

class CustomModel(eqx.Module):
    layers: FeedForward
    
    def __init__(self, in_features, out_features, num_layers):
        keys = jax.random.split(jax.random.PRNGKey(1), num_layers)
        self.layers = [FeedForward(in_features, out_features, key=k) for k in keys]
        self.layers = jtu.tree_map(lambda *x: jnp.stack(x) if eqx.is_array(x) else x , *self.layers)
    
    def __call__(self, x):
           ...

If I print out CustomModel, it has this structure:

CustomModel(
  layers=FeedForward(
    layer1=Linear(
      weight=(f32[8,16], f32[8,16]),
      bias=(f32[8], f32[8]),
      in_features=16,
      out_features=8,
      use_bias=True
    ),
    layer2=Linear(
      weight=(f32[8,8], f32[8,8]),
      bias=(f32[8], f32[8]),
      in_features=8,
      out_features=8,
      use_bias=True
    )
  )
)

You can see that the FeedForward layer has weight and bias as tuples, and not big stacked matrices of sizes [2, 8, 16] and [2, 8] respectively.

PS: If you just use jnp.stack(x), then it stacks properly, but with many layers, like the 32 TransformerBlocks in the original Mistral example, it OOMs out, and I have no idea why. A TPUv3-8 has enough memory to hold two copies of the weights, each of which is ~15G max.

@AakashKumarNain
Contributor Author

Any suggestions? @patrick-kidger

@patrick-kidger
Owner

Ah, the typo is that x is a tuple, so eqx.is_array(x) returns False. I suppose the robust way to do this would be something like this:

def _stack(*x):
    is_array = {eqx.is_array(xi) for xi in x}
    if len(is_array) == 1:
        if is_array.pop():
            return jnp.stack(x)
        else:
            return x
    else:
        raise ValueError

self.layers = jtu.tree_map(_stack, *layers)

Not terribly pretty, but it works.
(Hopefully with the new dtype PRs we shouldn't need to do any of this any more anyway.)

@AakashKumarNain
Contributor Author

Thanks for the help, but it OOMs out as soon as I try to stack the layers, and I have no idea why! The simple for loop for creating the layers works, but stacking them doesn't. Though I can give it one more try, I will take a break here as I am exhausted from debugging this.

Hopefully with the new dtype PRs then we shouldn't need to do any of htis any more anyway.

Yes, it will save a lot of time, and will make the code concise

@patrick-kidger
Owner

Probably because both versions of the weights (pre-stack and post-stack) need to live in memory at the same time for a short while.
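
If that transient duplication is what's biting you, one possible workaround (just a sketch with a hypothetical _stack_and_free helper; it assumes concrete arrays outside jit/vmap, and that non-array fields are identical across layers) is to free each per-layer buffer as soon as its stacked copy exists:

def _stack_and_free(*xs):
    if all(eqx.is_array(x) for x in xs):
        out = jnp.stack(xs)
        for x in xs:
            x.delete()  # release the per-layer buffer; only one leaf is duplicated at a time
        return out
    else:
        return xs[0]  # static leaves are assumed identical across layers

stacked_layers = jtu.tree_map(_stack_and_free, *layers)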

In any case you shouldn't need to worry about this any more anyway; we've merged your PR over on the dev branch, right?

@AakashKumarNain
Contributor Author

I tried it on a TPUv3-8 with the changes on the dev branch. Given that everything is now in half precision, I used eqx.filter_vmap(...) to create one big stacked layer, as opposed to creating the layers in a loop. It OOMs out!

@patrick-kidger
Owner

Hard to say what's going on from where I'm sitting I'm afraid. I hope this is something you can debug, and then let us know where the error came from!

@AakashKumarNain
Contributor Author

AakashKumarNain commented Apr 1, 2024

I literally have no idea. If it were just a memory issue, the same code should also OOM out when creating the transformer blocks in a loop. The issue lies in replacing the for loop with eqx.filter_vmap.

@AakashKumarNain
Contributor Author

I will pick this up again next week and will post my findings here.

@AakashKumarNain
Contributor Author

This is now solved; I figured out how to solve it in a much better way.

@patrick-kidger the Mistral-7B port is complete. Here is the source code for your reference. Thanks again for all the help.

@patrick-kidger
Owner

Oh awesome! I'm really glad that you managed to complete this.

I'd be happy to link to your GitHub page from the "advanced examples" in the Equinox documentation, if you'd like?

By the way, I see a comment in your code about RMSNorm being wrong -- what's going on there? :)

@AakashKumarNain
Contributor Author

AakashKumarNain commented May 12, 2024

I'd be happy to link to your GitHub page from the "advanced examples" in the Equinox documentation, if you'd like?

That would be awesome! Thanks 🍺

By the way, I see a comment in your code about RMSNorm being wrong -- what's going on there? :)

Like RoPE, the norm of the inputs should be computed in at least full precision. Right now the source code looks like this:

inv_rms = jax.lax.rsqrt(jnp.mean(x**2) + self.eps)

If I initialize RMSNorm with float16/bfloat16, the norm will be computed at that precision, resulting in a poor output (not mathematically wrong, but it can lead to instability in certain places). The inputs should first be upcast, and then the resulting output cast back to the desired precision:

upscaled_x = x.astype(jnp.float32)
inv_rms = jax.lax.rsqrt(jnp.mean(upscaled_x**2) + self.eps)
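
Something like this minimal sketch of the full change (not the actual Equinox implementation, just the shape of it):

def rms_norm(x, weight, eps):
    upscaled_x = x.astype(jnp.float32)  # do the reduction in float32
    inv_rms = jax.lax.rsqrt(jnp.mean(upscaled_x**2) + eps)
    out = upscaled_x * inv_rms * weight.astype(jnp.float32)
    return out.astype(x.dtype)  # cast back to the input precision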

I will send a PR for this
