
Merge pull request #30 from blbadger/selected-pkv
Renamed low memory flag
blbadger authored Jul 20, 2023
2 parents f310f83 + b11c156 commit c619204
Showing 1 changed file with 8 additions and 7 deletions.
src/transformers/generation/utils.py (8 additions, 7 deletions)
@@ -1563,7 +1563,7 @@ def generate(
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
-                low_memory=generation_config.low_memory,
+                sequential=generation_config.low_memory,
                 **model_kwargs,
             )
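The user-facing knob keeps its public name: callers still set `low_memory` on the generation config, and `generate` forwards it to `contrastive_search` under the internal parameter name `sequential`. A minimal usage sketch (assuming a transformers version that includes this change; the model choice is illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The low-memory flag", return_tensors="pt")

# penalty_alpha > 0 together with top_k > 1 selects contrastive search;
# low_memory=True makes the top-k forward passes run sequentially.
output = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=20, low_memory=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))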

@@ -1827,7 +1827,7 @@ def contrastive_search(
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
-        low_memory: Optional[bool] = False,
+        sequential: Optional[bool] = None,
         **model_kwargs,
     ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
         r"""
r"""
@@ -1878,7 +1878,7 @@ def contrastive_search(
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            low_memory (`bool`, *optional*):
+            sequential (`bool`, *optional*):
                 Switches topk hidden state computation from parallel to sequential to reduce memory if True.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
@@ -1919,6 +1919,7 @@ def contrastive_search(
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
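The added line is what makes the rename backward compatible: an explicitly passed `sequential` argument wins, and when it is left as `None` (the new default in the signature above) the value falls back to `generation_config.low_memory`. A standalone sketch of the resolution pattern (names are illustrative, not from the diff):

def resolve_sequential(sequential, config_low_memory):
    # Explicit argument wins; None means "defer to the config flag".
    return sequential if sequential is not None else config_low_memory

assert resolve_sequential(None, True) is True    # unset: falls back to config
assert resolve_sequential(False, True) is False  # explicit False overrides config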
@@ -1994,7 +1995,7 @@ def contrastive_search(
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 standardize_cache_format=True,
             )
-            if not low_memory:
+            if not sequential:
                 # Expands model inputs top_k times, for batched forward passes (akin to beam search).
                 _, model_kwargs = self._expand_inputs_for_generation(
                     expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
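In the default (parallel) branch every model input is expanded `top_k` times so that all candidate continuations are scored in one batched forward pass; the sequential branch skips this expansion and loops over candidates instead, trading speed for roughly a `top_k`-fold reduction in peak activation memory. An illustrative sketch of the expansion (dummy tensors, not library code):

import torch

top_k = 4
hidden_states = torch.randn(2, 10, 768)  # (batch, seq_len, hidden)

# Parallel mode: the batch grows top_k-fold before the single forward pass.
expanded = hidden_states.repeat_interleave(top_k, dim=0)
print(expanded.shape)  # torch.Size([8, 10, 768])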
@@ -2047,14 +2048,14 @@ def contrastive_search(
                     items = []
                     # item is either the key or the value matrix
                     for item in layer:
-                        if low_memory:
+                        if sequential:
                             items.append(item.repeat_interleave(1, dim=0))
                         else:
                             items.append(item.repeat_interleave(top_k, dim=0))
                     new_key_values.append(items)
                 model_kwargs["past_key_values"] = new_key_values

-            if low_memory:
+            if sequential:
                 all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
                 all_last_hstates, all_hstates, all_logits = [], [], []
                 for i in range(top_k):
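Two details in this hunk are worth noting. First, in sequential mode the cached key/value tensors are repeated only once (`repeat_interleave(1, dim=0)`, effectively a copy), which keeps the cache structure identical across both branches. Second, the sequential branch accumulates per-candidate outputs and concatenates them so downstream code sees the same layout as the batched pass. A runnable sketch of that accumulation (the forward pass is a hypothetical stand-in, not the model's real API):

import torch

top_k, batch, vocab = 4, 2, 50257

def fake_forward(candidate_ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for one model forward pass over a single candidate.
    return torch.randn(candidate_ids.shape[0], vocab)

candidates = [torch.randint(0, vocab, (batch, 1)) for _ in range(top_k)]
all_logits = [fake_forward(c) for c in candidates]  # one pass per candidate
logits = torch.cat(all_logits, dim=0)               # (top_k * batch, vocab), matching the batched layout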
@@ -2137,7 +2138,7 @@ def contrastive_search(
                         next_decoder_hidden_states += (layer,)

             # generate past_key_values cache of only the selected token
-            if low_memory:
+            if sequential:
                 next_model_input = self.prepare_inputs_for_generation(
                     top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
                 )
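Because the sequential branch never expanded the cache `top_k` times, it rebuilds `past_key_values` for just the chosen continuation: the selected token is sliced out of the candidate matrix and fed through one more prepared forward pass. An illustrative shape check (dummy values; in the real code `top_k_ids` and `selected_idx` come from the contrastive scoring step):

import torch

batch, top_k, vocab = 2, 4, 50257
top_k_ids = torch.randint(0, vocab, (batch, top_k))  # k candidate tokens per row
selected_idx = 1                                     # illustrative scalar choice

next_token = top_k_ids[:, selected_idx].view(-1, 1)
print(next_token.shape)  # torch.Size([2, 1]), fed back into the model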