
Commit

Fix device mismatch when running hf.generate (#1486)
ShashankMosaicML authored Aug 27, 2024
1 parent abdf7cf commit 2978c60
Showing 3 changed files with 6 additions and 6 deletions.
3 changes: 2 additions & 1 deletion llmfoundry/models/layers/blocks.py
@@ -204,7 +204,8 @@ def forward(
        n = self.apply_ffn(attention_mask, m)
        # In the following line we move the `x` tensor to the same devices as the output of ffn layer. This operation should be a no-op during training.
        # This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204
-       x = x.to(device=n.device) + self.resid_ffn_dropout(n)
+       x = x.to(device=n.device,
+               ) + self.resid_ffn_dropout(n).to(device=n.device,)
        return x, attn_weights, past_key_value

    def apply_ffn(
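
For context on the blocks.py change above: when the model is split across devices for pipeline-parallel generation (for example via Accelerate's device_map="auto"), the residual stream `x` and the FFN output `n` can end up on different devices, and the residual add then fails with a device-mismatch error. Below is a minimal sketch (not part of the commit; toy tensors only) of the failure mode and of the alignment that fixes it:

    # Minimal sketch, not from the commit: the residual tensor `x` sits on one
    # device while the FFN output `n` was produced on another, so `x + n` fails.
    # Moving both operands to `n.device` makes the add legal, and is a no-op when
    # the whole model already lives on a single device (e.g. during training).
    import torch

    ffn_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    x = torch.randn(2, 4)                      # residual stream, on CPU here
    n = torch.randn(2, 4, device=ffn_device)   # FFN output, possibly elsewhere

    # x + n  # raises "Expected all tensors to be on the same device" when devices differ
    out = x.to(device=n.device) + n.to(device=n.device)
    print(out.device)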
4 changes: 3 additions & 1 deletion llmfoundry/models/layers/ffn.py
@@ -214,7 +214,9 @@ def __init__(
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-       return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+       return self.down_proj(
+           self.act(self.gate_proj(x)).to(device=x.device) * self.up_proj(x),
+       )


def build_mptglu(
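
To make the ffn.py change concrete, here is a small self-contained sketch of the gated-FFN (GLU) forward with the same device cast on the gate activation. The class and plain nn.Linear projections below are illustrative stand-ins, not llm-foundry's registry-built layers:

    # Illustrative sketch only: a toy GLU block using plain nn.Linear in place of
    # llm-foundry's registry-built projections. Under pipeline sharding the gate
    # activation may land on a different device than `x`, so it is moved onto
    # x.device before the elementwise product; on a single device this is a no-op.
    import torch
    import torch.nn as nn

    class TinyGLU(nn.Module):
        def __init__(self, d_model: int = 8, expansion_ratio: int = 4):
            super().__init__()
            self.up_proj = nn.Linear(d_model, expansion_ratio * d_model)
            self.gate_proj = nn.Linear(d_model, expansion_ratio * d_model)
            self.down_proj = nn.Linear(expansion_ratio * d_model, d_model)
            self.act = nn.GELU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.down_proj(
                self.act(self.gate_proj(x)).to(device=x.device) * self.up_proj(x),
            )

    print(TinyGLU()(torch.randn(2, 8)).shape)  # torch.Size([2, 8])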
5 changes: 1 addition & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -161,10 +161,7 @@ def gen_rotary_embedding(
            hidden_size=d_model,
            num_attention_heads=n_heads,
        )
-       if rope_hf_config['type'] == 'no_scaling':
-           return LlamaRotaryEmbeddingFoundry(config=partial_llama_config)
-       elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}:
-           return LlamaRotaryEmbedding(config=partial_llama_config)
+       return LlamaRotaryEmbeddingFoundry(config=partial_llama_config)
    raise ValueError('rope_impl needs to be either dail or hf')


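The modeling_mpt.py change routes every HF rope type ('no_scaling', 'llama3', 'linear', 'dynamic') through LlamaRotaryEmbeddingFoundry rather than only 'no_scaling'. A plausible reason, sketched below under assumptions (this is not the actual LlamaRotaryEmbeddingFoundry body): a subclass can pin the precomputed inv_freq buffer to the input tensor's device on each forward, so the rotary cache follows the activations when layers are spread across devices during hf.generate:

    # Hedged sketch, not the actual llm-foundry implementation: a
    # LlamaRotaryEmbedding subclass that moves its inv_freq buffer onto the
    # input's device before computing cos/sin, so the rotary embeddings stay on
    # the same device as the hidden states under pipeline-parallel generation.
    import torch
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

    class LlamaRotaryEmbeddingFoundrySketch(LlamaRotaryEmbedding):

        @torch.no_grad()
        def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
            # Reassign the registered buffer onto x's device (no-op if already there).
            self.inv_freq = self.inv_freq.to(device=x.device)
            return super().forward(x, position_ids)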
