
Top-P sampling occasionally produces invalid tokens #1590

Closed · 4 tasks done
AlessioNetti opened this issue May 13, 2024 · 4 comments
Assignees: byshiue
Labels: bug (Something isn't working), stale, triaged (Issue has been triaged by maintainers)

Comments

@AlessioNetti (Contributor)

System Info

  • Nvidia A40
  • CUDA 12.2
  • TensorRT 10.0.1.6
  • TensorRT-LLM 0.10.0.dev2024050700

Who can help?

@byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

We noticed that TensorRT-LLM occasionally (~0.01% of requests) generates invalid tokens. The issue can be reproduced with a generic Falcon 7B Instruct model using the following commands:

python convert_checkpoint.py --model_dir ./falcon_7b_tp1_instruct/ --dtype bfloat16 --output_dir ./falcon_7b_tp1_instruct_trt_chkpt

trtllm-build --checkpoint_dir ./falcon_7b_tp1_instruct_trt_chkpt/ --gemm_plugin bfloat16 --remove_input_padding enable --gpt_attention_plugin bfloat16 --output_dir ./falcon_7b_tp1_instruct_p200_g200 --gather_all_token_logits --max_input_len 200 --max_output_len 200 --max_batch_size 64

python example_basic.py --model_path ./falcon_7b_tp1_instruct_p200_g200

The examples/bindings/executor/example_basic.py script was modified to issue random top-P requests (in batches of 16) until an invalid token is detected in the output. The changes are as follows:

diff --git a/examples/bindings/executor/example_basic.py b/examples/bindings/executor/example_basic.py
index 2c7a3fc..65a9b57 100644
--- a/examples/bindings/executor/example_basic.py
+++ b/examples/bindings/executor/example_basic.py
@@ -1,4 +1,6 @@
 import argparse
+import torch
+import random
 
 import tensorrt_llm.bindings.executor as trtllm
 
@@ -20,16 +22,25 @@ if __name__ == "__main__":
                                trtllm.ExecutorConfig(1))
 
     if executor.can_enqueue_requests():
-        # Create the request.
-        request = trtllm.Request(input_token_ids=[1, 2, 3, 4],
-                                 max_new_tokens=10)
-
-        # Enqueue the request.
-        request_id = executor.enqueue_request(request)
-
-        # Wait for the new tokens.
-        responses = executor.await_responses(request_id)
-        output_tokens = responses[0].result.output_token_ids
-
-        # Print tokens.
-        print(output_tokens)
+        while True:
+            # Create the request.
+            requests = []
+            for _ in range(16):
+                input_token_ids = [random.randint(100, 10000) for _ in range(200)]
+                requests.append(trtllm.Request(input_token_ids=input_token_ids, max_new_tokens=200,
+                                               sampling_config=trtllm.SamplingConfig(top_p=0.5, top_k=None, temperature=20.0)))
+
+            # Enqueue the request.
+            request_ids = executor.enqueue_requests(requests)
+
+            # Wait for the new tokens.
+            responses = executor.await_responses(request_ids)
+            
+            for idx, re in enumerate(responses):
+                output_tokens = re[0].result.output_token_ids[0]
+                valid_output = all(el >= 0 and el < 200000 for el in output_tokens)
+                if not valid_output:
+                    print(f"Output tokens : {output_tokens[200:]}")
+                    exit(-1)
+                else:
+                    print(f"Valid output produced for request {request_ids[idx]}.")

Expected behavior

Requests should always generate valid tokens, i.e. token IDs in the [0, vocabulary_size) range.
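
For illustration only, a check against the real vocabulary size could look like the sketch below; the repro script above simply uses a loose upper bound of 200000 instead. The Hugging Face tokenizer name is an assumption and is not part of the original report.

from transformers import AutoTokenizer

# Assumption: the converted engine uses the standard falcon-7b-instruct tokenizer.
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")
vocab_size = len(tokenizer)


def tokens_are_valid(output_tokens):
    # Every generated token ID must lie in [0, vocab_size).
    return all(0 <= tok < vocab_size for tok in output_tokens)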

Actual behavior

Occasionally, a request will produce invalid tokens that fall outside the model's vocabulary range. Below is an example of the issue produced by our modified example_basic.py script:

Valid output produced for request 9534.
Valid output produced for request 9535.
Valid output produced for request 9536.
Output tokens : [47796, 54241, 47783, 58101, 6674, 23726, 23592, 42594, 6139, 25248, 52039, 47238, 46481, 59789, 36977, 9214, 30383, 31047, 19853, 59072, 25294, 63500, 59925, 44334, 38232, 28210, 38889, 26873, 35512, 48818, 38165, 14048, 49025, 30020, 59300, 49636, 5338, 63956, 4748, 22356, 26041, 19883, 22013, 32389, 24446, 36715, 11451, 13325, 58318, 29675, 12733, 15128, 323, 26868, 42477, 28018, 18622, 52692, 60096, 19486, 3727, 1427, 32693, 18763, 38281, 38747, 52358, 58497, 17945, 36842, 9453, 23113, 21691, 22407, 9894, 27278, 8361, 40261, 2147483647, 18931, 38614, 47912, 48115, 36611, 33955, 41329, 45530, 23243, 43669, 10268, 19238, 6055, 49515, 63961, 29434, 48151, 54508, 25936, 55805, 10214, 28366, 22400, 7200, 17613, 30007, 16812, 1529, 62540, 63633, 7331, 58970, 46938, 25656, 52488, 11953, 32571, 13142, 61313, 9385, 49280, 43718, 47734, 27930, 3368, 56759, 41270, 23886, 32473, 48038, 12786, 39043, 4837, 16915, 2584, 16430, 56707, 46255, 26404, 33055, 51739, 14011, 18179, 25129, 7630, 62620, 11823, 51429, 7700, 17108, 7422, 9389, 9999, 32405, 36641, 6937, 13023, 29698, 60332, 10098, 46336, 54260, 41558, 32326, 7579, 58826, 2443, 12843, 38563, 51635, 63544, 10124, 2484, 43080, 16858, 24803, 3017, 42640, 46269, 22102, 53352, 51123, 42491, 55109, 27590, 2322, 28774, 9365, 19873, 1538, 64635, 8407, 63458, 49056, 53777, 5887, 16413, 5956, 36375, 42348, 27573]

As can be seen, one of the tokens is 2147483647. In other instances we have also observed negative tokens, always in the billions range, which suggests an integer overflow issue somewhere in the top-P sampling logic.
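
For reference, 2147483647 is exactly INT32_MAX, the largest value a signed 32-bit integer can hold, which is consistent with the integer-overflow hypothesis above; a purely illustrative check:

import torch

# 2147483647 == 2**31 - 1, the maximum value of a signed 32-bit integer.
print(2147483647 == 2**31 - 1)           # True
print(torch.iinfo(torch.int32).max)      # 2147483647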

Additional notes

  • We first observed the issue on TensorRT-LLM 0.10.0.dev2024041600, and it is present up until 0.10.0.dev2024050700;
  • The issue occurs with both the Executor and the Python ModelRunner APIs; a ModelRunner-based sketch is included below.
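
For completeness, here is a minimal sketch of how the same check could be attempted through the Python ModelRunner API. It is not a verified repro: the generate() keyword set, the end/pad IDs, and the top_k=0 convention (to force top-p sampling) are assumptions based on the examples/run.py usage and may need adjusting for your TensorRT-LLM version.

import random

import torch
from tensorrt_llm.runtime import ModelRunner

# Assumption: same engine directory as in the trtllm-build command above.
runner = ModelRunner.from_dir(engine_dir="./falcon_7b_tp1_instruct_p200_g200")

while True:
    # Same request shape as the Executor repro: 16 random prompts of 200 tokens each.
    batch_input_ids = [
        torch.tensor([random.randint(100, 10000) for _ in range(200)],
                     dtype=torch.int32)
        for _ in range(16)
    ]
    output_ids = runner.generate(batch_input_ids,
                                 max_new_tokens=200,
                                 temperature=20.0,
                                 top_k=0,      # assumption: 0 disables top-k so top-p sampling is used
                                 top_p=0.5,
                                 end_id=11,    # assumption: Falcon <|endoftext|> token ID
                                 pad_id=11)
    # output_ids has shape [batch_size, num_beams, seq_len]; reuse the loose
    # validity bound from the Executor repro above.
    if (output_ids < 0).any() or (output_ids >= 200000).any():
        print("Invalid token detected")
        break
    print("Valid output produced for the whole batch.")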
@AlessioNetti AlessioNetti added the bug Something isn't working label May 13, 2024
@byshiue (Collaborator)

byshiue commented May 22, 2024

Thank you. I can reproduce the issue. I made a small change to example_basic.py to speed up the reproduction:

import argparse
import torch
import random

import tensorrt_llm.bindings.executor as trtllm

# This example shows how to use the python bindings to create an executor, enqueue a
# request, and get the generated tokens.

# First, follow the steps in README.md to generate the engines.

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Executor Bindings Example")
    parser.add_argument("--model_path",
                        type=str,
                        required=True,
                        help="Directory containing model engine")
    args = parser.parse_args()

    # Create the executor.
    executor = trtllm.Executor(args.model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig(1))

    random.seed(1234)
    if executor.can_enqueue_requests():
        ite_count = 0
        while True:
            # Create the request.
            requests = []
            ite_count += 16
            
            for _ in range(16):
                input_token_ids = [random.randint(100, 10000) for _ in range(200)]
                requests.append(trtllm.Request(input_token_ids=input_token_ids, max_new_tokens=105,
                                               sampling_config=trtllm.SamplingConfig(top_p=0.5, top_k=None, temperature=20.0)))
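            # Batches before ite_count == 6616 are generated but never enqueued:
            # creating the requests still consumes the seeded RNG, so this fast-forwards
            # to the state that triggers the bug without running inference on the rest.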
            if ite_count < 6616:
                continue
            
            # Enqueue the request.
            request_ids = executor.enqueue_requests(requests)

            # Wait for the new tokens.
            responses = executor.await_responses(request_ids)
            
            for idx, re in enumerate(responses):
                output_tokens = re[0].result.output_token_ids[0]
                valid_output = all(el >= 0 and el < 200000 for el in output_tokens)
                if not valid_output:
                    print(f"InValid output produced for request {request_ids[idx]}.")
                    print(f"Output tokens : {output_tokens[200:]}")
                    exit(-1)
                else:
                    print(f"Valid output produced for request {request_ids[idx]}.")

We are still investigating the reason.

@byshiue byshiue self-assigned this May 23, 2024
@byshiue byshiue added the triaged Issue has been triaged by maintainers label May 23, 2024
@ChristinaZ

Hi Alessio,
Thank you for finding this bug. We are looking into it. In case it becomes a bottleneck in your workflow, one workaround is to set the variable mIsAirTopP to false, so that TRT-LLM falls back to its other top-p sampling implementation. We will try to fix the bug as soon as possible.

@nv-guomingz (Collaborator)

Hi @AlessioNetti, do you still have any further issues or questions? If not, we'll close this soon.

@AlessioNetti (Contributor, Author)

Hi - the bug was fixed a few versions back, so we can close this.
