This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

fix(llama): buffer tokens until valid UTF-8 #122

Merged
merged 2 commits into rustformers:main from buffer-utf8 on Apr 13, 2023

Conversation

philpax (Collaborator) commented Apr 7, 2023

As discussed on Discord and in #11.

This switches the internal representation of tokens over to raw bytes, and buffers them in inference_with_prompt until they form valid UTF-8.

Open questions:

  1. Should we use smallvec or similar for tokens? We're going to be making a lot of unnecessary tiny allocations as-is.
  2. FnMut as a bound is OK, right?
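
For illustration, here's a minimal sketch of the buffering idea (hypothetical helper, not the exact code in this PR; the real logic lives in inference_with_prompt):

// Accumulate raw token bytes and only hand the callback a &str once the
// buffer holds valid UTF-8. The function name and signature are illustrative.
fn feed_token(buffer: &mut Vec<u8>, token_bytes: &[u8], mut callback: impl FnMut(&str)) {
    buffer.extend_from_slice(token_bytes);
    if let Ok(s) = std::str::from_utf8(buffer) {
        // The buffered bytes now form a complete UTF-8 string; flush and reset.
        callback(s);
        buffer.clear();
    }
    // Otherwise keep buffering until a later token completes the codepoint.
}

A real implementation also has to decide what to do with bytes that never become valid UTF-8, e.g. flushing or replacing them at the end of inference.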

KerfuffleV2 (Contributor) commented Apr 7, 2023

Nice, this seems to fix my problems. I can't test with Vicuna because it's GGJT.


我会讲你一个关于狐狸的故事: Once upon a time, there was an enchanted village of Kitsune. The villagers were very proud and pleased to live in such a beautiful place with bounteous natural resources. However, they soon realized that the village faced one great danger - the dragon who lived deep beneath them, guarding all sorts of magical treasures it had accumulated over its long life span...


The prompt is the bold part (the Chinese sentence, roughly "I'll tell you a story about a fox:"). It also seems to work fine with normal text, and I tested it against the main branch with a fixed seed and got the same output in both cases (not a very extensive test).

I wouldn't really worry about performance too much for this: nobody is generating more than 10 tokens a second, and the context limit is 2048, so the effect is going to be pretty insignificant.

If you actually cared about allocations, probably the best way would be to just preserve the buffer. You could have the callback receive a mutable reference to copy the completed token into when it's ready. That way both buffers only need to be allocated once and live for the length of the session.
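
For what it's worth, one possible shape of that reuse idea (hypothetical types and signatures, not the llama-rs API; here the copy into the reusable String happens before the callback sees a borrow of it):

struct TokenBuffers {
    partial: Vec<u8>,   // bytes of a not-yet-complete UTF-8 sequence
    completed: String,  // reusable scratch space for the finished text
}

impl TokenBuffers {
    // Append a token's raw bytes; once they form valid UTF-8, copy the text
    // into the reusable String and hand the callback a borrow of it.
    fn push(&mut self, token_bytes: &[u8], mut callback: impl FnMut(&str)) {
        self.partial.extend_from_slice(token_bytes);
        if let Ok(s) = std::str::from_utf8(&self.partial) {
            self.completed.clear();
            self.completed.push_str(s);
            self.partial.clear();
            callback(&self.completed);
        }
    }
}

Neither buffer is reallocated per token once it has grown to its working size.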

philpax (Collaborator, author) commented Apr 7, 2023

Eh, I'm not so worried about the allocations as I am about cache coherency. We'd be allocating lots of tiny little buffers that could just as well be stored inline.
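
For reference, a minimal sketch of the inline-storage idea, assuming the smallvec crate (the type alias and helper are hypothetical):

use smallvec::SmallVec;

// Most token byte strings are only a few bytes long, so they can be stored
// inline in the containing collection; longer ones spill to the heap.
type TokenBytes = SmallVec<[u8; 8]>;

fn token_bytes(raw: &[u8]) -> TokenBytes {
    SmallVec::from_slice(raw)
}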

You might be right though, we can figure that out later.

iacore (Contributor) commented Apr 7, 2023

Same problem with Vicuna (using #114)

>> hello
⣽ 
 こんにちは、GPT-4についての������やご������ください。よろしくお���いします!—これと���った���みで���望的な状���でしたが、何かあり得るのはGPT-4について教えてほしい������やご
### Human: もう少す���ません。お���いします!—これと���った���みで���望的な状���でしたが、何かあり^C⏎

🤦 the model isn't trained to speak Unicode codepoints coherently

KerfuffleV2 (Contributor) commented:

Are you saying it's worse than it was originally? The version I tried it with at least seemed to do a reasonable job with Mandarin; not sure about Japanese.

As far as I know they really were only trained on English so it's not surprising if their non-English output is less than ideal.

iacore (Contributor) commented Apr 7, 2023

Are you saying it's worse than it was originally? The version I tried it with at least seemed to do a reasonable job with Mandarin; not sure about Japanese.

I haven't tried this patch yet.

As far as I know they really were only trained on English so it's not surprising if their non-English output is less than ideal.

That's not the point. The model was definitely rewarded for emitting partial codepoints during training.

iacore (Contributor) commented Apr 7, 2023

I found a fix. I set the logits of invalid tokens to 0.0.

Here's Vicuna speaking fluently.

>> こんにちは
⢿ 、

APIのテストを行うための`requests`と`json`クライアントを作成します。そして、Win32网关(win32net.exe)を使用してWindowsファイアウォール上にプロキシサーバーIP地址へのリンックを取得します。これは、APIが起動中の状态である場合や、コマンドレザスのタイムアウトなどから発生可能とさ
### Human: ごめんに、日本^C⏎      
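
For clarity, a sketch of what that first fix amounts to (hypothetical function; the 131..=258 id range is taken from the diff in the next comment, and treating those ids as the stand-alone-invalid byte tokens is an assumption about the vocabulary layout):

// Zero the logits of the assumed byte-token ids so they are rarely sampled.
fn zero_byte_token_logits(logits: &mut [f32]) {
    for (i, logit) in logits.iter_mut().enumerate() {
        if (131..=258).contains(&i) {
            *logit = 0.0;
        }
    }
}

Note that a zero logit can still be sampled after softmax, which may be why the follow-up diff skips those ids entirely instead.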

iacore (Contributor) commented Apr 7, 2023

Better fix:

diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs
index ddffce0..0a4c4db 100644
--- a/llama-rs/src/lib.rs
+++ b/llama-rs/src/lib.rs
@@ -1386,6 +1386,9 @@ impl InferenceSession {
         {
             let scale = 1.0 / params.temperature;
             for (i, &logit) in logits.iter().enumerate() {
+                if (131..=258).contains(&i) {
+                    continue;
+                };
                 let tid = i as TokenId;
 
                 let val = if let Some(logit_override) = params.bias_tokens.get(tid) {

KerfuffleV2 (Contributor) commented Apr 7, 2023

I found a fix. I set the logits of invalid tokens to 0.0.

Which ones? The token might be invalid individually, but get combined with other tokens to form a valid unicode character. So if you just set them all to 0.0, you'll prevent it from expressing any unicode characters where the components aren't all valid individually.

iacore (Contributor) commented Apr 7, 2023

Which ones?

All of them.

prevent it from expressing any unicode characters where the components aren't all valid individually.

I think it's "unicode codepoints not present in the vocabulary as a standalone token".

KerfuffleV2 (Contributor) commented:

I think it's "unicode codepoints not present in the vocabulary as a standalone token".

Right, but LLMs can combine tokens that aren't valid on their own into sequences that are. If you remove all the ones that are invalid individually, that will limit the LLM's ability to express certain things. For example, it may not be able to use emoji (unless the emoji already exists as a complete token in its vocabulary).
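
A tiny self-contained demonstration of that point (the crab emoji is just an arbitrary multi-byte example):

fn main() {
    let crab = "🦀"; // a 4-byte UTF-8 sequence
    for &b in crab.as_bytes() {
        // No single byte of it is valid UTF-8 on its own...
        assert!(std::str::from_utf8(&[b]).is_err());
    }
    // ...but the full byte sequence is.
    assert_eq!(std::str::from_utf8(crab.as_bytes()), Ok("🦀"));
}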

KerfuffleV2 (Contributor) commented:

llama-rs can't tokenize this yet (using Vicuna)

That's what this pull is intending to fix. Or do you mean it doesn't work even with this pull?

iacore (Contributor) commented Apr 7, 2023

That's what this pull is intending to fix. Or do you mean it doesn't work even with this pull?

Sorry. This patch works for me.

I've merged this in my repo as branch ggjt+buffer-utf8. Maybe this is useful.

iacore mentioned this pull request on Apr 8, 2023
philpax added this to the 0.1 milestone on Apr 10, 2023
setzer22 (Collaborator) left a comment

This looks great! 😄 Only one big comment w.r.t. error recovery but other than that we should be good to go.

Resolved review threads: llama-rs/src/convert.rs (outdated), llama-rs/src/lib.rs (3 threads)
philpax merged commit 7dd6748 into rustformers:main on Apr 13, 2023
philpax deleted the buffer-utf8 branch on Apr 13, 2023 at 00:10