Skip to content

Commit

Permalink
making gpt2 fx tracable
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzifei-dmatrix committed Nov 7, 2024
1 parent 7bbc624 commit 746d252
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,8 @@ def forward(
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
for i in range(len(self.h)):
block, layer_past = self.h[i], past_key_values[i]
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
Expand Down

0 comments on commit 746d252

Please sign in to comment.