Skip to content

Commit

Permalink
Fixing GPTQ device santacoder.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jul 20, 2023
1 parent 7faef69 commit 900ac49
Showing 1 changed file with 5 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
qweight = torch.cat([q_tensor, kv_tensor], dim=1)
qweight = qweight.to(device=weights.device)

slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
shape = slice_.get_shape()
Expand All @@ -59,6 +60,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
scales = torch.cat([q_tensor, kv_tensor], dim=1)
scales = scales.to(device=weights.device)

slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
shape = slice_.get_shape()
Expand All @@ -69,8 +71,10 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)

g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
bits, groupsize = weights._get_gptq_qparams()

from text_generation_server.utils.layers import HAS_EXLLAMA
Expand All @@ -88,6 +92,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0)
bias = bias.to(device=weights.device)

return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
else:
Expand Down

0 comments on commit 900ac49

Please sign in to comment.