diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index f097d1e4a5f8..97aa95bdfe91 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1563,7 +1563,7 @@ def generate(
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
-                low_memory=generation_config.low_memory,
+                sequential=generation_config.low_memory,
                 **model_kwargs,
             )
 
@@ -1827,7 +1827,7 @@ def contrastive_search(
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
-        low_memory: Optional[bool] = False,
+        sequential: Optional[bool] = None,
         **model_kwargs,
     ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
         r"""
@@ -1878,7 +1878,7 @@ def contrastive_search(
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            low_memory (`bool`, *optional*):
+            sequential (`bool`, *optional*):
                 Switches topk hidden state computation from parallel to sequential to reduce memory if True.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
@@ -1919,6 +1919,7 @@ def contrastive_search(
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
@@ -1994,7 +1995,7 @@ def contrastive_search(
                     is_encoder_decoder=self.config.is_encoder_decoder,
                     standardize_cache_format=True,
                 )
-                if not low_memory:
+                if not sequential:
                     # Expands model inputs top_k times, for batched forward passes (akin to beam search).
                     _, model_kwargs = self._expand_inputs_for_generation(
                         expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
                     )
@@ -2047,14 +2048,14 @@ def contrastive_search(
                 items = []
                 # item is either the key or the value matrix
                 for item in layer:
-                    if low_memory:
+                    if sequential:
                         items.append(item.repeat_interleave(1, dim=0))
                     else:
                         items.append(item.repeat_interleave(top_k, dim=0))
                 new_key_values.append(items)
             model_kwargs["past_key_values"] = new_key_values
 
-            if low_memory:
+            if sequential:
                 all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
                 all_last_hstates, all_hstates, all_logits = [], [], []
                 for i in range(top_k):
@@ -2137,7 +2138,7 @@ def contrastive_search(
                     next_decoder_hidden_states += (layer,)
 
             # generate past_key_values cache of only the selected token
-            if low_memory:
+            if sequential:
                 next_model_input = self.prepare_inputs_for_generation(
                     top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
                 )
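
Note that the public API is unchanged by this rename: callers still set `low_memory` on the generation config (or pass it to `generate()`), and `generate()` forwards it to `contrastive_search()` as `sequential=`; direct callers of `contrastive_search()` that omit the kwarg fall back to `self.generation_config.low_memory`. A minimal sketch of exercising the sequential path from the public API follows; the checkpoint and prompt are illustrative assumptions, not part of this diff:

```python
# Sketch: triggering low-memory (sequential) contrastive search via generate().
# "gpt2" and the prompt are illustrative; any causal LM checkpoint would do.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# penalty_alpha > 0 and top_k > 1 select contrastive search. After this change,
# low_memory=True reaches contrastive_search() as sequential=True, so the top-k
# candidate forward passes run one at a time instead of as a single batched pass.
outputs = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=32,
    low_memory=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```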