Add support for SqueezeLLM
danielz02 committed Mar 26, 2024
1 parent 3a2b0d3 commit d8da4ca
Showing 2 changed files with 70 additions and 2 deletions.
60 changes: 59 additions & 1 deletion src/helm/proxy/clients/huggingface_client.py
@@ -36,6 +36,7 @@ class StopOnToken(StoppingCriteria):
    This class can be used to stop generation whenever the generated sequence contains one of the token(s)
    specified in `stop_token_ids`. This is particularly useful for chat models with multiple stopping criteria.
    """

    def __init__(self, stop_token_id: Union[int, torch.LongTensor]):
        if isinstance(stop_token_id, torch.LongTensor):
            stop_token_id = stop_token_id.item()
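For context, a usage sketch (assumed, not part of this commit): the criterion plugs into `model.generate` through `StoppingCriteriaList`. The model choice and stop token below are hypothetical.

    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model choice
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Stop as soon as the model emits the newline token.
    stop_id = tokenizer.encode("\n")[0]
    inputs = tokenizer("Q: What is 2+2?\nA:", return_tensors="pt")
    output = model.generate(**inputs, stopping_criteria=StoppingCriteriaList([StopOnToken(stop_id)]))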
@@ -88,6 +89,63 @@ def __init__(self, model_config: HuggingFaceModelConfig):
                **model_kwargs
            )
            hlog(self.model.hf_device_map)
        elif quantization_config and quantization_config.model_loader == ModelLoader.SQUEEZELLM:
            import torch
            from squeezellm.modelutils import find_layers
            from squeezellm.quant import make_quant_lut

            if (
                "xgen" in quantization_config.quant_file
                or "opt" in quantization_config.quant_file
                or ("vicuna" in quantization_config.quant_file and "v1.3" in quantization_config.quant_file)
                or "llama-2" in quantization_config.quant_file
                or "mistral" in quantization_config.quant_file
            ):
                # TODO: this is a hacky solution; it will be properly implemented once all the model
                # checkpoints are updated with the new packing scheme that includes the non-linear weights.
                from transformers import AutoConfig

                config = AutoConfig.from_pretrained(model_name)
                model = AutoModelForCausalLM.from_config(config)
            else:
                from transformers import LlamaForCausalLM

                model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype="auto")

            model = model.eval()
            layers = find_layers(model)

            state_dict = torch.load(quantization_config.quant_file)

            # Load sparse thresholds from the checkpoint.
            if quantization_config.include_sparse:
                num_vals = {}
                for k, v in state_dict.items():
                    if "sparse_threshold." in k:
                        key = k.replace("sparse_threshold.", "")
                        num_vals[key] = v
                for k, v in num_vals.items():
                    del state_dict["sparse_threshold." + k]
            else:
                num_vals = None

            # Replace quantizable layers with LUT-based quantized layers, excluding the LM head.
            for name in ["lm_head"]:
                if name in layers:
                    del layers[name]
            make_quant_lut(
                model,
                layers,
                quantization_config.wbits,
                include_sparse=quantization_config.include_sparse,
                numvals=num_vals,
                topX=quantization_config.num_dense_channels,
            )
            del layers

hlog("SqueezeLLM Loading model ...")
state_dict = torch.load(quantization_config.quant_file)
model.load_state_dict(state_dict, strict=False)
model.seqlen = 2048
self.model = model
hlog("Done.")

        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, trust_remote_code=True, **model_kwargs
@@ -390,4 +448,4 @@ def do_it():

        return DecodeRequestResult(
            success=True, cached=cached, text=result["text"], request_time=result["request_time"]
        )
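The SqueezeLLM branch above boils down to three steps: build an fp16 skeleton of the model, swap each quantizable Linear layer for a lookup-table layer, then load the quantized weights. A condensed sketch using the same squeezellm helpers (the model name, checkpoint path, and bit width are hypothetical; the helper calls mirror the ones in this diff):

    import torch
    from transformers import LlamaForCausalLM
    from squeezellm.modelutils import find_layers
    from squeezellm.quant import make_quant_lut

    # 1. Build the unquantized model skeleton.
    model = LlamaForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype="auto").eval()

    # 2. Swap every quantizable layer except the LM head for a LUT-based quantized layer.
    layers = find_layers(model)
    layers.pop("lm_head", None)
    make_quant_lut(model, layers, 4, include_sparse=False, numvals=None, topX=10)

    # 3. Apply the quantized checkpoint over the rewritten module tree.
    model.load_state_dict(torch.load("llama-7b-w4.pt"), strict=False)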
12 changes: 11 additions & 1 deletion src/helm/proxy/clients/huggingface_model_registry.py
@@ -21,6 +21,7 @@ class ModelLoader(Enum):
    HF = auto()
    GPTQ = auto()
    AWQ = auto()
    SQUEEZELLM = auto()


class WeightType(Enum):
@@ -40,7 +41,16 @@ class HuggingfaceModelQuantizationConfig:
    disable_exllama: bool = False

    quant_file: Optional[str] = None
    """Path to the .pt checkpoint generated by AWQ or SqueezeLLM. This argument is required when using AutoAWQ or SqueezeLLM."""

    wbits: Optional[int] = None
    """Parameter for SqueezeLLM: number of bits to use for quantization; use 16 to evaluate the base model."""

    include_sparse: Optional[bool] = None
    """Parameter for SqueezeLLM: whether the loaded checkpoint includes a sparse matrix."""

    num_dense_channels: Optional[int] = 10
    """Parameter for SqueezeLLM: number of dense channels used by the hybrid kernel."""


@dataclass(frozen=True)
