-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
change apply_rotary_pos_emb of Glmmodel for GLM-Edge Series model #34629
Conversation
Now this implementation is compatible with both GLM-Edge and GLM-4 models, @Cyrilvallez , I would like to know how to modify modular_glm.py to achieve automatic updates, because some parts of the implementation in modeling_glm.py need to add new parameters to work properly. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey! Sorry for the delay, as I've said we were all in Martinique for our offsite the past week!
You can check my comments, but unless I'm very much mistaken or missing something, most of the changes you propose are no-ops, and you only need to change q, q_pass
/k, k_pass
in apply_rptary_pos_emb
BTW, you tagged the wrong Cyril in the PR 🤣
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): | ||
def __init__(self, dim, max_position_embeddings=2048, base=10000, rotary_percent=0.5, device=None): | ||
super().__init__() | ||
|
||
self.dim = dim | ||
self.rotary_percent = rotary_percent | ||
self.dim = dim * rotary_percent | ||
self.max_position_embeddings = max_position_embeddings | ||
self.base = base | ||
|
||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) | ||
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) | ||
|
||
@torch.no_grad() | ||
def forward(self, x, position_ids, seq_len=None): | ||
# x: [bs, num_attention_heads, seq_len, head_size] | ||
self.inv_freq.to(x.device) | ||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) | ||
position_ids_expanded = position_ids[:, None, :].float() | ||
# Force float32 since bfloat16 loses precision on long contexts | ||
# See https://github.com/huggingface/transformers/pull/29285 | ||
device_type = x.device.type | ||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) | ||
self.register_buffer("inv_freq", inv_freq) | ||
|
||
def forward(self, x, position_ids=None): | ||
batch_size, seq_len, head_dim = x.shape | ||
device = x.device | ||
dtype = x.dtype | ||
|
||
seq_idx = torch.arange(0, self.max_position_embeddings, device=device).float() | ||
idx_theta = torch.outer(seq_idx, self.inv_freq) | ||
|
||
if position_ids is not None: | ||
idx_theta = idx_theta[position_ids[0]] | ||
else: | ||
idx_theta = idx_theta[:seq_len] | ||
if self.rotary_percent == 0.5: | ||
idx_theta = torch.cat([idx_theta, idx_theta], dim=-1) # for glm-4-9b | ||
|
||
device_type = device.type | ||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" | ||
with torch.autocast(device_type=device_type, enabled=False): | ||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) | ||
emb = torch.cat((freqs, freqs), dim=-1) | ||
cos = emb.cos() | ||
sin = emb.sin() | ||
cos = torch.cos(idx_theta).to(dtype=dtype) | ||
sin = torch.sin(idx_theta).to(dtype=dtype) | ||
|
||
cos = cos[None, :, :].expand(batch_size, seq_len, -1) | ||
sin = sin[None, :, :].expand(batch_size, seq_len, -1) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand why you modified the RotaryEmbedding
class here.
cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) | ||
sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) | ||
cos = cos[..., : int(cos.shape[-1] * rotary_percent)].repeat_interleave(2, dim=-1) | ||
sin = sin[..., : int(sin.shape[-1] * rotary_percent)].repeat_interleave(2, dim=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get it here either, this is exactly the same as before without modifying the RotaryEmbedding
class, but will only work with rotary_percent=0.5
or rotary_percent=1
, and is much more confusing IMO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because there are two different models, 0.5 and 1. In the config, the glm-edge series model needs to be set to 1.
https://huggingface.co/ZP2HF/glm-edge-4b-chat/blob/6a5e92d0092bba5f94abd471720238b6dda8f9de/config.json#L11
Here I have made annotations.
# Keep rotary_percent(half or not) for later concatenation | ||
rotary_dim = int(q.shape[-1] * rotary_percent) | ||
q, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] | ||
k, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indeed this is for me the only part that should need to be modified. The rest should not need any modification
I referred to your plan and modified it to look like this. I believe the mathematical logic of this implementation is equivalent. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's getting better! Still some issue in the rotary though I think
Then, once we agree on the changes, you'll need to apply the changes in modular
instead of modeling
🤗
# Apply rotary embeddings on the first half | ||
q_embed = (q * cos) + (rotate_half(q) * sin) | ||
k_embed = (k * cos) + (rotate_half(k) * sin) | ||
# Apply rotary embeddings to the rotary portion | ||
q = (q * cos[..., :rotary_dim]) + (rotate_half(q) * sin[..., :rotary_dim]) | ||
k = (k * cos[..., :rotary_dim]) + (rotate_half(k) * sin[..., :rotary_dim]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here you don't need to modify anything, you are basically slicing up to the full length which is useless
I would like to know if there are any improvements needed for this version, and also, I would like to know if @Cyrilvallez could guide me on how to modify modular_glm.py to make good changes.
Let it automatically generate to modeling_glm.py. |
cc @Cyrilvallez but I think it would help to remove unrelated changes! 🤗 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! That's it! 🤗
Final comment is that we don't even need to change the signature of apply_rotary_pos_embed
as we can retrieve the rotary_dim
from cos
and sin
, I forgot before sorry! That way, we don't even have to modify the attention implementation, which is a big win for the modular!
Also, please remove the unrelated notebook change you added (I assume as a mistake) 🤗
# Keep half or full tensor for later concatenation | ||
rotary_dim = int(q.shape[-1] * partial_rotary_factor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Keep half or full tensor for later concatenation | |
rotary_dim = int(q.shape[-1] * partial_rotary_factor) | |
# Keep half or full tensor for later concatenation | |
rotary_dim = cos.shape[-1] |
We actually don't need to pass the rotary_factor
as an argument to the function!
@@ -142,7 +142,7 @@ def rotate_half(x): | |||
return torch.stack((-x2, x1), dim=-1).flatten(-2) | |||
|
|||
|
|||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): | |||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, partial_rotary_factor=0.5): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, partial_rotary_factor=0.5): | |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
We actually don't need to pass the rotary_factor as an argument to the function! See next comment, that way we don't even have to modify the modular file for the Attentions!
query_states, key_states = apply_rotary_pos_emb( | ||
query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to pass the extra arg! See above
query_states, key_states = apply_rotary_pos_emb( | ||
query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same
query_states, key_states = apply_rotary_pos_emb( | ||
query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same
@@ -68,7 +68,7 @@ def rotate_half(x): | |||
return torch.stack((-x2, x1), dim=-1).flatten(-2) | |||
|
|||
|
|||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): | |||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, partial_rotary_factor=0.5): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to pass the extra arg! See above
@@ -85,6 +85,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): | |||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes | |||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have | |||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. | |||
partial_rotary_factor (`float`, *optional*, defaults to 0.5): The factor by which the rotary embedding. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same, can be removed
# Keep half or full tensor for later concatenation | ||
rotary_dim = int(q.shape[-1] * partial_rotary_factor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Keep half or full tensor for later concatenation | |
rotary_dim = int(q.shape[-1] * partial_rotary_factor) | |
# Keep half or full tensor for later concatenation | |
rotary_dim = cos.shape[-1] |
Same as above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, forgot a comment, you need to ensure the dim in an integer as well.
To automatically generate the modeling file from the modular, you can run
python utils/modular_model_converter.py --files_to_parse src/transformers/models/glm/modular_glm.py
from the root of the transformers
repo 🤗
self.rotary_emb = GlmRotaryEmbedding( | ||
dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta | ||
dim=config.head_dim * config.partial_rotary_factor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dim=config.head_dim * config.partial_rotary_factor, | |
dim=int(config.head_dim * config.partial_rotary_factor), |
You need int
here as well
This modification should meet the requirements, and I have tried to remove all unnecessary code. The remaining code is all the code that will be used. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All good, thanks for iterating!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, small convention for q_rot, q_pass, and a nit! We can merge afterwards!
q, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] | ||
k, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
q, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] | |
k, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] | |
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] | |
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] |
we usually use these notations!
@@ -151,8 +152,11 @@ def __init__(self, config: GlmConfig): | |||
[GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] | |||
) | |||
self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | |||
self.partial_rotary_factor = config.partial_rotary_factor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.partial_rotary_factor = config.partial_rotary_factor |
I don't think this is used no?
All good, merging! |
…ggingface#34629) * change apply_rotary_pos_emb * upload for glm-edge * remove useless part * follow the suggestion * fix * format * format * test * format again * format again * remove modular change * remove modular change * this apply_rotary_pos_emb need modify? * fix with this * format * format * ruff check * modify modular_glm failed * remove partial_rotary_factor of function partial_rotary_factor * fix wrong change of examples/research_projects * revert * remove line 118 * use q_rot
What does this PR do?
This PR is to allow the new version of the GLM-4 model to use different rotary_pos_emb.
I am still researching how to modify modular_glm.py so that model_glm.py can automatically generate an additional parameter called apply_rotary_pos_emb.
Who can review?
This PR may Cyrilvallez to help.