From 3bda3624ef92c4c5dd8552acd151427cb2471298 Mon Sep 17 00:00:00 2001 From: yukang Date: Tue, 3 Oct 2023 11:44:07 +0800 Subject: [PATCH] Update inference.py --- inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference.py b/inference.py index bdd3bb5a..2521bf7f 100644 --- a/inference.py +++ b/inference.py @@ -109,7 +109,7 @@ def main(args): if torch.__version__ >= "2" and sys.platform != "win32": model = torch.compile(model) respond = build_generator(model, tokenizer, temperature=args.temperature, top_p=args.top_p, - max_gen_len=args.max_gen_len, use_cache=not args.flash_attn) + max_gen_len=args.max_gen_len, use_cache=True) output = respond(args.material, args.question, args.material_type, args.material_title) print("output", output)