Take hidden states from last non-padding token when batching #38

Merged (4 commits into vgel:main, Jul 2, 2024)

Conversation

@ohxh (Contributor) commented Jun 30, 2024

First of all, this is a really neat repo!

I noticed that batched_get_hiddens always takes hidden states from the last position in each padded sequence. Since sequences are right-padded to a common length, that last position is a padding token for every sequence except the longest one in the batch, so batching changes those sequences' hidden states.

After this change, there's still some difference between the batched and non-batched hidden states, but I think that might be due to the model itself since batching changes the order of operations: huggingface/transformers#23017 (comment)

I've only tried this on llama-3-8b; I'm not sure whether it will need changes to work with other models.
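
Roughly, the fix is to index each sequence's hidden states at its last non-padding position (using the attention mask) instead of at the final, possibly padded, position. A minimal sketch of the idea, not the exact diff in this PR, with illustrative names:

import torch

def last_nonpad_hidden(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden_dim) from one layer
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    last_idx = attention_mask.sum(dim=1) - 1  # index of the last real token per sequence (right-padded batches)
    batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
    return hidden_states[batch_idx, last_idx]  # (batch, hidden_dim)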

4 sequences, batch_size=4, old method:
[[-2.777  -2.205   3.318  ...  1.834   2.014   1.123 ]
 [ 1.21   -2.031   2.41   ...  1.883   0.391   1.652 ]
 [ 1.153  -1.737   2.281  ...  2.236   2.676   2.178 ]
 [ 1.25   -1.308   2.342  ...  0.9683  3.71    2.516 ]]
4 sequences, batch_size=4, new method:
[[-2.777   -2.205    3.318   ...  1.834    2.014    1.123  ]
 [ 0.852   -3.914    1.661   ...  1.693    0.828   -0.0934 ]
 [ 0.10767 -2.484   -1.208   ...  2.771    2.46     0.7217 ]
 [-1.701   -2.082    2.62    ...  1.927    2.334   -0.33   ]]
4 sequences, batch_size=1, old method:
[[-2.79    -2.2      3.314   ...  1.833    2.012    1.125  ]
 [ 0.8516  -3.912    1.659   ...  1.693    0.8306  -0.08746]
 [ 0.1023  -2.49    -1.211   ...  2.775    2.453    0.714  ]
 [-1.699   -2.084    2.61    ...  1.923    2.334   -0.3232 ]]
4 sequences, batch_size=1, new method:
[[-2.79    -2.2      3.314   ...  1.833    2.012    1.125  ]
 [ 0.8516  -3.912    1.659   ...  1.693    0.8306  -0.08746]
 [ 0.1023  -2.49    -1.211   ...  2.775    2.453    0.714  ]
 [-1.699   -2.084    2.61    ...  1.923    2.334   -0.3232 ]]
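
To reproduce this comparison, something like the following works, assuming hiddens_b4 and hiddens_b1 are the batch_size=4 and batch_size=1 arrays above:

import numpy as np

# With the fix, the remaining batched-vs-unbatched differences are small numerical noise.
print(np.max(np.abs(hiddens_b4 - hiddens_b1)))
print(np.allclose(hiddens_b4, hiddens_b1, atol=1e-2))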

@vgel self-requested a review on Jul 2, 2024
@vgel (Owner) commented Jul 2, 2024

weird! i thought the tokenizers were left-padding by default... ah, mistral does...

>>> llama3_tokenizer(["x", "x x"], padding=True)
{'input_ids': [[128000, 87, 128001], [128000, 87, 865]], 'attention_mask': [[1, 1, 0], [1, 1, 1]]}
>>> mistral_tokenizer(["x", "x x"], padding=True)
{'input_ids': [[2, 1, 1318], [1, 1318, 1318]], 'attention_mask': [[0, 1, 1], [1, 1, 1]]}

@ohxh (Contributor, Author) commented Jul 2, 2024

Oh huh… maybe an easier fix would be to force the tokenizer to always left pad
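
For reference, a sketch of that alternative on a Hugging Face tokenizer (not what this PR ends up doing, see below):

# Force left padding so the last position is always a real token.
tokenizer.padding_side = "left"
batch = tokenizer(["x", "x x"], padding=True)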

@vgel (Owner) commented Jul 2, 2024

can you check "allow edits by maintainers" so i can make changes to this PR?
[screenshot: the "Allow edits by maintainers" checkbox]

@vgel (Owner) commented Jul 2, 2024

> Oh huh… maybe an easier fix would be to force the tokenizer to always left pad

yeah i was thinking that, but i think your approach is better because the user might want right-padding for whatever reason--better to not mess with their tokenizer instance if we can avoid it.

@ohxh (Contributor, Author) commented Jul 2, 2024

I think it is checked already…

@vgel (Owner) left a review comment

Thanks so much for catching this! Will have to retry all my llama-3 generations now... :-)

@vgel merged commit 9c1c4c2 into vgel:main on Jul 2, 2024
3 checks passed
@vgel (Owner) commented Jul 2, 2024

Glad I checked the PRs too, was just about to cut the 0.3 release so you just squeaked in!
