Change apply_rotary_pos_emb of GlmModel for GLM-Edge series models #34629

Merged
28 commits merged into huggingface:main from zRzRzRzRzRzRzR:glm-4-1108 on Nov 26, 2024.
Changes shown are from 4 of the 28 commits.

Commits
- e6e54f0 change apply_rotary_pos_emb (zRzRzRzRzRzRzR, Nov 6, 2024)
- 6a75751 upload for glm-edge (zRzRzRzRzRzRzR, Nov 20, 2024)
- 935fe8a remove useless part (zRzRzRzRzRzRzR, Nov 21, 2024)
- fa070e0 follow the suggestion (zRzRzRzRzRzRzR, Nov 21, 2024)
- 0ba58d6 fix (zRzRzRzRzRzRzR, Nov 21, 2024)
- 87d90e5 format (zRzRzRzRzRzRzR, Nov 21, 2024)
- 7038703 format (zRzRzRzRzRzRzR, Nov 21, 2024)
- 1f17ea5 test (zRzRzRzRzRzRzR, Nov 21, 2024)
- ef9fd9c format again (zRzRzRzRzRzRzR, Nov 21, 2024)
- aceb417 format again (zRzRzRzRzRzRzR, Nov 21, 2024)
- 31cf72e remove modular change (zRzRzRzRzRzRzR, Nov 21, 2024)
- a8d3377 remove modular change (zRzRzRzRzRzRzR, Nov 21, 2024)
- a75d83c this apply_rotary_pos_emb need modify? (zRzRzRzRzRzRzR, Nov 21, 2024)
- 2a12a1c fix with this (zRzRzRzRzRzRzR, Nov 21, 2024)
- cb7a09b format (zRzRzRzRzRzRzR, Nov 21, 2024)
- a9001a1 format (zRzRzRzRzRzRzR, Nov 21, 2024)
- 93fb505 ruff check (zRzRzRzRzRzRzR, Nov 21, 2024)
- c674c3e Merge branch 'huggingface:main' into glm-4-1108 (zRzRzRzRzRzRzR, Nov 21, 2024)
- 34e7229 modify modular_glm failed (zRzRzRzRzRzRzR, Nov 21, 2024)
- c57cd93 Merge branch 'huggingface:main' into glm-4-1108 (zRzRzRzRzRzRzR, Nov 24, 2024)
- b605489 Merge branch 'huggingface:main' into glm-4-1108 (zRzRzRzRzRzRzR, Nov 26, 2024)
- 0c44372 remove partial_rotary_factor of function partial_rotary_factor (zRzRzRzRzRzRzR, Nov 26, 2024)
- 8703374 fix wrong change of examples/research_projects (zRzRzRzRzRzRzR, Nov 26, 2024)
- f81ba89 revert (zRzRzRzRzRzRzR, Nov 26, 2024)
- 73afd71 remove line 118 (zRzRzRzRzRzRzR, Nov 26, 2024)
- 73614df Merge branch 'huggingface:main' into glm-4-1108 (zRzRzRzRzRzRzR, Nov 26, 2024)
- dd47bb0 use q_rot (zRzRzRzRzRzRzR, Nov 26, 2024)
- 1ae053c Merge branch 'glm-4-1108' of github.com:zRzRzRzRzRzRzR/transformers i… (zRzRzRzRzRzRzR, Nov 26, 2024)
2 changes: 2 additions & 0 deletions src/transformers/models/glm/configuration_glm.py
@@ -93,6 +93,7 @@ def __init__(
num_hidden_layers=40,
num_attention_heads=32,
num_key_value_heads=2,
partial_rotary_factor=0.5,
head_dim=128,
hidden_act="silu",
attention_dropout=0.0,
@@ -114,6 +115,7 @@ def __init__(
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.partial_rotary_factor = partial_rotary_factor
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
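The new partial_rotary_factor (0.5 here) sets the fraction of each attention head that receives rotary position embeddings; the remaining channels pass through unrotated, which is what the PR's apply_rotary_pos_emb change (see the "use q_rot" commit) relies on. A minimal sketch of the idea, with hypothetical helper names rather than the exact merged code:

import torch

def rotate_half(x):
    # Rotate the last dimension: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_partial_rotary(q, k, cos, sin, head_dim, partial_rotary_factor=0.5):
    # Only the first rotary_dim channels of each head are rotated;
    # cos/sin are assumed to already match rotary_dim in the last dimension.
    rotary_dim = int(head_dim * partial_rotary_factor)
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
    return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1)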
72 changes: 46 additions & 26 deletions src/transformers/models/glm/convert_glm_weights_to_hf.py
@@ -2,7 +2,6 @@
import json
import os
import re

import torch
from safetensors.torch import load_file
from tokenizers import processors
@@ -37,16 +36,28 @@
# fmt: on


def merge_safetensors(input_dir: str):
all_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".safetensors")]
all_files = sorted(all_files, key=lambda x: int(x.rsplit("-", 3)[1]))
def load_weights(input_dir: str):
safetensor_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".safetensors")]
bin_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".bin")]

all_weights = {}
for file in all_files:
tensors = load_file(file)
all_weights.update(tensors)

return all_weights
if safetensor_files:
safetensor_files = sorted(safetensor_files, key=lambda x: int(x.rsplit("-", 3)[1]))
for file in safetensor_files:
tensors = load_file(file)
all_weights.update(tensors)
return all_weights

elif bin_files:
bin_files = sorted(bin_files, key=lambda x: int(x.rsplit("-", 3)[1]))
for file in bin_files:
tensors = torch.load(file, map_location="cpu")
all_weights.update(tensors)
return all_weights

else:
raise ValueError("No .safetensors or .bin files found in the specified directory.")


def map_old_key_to_new(old_key):
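A note on the sort key in load_weights above: it assumes sharded checkpoints follow the usual model-00001-of-00002.safetensors naming, so rsplit("-", 3)[1] recovers the zero-padded shard index even when the directory path contains hyphens of its own. A quick sanity check under that assumption (hypothetical filename):

path = "/ckpt/model-00001-of-00002.safetensors"
# rsplit splits from the right at most 3 times, isolating the shard index.
assert path.rsplit("-", 3) == ["/ckpt/model", "00001", "of", "00002.safetensors"]
assert int(path.rsplit("-", 3)[1]) == 1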
@@ -100,7 +111,8 @@ def convert_config(original_config: dict):
"attention_bias": "add_qkv_bias",
}
similar_keys_to_keep = [
"num_attention_heads" "hidden_size",
"num_attention_heads",
"hidden_size",
"attention_dropout",
"use_cache",
"eos_token_id",
@@ -120,40 +132,43 @@ def convert_config(original_config: dict):
return new_config
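The comma added to similar_keys_to_keep above also fixes a classic Python pitfall: adjacent string literals concatenate implicitly, so the old list silently contained the single bogus key "num_attention_headshidden_size" instead of two real ones. A small illustration:

old = ["num_attention_heads" "hidden_size"]   # implicit concatenation: one element
new = ["num_attention_heads", "hidden_size"]  # two separate keys
assert old == ["num_attention_headshidden_size"]
assert len(new) == 2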


def convert_glm_tokenizer(input_dir):
def convert_glm_tokenizer(input_dir, use_post_processor=False):
fast_tok = PreTrainedTokenizerFast.from_pretrained(input_dir, model_input_names=["input_ids", "attention_mask"])
# Add the two tokens automatically with post processor
fast_tok._tokenizer.post_processor = processors.Sequence(
[
processors.ByteLevel(trim_offsets=False),
processors.TemplateProcessing(
single="[gMASK]:0 <sop>:0 $A:0",
pair="[gMASK]:0 <sop>:0 $A:0 $B:1",
special_tokens=[("[gMASK]", 151331), ("<sop>", 151333)],
),
],
)

if use_post_processor:
fast_tok._tokenizer.post_processor = processors.Sequence(
[
processors.ByteLevel(trim_offsets=False),
processors.TemplateProcessing(
single="[gMASK]:0 <sop>:0 $A:0",
pair="[gMASK]:0 <sop>:0 $A:0 $B:1",
special_tokens=[("[gMASK]", 151331), ("<sop>", 151333)],
),
],
)
else:
fast_tok._tokenizer.post_processor = processors.Sequence(
[processors.ByteLevel(trim_offsets=False)],
)
return fast_tok
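When use_post_processor is set, the template processing above prepends the two GLM special tokens to every encoded sequence; otherwise only byte-level post-processing runs. A hedged check of the expected effect (placeholder path, assumes a tokenizer converted by this script):

from transformers import PreTrainedTokenizerFast

tok = PreTrainedTokenizerFast.from_pretrained("/path/to/converted-glm")  # placeholder path
ids = tok("hello").input_ids
# Converted with --use_post_processor, the sequence should start with
# [gMASK] (151331) and <sop> (151333); without it, it should not.
print(ids[:2])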


def convert_glm_model(input_dir, output_dir):
def convert_glm_model(input_dir, output_dir, use_post_processor=False):
# Load and convert config
with open(os.path.join(input_dir, "config.json")) as f:
original_config = json.load(f)
config = convert_config(original_config)
config.save_pretrained(output_dir)

# Load and convert weights
original_state_dict = merge_safetensors(input_dir)
original_state_dict = load_weights(input_dir)
new_dict = convert_state_dict(original_state_dict, config)
with torch.device("meta"):
model = GlmForCausalLM(config)
model.load_state_dict(new_dict, strict=True, assign=True)
model.save_pretrained(output_dir)

# Load and convert tokenizer
tokenizer = convert_glm_tokenizer(input_dir)
tokenizer = convert_glm_tokenizer(input_dir, use_post_processor)
tokenizer.save_pretrained(output_dir)


@@ -169,6 +184,11 @@ def convert_glm_model(input_dir, output_dir):
type=str,
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--use_post_processor",
action="store_true",
help="Whether to apply post processor with special tokens",
)

args = parser.parse_args()
convert_glm_model(args.input_dir, args.output_dir)
convert_glm_model(args.input_dir, args.output_dir, args.use_post_processor)
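For reference, a sketch of driving the converter from Python instead of the CLI (paths are placeholders; the import path assumes the script's location shown in this diff):

from transformers.models.glm.convert_glm_weights_to_hf import convert_glm_model

convert_glm_model(
    input_dir="/path/to/original/glm-edge",  # directory with config.json plus .safetensors or .bin shards
    output_dir="/path/to/glm-edge-hf",       # where the HF model and tokenizer are written
    use_post_processor=True,                 # prepend [gMASK] <sop> via the tokenizer post processor
)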