diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 5e99e0a960..9389cf385f 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -144,6 +144,7 @@ def __init__( **self.fc_kwargs, ) + @torch.compile def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))