Batched generation with masking/padding #66
The zero-shot evals only require evaluating likelihood (to pick among multiple choices) and not generation. I don't think the current generation code supports batched generation of different lengths.
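For context, here is a rough sketch of what likelihood-only multiple-choice scoring can look like. This is not lm-evaluation-harness's actual implementation; `score_choice` and the tensor handling are illustrative, and the only thing borrowed from this thread is that the model output exposes `.logits`.

```python
import torch
import torch.nn.functional as F

def score_choice(model, prompt_ids, choice_ids):
    """Sum of log-probs of the choice tokens given the prompt (higher is better)."""
    input_ids = torch.cat([prompt_ids, choice_ids]).unsqueeze(0)   # (1, L)
    with torch.no_grad():
        logits = model(input_ids).logits[0]                        # (L, vocab)
    # Logits at position t predict token t + 1, so shift by one.
    logprobs = F.log_softmax(logits[:-1].float(), dim=-1)
    targets = input_ids[0, 1:]
    token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return token_logprobs[len(prompt_ids) - 1:].sum()              # choice tokens only

# best = max(range(len(choices)), key=lambda i: score_choice(model, prompt_ids, choices[i]))
```

Each candidate answer gets its own forward pass here, so no padding or masking is needed.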
That makes sense, thanks!
@tridao do you think it would be feasible to implement masking by setting padded timesteps of the discretized A and B matrices to identity operators (i.e. all 1's for A and all 0's for B)? I tried implementing this in the naive `selective_scan_ref`:

from einops import rearrange, repeat
import torch
import torch.nn.functional as F
def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, mask=None, delta_softplus=False,
                       return_last_state=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    mask: r(B L), 0 at padded timesteps (added for masking)
    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if is_variable_C:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    # added for masking: on padded timesteps, make the transition the identity (Abar = 1)
    # and zero the input contribution (Bbar u = 0) so the state is carried through unchanged.
    if mask is not None:
        mask = mask[:, None, :, None].expand(-1, dim, -1, dstate) == 0
        deltaA[mask] = 1
        deltaB_u[mask] = 0
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
            else:
                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)
batch = 1
dim = 1536
L = 10
N = 16
u = torch.randn((batch, dim, L))
delta = torch.randn((batch, dim, L))
A = torch.randn((dim, N))
B = torch.randn((batch, N, L))
C = torch.randn((batch, N, L))
D = torch.randn(dim)
z = torch.randn((batch, dim, L))
delta_bias = torch.randn(dim)
mask = torch.tensor([[0] + [1] * (L - 1)])
out = selective_scan_ref(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus=True,
)
out_masked = selective_scan_ref(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
mask,
delta_softplus=True,
)
out_true = selective_scan_ref(
u[..., 1:],
delta[..., 1:],
A,
B[..., 1:],
C[..., 1:],
D,
z[..., 1:],
delta_bias,
mask[..., 1:],
delta_softplus=True,
)
print("Should be False:", torch.allclose(out[:, :, 1:], out_true))
print("Should be True:", torch.allclose(out_masked[:, :, 1:], out_true)) But I'm not sure if there's a simple way to do this in the CUDA kernels. |
Yeah, that should work in principle. It might be easier to instead right-align (left-pad) all the prompts in your batch, and make sure that each layer zeros out its output in the padded regions (e.g. by passing in a mask as you did). Then you don't have to touch the internals of the SSM.
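A minimal sketch of that suggestion, for concreteness. `left_pad_batch` is a hypothetical helper (not part of mamba_ssm), and the per-layer masking line mirrors the modification discussed below in this thread:

```python
import torch

def left_pad_batch(prompts, pad_id):
    """Right-align variable-length prompts; mask is 1 on real tokens, 0 on pads."""
    max_len = max(len(p) for p in prompts)
    input_ids = torch.full((len(prompts), max_len), pad_id, dtype=torch.long)
    mask = torch.zeros(len(prompts), max_len, dtype=torch.long)
    for i, p in enumerate(prompts):
        input_ids[i, max_len - len(p):] = torch.tensor(p, dtype=torch.long)
        mask[i, max_len - len(p):] = 1
    return input_ids, mask

# Inside each layer the mask would be applied as x = x * mask.unsqueeze(1)
# on (B, D, L) activations, so the recurrent state is still 0 when the
# first real token of each sequence arrives.
```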
Thanks, that makes sense. I didn't realize that
Yeah, I think this relies on the idea that the only thing you care about is that the recurrent state is 0 in the padded region, so it's not affecting the relevant parts. Similar to your idea of setting Abar and Bbar appropriately to ensure that the hidden state gets transmitted through. In this case if the input
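A tiny numerical illustration of that argument (shapes are made up): if the recurrence starts from a zero state and the input contribution is zeroed on the left pads, the state is still exactly zero when the first real token arrives, regardless of what the transition Abar does on those steps.

```python
import torch

torch.manual_seed(0)
dim, dstate, n_pad, n_real = 4, 3, 2, 5
L = n_pad + n_real

deltaA = torch.rand(dim, dstate, L)      # arbitrary transitions, even on the pads
deltaB_u = torch.rand(dim, dstate, L)
deltaB_u[..., :n_pad] = 0                # zeroed input contribution on the left pads

x = torch.zeros(dim, dstate)             # recurrence starts from 0
for t in range(L):
    x = deltaA[..., t] * x + deltaB_u[..., t]
    if t == n_pad - 1:
        print("state after the pads is still 0:", bool(torch.all(x == 0)))  # True
```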
I tested this out in the slow path of `Mamba`:

class Mamba(nn.Module):
    ...
    def forward(self, hidden_states, mask=None, inference_params=None):
        ...
        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
            ...
        else:
            x, z = xz.chunk(2, dim=1)
            # added: zero out padded positions before the short convolution
            if mask is not None:
                x = x * mask.unsqueeze(1)
            # Compute short convolution
            if conv_state is not None:
                conv_state.copy_(x[:, :, -self.d_conv :])  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x,
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    self.conv1d.bias,
                    self.activation,
                )
            # added: zero out padded positions again after the convolution
            if mask is not None:
                x = x * mask.unsqueeze(1)
            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            ...

Testing with this script:

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
model = MambaLMHeadModel.from_pretrained('/data/norman_mu/models/mamba-130m').to('cuda')
input_ids = torch.randint(1, 1000, (1, 1024)).to('cuda')
input_ids_padded = torch.cat([torch.zeros_like(input_ids[:, [0]]), input_ids], dim=1)
attention_mask = torch.cat([torch.zeros_like(input_ids[:, [0]]), torch.ones_like(input_ids)], dim=1)
out = model(input_ids_padded).logits.detach().cpu()
out_padded = model(input_ids_padded, attention_mask).logits.detach().cpu()
out_true = model(input_ids).logits.detach().cpu()
print("max L2 error:", (out_true - out[:, 1:]).norm(dim=-1).max())
print("max L2 errors (padded):", (out_true - out_padded[:, 1:]).norm(dim=-1).max()) This prints:
which isn't perfect but also doesn't seem too bad for 50k dim logits. I'm guessing this is due to the causal_conv1d leaking information from the pad token in index 0. Does causal_conv1d not use zero padding? |
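For what it's worth, here is a quick standalone check, as a sketch with made-up sizes. It assumes the slow path's conv is a depthwise nn.Conv1d with padding=d_conv - 1 truncated to seqlen, which is how it appears above. Under that assumption the causal conv does left-pad with zeros, so a zeroed pad position by itself shouldn't leak into later positions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, L, d_conv = 8, 16, 4

# Depthwise causal conv in the style of the slow path above:
# kernel_size=d_conv, groups=d, padding=d_conv - 1, then truncated to seqlen.
conv = nn.Conv1d(d, d, kernel_size=d_conv, groups=d, padding=d_conv - 1)

x = torch.randn(1, d, L)
y_ref = conv(x)[..., :L]

# Prepend one zeroed timestep, as a masked pad position would be.
x_padded = torch.cat([torch.zeros(1, d, 1), x], dim=2)
y_padded = conv(x_padded)[..., : L + 1]

print(torch.allclose(y_ref, y_padded[..., 1:]))  # expect True: no leak from the zeroed pad
```

If this prints True, the remaining discrepancy is more likely coming from somewhere other than the conv's padding itself.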
The instructions in the README on running lm-evaluation-harness set batch size > 1, and I would like to try batched generation in a standalone script. Per this previous thread (#49 (comment)) it seems like standard attention masking/padding tokens are not supported yet, which should also mean batched generation with differently sized prompts is not currently possible, so how is lm-evaluation-harness able to handle batch size > 1?
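For reference, a rough sketch of how batched greedy decoding could look once left-padding plus a per-layer mask (as discussed above) is in place. This is illustrative only: it recomputes the full prefix every step instead of using the inference cache, and the positional `mask` argument assumes the modified forward from this thread.

```python
import torch

@torch.no_grad()
def greedy_generate(model, input_ids, mask, max_new_tokens=20):
    """Batched greedy decoding over left-padded prompts (no state cache)."""
    for _ in range(max_new_tokens):
        logits = model(input_ids, mask).logits               # (B, L, vocab)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        # Newly generated tokens are always real, so extend the mask with ones.
        mask = torch.cat([mask, torch.ones_like(mask[:, :1])], dim=1)
    return input_ids
```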