diff --git a/src/helm/proxy/clients/huggingface_client.py b/src/helm/proxy/clients/huggingface_client.py
index ad407beb2a0..4aab3bf0adf 100644
--- a/src/helm/proxy/clients/huggingface_client.py
+++ b/src/helm/proxy/clients/huggingface_client.py
@@ -36,6 +36,7 @@ class StopOnToken(StoppingCriteria):
     This class can be used to stop generation whenever the generation results encounter the token(s) specified in
     `stop_token_ids`. This is particularly useful for chat models with multiple stopping criteria
     """
+
     def __init__(self, stop_token_id: Union[int, torch.LongTensor]):
         if isinstance(stop_token_id, torch.LongTensor):
             stop_token_id = stop_token_id.item()
@@ -88,6 +89,63 @@ def __init__(self, model_config: HuggingFaceModelConfig):
                 **model_kwargs
             )
             hlog(self.model.hf_device_map)
+        elif quantization_config and quantization_config.model_loader == ModelLoader.SQUEEZELLM:
+            import torch
+            from squeezellm.modelutils import *
+            from squeezellm.quant import *
+
+            if (
+                "xgen" in quantization_config.quant_file
+                or "opt" in quantization_config.quant_file
+                or ("vicuna" in quantization_config.quant_file and "v1.3" in quantization_config.quant_file)
+                or "llama-2" in quantization_config.quant_file
+                or "mistral" in quantization_config.quant_file
+            ):
+                # TODO: this is a hacky solution; it will be properly implemented once all the model checkpoints
+                # are updated with the new packing scheme that includes the non-linear weights.
+                from transformers import AutoConfig
+
+                config = AutoConfig.from_pretrained(model_name)
+                model = AutoModelForCausalLM.from_config(config)
+            else:
+                from transformers import LlamaForCausalLM
+
+                model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype="auto")
+
+            model = model.eval()
+            layers = find_layers(model)
+
+            state_dict = torch.load(quantization_config.quant_file)
+
+            # Load the sparse thresholds from the checkpoint.
+            if quantization_config.include_sparse:
+                num_vals = {}
+                for k, v in state_dict.items():
+                    if "sparse_threshold." in k:
+                        key = k.replace("sparse_threshold.", "")
+                        num_vals[key] = v
+                for k, v in num_vals.items():
+                    del state_dict["sparse_threshold." + k]
+            else:
+                num_vals = None
+
+            # Replace the linear layers with quantized LUT-based layers, skipping the lm_head.
+            for name in ["lm_head"]:
+                if name in layers:
+                    del layers[name]
+            make_quant_lut(
+                model, layers, quantization_config.wbits, include_sparse=quantization_config.include_sparse,
+                numvals=num_vals, topX=quantization_config.num_dense_channels
+            )
+            del layers
+
+            hlog("Loading SqueezeLLM model ...")
+            state_dict = torch.load(quantization_config.quant_file)
+            model.load_state_dict(state_dict, strict=False)
+            model.seqlen = 2048
+            self.model = model
+            hlog("Done.")
+
         else:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_name, trust_remote_code=True, **model_kwargs
@@ -390,4 +448,4 @@ def do_it():
 
         return DecodeRequestResult(
             success=True, cached=cached, text=result["text"], request_time=result["request_time"]
-        )
\ No newline at end of file
+        )
diff --git a/src/helm/proxy/clients/huggingface_model_registry.py b/src/helm/proxy/clients/huggingface_model_registry.py
index 3869658ae27..103f9829286 100644
--- a/src/helm/proxy/clients/huggingface_model_registry.py
+++ b/src/helm/proxy/clients/huggingface_model_registry.py
@@ -21,6 +21,7 @@ class ModelLoader(Enum):
     HF = auto()
     GPTQ = auto()
     AWQ = auto()
+    SQUEEZELLM = auto()
 
 
 class WeightType(Enum):
@@ -40,7 +41,16 @@ class HuggingfaceModelQuantizationConfig:
     disable_exllama: bool = False
 
     quant_file: Optional[str] = None
-    """Path to .pt file generated by AWQ. This argument is required when using AutoAWQ"""
+    """Path to the .pt file generated by AWQ or SqueezeLLM. This argument is required when using AutoAWQ or SqueezeLLM."""
+
+    wbits: Optional[int] = None
+    """Parameter for SqueezeLLM: number of bits to use for quantization; use 16 to evaluate the base model."""
+
+    include_sparse: Optional[bool] = None
+    """Parameter for SqueezeLLM: whether the loaded checkpoint contains a sparse matrix."""
+
+    num_dense_channels: Optional[int] = 10
+    """Parameter for SqueezeLLM: number of dense channels used for the hybrid kernel."""
 
 
 @dataclass(frozen=True)
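A minimal illustrative sketch of how the new SqueezeLLM options could be populated on the registry side. Only `ModelLoader.SQUEEZELLM` and the `quant_file`, `wbits`, `include_sparse`, and `num_dense_channels` fields come from this diff; the `model_loader` keyword (implied by the `quantization_config.model_loader` check in the client hunk), the exact constructor signature, the import path, and the checkpoint path are assumptions for illustration, not part of the change:

    # Illustrative sketch only -- not part of this diff.
    # Field names follow HuggingfaceModelQuantizationConfig above; the checkpoint path is hypothetical.
    from helm.proxy.clients.huggingface_model_registry import (
        HuggingfaceModelQuantizationConfig,
        ModelLoader,
    )

    quantization_config = HuggingfaceModelQuantizationConfig(
        model_loader=ModelLoader.SQUEEZELLM,  # new loader value added by this diff
        quant_file="/path/to/squeezellm-checkpoint.pt",  # hypothetical SqueezeLLM .pt checkpoint
        wbits=4,  # 4-bit quantization; 16 would evaluate the base model
        include_sparse=True,  # checkpoint stores sparse_threshold.* entries
        num_dense_channels=10,  # dense channels used by the hybrid kernel (default)
    )

The client hunk above then consumes this config: it builds the model skeleton, strips the sparse_threshold entries when include_sparse is set, calls make_quant_lut over the discovered layers (excluding lm_head), and loads the quantized state dict with strict=False.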