Commit 1431d48
removing duplicate MLP builders in favor of a single layer_mlp
risingsunomi committed Jan 31, 2025
1 parent a7757d3 commit 1431d48
Showing 1 changed file with 0 additions and 27 deletions.
exo/inference/torch/models/llm_utils.py (0 additions, 27 deletions)
@@ -517,31 +517,4 @@ def layer_mlp(dim: int, hidden_dim: int) -> FeedForward:
   gate_proj = nn.Linear(dim, hidden_dim, bias=False)
   down_proj = nn.Linear(hidden_dim, dim, bias=False)
   up_proj = nn.Linear(dim, hidden_dim, bias=False)
   return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)
-
-"""
-Llama utils
-"""
-def llama3_mlp(dim: int, hidden_dim: int) -> FeedForward:
-  """
-  Build the MLP layer associated with the Llama model.
-  Ref: https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_1/_component_builders.py#L124
-  """
-  gate_proj = nn.Linear(dim, hidden_dim, bias=False)
-  down_proj = nn.Linear(hidden_dim, dim, bias=False)
-  up_proj = nn.Linear(dim, hidden_dim, bias=False)
-
-  return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)
-
-"""
-Qwen utils
-"""
-def qwen2_mlp(dim: int, hidden_dim: int) -> FeedForward:
-  """
-  Build the MLP layer associated with the Qwen2 model.
-  Ref: https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_component_builders.py#L127C1-L134C82
-  """
-  gate_proj = nn.Linear(dim, hidden_dim, bias=False)
-  down_proj = nn.Linear(hidden_dim, dim, bias=False)
-  up_proj = nn.Linear(dim, hidden_dim, bias=False)
-  return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)
