Commit: minor (#1394)

ShashankMosaicML authored Jul 24, 2024
1 parent 70586c4 commit cfab70e

Showing 2 changed files with 10 additions and 11 deletions.
12 changes: 6 additions & 6 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -49,10 +49,10 @@
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
-from transformers.models.llama.modeling_llama import LlamaConfig
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
-from transformers.models.llama.modeling_llama import \
-    LlamaRotaryEmbedding as HFRotaryEmbedding
+from transformers.models.llama.modeling_llama import (
+    LlamaConfig,
+    LlamaRotaryEmbedding,
+)
 
 from llmfoundry.layers_registry import norms, param_init_fns
 from llmfoundry.models.layers.attention import (
@@ -166,7 +166,7 @@ def gen_rotary_embedding(
         num_attention_heads=n_heads,
     )
     if rope_hf_config['type'] == 'no_scaling':
-        return HFRotaryEmbeddingFoundry(config=partial_llama_config)
+        return LlamaRotaryEmbeddingFoundry(config=partial_llama_config)
     elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}:
         return LlamaRotaryEmbedding(config=partial_llama_config)
     raise ValueError('rope_impl needs to be either dail or hf')
@@ -341,7 +341,7 @@ def apply_sequence_id(
     return attn_bias
 
 
-class HFRotaryEmbeddingFoundry(HFRotaryEmbedding):
+class LlamaRotaryEmbeddingFoundry(LlamaRotaryEmbedding):
 
     @torch.no_grad()
     def forward(
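Below is a minimal usage sketch (not part of the diff) showing how the renamed LlamaRotaryEmbeddingFoundry is built from a partial LlamaConfig, mirroring the 'no_scaling' branch of gen_rotary_embedding above. The concrete values passed to LlamaConfig (hidden_size, num_attention_heads, max_position_embeddings, rope_theta) are illustrative assumptions, not taken from the commit.

# Sketch only: construct the renamed LlamaRotaryEmbeddingFoundry from a
# partial LlamaConfig, as the 'no_scaling' branch of gen_rotary_embedding
# does. All numeric values below are assumptions for illustration.
import torch
from transformers.models.llama.modeling_llama import LlamaConfig

from llmfoundry.models.mpt.modeling_mpt import LlamaRotaryEmbeddingFoundry

partial_llama_config = LlamaConfig(
    hidden_size=1024,              # assumed d_model
    num_attention_heads=8,         # assumed n_heads (head dim = 128)
    max_position_embeddings=2048,  # assumed max sequence length
    rope_theta=10000.0,            # assumed rope base
)

rotary_emb = LlamaRotaryEmbeddingFoundry(config=partial_llama_config)

# forward takes a tensor (used for device/dtype) plus position ids and
# returns the cos/sin tables consumed by the attention layers.
x = torch.zeros(1, 2048, 1024)
position_ids = torch.arange(2048).unsqueeze(0)
cos, sin = rotary_emb(x, position_ids)
print(cos.shape, sin.shape)  # expected: (1, 2048, 128) each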
9 changes: 4 additions & 5 deletions tests/models/test_model.py
@@ -35,8 +35,7 @@
 )
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.bloom.modeling_bloom import build_alibi_tensor
-from transformers.models.llama.modeling_llama import \
-    LlamaRotaryEmbedding as HFRotaryEmbedding
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 
 from llmfoundry import ComposerHFCausalLM
 from llmfoundry.layers_registry import norms
@@ -48,7 +47,7 @@
 )
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel
-from llmfoundry.models.mpt.modeling_mpt import HFRotaryEmbeddingFoundry
+from llmfoundry.models.mpt.modeling_mpt import LlamaRotaryEmbeddingFoundry
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
 from llmfoundry.utils.config_utils import to_dict_container
@@ -2924,15 +2923,15 @@ def test_hf_rotary_child_class_builds():
         list(range(max_seq_len)),
     ] * bsz)
 
-    rot_emb_mp = HFRotaryEmbeddingFoundry(
+    rot_emb_mp = LlamaRotaryEmbeddingFoundry(
         rope_head_dim,
         max_seq_len,
         rope_theta,
         device='cpu',
     )
     cos_mp, sin_mp = rot_emb_mp(value, position_ids)
 
-    rot_emb = HFRotaryEmbedding(
+    rot_emb = LlamaRotaryEmbedding(
         rope_head_dim,
         max_seq_len,
         rope_theta,
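For reference, a standalone sketch of the comparison test_hf_rotary_child_class_builds is driving at: on CPU, the Foundry subclass should produce the same cos/sin tables as the upstream LlamaRotaryEmbedding it inherits from. The tensor shapes, rope_theta value, and the torch.testing.assert_close calls below are illustrative assumptions, not copied from the test.

# Sketch only: compare the Foundry subclass against the upstream
# LlamaRotaryEmbedding on CPU. Shapes, theta, and the use of
# torch.testing.assert_close are assumptions for illustration.
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

from llmfoundry.models.mpt.modeling_mpt import LlamaRotaryEmbeddingFoundry

bsz, max_seq_len, rope_head_dim, rope_theta = 2, 64, 32, 10000

value = torch.randn(bsz, max_seq_len, rope_head_dim)
position_ids = torch.tensor([list(range(max_seq_len))] * bsz)

rot_emb_mp = LlamaRotaryEmbeddingFoundry(
    rope_head_dim,
    max_seq_len,
    rope_theta,
    device='cpu',
)
rot_emb = LlamaRotaryEmbedding(
    rope_head_dim,
    max_seq_len,
    rope_theta,
    device='cpu',
)

cos_mp, sin_mp = rot_emb_mp(value, position_ids)
cos, sin = rot_emb(value, position_ids)

# Expected to match if the subclass is numerically equivalent to the parent.
torch.testing.assert_close(cos_mp, cos)
torch.testing.assert_close(sin_mp, sin)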
