This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit c9fb9d1

[LLM Runtime] ChatGLM-V1 multi-batch infer and batched greedy search generation (#700)

zhentaoyu authored and VincyZhang committed Dec 20, 2023
1 parent 2ee9fec
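
As a rough usage sketch of what this commit enables (several prompts driven through greedy search in one call), the snippet below prepares a padded multi-prompt batch on the Python side. The tokenizer/model id is illustrative only, and the final generate() call is shown as a comment because it needs an already-initialized instance of the class patched in graph/__init__.py.

# Hedged sketch: prepare a padded batch for batched greedy generation.
# The model id is illustrative; `model` stands for an initialized instance of the
# class modified in graph/__init__.py (not constructed here).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
if tokenizer.pad_token_id is None:
    # The new pad_token_id() helper raises for batch sizes > 1 without a pad token.
    tokenizer.pad_token = tokenizer.eos_token

prompts = ["What is the capital of France?", "Write a short poem about spring."]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
print(inputs.input_ids.shape)  # (2, padded_len); shape[0] becomes self.batch_size in generate()

# outputs = model.generate(inputs.input_ids, max_new_tokens=32, do_sample=False)  # batched greedy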
Showing 15 changed files with 575 additions and 224 deletions.
graph/CMakeLists.txt: 5 additions, 1 deletion
@@ -91,7 +91,11 @@ option(NE_GELU_VEC "neural_engine: enable vec in gelu"
 if (NE_GELU_VEC)
   add_compile_definitions(NE_GELU_USE_VEC)
 endif()
-option(NE_PYTHON_API "neural_engine: use python api" OFF)
+option(NE_PYTHON_API "neural_engine: use python api" OFF)
+option(NE_SIMD_VEC_DOT_F16 "neural_engine: enable vec_dot_fp16 SIMD optimization" ON)
+if (NE_SIMD_VEC_DOT_F16)
+  add_compile_definitions(NE_SIMD_VEC_DOT_F16)
+endif()

 if(NE_BUILD_TESTS)
   enable_testing()
graph/__init__.py: 16 additions, 6 deletions
@@ -145,8 +145,9 @@ def quant_model(self, model_type, model_path, out_path, **quant_kwargs):

     def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
         max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
+        self.batch_size = input_ids.shape[0]
         if self.model is None:
-            self.init_from_bin(self.model_type, self.bin_file, batch_size=input_ids.shape[0],
+            self.init_from_bin(self.model_type, self.bin_file, batch_size=self.batch_size,
                                **generate_kwargs)
             self.generate_round = 0
         elif not interactive:
@@ -160,9 +161,6 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
         beam_search = False
         if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False):
             beam_search = True
-        if not beam_search:
-            # TODO support multi batch
-            assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids."

         if streamer:
             assert input_ids.shape[0] == 1, "Streamer only supports batch size 1."
@@ -190,9 +188,12 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
             if stopping_criteria is not None:
                 if stopping_criteria(torch.tensor(ret), None):
                     break
-            elif ret[0][-1] == self.eos_token_id() or \
-                    (max_new_tokens != -1 and out_count >= max_new_tokens):
+            elif (max_new_tokens != -1 and out_count >= max_new_tokens):
                 break
+            else:
+                all_done = [(r[-1] in [self.eos_token_id(), self.pad_token_id()]) for r in ret]
+                if False not in all_done:
+                    break
         if streamer:
             streamer.end()
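
For readability, the new per-batch stop condition from the hunk above can be rendered as a standalone snippet; the token ids and helper values below are made up purely for illustration.

# Toy rendering of the batched stop check: generation continues until every row in the
# batch has ended with eos (or pad, for rows that finished earlier and are being padded).
eos_token_id = 2   # illustrative value
pad_token_id = 0   # illustrative value

ret = [
    [15, 27, 2],   # row 0: last token is eos -> done
    [15, 31, 42],  # row 1: still generating
]

all_done = [(r[-1] in [eos_token_id, pad_token_id]) for r in ret]
print(all_done)               # [True, False]
print(False not in all_done)  # False -> keep looping; True only once every row is done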

@@ -206,6 +207,15 @@ def eos_token_id(self):
         if self.model_type == 'qwen':
             return self.tokenizer.special_tokens['<|endoftext|>']
         return self.tokenizer.eos_token_id

+    def pad_token_id(self):
+        if self.tokenizer.pad_token_id == None:
+            if self.batch_size == 1:
+                return None
+            else:
+                raise ValueError("Please set pad_token_id when doing multi batch inference"\
+                                 " with padding!")
+        return self.tokenizer.pad_token_id
+
     def __call__(self, input_ids, reinit=False, **kwargs):
         if self.model is None:
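
The new pad_token_id() helper only enforces a pad token when more than one sequence is being generated. A minimal standalone sketch of that branching, using a stand-in tokenizer object rather than the real one, might look like this:

# Stand-in objects that mirror the branching of the new pad_token_id() method.
class FakeTokenizer:
    pad_token_id = None  # simulate a tokenizer without a pad token configured

def pad_token_id(tokenizer, batch_size):
    if tokenizer.pad_token_id is None:
        if batch_size == 1:
            return None  # single-prompt inference still works without a pad token
        raise ValueError("Please set pad_token_id when doing multi batch inference with padding!")
    return tokenizer.pad_token_id

print(pad_token_id(FakeTokenizer(), batch_size=1))  # None
# pad_token_id(FakeTokenizer(), batch_size=4)       # raises ValueError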
(diffs for the remaining 13 changed files are not shown here)

0 comments on commit c9fb9d1
