CodeShell gives wrong results with batched inputs: how can batched inference be supported? #63

Open
MeJerry215 opened this issue Dec 11, 2023 · 0 comments


MeJerry215 commented Dec 11, 2023

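The current examples appear to use a single prompt at a time (batch size 1). How can I run inference on a batch of several prompts? At the moment the batched results don't look right.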

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("CodeShell-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("CodeShell-7B", torch_dtype=torch.float16, trust_remote_code=True).cuda()
examples = [
    "import math\ndef print_hello():",
    "import math\ndef quick_sort():",
    "import math\ndef test_quick_sort():",
    "import math\ndef test_print_hello():",
    "import math\ndef test_merge_sort():",
    "import math\ndef two_sum():",
    "import math\ndef preoder_transverse():",
    "import math\ndef merge_sort():",
]

# Tokenize all prompts as one batch; padding=True pads them to a common length.
# Only input_ids are kept, so no attention mask is passed to generate().
inputs = tokenizer(examples, return_tensors='pt', padding=True)['input_ids'].cuda()
outputs = model.generate(inputs, max_new_tokens=128)
# Print each full sequence (prompt + generated tokens).
for output in outputs:
    print("=====================> ", tokenizer.decode(output))
```

The test code is above.

The results are rather strange:

```
=====================>  import math
def print_hello():<|endoftext|><|endoftext|><|endoftext|><fim_prefix><fim_suffix>  }
    }
}<fim_middle>using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace _04.Longest_Increasing_Subsequence
{
    class Program
    {
        static void Main(string[] args)
        {
            int[] nums = Console.ReadLine().Split(' ').Select(int.Parse).ToArray();
            int[] len = new int[nums.Length];
            int[] prev = new int[nums.Length];
            int maxLen = 0;
=====================>  import math
def quick_sort():<|endoftext|><|endoftext|><|endoftext|><fim_prefix><fim_suffix>  }
    }
}<fim_middle>using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
```
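The `<|endoftext|>` tokens that appear right after the shorter prompts suggest the batch was padded on the right, with the pad token doubling as end-of-text. The model then continues after an end-of-text token and starts an unrelated document (here a FIM/C# sample). For decoder-only models in Hugging Face `transformers`, a common remedy is to pad on the left (`tokenizer.padding_side = "left"`) and pass the attention mask, e.g. `model.generate(enc["input_ids"].cuda(), attention_mask=enc["attention_mask"].cuda(), max_new_tokens=128, pad_token_id=tokenizer.pad_token_id)`, where `enc` is the full tokenizer output rather than only `['input_ids']`. Whether CodeShell's remote modeling code handles left padding correctly (for example, its position ids) is an assumption not verified here. The pure-Python sketch below only illustrates the mechanics with toy token ids; `PAD_ID = 0` is a hypothetical stand-in for the real pad/`<|endoftext|>` id, and the helper names are illustrative.

```python
from typing import List, Tuple

# Hypothetical token id standing in for the pad token. In the issue's output the
# padding appears as <|endoftext|>, i.e. the pad token doubles as end-of-text.
PAD_ID = 0


def right_pad(batch: List[List[int]], pad_id: int = PAD_ID) -> List[List[int]]:
    """Pad on the right: shorter prompts end with pad tokens."""
    width = max(len(seq) for seq in batch)
    return [list(seq) + [pad_id] * (width - len(seq)) for seq in batch]


def left_pad(
    batch: List[List[int]], pad_id: int = PAD_ID
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad on the left and build the matching attention mask.

    Every prompt's last real token lands in the final column, which is the
    position a decoder-only model continues generating from. The mask marks
    padding with 0 so the model can ignore it.
    """
    width = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = width - len(seq)
        input_ids.append([pad_id] * n_pad + list(seq))
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


def new_tokens(
    outputs: List[List[int]], prompt_width: int, stop_id: int = PAD_ID
) -> List[List[int]]:
    """Keep only the generated part of each row, cut at the first stop token.

    With left padding all prompts share the same width, so the continuation
    starts at column prompt_width in every row.
    """
    result = []
    for row in outputs:
        generated = row[prompt_width:]
        if stop_id in generated:
            generated = generated[: generated.index(stop_id)]
        result.append(generated)
    return result


if __name__ == "__main__":
    prompts = [[5, 6], [7, 8, 9, 10]]  # toy token ids, not real CodeShell ids
    print(right_pad(prompts))  # row 0 ends in padding, like print_hello above
    ids, mask = left_pad(prompts)
    print(ids)   # [[0, 0, 5, 6], [7, 8, 9, 10]]
    print(mask)  # [[0, 0, 1, 1], [1, 1, 1, 1]]
```

With right padding, the last column of the short prompt is a pad token, which matches the `<|endoftext|>` runs in the output above. With left padding every row ends on a real prompt token, and `new_tokens` mirrors slicing `outputs[:, enc["input_ids"].shape[1]:]` before decoding.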