Skip to content

Commit

Permalink
Do #298 only for MPS (#308)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored May 22, 2023
1 parent cdfbd91 commit 6e60162
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def generate(
xm.mark_step()

# concatenate the new generation
idx[t] = idx_next.item()
# https://github.com/pytorch/pytorch/issues/101936
idx[t] = idx_next.item() if idx.device.type == "mps" else idx_next

# if <eos> token is triggered, return the output (stop generation)
if idx_next == eos_id:
Expand Down
3 changes: 2 additions & 1 deletion generate/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def generate(
idx_next = torch.multinomial(probs, num_samples=1)

# concatenate the new generation
idx[t] = idx_next
# https://github.com/pytorch/pytorch/issues/101936
idx[t] = idx_next.item() if idx.device.type == "mps" else idx_next

# if <eos> token is triggered, return the output (stop generation)
if idx_next == eos_id:
Expand Down

0 comments on commit 6e60162

Please sign in to comment.