
Batched generation with masking/padding #66

Open
normster opened this issue Dec 19, 2023 · 7 comments

@normster

The instructions in the README on running lm-evaluation-harness set batch size > 1, and I would like to try batched generation in a standalone script.

Per this previous thread (#49 (comment)), it seems that standard attention masking/padding tokens are not supported yet, which should also mean batched generation with differently sized prompts is not currently possible. So how is lm-evaluation-harness able to handle batch size > 1?

@tridao
Collaborator

tridao commented Dec 19, 2023

The zero-shot evals only require evaluating likelihood (to pick among multiple choices) and not generation.

I don't think the current generation code supports batched generation of different lengths.
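
For reference, a rough sketch of how batched likelihood scoring can work with right-padding (the batched_loglikelihood helper and its lengths argument are illustrative, not lm-evaluation-harness's actual code): because the model is causal, pad tokens appended on the right never influence the logits at earlier, real positions, so padded positions can simply be ignored when summing log-probabilities.

import torch
import torch.nn.functional as F

def batched_loglikelihood(model, input_ids, lengths):
    """Per-sequence log-likelihood for a right-padded batch of token ids.

    Assumes model(input_ids) returns an object with a .logits field of shape
    (batch, seq_len, vocab), as MambaLMHeadModel does. lengths holds the
    number of real (non-pad) tokens in each row.
    """
    logits = model(input_ids).logits  # (B, L, V)
    logprobs = F.log_softmax(logits.float(), dim=-1)
    # log p(token_{t+1} | tokens_{<=t}): the prediction at position t scores token t+1
    token_lp = logprobs[:, :-1].gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    # keep only predictions of real tokens; padded positions contribute nothing
    pos = torch.arange(input_ids.shape[1] - 1, device=input_ids.device)
    real = pos[None, :] < (lengths[:, None] - 1)
    return (token_lp * real).sum(dim=-1)  # (B,)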

@normster
Author

That makes sense, thanks!

@normster
Author

normster commented Dec 19, 2023

@tridao do you think it would be feasible to implement masking by setting padded timesteps of the discretized A and B matrices to identity operators (i.e. all 1's for A and all 0's for B)? I tried implementing this in the naive selective_scan_ref and it seems to work:

from einops import rearrange, repeat
import torch
import torch.nn.functional as F

def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, mask=None, delta_softplus=False,
                      return_last_state=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
+    mask: (B L)

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if is_variable_C:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
+    if mask is not None:
+        mask = mask[:, None, :, None].expand(-1, dim, -1, dstate) == 0
+        deltaA[mask] = 1
+        deltaB_u[mask] = 0
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
            else:
                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2) # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)

batch = 1
dim = 1536
L = 10
N = 16

u = torch.randn((batch, dim, L))
delta = torch.randn((batch, dim, L))
A = torch.randn((dim, N))
B = torch.randn((batch, N, L))
C = torch.randn((batch, N, L))
D = torch.randn(dim)
z = torch.randn((batch, dim, L))
delta_bias = torch.randn(dim)
mask = torch.tensor([[0] + [1] * (L - 1)])

out = selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D,
    z,
    delta_bias,
    delta_softplus=True,
)

out_masked = selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D,
    z,
    delta_bias,
    mask,
    delta_softplus=True,
)

out_true = selective_scan_ref(
    u[..., 1:],
    delta[..., 1:],
    A,
    B[..., 1:],
    C[..., 1:],
    D,
    z[..., 1:],
    delta_bias,
    mask[..., 1:],
    delta_softplus=True,
)

print("Should be False:", torch.allclose(out[:, :, 1:], out_true))
print("Should be True:", torch.allclose(out_masked[:, :, 1:], out_true))

But I'm not sure if there's a simple way to do this in the CUDA kernels.

@normster normster reopened this Dec 19, 2023
@albertfgu
Contributor

Yeah, that should work in principle. It might be easier to instead right-align (left-pad) all the prompts in your batch, and make sure that each layer zeros out its output in the padded regions (e.g. by passing in a mask as you did). Then you don't have to touch the internals of the SSM.
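
A minimal sketch of the left-padding half of this suggestion (left_pad_batch is an illustrative helper, not an existing API; the per-layer zeroing is the same kind of mask multiplication shown later in this thread):

import torch

def left_pad_batch(prompts, pad_id=0):
    """Right-align (left-pad) a list of 1D token tensors to a common length.

    Returns (input_ids, mask) where mask is 1 on real tokens and 0 on pads.
    The pad_id value doesn't matter much, since masked positions get zeroed
    inside each layer anyway.
    """
    max_len = max(p.shape[0] for p in prompts)
    input_ids = torch.full((len(prompts), max_len), pad_id, dtype=torch.long)
    mask = torch.zeros((len(prompts), max_len), dtype=torch.long)
    for i, p in enumerate(prompts):
        input_ids[i, max_len - p.shape[0]:] = p
        mask[i, max_len - p.shape[0]:] = 1
    return input_ids, mask

Each layer would then multiply its output by the mask (e.g. hidden_states * mask.unsqueeze(-1) for (B, L, D) activations) so the SSM state is still 0 when the first real token arrives.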

@normster
Author

Thanks, that makes sense. I hadn't realized that deltaB_u is linear in the input x, so zeroing x in the padded positions also zeroes its contribution to the state. I guess this approach doesn't technically handle internal pad tokens correctly, but it works for left-padded generation.

@albertfgu
Contributor

Yeah, I think this relies on the idea that the only thing you care about is that the recurrent state is 0 in the padded region, so it isn't affecting the relevant parts. It's similar to your idea of setting $\bar{A}$ and $\bar{B}$ appropriately to ensure that the hidden state gets transmitted through. In this case, if the input $x = 0$ and the state $h = 0$, then the state should remain 0.
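
Spelling this out with the discretized recurrence: if the input has been zeroed at a padded step, so that $\bar{B}_t x_t = 0$, and the state entering that step is already zero, then

$$h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t = \bar{A}_t \cdot 0 + 0 = 0,$$

so the state stays 0 through the entire left-padded region and the first real token starts the scan from a clean zero state.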

@normster
Author

I tested this out in the slow path of Mamba.forward by masking twice (once before the causal conv1d and once before the selective scan):

class Mamba(nn.Module):
    ...
    def forward(self, hidden_states, mask=None, inference_params=None):
        ....
        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
            ...
        else:
            x, z = xz.chunk(2, dim=1)

+            if mask is not None:
+                x = x * mask.unsqueeze(1)

            # Compute short convolution
            if conv_state is not None:
                conv_state.copy_(x[:, :, -self.d_conv :])  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x,
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    self.conv1d.bias,
                    self.activation,
                )

+            if mask is not None:
+                x = x * mask.unsqueeze(1)

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            ...

Testing with this script:

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained('/data/norman_mu/models/mamba-130m').to('cuda')
input_ids = torch.randint(1, 1000, (1, 1024)).to('cuda')
input_ids_padded = torch.cat([torch.zeros_like(input_ids[:, [0]]), input_ids], dim=1)
attention_mask = torch.cat([torch.zeros_like(input_ids[:, [0]]), torch.ones_like(input_ids)], dim=1)

out = model(input_ids_padded).logits.detach().cpu()
out_padded = model(input_ids_padded, attention_mask).logits.detach().cpu()
out_true = model(input_ids).logits.detach().cpu()

print("max L2 error:", (out_true - out[:, 1:]).norm(dim=-1).max())
print("max L2 errors (padded):", (out_true - out_padded[:, 1:]).norm(dim=-1).max())

This prints:

max L2 error: tensor(24580.3848)
max L2 errors (padded): tensor(0.5131)

which isn't perfect but also doesn't seem too bad for 50k-dimensional logits. I'm guessing this is due to the causal_conv1d leaking information from the pad token at index 0. Does causal_conv1d not use zero padding?
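
One way to sanity-check the convolution in isolation (a standalone sketch using a plain nn.Conv1d with the same depthwise, causal layout as the slow path's self.conv1d, not the fused causal_conv1d_fn kernel; the shapes loosely follow mamba-130m but are otherwise arbitrary):

import torch
import torch.nn as nn

# Depthwise causal conv, as in the slow path: self.act(self.conv1d(x)[..., :seqlen])
d, d_conv, L = 1536, 4, 32
conv1d = nn.Conv1d(d, d, kernel_size=d_conv, groups=d, padding=d_conv - 1)

x = torch.randn(1, d, L)
# Left-pad by one position that the mask has already zeroed out
x_padded = torch.cat([torch.zeros(1, d, 1), x], dim=-1)

out = conv1d(x)[..., :L]
out_padded = conv1d(x_padded)[..., :L + 1]

# If a zeroed pad position behaves the same as the conv's own zero padding,
# the outputs at the real positions should match
print(torch.allclose(out, out_padded[..., 1:]))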
