Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
improve lm_eval get chatglm2 tokenizer from local (#1598)
Browse files Browse the repository at this point in the history
Signed-off-by: changwangss <[email protected]>
  • Loading branch information
changwangss authored Jun 7, 2024
1 parent b1d3d3c commit 5df1556
Showing 1 changed file with 14 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -838,12 +838,20 @@ def _create_tokenizer(
else:
# get the HF hub name via accessor on model
model_name = self.model.name_or_path
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
)

# The chatglm2 tokenizer doesn't support loading from a local path.
if hasattr(self.model, "config") and hasattr(self.model.config, "auto_map") and \
"chatglm2" in self.model.config.auto_map["AutoConfig"]:
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
"THUDM/chatglm2-6b", trust_remote_code=True
)
else:
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
)
return None

def _detect_batch_size(self, requests=None, pos: int = 0):
Expand Down

0 comments on commit 5df1556

Please sign in to comment.