From 5df1556ee6ee559a6b97b8d7e85662d5c82d4d3c Mon Sep 17 00:00:00 2001
From: "Wang, Chang"
Date: Fri, 7 Jun 2024 10:49:03 +0800
Subject: [PATCH] improve lm_eval get chatglm2 tokenizer from local (#1598)

Signed-off-by: changwangss
---
 .../evaluation/lm_eval/models/huggingface.py | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/models/huggingface.py b/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/models/huggingface.py
index 42541846c58..24ed84b63a4 100644
--- a/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/models/huggingface.py
+++ b/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/models/huggingface.py
@@ -838,12 +838,20 @@ def _create_tokenizer(
         else:
             # get the HF hub name via accessor on model
             model_name = self.model.name_or_path
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-            model_name,
-            revision=revision,
-            trust_remote_code=trust_remote_code,
-            use_fast=use_fast_tokenizer,
-        )
+
+        # chatglm2 tokenizer doesn't support loading from a local path.
+        if hasattr(self.model, "config") and hasattr(self.model.config, "auto_map") and \
+            "chatglm2" in self.model.config.auto_map["AutoConfig"]:
+            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                "THUDM/chatglm2-6b", trust_remote_code=True
+            )
+        else:
+            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                model_name,
+                revision=revision,
+                trust_remote_code=trust_remote_code,
+                use_fast=use_fast_tokenizer,
+            )
         return None

     def _detect_batch_size(self, requests=None, pos: int = 0):