diff --git a/examples/generate.py b/examples/generate.py
index 043e3487..c14f9086 100644
--- a/examples/generate.py
+++ b/examples/generate.py
@@ -2,28 +2,36 @@
 from transformers import AutoTokenizer, TextStreamer
 
-quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"
+quant_path = "casperhansen/llama-3-8b-instruct-awq"
 
 # Load model
 model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
 tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
 streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-# Convert prompt to tokens
-prompt_template = "[INST] {prompt} [/INST]"
-
 prompt = "You're standing on the surface of the Earth. "\
         "You walk one mile south, one mile west and one mile north. "\
         "You end up exactly where you started. Where are you?"
 
-tokens = tokenizer(
-    prompt_template.format(prompt=prompt),
-    return_tensors='pt'
-).input_ids.cuda()
+chat = [
+    {"role": "system", "content": "You are a concise assistant that helps answer questions."},
+    {"role": "user", "content": prompt},
+]
+
+terminators = [
+    tokenizer.eos_token_id,
+    tokenizer.convert_tokens_to_ids("<|eot_id|>")
+]
+
+tokens = tokenizer.apply_chat_template(
+    chat,
+    return_tensors="pt"
+).cuda()
 
 # Generate output
 generation_output = model.generate(
     tokens,
     streamer=streamer,
-    max_new_tokens=512
-)
\ No newline at end of file
+    max_new_tokens=64,
+    eos_token_id=terminators
+)
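
As a rough sketch (not part of the patch): once the updated example has run, the reply text can be recovered from `generation_output` by slicing off the prompt tokens and decoding, assuming the default HF behavior where `model.generate` returns a tensor of token ids:

    # Illustrative follow-up only, using the variables defined in the example above.
    # generation_output has shape (1, prompt_len + new_tokens); drop the prompt part.
    response_ids = generation_output[0][tokens.shape[-1]:]
    print(tokenizer.decode(response_ids, skip_special_tokens=True))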