From e9c05dc676a9da895b2ee002567153f15dee24e3 Mon Sep 17 00:00:00 2001 From: blbadger Date: Tue, 6 Jun 2023 15:14:43 -0400 Subject: [PATCH 01/83] added hidden subset --- src/transformers/generation/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9bf9add17e1..599ee8935bb 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1807,6 +1807,7 @@ def contrastive_search( return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, + subset_hidden: Optional[bool] = True, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: r""" @@ -1961,6 +1962,9 @@ def contrastive_search( last_hidden_states = outputs.decoder_hidden_states[-1] else: last_hidden_states = outputs.hidden_states[-1] + + if subset_hidden: + last_hidden_states = last_hidden_states[:, :, :100] # next logit for contrastive search to select top-k candidate tokens logit_for_next_step = outputs.logits[:, -1, :] From fc5216677ab8b7942572d5fa050a715538c3ff1a Mon Sep 17 00:00:00 2001 From: blbadger Date: Tue, 6 Jun 2023 16:11:04 -0400 Subject: [PATCH 02/83] debugged hidden subset contrastive search --- src/transformers/generation/utils.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 599ee8935bb..d9188b84f2e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1808,6 +1808,7 @@ def contrastive_search( synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, subset_hidden: Optional[bool] = True, + compress_hidden: Optional[bool] = False, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: r""" @@ -1932,6 +1933,14 @@ def contrastive_search( this_peer_finished = False # used by synced_gpus only batch_size = input_ids.shape[0] + # compression mat mult init + if model.config.n_embd: + hidden_size = model.config.n_embd + elif model.config.hidden_size: + hidden_size = model.config.hidden_size + compression_factor = 8 + compressor = torch.nn.Linear(hidden_size, hidden_size//compression_factor, bias=False) + while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 
@@ -1964,7 +1973,13 @@ def contrastive_search( last_hidden_states = outputs.hidden_states[-1] if subset_hidden: + print ('Using subset of last hidden layer') last_hidden_states = last_hidden_states[:, :, :100] + + elif compress_hidden: + print ('Using compressed last hidden layer') + last_hidden_states = compressor(last_hidden_states) + # next logit for contrastive search to select top-k candidate tokens logit_for_next_step = outputs.logits[:, -1, :] @@ -2047,6 +2062,10 @@ def contrastive_search( else: next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states + + if subset_hidden: + next_hidden = next_hidden[:, :, :100] + context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the From 3906d608f176de70e0902db6e7c188096155daba Mon Sep 17 00:00:00 2001 From: blbadger Date: Tue, 6 Jun 2023 17:02:53 -0400 Subject: [PATCH 03/83] added contrastive search compression --- src/transformers/generation/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d9188b84f2e..f0eab2eaaae 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1934,12 +1934,10 @@ def contrastive_search( batch_size = input_ids.shape[0] # compression mat mult init - if model.config.n_embd: - hidden_size = model.config.n_embd - elif model.config.hidden_size: - hidden_size = model.config.hidden_size + outputs = self(**model_inputs, output_hidden_states=True) + hidden_dim = outputs.hidden_states[-1].shape[-1] compression_factor = 8 - compressor = torch.nn.Linear(hidden_size, hidden_size//compression_factor, bias=False) + compressor = torch.nn.Linear(hidden_dim, hidden_dim//compression_factor, bias=False) while True: if synced_gpus: @@ -2066,6 +2064,9 @@ def contrastive_search( if subset_hidden: next_hidden = next_hidden[:, :, :100] + elif compress_hidden: + next_hidden = compressor(next_hidden) + context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the From 2881aef1db6925376fe3956046d1cb4e419430a5 Mon Sep 17 00:00:00 2001 From: blbadger Date: Tue, 6 Jun 2023 17:13:56 -0400 Subject: [PATCH 04/83] debugged compressed contrastive search --- src/transformers/generation/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f0eab2eaaae..71e18540726 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1934,6 +1934,7 @@ def contrastive_search( batch_size = input_ids.shape[0] # compression mat mult init + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self(**model_inputs, output_hidden_states=True) hidden_dim = outputs.hidden_states[-1].shape[-1] compression_factor = 8 From 7d29c5508871f6c5ed010e4a7dca861e258e826e Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 10:42:37 -0400 Subject: [PATCH 05/83] memory reduction for contrastive search --- src/transformers/generation/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 71e18540726..ecf06017725 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2048,6 +2048,8 @@ def 
contrastive_search( # compute the candidate tokens by the language model and collects their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + print (next_model_inputs.shape) + del outputs outputs = self( **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions ) From 57dfaacb82cf9d95d69f460241a7a1a7399278ac Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 11:03:18 -0400 Subject: [PATCH 06/83] debugged mem red --- src/transformers/generation/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ecf06017725..62a35209d29 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2048,8 +2048,10 @@ def contrastive_search( # compute the candidate tokens by the language model and collects their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) - print (next_model_inputs.shape) + print (next_model_inputs) + print (next_model_inputs.input_ids.shape) del outputs + outputs = self( **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions ) From fd0e19f61b679ed108635ee355505a3240bdf887 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 12:06:16 -0400 Subject: [PATCH 07/83] added low memory option feature --- src/transformers/generation/utils.py | 55 ++++++++++++++++++---------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 62a35209d29..feea0ab1f4b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2036,27 +2036,42 @@ def contrastive_search( else (outputs.hidden_states,) ) - # Replicates the new past_key_values to match the `top_k` candidates - new_key_values = [] - for layer in model_kwargs["past_key_values"]: - items = [] - # item is either the key or the value matrix - for item in layer: - items.append(item.repeat_interleave(top_k, dim=0)) - new_key_values.append(items) - model_kwargs["past_key_values"] = new_key_values - - # compute the candidate tokens by the language model and collects their hidden_states - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) - print (next_model_inputs) - print (next_model_inputs.input_ids.shape) - del outputs - - outputs = self( - **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions - ) - next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) + # if the used memory exceeds a threshold, do not batch + if torch.cuda.get_mem_info()[0] < torch.cuda_get_mem_info()[1]//2: + outputs = dict() + for i in range(len(top_k_ids)) + # compute the candidate tokens by the language model and collect their hidden_states + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[i], **model_kwargs) + new_outputs = self( + **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions + ) + if not outputs: + outputs = new_outputs + else: + for key in outputs: + outputs[key] = torch.stack(outputs[key], new_outputs[key], dim=0) + print (outputs) + + else: + # Replicates the new past_key_values to match the `top_k` candidates + new_key_values = [] + for layer in model_kwargs["past_key_values"]: 
+ items = [] + # item is either the key or the value matrix + for item in layer: + items.append(item.repeat_interleave(top_k, dim=0)) + new_key_values.append(items) + model_kwargs["past_key_values"] = new_key_values + + # compute the candidate tokens by the language model and collect their hidden_states + # assembles top_k_ids into batch of size k (leading to OOM for large models) + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + + outputs = self( + **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions + ) + next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) logits = outputs.logits[:, -1, :] # name is different for encoder-decoder and decoder-only models if self.config.is_encoder_decoder: From 802cfd4b6efd0276dcb53d5bc609029c60b1c6e1 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 12:35:13 -0400 Subject: [PATCH 08/83] debugged mem optmimization output stack --- src/transformers/generation/utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index feea0ab1f4b..bace0f9d9cd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2038,19 +2038,19 @@ def contrastive_search( # if the used memory exceeds a threshold, do not batch if torch.cuda.get_mem_info()[0] < torch.cuda_get_mem_info()[1]//2: - outputs = dict() - for i in range(len(top_k_ids)) + all_outputs = {key:[] for key in outputs} # defined in first loop iteration + for i in range(len(top_k_ids)): # compute the candidate tokens by the language model and collect their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[i], **model_kwargs) - new_outputs = self( + outputs = self( **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions ) - if not outputs: - outputs = new_outputs - else: - for key in outputs: - outputs[key] = torch.stack(outputs[key], new_outputs[key], dim=0) - print (outputs) + for key in all_outputs: + all_outputs[key].append(outputs[key]) + + for key in all_outputs: + all_outputs[key] = torch.stack(all_outputs[key], dim=0) + print (outputs) else: # Replicates the new past_key_values to match the `top_k` candidates From 0632f0611aefe346c99e43d2e25b3408fe8d885a Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 12:37:13 -0400 Subject: [PATCH 09/83] debugged mem optmimization output stack --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bace0f9d9cd..023ca3edc76 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2037,7 +2037,7 @@ def contrastive_search( ) # if the used memory exceeds a threshold, do not batch - if torch.cuda.get_mem_info()[0] < torch.cuda_get_mem_info()[1]//2: + if torch.cuda.get_mem_info()[0] < torch.cuda.get_mem_info()[1]//2: all_outputs = {key:[] for key in outputs} # defined in first loop iteration for i in range(len(top_k_ids)): # compute the candidate tokens by the language model and collect their hidden_states From 9bad256efb5c5cae15534ddf1a119d17b6d2846a Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 13:05:04 -0400 Subject: [PATCH 10/83] debugged low mem --- src/transformers/generation/utils.py | 3 ++- 1 file 
changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 023ca3edc76..97e316d5012 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1809,6 +1809,7 @@ def contrastive_search( streamer: Optional["BaseStreamer"] = None, subset_hidden: Optional[bool] = True, compress_hidden: Optional[bool] = False, + low_memory: Optional[bool] = True, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: r""" @@ -2037,7 +2038,7 @@ def contrastive_search( ) # if the used memory exceeds a threshold, do not batch - if torch.cuda.get_mem_info()[0] < torch.cuda.get_mem_info()[1]//2: + if low_memory: all_outputs = {key:[] for key in outputs} # defined in first loop iteration for i in range(len(top_k_ids)): # compute the candidate tokens by the language model and collect their hidden_states From a89bb8e8a2c37809b2cfb1993986ce4d9f8264d5 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 14:18:39 -0400 Subject: [PATCH 11/83] added low mem cache --- src/transformers/generation/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 97e316d5012..4b407791e45 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1989,11 +1989,11 @@ def contrastive_search( is_encoder_decoder=self.config.is_encoder_decoder, standardize_cache_format=True, ) - - # Expands model inputs top_k times, for batched forward passes (akin to beam search). - _, model_kwargs = self._expand_inputs_for_generation( - expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs - ) + if not low_memory: + # Expands model inputs top_k times, for batched forward passes (akin to beam search). 
+ _, model_kwargs = self._expand_inputs_for_generation( + expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) past_key_values = model_kwargs.get("past_key_values") if past_key_values is None: @@ -2036,6 +2036,8 @@ def contrastive_search( if self.config.is_encoder_decoder else (outputs.hidden_states,) ) + + print (top_k_ids) # if the used memory exceeds a threshold, do not batch if low_memory: From f90f948946c7fa7683e895e9df7b7f978c0c9279 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 14:54:25 -0400 Subject: [PATCH 12/83] fixed 2047 tensor view --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4b407791e45..0432eccd247 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2044,7 +2044,7 @@ def contrastive_search( all_outputs = {key:[] for key in outputs} # defined in first loop iteration for i in range(len(top_k_ids)): # compute the candidate tokens by the language model and collect their hidden_states - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[i], **model_kwargs) + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1)[i], **model_kwargs) outputs = self( **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions ) From e1718c372a6318b6ad1ef67176cf0448a5321f55 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 15:34:46 -0400 Subject: [PATCH 13/83] debugged 2042 past key val inputs --- src/transformers/generation/utils.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0432eccd247..5bc9bfeade5 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2036,8 +2036,20 @@ def contrastive_search( if self.config.is_encoder_decoder else (outputs.hidden_states,) ) - - print (top_k_ids) + + # Replicates the new past_key_values to match the `top_k` candidates + new_key_values = [] + for layer in model_kwargs["past_key_values"]: + items = [] + # item is either the key or the value matrix + for item in layer: + if low_memory: + items.append(item.repeat_interleave(1, dim=0)) + else: + items.append(item.repeat_interleave(top_k, dim=0)) + new_key_values.append(items) + model_kwargs["past_key_values"] = new_key_values + print (top_k_ids.view(-1, 1)) # if the used memory exceeds a threshold, do not batch if low_memory: @@ -2056,16 +2068,6 @@ def contrastive_search( print (outputs) else: - # Replicates the new past_key_values to match the `top_k` candidates - new_key_values = [] - for layer in model_kwargs["past_key_values"]: - items = [] - # item is either the key or the value matrix - for item in layer: - items.append(item.repeat_interleave(top_k, dim=0)) - new_key_values.append(items) - model_kwargs["past_key_values"] = new_key_values - # compute the candidate tokens by the language model and collect their hidden_states # assembles top_k_ids into batch of size k (leading to OOM for large models) next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) From 3fd54e679cd9936c8b1d97306a2a1088d14ffe07 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 16:27:34 -0400 Subject: [PATCH 14/83] reformatted tensors --- src/transformers/generation/utils.py | 5 +++-- 1 file 
changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5bc9bfeade5..a23096e0051 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2049,14 +2049,15 @@ def contrastive_search( items.append(item.repeat_interleave(top_k, dim=0)) new_key_values.append(items) model_kwargs["past_key_values"] = new_key_values - print (top_k_ids.view(-1, 1)) + print (top_k_ids) + print (top_k_ids[:, 0].unsqueeze(0)) # if the used memory exceeds a threshold, do not batch if low_memory: all_outputs = {key:[] for key in outputs} # defined in first loop iteration for i in range(len(top_k_ids)): # compute the candidate tokens by the language model and collect their hidden_states - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1)[i], **model_kwargs) + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), **model_kwargs) outputs = self( **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions ) From 12d5aea586c52609713e2bf0eebd9e579ead309c Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 16:47:28 -0400 Subject: [PATCH 15/83] changed low mem output --- src/transformers/generation/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a23096e0051..c9cc2ad6ab7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2049,8 +2049,6 @@ def contrastive_search( items.append(item.repeat_interleave(top_k, dim=0)) new_key_values.append(items) model_kwargs["past_key_values"] = new_key_values - print (top_k_ids) - print (top_k_ids[:, 0].unsqueeze(0)) # if the used memory exceeds a threshold, do not batch if low_memory: @@ -2065,7 +2063,8 @@ def contrastive_search( all_outputs[key].append(outputs[key]) for key in all_outputs: - all_outputs[key] = torch.stack(all_outputs[key], dim=0) + if torch.is_tensor(all_outputs[key]): + outputs[key] = torch.stack(all_outputs[key], dim=0) print (outputs) else: From 44a9ec4260e46a89cad62c71cf8e3c654b4f55d1 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 7 Jun 2023 17:46:05 -0400 Subject: [PATCH 16/83] final clean --- src/transformers/generation/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c9cc2ad6ab7..d30ac8d9b21 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1807,7 +1807,7 @@ def contrastive_search( return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, - subset_hidden: Optional[bool] = True, + subset_hidden: Optional[bool] = False, compress_hidden: Optional[bool] = False, low_memory: Optional[bool] = True, **model_kwargs, @@ -1974,7 +1974,7 @@ def contrastive_search( if subset_hidden: print ('Using subset of last hidden layer') - last_hidden_states = last_hidden_states[:, :, :100] + last_hidden_states = last_hidden_states[:, :, :500] elif compress_hidden: print ('Using compressed last hidden layer') @@ -2065,7 +2065,6 @@ def contrastive_search( for key in all_outputs: if torch.is_tensor(all_outputs[key]): outputs[key] = torch.stack(all_outputs[key], dim=0) - print (outputs) else: # compute the candidate tokens by the language model and collect their hidden_states @@ -2087,7 +2086,7 
@@ def contrastive_search( full_hidden_states = outputs.hidden_states if subset_hidden: - next_hidden = next_hidden[:, :, :100] + next_hidden = next_hidden[:, :, :500] elif compress_hidden: next_hidden = compressor(next_hidden) From 37bb62d9d5e0599813986ccd2f0bb1079f616078 Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 8 Jun 2023 10:58:38 -0400 Subject: [PATCH 17/83] removed subset hidden csearch --- src/transformers/generation/utils.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d30ac8d9b21..e810dfecb65 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1807,7 +1807,6 @@ def contrastive_search( return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, - subset_hidden: Optional[bool] = False, compress_hidden: Optional[bool] = False, low_memory: Optional[bool] = True, **model_kwargs, @@ -1934,12 +1933,13 @@ def contrastive_search( this_peer_finished = False # used by synced_gpus only batch_size = input_ids.shape[0] - # compression mat mult init - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - outputs = self(**model_inputs, output_hidden_states=True) - hidden_dim = outputs.hidden_states[-1].shape[-1] - compression_factor = 8 - compressor = torch.nn.Linear(hidden_dim, hidden_dim//compression_factor, bias=False) + # compression initialization + if compress_hidden: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + outputs = self(**model_inputs, output_hidden_states=True) + hidden_dim = outputs.hidden_states[-1].shape[-1] + compression_factor = 8 + compressor = torch.nn.Linear(hidden_dim, hidden_dim//compression_factor, bias=False, device=hidden_dim.device) while True: if synced_gpus: @@ -1972,11 +1972,7 @@ def contrastive_search( else: last_hidden_states = outputs.hidden_states[-1] - if subset_hidden: - print ('Using subset of last hidden layer') - last_hidden_states = last_hidden_states[:, :, :500] - - elif compress_hidden: + if compress_hidden: print ('Using compressed last hidden layer') last_hidden_states = compressor(last_hidden_states) @@ -2013,7 +2009,6 @@ def contrastive_search( # contrastive_search main logic start: # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by # degeneration penalty - logit_for_next_step = logits_processor(input_ids, logit_for_next_step) logit_for_next_step = logits_warper(input_ids, logit_for_next_step) next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) @@ -2085,10 +2080,7 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - if subset_hidden: - next_hidden = next_hidden[:, :, :500] - - elif compress_hidden: + if compress_hidden: next_hidden = compressor(next_hidden) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) From 68c1cd86494b037bfb946c29fdf33edfd9beaf85 Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 8 Jun 2023 11:18:55 -0400 Subject: [PATCH 18/83] fixed hidden device --- src/transformers/generation/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e810dfecb65..8b6bca21b05 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1938,8 +1938,8 @@ def 
contrastive_search( model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self(**model_inputs, output_hidden_states=True) hidden_dim = outputs.hidden_states[-1].shape[-1] - compression_factor = 8 - compressor = torch.nn.Linear(hidden_dim, hidden_dim//compression_factor, bias=False, device=hidden_dim.device) + r = 8 + compressor = torch.nn.Linear(hidden_dim, hidden_dim//r, bias=False, device=outputs.hidden_states.device) while True: if synced_gpus: From e199ddc950956cd6728d84aff18d345128b0e2d9 Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 8 Jun 2023 11:29:37 -0400 Subject: [PATCH 19/83] fixed hidden device --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8b6bca21b05..3e8dc15e3a8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1939,7 +1939,7 @@ def contrastive_search( outputs = self(**model_inputs, output_hidden_states=True) hidden_dim = outputs.hidden_states[-1].shape[-1] r = 8 - compressor = torch.nn.Linear(hidden_dim, hidden_dim//r, bias=False, device=outputs.hidden_states.device) + compressor = torch.nn.Linear(hidden_dim, hidden_dim//r, bias=False, device=outputs.hidden_states[-1].device) while True: if synced_gpus: From 8ace5a32eef3513e2bc86ce0af4610db1b7a3789 Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 8 Jun 2023 11:46:36 -0400 Subject: [PATCH 20/83] changed compressor dtype --- src/transformers/generation/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3e8dc15e3a8..4bf5f2ffe37 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1939,7 +1939,12 @@ def contrastive_search( outputs = self(**model_inputs, output_hidden_states=True) hidden_dim = outputs.hidden_states[-1].shape[-1] r = 8 - compressor = torch.nn.Linear(hidden_dim, hidden_dim//r, bias=False, device=outputs.hidden_states[-1].device) + compressor = torch.nn.Linear(hidden_dim, + hidden_dim//r, + bias=False, + device=outputs.hidden_states[-1].device, + dtype=outputs.hidden_states[-1].dtype + ) while True: if synced_gpus: From 1ac80a09db4fa89dd5e586218699667385b2f9ba Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 8 Jun 2023 12:16:33 -0400 Subject: [PATCH 21/83] removed hstate compression --- src/transformers/generation/utils.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4bf5f2ffe37..80280d8f2d3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1807,8 +1807,7 @@ def contrastive_search( return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, - compress_hidden: Optional[bool] = False, - low_memory: Optional[bool] = True, + low_memory: Optional[bool] = False, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: r""" @@ -1859,6 +1858,8 @@ def contrastive_search( streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. 
+ low_memory (`bool`, *optional*): + Switches topk hidden state computation from parallel to sequential to reduce memory if True. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -1933,19 +1934,6 @@ def contrastive_search( this_peer_finished = False # used by synced_gpus only batch_size = input_ids.shape[0] - # compression initialization - if compress_hidden: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - outputs = self(**model_inputs, output_hidden_states=True) - hidden_dim = outputs.hidden_states[-1].shape[-1] - r = 8 - compressor = torch.nn.Linear(hidden_dim, - hidden_dim//r, - bias=False, - device=outputs.hidden_states[-1].device, - dtype=outputs.hidden_states[-1].dtype - ) - while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. @@ -1977,10 +1965,6 @@ def contrastive_search( else: last_hidden_states = outputs.hidden_states[-1] - if compress_hidden: - print ('Using compressed last hidden layer') - last_hidden_states = compressor(last_hidden_states) - # next logit for contrastive search to select top-k candidate tokens logit_for_next_step = outputs.logits[:, -1, :] @@ -2085,9 +2069,6 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - if compress_hidden: - next_hidden = compressor(next_hidden) - context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the From 1c3aae7d8f8f65f4d74c4081724801081b83cf2f Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 8 Jun 2023 12:43:16 -0400 Subject: [PATCH 22/83] integrated csearch in generate --- src/transformers/generation/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 80280d8f2d3..193bafafcc8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1549,6 +1549,7 @@ def generate( return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, + *low_memory=low_memory, **model_kwargs, ) From f18bccd67c9bc7712f7baff85a94946395a0e9e9 Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 8 Jun 2023 12:56:00 -0400 Subject: [PATCH 23/83] test csearch integration into generation exit() --- src/transformers/generation/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 193bafafcc8..c81f784eacd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1549,7 +1549,6 @@ def generate( return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, - *low_memory=low_memory, **model_kwargs, ) @@ -1808,7 +1807,7 @@ def contrastive_search( return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, - low_memory: Optional[bool] = False, + low_memory: Optional[bool] = True, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: r""" From abf0a72e2e586c287249285a3101c1a52bc59cfc Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 8 Jun 2023 13:12:36 -0400 Subject: [PATCH 24/83] fixed csearch kwarg integration with 
generation --- src/transformers/generation/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c81f784eacd..ac660deac95 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1152,6 +1152,7 @@ def generate( prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, + low_memory: Optional[bool] = False, streamer: Optional["BaseStreamer"] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: @@ -1549,6 +1550,7 @@ def generate( return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, + low_memory=low_memory, **model_kwargs, ) From e517d5f3a24b289645ecdc9820c52da741541e67 Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 8 Jun 2023 13:33:58 -0400 Subject: [PATCH 25/83] final wrap and added doc --- src/transformers/generation/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ac660deac95..81c91f29d83 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1211,6 +1211,8 @@ def generate( streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + low_memory (`bool`, *optional*): + Switch to sequential topk for contrastive search to reduce peak memory requirements. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder From cc1ea6d4c63a2042d0b4cc8fb5f53c4858bb5d45 Mon Sep 17 00:00:00 2001 From: Benjamin Badger <54602201+blbadger@users.noreply.github.com> Date: Thu, 15 Jun 2023 17:25:09 -0400 Subject: [PATCH 26/83] Update src/transformers/generation/utils.py Co-authored-by: Joao Gante --- src/transformers/generation/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 81c91f29d83..d9a5049d88d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2038,7 +2038,6 @@ def contrastive_search( new_key_values.append(items) model_kwargs["past_key_values"] = new_key_values - # if the used memory exceeds a threshold, do not batch if low_memory: all_outputs = {key:[] for key in outputs} # defined in first loop iteration for i in range(len(top_k_ids)): From bd2e36b1cf4934789b4cbc5d508643dad987fafa Mon Sep 17 00:00:00 2001 From: Benjamin Badger <54602201+blbadger@users.noreply.github.com> Date: Thu, 15 Jun 2023 17:26:25 -0400 Subject: [PATCH 27/83] Update src/transformers/generation/utils.py Co-authored-by: Joao Gante --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d9a5049d88d..9a8c5b1ee5e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1212,7 +1212,7 @@ def generate( Streamer object that will be used to stream the generated sequences. 
Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. low_memory (`bool`, *optional*): - Switch to sequential topk for contrastive search to reduce peak memory requirements. + Switch to sequential topk for contrastive search to reduce peak memory requirements. Used with contrastive search. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder From b59ec6d17145911f03c3ea3d08a13ed05ce9d38a Mon Sep 17 00:00:00 2001 From: Benjamin Badger <54602201+blbadger@users.noreply.github.com> Date: Thu, 15 Jun 2023 17:32:32 -0400 Subject: [PATCH 28/83] Update src/transformers/generation/utils.py Co-authored-by: Joao Gante --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9a8c5b1ee5e..d4a5b8c5b51 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1811,7 +1811,7 @@ def contrastive_search( return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, - low_memory: Optional[bool] = True, + low_memory: Optional[bool] = False, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: r""" From a7fb76e6742d7b74a8ce3b733f47261e664152cd Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 15 Jun 2023 22:48:41 -0400 Subject: [PATCH 29/83] added debug print --- src/transformers/generation/utils.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 81c91f29d83..59e47b84212 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1811,7 +1811,7 @@ def contrastive_search( return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, - low_memory: Optional[bool] = True, + low_memory: Optional[bool] = False, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: r""" @@ -2038,14 +2038,17 @@ def contrastive_search( new_key_values.append(items) model_kwargs["past_key_values"] = new_key_values - # if the used memory exceeds a threshold, do not batch if low_memory: all_outputs = {key:[] for key in outputs} # defined in first loop iteration for i in range(len(top_k_ids)): # compute the candidate tokens by the language model and collect their hidden_states - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), **model_kwargs) + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), + **model_kwargs) outputs = self( - **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions ) for key in all_outputs: all_outputs[key].append(outputs[key]) @@ -2060,7 +2063,10 @@ def contrastive_search( next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) outputs = self( - **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions ) next_past_key_values = 
self._extract_past_from_model_output(outputs, standardize_cache_format=True) @@ -2073,6 +2079,7 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states + print (next_hidden.shape) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the From 961a1babccfaf30e31dbf015e1ff17276dc4c5ac Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 15 Jun 2023 22:59:50 -0400 Subject: [PATCH 30/83] direct hstate cat --- src/transformers/generation/utils.py | 35 ++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 59e47b84212..07202f404ff 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2040,6 +2040,7 @@ def contrastive_search( if low_memory: all_outputs = {key:[] for key in outputs} # defined in first loop iteration + all_last_hstates = [] for i in range(len(top_k_ids)): # compute the candidate tokens by the language model and collect their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), @@ -2053,6 +2054,23 @@ def contrastive_search( for key in all_outputs: all_outputs[key].append(outputs[key]) + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + + all_last_hstates.append(next_hidden) + all_hstates.append(full_hidden_states) + + # stack hidden states + for i in range(top_k): + next_hidden = torch.stack(all_last_hstates[i], dim=0) + full_hidden_states = torch.stack(all_hstates[i], dim=0) + + # stack all_outputs attentions for key in all_outputs: if torch.is_tensor(all_outputs[key]): outputs[key] = torch.stack(all_outputs[key], dim=0) @@ -2068,16 +2086,19 @@ def contrastive_search( output_hidden_states=True, output_attentions=output_attentions ) + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + + print (next_hidden.shape, full_hidden_states.shape) next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) logits = outputs.logits[:, -1, :] - # name is different for encoder-decoder and decoder-only models - if self.config.is_encoder_decoder: - next_hidden = outputs.decoder_hidden_states[-1] - full_hidden_states = outputs.decoder_hidden_states - else: - next_hidden = outputs.hidden_states[-1] - full_hidden_states = outputs.hidden_states + print (next_hidden.shape) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) From 882b6d2fdd39f0a35404fa5497a688986c37fe4f Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 15 Jun 2023 23:04:55 -0400 Subject: [PATCH 31/83] direct hstate cat --- src/transformers/generation/utils.py | 8 +++----- tests/generation/test_utils.py | 30 ++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 07202f404ff..289bb0f2222 100644 --- a/src/transformers/generation/utils.py +++ 
b/src/transformers/generation/utils.py @@ -2041,6 +2041,7 @@ def contrastive_search( if low_memory: all_outputs = {key:[] for key in outputs} # defined in first loop iteration all_last_hstates = [] + all_hstates = [] for i in range(len(top_k_ids)): # compute the candidate tokens by the language model and collect their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), @@ -2066,9 +2067,8 @@ def contrastive_search( all_hstates.append(full_hidden_states) # stack hidden states - for i in range(top_k): - next_hidden = torch.stack(all_last_hstates[i], dim=0) - full_hidden_states = torch.stack(all_hstates[i], dim=0) + next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) + full_hidden_states = torch.stack([all_hstates[i] for i in range(top_k)], dim=0) # stack all_outputs attentions for key in all_outputs: @@ -2099,8 +2099,6 @@ def contrastive_search( next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) logits = outputs.logits[:, -1, :] - - print (next_hidden.shape) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 5f835917ea0..e6ab884ed2f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1457,6 +1457,36 @@ def test_contrastive_generate_dict_outputs_use_cache(self): for output in (output_contrastive, output_generate): self._check_outputs(output, input_ids, model.config, use_cache=True) + def test_contrastive_generate_low_memory(self): + # Check that choosing 'low_memory' reduces memory overhead and does not change the model output + for model_class in self.all_generative_model_classes: + # won't fix: FSMT and Reformer have a different cache variable type (and format). + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + return + + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # NOTE: contrastive search only works with cache on at the moment. + if not hasattr(config, "use_cache"): + return + config.use_cache = True + config.is_decoder = True + + prompt = "The rain in Spain" + + # test output equality of low versus high memory + model = model_class(config).to(torch_device).eval() + low_output = model.generate(input_ids, top_k=4, penalty_alpha=0.6, low_memory=True) + lmem_usage = torch.cuda.max_memory_allocated() + + high_output = model.generate(input_ids, top_k=4, penalty_alpha=0.6, low_memory=False) + hmem_usage = torch.cuda.max_memory_allocated() + self.assertListEqual(low_mem_output.tolist(), high_mem_output.tolist()) + assert lmem_usage < hmem_usage + + return + + @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%. def test_assisted_decoding_matches_greedy_search(self): # This test ensures that the assisted generation does not introduce output changes over greedy search. 
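Note on usage: the `test_contrastive_generate_low_memory` test added above drives the new `low_memory` flag end to end through `generate()`. For reference, a minimal standalone sketch of the same usage follows; the checkpoint name, prompt, and `max_new_tokens` value are illustrative and not part of these patches.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any decoder-only causal LM that supports contrastive search works.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

inputs = tokenizer("The rain in Spain", return_tensors="pt")

# top_k > 1 together with penalty_alpha > 0 selects contrastive search inside generate().
# low_memory=True (added by this series) runs the top-k candidate forward passes one at a
# time instead of as a single batch of size top_k, lowering peak memory at some speed cost.
with torch.no_grad():
    low_mem = model.generate(**inputs, top_k=4, penalty_alpha=0.6, max_new_tokens=32, low_memory=True)
    batched = model.generate(**inputs, top_k=4, penalty_alpha=0.6, max_new_tokens=32, low_memory=False)

# Mirrors the new test: the sequential and batched paths are expected to produce identical sequences.
assert low_mem.tolist() == batched.tolist()
print(tokenizer.decode(low_mem[0], skip_special_tokens=True))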
From c3f3db3769238cfea2c75bc4fe267487be7d5e71 Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 15 Jun 2023 23:12:41 -0400 Subject: [PATCH 32/83] direct hstate cat debug --- src/transformers/generation/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 289bb0f2222..3e9b5798225 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2066,6 +2066,8 @@ def contrastive_search( all_last_hstates.append(next_hidden) all_hstates.append(full_hidden_states) + print (len(all_last_hstates)) + print (len(all_hstates)) # stack hidden states next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) full_hidden_states = torch.stack([all_hstates[i] for i in range(top_k)], dim=0) From 692b5e10783a42c8672563a6699ca512e12c7116 Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 15 Jun 2023 23:16:01 -0400 Subject: [PATCH 33/83] direct hstate cat debug --- src/transformers/generation/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3e9b5798225..228c584b772 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2042,7 +2042,7 @@ def contrastive_search( all_outputs = {key:[] for key in outputs} # defined in first loop iteration all_last_hstates = [] all_hstates = [] - for i in range(len(top_k_ids)): + for i in range(top_k): # compute the candidate tokens by the language model and collect their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), **model_kwargs) @@ -2066,8 +2066,6 @@ def contrastive_search( all_last_hstates.append(next_hidden) all_hstates.append(full_hidden_states) - print (len(all_last_hstates)) - print (len(all_hstates)) # stack hidden states next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) full_hidden_states = torch.stack([all_hstates[i] for i in range(top_k)], dim=0) From 349bbf95a17319711bcd52893a082566af28eadc Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 10:30:06 -0400 Subject: [PATCH 34/83] expanded full hidden state stack --- src/transformers/generation/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 228c584b772..2b806b2ef80 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2068,7 +2068,13 @@ def contrastive_search( # stack hidden states next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) - full_hidden_states = torch.stack([all_hstates[i] for i in range(top_k)], dim=0) + print (next_hidden.shape) + + for layer in range(len(full_hidden_states)): + full_hidden_states[layer] = torch.stack([all_hstates[layer][i] for i in range(top_k)], dim=0) + + print (full_hidden_states[0].shape) + # full_hidden_states = tuple(map(torch.stack, zip(*full_hidden_states))) # stack all_outputs attentions for key in all_outputs: From cd4bed0e791239dbffa8a2910afe71843c7af28b Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 10:32:29 -0400 Subject: [PATCH 35/83] expanded full hidden state stack --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2b806b2ef80..4060fe9a811 100644 --- 
a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2100,7 +2100,7 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - print (next_hidden.shape, full_hidden_states.shape) + print (next_hidden.shape, full_hidden_states[0].shape) next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) logits = outputs.logits[:, -1, :] From ae41c50c3dcdd7c0b91fddab8d295317ada5ff15 Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 10:41:31 -0400 Subject: [PATCH 36/83] matched dims for hstates --- src/transformers/generation/utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4060fe9a811..68a096628ae 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2063,18 +2063,16 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - all_last_hstates.append(next_hidden) + all_last_hstates.append(torch.squeeze(next_hidden, 1)) all_hstates.append(full_hidden_states) # stack hidden states next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) - print (next_hidden.shape) - + final_full_hstates = [0 for i in range(len(full_hidden_states))] for layer in range(len(full_hidden_states)): - full_hidden_states[layer] = torch.stack([all_hstates[layer][i] for i in range(top_k)], dim=0) - - print (full_hidden_states[0].shape) - # full_hidden_states = tuple(map(torch.stack, zip(*full_hidden_states))) + final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[layer][i], 1) + for i in range(top_k)], dim=0) + full_hidden_states = tuple(final_full_hstates) # stack all_outputs attentions for key in all_outputs: From 30baaa650f2f718754ab238dd9bd92d1c78d37af Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 10:45:42 -0400 Subject: [PATCH 37/83] matched dims for hstates --- src/transformers/generation/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 68a096628ae..8c85943f6f6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2068,9 +2068,10 @@ def contrastive_search( # stack hidden states next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) + print (next_hidden.shape) final_full_hstates = [0 for i in range(len(full_hidden_states))] for layer in range(len(full_hidden_states)): - final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[layer][i], 1) + final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[i][layer], 1) for i in range(top_k)], dim=0) full_hidden_states = tuple(final_full_hstates) From ebc19ffaaf9e639429a9423cb785e51b9b0af04a Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 10:59:42 -0400 Subject: [PATCH 38/83] logits fix --- src/transformers/generation/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8c85943f6f6..9992d66292a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2042,6 +2042,7 @@ def contrastive_search( all_outputs = {key:[] for key in outputs} # defined in first loop iteration all_last_hstates = [] all_hstates = [] + all_logits = [] 
for i in range(top_k): # compute the candidate tokens by the language model and collect their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), @@ -2065,10 +2066,10 @@ def contrastive_search( all_last_hstates.append(torch.squeeze(next_hidden, 1)) all_hstates.append(full_hidden_states) + all_logits.append(outputs.logits[:, -1, :]) # stack hidden states next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) - print (next_hidden.shape) final_full_hstates = [0 for i in range(len(full_hidden_states))] for layer in range(len(full_hidden_states)): final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[i][layer], 1) @@ -2079,6 +2080,9 @@ def contrastive_search( for key in all_outputs: if torch.is_tensor(all_outputs[key]): outputs[key] = torch.stack(all_outputs[key], dim=0) + + logits = torch.cat(all_logits, dim=0) + print (logits.shape) else: # compute the candidate tokens by the language model and collect their hidden_states @@ -2099,11 +2103,10 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - print (next_hidden.shape, full_hidden_states[0].shape) + logits = outputs.logits[:, -1, :] + print (logits.shape) next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) - logits = outputs.logits[:, -1, :] - context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the From 752a488a3d943767a80bfa8cc59e2d13e3c895e9 Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 11:07:47 -0400 Subject: [PATCH 39/83] equality test --- src/transformers/generation/utils.py | 50 +++++++++++++++------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9992d66292a..d0080b88665 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2040,9 +2040,7 @@ def contrastive_search( if low_memory: all_outputs = {key:[] for key in outputs} # defined in first loop iteration - all_last_hstates = [] - all_hstates = [] - all_logits = [] + all_last_hstates, all_hstates, all_logits = [], [], [] for i in range(top_k): # compute the candidate tokens by the language model and collect their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), @@ -2069,7 +2067,7 @@ def contrastive_search( all_logits.append(outputs.logits[:, -1, :]) # stack hidden states - next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) + cnext_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) final_full_hstates = [0 for i in range(len(full_hidden_states))] for layer in range(len(full_hidden_states)): final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[i][layer], 1) @@ -2081,30 +2079,34 @@ def contrastive_search( if torch.is_tensor(all_outputs[key]): outputs[key] = torch.stack(all_outputs[key], dim=0) - logits = torch.cat(all_logits, dim=0) - print (logits.shape) + coutputs = torch.clone(outputs) + # stack logits + clogits = torch.cat(all_logits, dim=0) + # else: + # compute the candidate tokens by the language model and collect their hidden_states + # assembles top_k_ids into batch of size k (leading to OOM for large models) + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + + 
outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions + ) + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states else: - # compute the candidate tokens by the language model and collect their hidden_states - # assembles top_k_ids into batch of size k (leading to OOM for large models) - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states - outputs = self( - **next_model_inputs, - return_dict=True, - output_hidden_states=True, - output_attentions=output_attentions - ) - # name is different for encoder-decoder and decoder-only models - if self.config.is_encoder_decoder: - next_hidden = outputs.decoder_hidden_states[-1] - full_hidden_states = outputs.decoder_hidden_states - else: - next_hidden = outputs.hidden_states[-1] - full_hidden_states = outputs.hidden_states + logits = outputs.logits[:, -1, :] - logits = outputs.logits[:, -1, :] - print (logits.shape) + print (f'Logit equality: {torch.equal(clogits, logits)}') + print (f'Attentions equality: {torch.equal(outputs[key], coutputs[key])}') + print (f'Last hidden equality: {torch.equal(cnext_hidden, next_hidden)}') next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) From 4f973ba98b7094eb79ffa9b9042e9ca41a77d94f Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 11:12:44 -0400 Subject: [PATCH 40/83] equality hidden debug --- src/transformers/generation/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d0080b88665..0afffb525ec 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2079,7 +2079,6 @@ def contrastive_search( if torch.is_tensor(all_outputs[key]): outputs[key] = torch.stack(all_outputs[key], dim=0) - coutputs = torch.clone(outputs) # stack logits clogits = torch.cat(all_logits, dim=0) @@ -2105,7 +2104,7 @@ def contrastive_search( logits = outputs.logits[:, -1, :] print (f'Logit equality: {torch.equal(clogits, logits)}') - print (f'Attentions equality: {torch.equal(outputs[key], coutputs[key])}') + # print (f'Attentions equality: {torch.equal(outputs[key], coutputs[key])}') print (f'Last hidden equality: {torch.equal(cnext_hidden, next_hidden)}') next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) From b8094156b24e8f94d7fcae3ed53c5ba2bd36b94a Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 11:24:20 -0400 Subject: [PATCH 41/83] debug --- src/transformers/generation/utils.py | 45 +++++++++++++--------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0afffb525ec..05e97869dd2 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2067,7 +2067,7 @@ def contrastive_search( all_logits.append(outputs.logits[:, -1, :]) # stack hidden states - cnext_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) + next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) final_full_hstates = 
[0 for i in range(len(full_hidden_states))] for layer in range(len(full_hidden_states)): final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[i][layer], 1) @@ -2080,33 +2080,30 @@ def contrastive_search( outputs[key] = torch.stack(all_outputs[key], dim=0) # stack logits - clogits = torch.cat(all_logits, dim=0) + logits = torch.cat(all_logits, dim=0) - # else: - # compute the candidate tokens by the language model and collect their hidden_states - # assembles top_k_ids into batch of size k (leading to OOM for large models) - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) - - outputs = self( - **next_model_inputs, - return_dict=True, - output_hidden_states=True, - output_attentions=output_attentions - ) - # name is different for encoder-decoder and decoder-only models - if self.config.is_encoder_decoder: - next_hidden = outputs.decoder_hidden_states[-1] - full_hidden_states = outputs.decoder_hidden_states else: - next_hidden = outputs.hidden_states[-1] - full_hidden_states = outputs.hidden_states - - logits = outputs.logits[:, -1, :] + # compute the candidate tokens by the language model and collect their hidden_states + # assembles top_k_ids into batch of size k (leading to OOM for large models) + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) - print (f'Logit equality: {torch.equal(clogits, logits)}') - # print (f'Attentions equality: {torch.equal(outputs[key], coutputs[key])}') - print (f'Last hidden equality: {torch.equal(cnext_hidden, next_hidden)}') + outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions + ) + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + logits = outputs.logits[:, -1, :] + + print (logits, next_hidden, full_hidden_states) next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) From 9230061cb6e40f2562d9224d90492550c5636d73 Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 11:37:38 -0400 Subject: [PATCH 42/83] added prints for debug --- src/transformers/generation/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 05e97869dd2..5f6377b2379 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2102,14 +2102,16 @@ def contrastive_search( full_hidden_states = outputs.hidden_states logits = outputs.logits[:, -1, :] - - print (logits, next_hidden, full_hidden_states) + + print ('Logits: ', logits) + print ('last hidden start', next_hidden[0]) next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) + print ('index: ', selected_index) # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing # the 
degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores From 2863471ba83a19783c5edd8bcb9082bf2c1ab9f9 Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 11:41:09 -0400 Subject: [PATCH 43/83] added prints for debug --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5f6377b2379..eec9addd278 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2111,7 +2111,7 @@ def contrastive_search( # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) - print ('index: ', selected_index) + print ('index: ', selected_idx) # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores From e65335368e4f37ce751505fbdacd3d5d268bcf34 Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 11:52:17 -0400 Subject: [PATCH 44/83] equality check --- src/transformers/generation/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index eec9addd278..7dd63a099bb 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2103,8 +2103,6 @@ def contrastive_search( logits = outputs.logits[:, -1, :] - print ('Logits: ', logits) - print ('last hidden start', next_hidden[0]) next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) @@ -2112,6 +2110,7 @@ def contrastive_search( # model confidence selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) print ('index: ', selected_idx) + print ('next_hidden', next_hidden) # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores From d790ea589f20a4eef7fd8c4131f23c2b12a058e5 Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 12:19:21 -0400 Subject: [PATCH 45/83] switched squeeze dim --- src/transformers/generation/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7dd63a099bb..f74ca81b3dc 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2062,7 +2062,7 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - all_last_hstates.append(torch.squeeze(next_hidden, 1)) + all_last_hstates.append(torch.squeeze(next_hidden, 0)) all_hstates.append(full_hidden_states) all_logits.append(outputs.logits[:, -1, :]) @@ -2070,7 +2070,7 @@ def contrastive_search( next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) final_full_hstates = [0 for i in range(len(full_hidden_states))] for layer in range(len(full_hidden_states)): - final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[i][layer], 1) + final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[i][layer], 0) 
for i in range(top_k)], dim=0) full_hidden_states = tuple(final_full_hstates) @@ -2109,8 +2109,6 @@ def contrastive_search( # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) - print ('index: ', selected_idx) - print ('next_hidden', next_hidden) # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores From f1942219088e86d13db2ba691a4cf3e635a83617 Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 12:23:25 -0400 Subject: [PATCH 46/83] input format debug --- src/transformers/generation/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f74ca81b3dc..6830349de12 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2045,6 +2045,7 @@ def contrastive_search( # compute the candidate tokens by the language model and collect their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), **model_kwargs) + print (next_model_inputs) outputs = self( **next_model_inputs, return_dict=True, @@ -2086,6 +2087,7 @@ def contrastive_search( # compute the candidate tokens by the language model and collect their hidden_states # assembles top_k_ids into batch of size k (leading to OOM for large models) next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + print (next_model_inputs) outputs = self( **next_model_inputs, From 665c323c601262fc2832d4cac0e2dc71556a10d8 Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 15:08:50 -0400 Subject: [PATCH 47/83] tracing top_k_ids --- src/transformers/generation/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6830349de12..03e216eb4cd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2045,7 +2045,7 @@ def contrastive_search( # compute the candidate tokens by the language model and collect their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), **model_kwargs) - print (next_model_inputs) + outputs = self( **next_model_inputs, return_dict=True, @@ -2087,7 +2087,6 @@ def contrastive_search( # compute the candidate tokens by the language model and collect their hidden_states # assembles top_k_ids into batch of size k (leading to OOM for large models) next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) - print (next_model_inputs) outputs = self( **next_model_inputs, @@ -2105,6 +2104,8 @@ def contrastive_search( logits = outputs.logits[:, -1, :] + print (top_k_ids) + next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) From 6259b56d6c7cf1784bb04a049113d10c27858aac Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 20:22:55 -0400 Subject: [PATCH 48/83] removed trace --- src/transformers/generation/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 03e216eb4cd..01724789f16 
100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2104,8 +2104,6 @@ def contrastive_search( logits = outputs.logits[:, -1, :] - print (top_k_ids) - next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) From 6d2734cfa97974714d387ca502ad7dabc56e8014 Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 16 Jun 2023 20:46:31 -0400 Subject: [PATCH 49/83] added test context --- tests/generation/test_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 357d25511d6..df42413665f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1481,7 +1481,10 @@ def test_contrastive_generate_low_memory(self): high_output = model.generate(input_ids, top_k=4, penalty_alpha=0.6, low_memory=False) hmem_usage = torch.cuda.max_memory_allocated() + + # will usually fail due to the propegation of batch vs unbatched forward pass numerical errors self.assertListEqual(low_mem_output.tolist(), high_mem_output.tolist()) + assert lmem_usage < hmem_usage return From 4033b191b2019b93bce2efadb37b75a6aa0c3e0e Mon Sep 17 00:00:00 2001 From: blbadger Date: Sat, 17 Jun 2023 21:25:26 -0400 Subject: [PATCH 50/83] added jitter --- src/transformers/generation/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index de3e613a23f..508856c9c2c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2106,7 +2106,12 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - logits = outputs.logits[:, -1, :] + next_hidden = next_hidden + (torch.randn(next_hidden.shape) / 100).to(input_ids.device) + for i in full_hidden_states: + full_hidden_states[i] += (torch.randn(full_hidden_states[i].shape)/100).to(input_ids.device) + + + logits = outputs.logits[:, -1, :] + torch.randn(outputs.logits[:, -1, :].shape).to(inputs_ids.device) next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) From e2051a727be37645920ea8f911f63a0dabc1c059 Mon Sep 17 00:00:00 2001 From: blbadger Date: Sat, 17 Jun 2023 21:27:34 -0400 Subject: [PATCH 51/83] added jitter --- src/transformers/generation/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 508856c9c2c..34b3c04623d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2106,8 +2106,8 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - next_hidden = next_hidden + (torch.randn(next_hidden.shape) / 100).to(input_ids.device) - for i in full_hidden_states: + next_hidden = next_hidden + (torch.randn(next_hidden.shape)/100).to(input_ids.device) + for i in range(len(full_hidden_states)): full_hidden_states[i] += (torch.randn(full_hidden_states[i].shape)/100).to(input_ids.device) From e8f4cd1e6c885a9db59c819ac0234859e325ed22 Mon Sep 17 00:00:00 2001 From: blbadger Date: Sat, 17 Jun 2023 21:30:13 -0400 Subject: [PATCH 52/83] added jitter --- src/transformers/generation/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 34b3c04623d..b43eda8f258 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2107,9 +2107,10 @@ def contrastive_search( full_hidden_states = outputs.hidden_states next_hidden = next_hidden + (torch.randn(next_hidden.shape)/100).to(input_ids.device) + final = [] for i in range(len(full_hidden_states)): - full_hidden_states[i] += (torch.randn(full_hidden_states[i].shape)/100).to(input_ids.device) - + final.append(full_hidden_states[i] + (torch.randn(full_hidden_states[i].shape)/100).to(input_ids.device)) + full_hidden_states = tuple(final) logits = outputs.logits[:, -1, :] + torch.randn(outputs.logits[:, -1, :].shape).to(inputs_ids.device) @@ -2119,6 +2120,7 @@ def contrastive_search( # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) + print (selected_idx) # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores From 6bed197140c392ef860461810252f9b833083742 Mon Sep 17 00:00:00 2001 From: blbadger Date: Sat, 17 Jun 2023 21:38:23 -0400 Subject: [PATCH 53/83] returned state --- src/transformers/generation/utils.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b43eda8f258..d5bdaffcc48 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2106,13 +2106,27 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - next_hidden = next_hidden + (torch.randn(next_hidden.shape)/100).to(input_ids.device) + next_hidden = next_hidden #+ (torch.randn(next_hidden.shape)/100).to(input_ids.device) final = [] for i in range(len(full_hidden_states)): - final.append(full_hidden_states[i] + (torch.randn(full_hidden_states[i].shape)/100).to(input_ids.device)) + final.append(full_hidden_states[i]) #+ (torch.randn(full_hidden_states[i].shape)/100).to(input_ids.device)) full_hidden_states = tuple(final) - logits = outputs.logits[:, -1, :] + torch.randn(outputs.logits[:, -1, :].shape).to(inputs_ids.device) + logits = outputs.logits[:, -1, :] #+ torch.randn(outputs.logits[:, -1, :].shape).to(inputs_ids.device) + + if low_memory: + _, model_kwargs = self._expand_inputs_for_generation( + expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + + new_key_values = [] + for layer in model_kwargs["past_key_values"]: + items = [] + # item is either the key or the value matrix + for item in layer: + items.append(item.repeat_interleave(top_k, dim=0)) + new_key_values.append(items) + model_kwargs["past_key_values"] = new_key_values next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) @@ -2120,7 +2134,7 @@ def contrastive_search( # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) - print (selected_idx) + print (selected_idx.item) # prepare for the next step: (1) 
next token_id; (2) past_key_values; (3) last_hidden_states for computing # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores From 67946f214dd8c3ff53aa308104dc4ff513068224 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 21 Jun 2023 10:05:23 -0400 Subject: [PATCH 54/83] rebuilt past key value reconstruction --- src/transformers/generation/utils.py | 34 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d5bdaffcc48..e1a543e5ab7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2054,7 +2054,8 @@ def contrastive_search( **next_model_inputs, return_dict=True, output_hidden_states=True, - output_attentions=output_attentions + output_attentions=output_attentions, + past_key_values=model.past_key_values ) for key in all_outputs: all_outputs[key].append(outputs[key]) @@ -2079,9 +2080,22 @@ def contrastive_search( for i in range(top_k)], dim=0) full_hidden_states = tuple(final_full_hstates) - # stack all_outputs attentions for key in all_outputs: - if torch.is_tensor(all_outputs[key]): + # rebuild key value output + if key == 'past_key_values': + layers_kv = [] + for layer in range(len(all_outputs[key][0])): + kv = [] + kv.append(torch.cat([all_outputs[key][seq][layer][0] + for seq in range(len(all_outputs[key]))], dim=0)) + + kv.append(torch.cat([all_outputs[key][seq][layer][1] + for seq in range(len(all_outputs[key]))], dim=0)) + + layers_kv.append(tuple(kv)) + outputs[key] = tuple(layers_kv) + + elif torch.is_tensor(all_outputs[key]): outputs[key] = torch.stack(all_outputs[key], dim=0) # stack logits @@ -2114,20 +2128,6 @@ def contrastive_search( logits = outputs.logits[:, -1, :] #+ torch.randn(outputs.logits[:, -1, :].shape).to(inputs_ids.device) - if low_memory: - _, model_kwargs = self._expand_inputs_for_generation( - expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs - ) - - new_key_values = [] - for layer in model_kwargs["past_key_values"]: - items = [] - # item is either the key or the value matrix - for item in layer: - items.append(item.repeat_interleave(top_k, dim=0)) - new_key_values.append(items) - model_kwargs["past_key_values"] = new_key_values - next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) From 3dbd776268b76e331692c1e197b36d53ec2dff10 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 21 Jun 2023 10:09:27 -0400 Subject: [PATCH 55/83] debugged --- src/transformers/generation/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e1a543e5ab7..7373ec45d73 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2055,7 +2055,6 @@ def contrastive_search( return_dict=True, output_hidden_states=True, output_attentions=output_attentions, - past_key_values=model.past_key_values ) for key in all_outputs: all_outputs[key].append(outputs[key]) From 547df692791ee40bc2bccff4f534c8eb68373a11 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 21 Jun 2023 10:35:13 -0400 Subject: [PATCH 56/83] cleaned traces --- src/transformers/generation/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7373ec45d73..b86160945d6 
100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2133,7 +2133,6 @@ def contrastive_search( # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) - print (selected_idx.item) # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores From f4b1f28401ccd93e2e0d00ac03d33ac823ad0ceb Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 21 Jun 2023 10:52:16 -0400 Subject: [PATCH 57/83] added selection for pkv --- src/transformers/generation/utils.py | 53 +++++++++++++++++----------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b86160945d6..5a349f33d1c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2082,17 +2082,18 @@ def contrastive_search( for key in all_outputs: # rebuild key value output if key == 'past_key_values': - layers_kv = [] - for layer in range(len(all_outputs[key][0])): - kv = [] - kv.append(torch.cat([all_outputs[key][seq][layer][0] - for seq in range(len(all_outputs[key]))], dim=0)) + pass + # layers_kv = [] + # for layer in range(len(all_outputs[key][0])): + # kv = [] + # kv.append(torch.cat([all_outputs[key][seq][layer][0] + # for seq in range(len(all_outputs[key]))], dim=0)) - kv.append(torch.cat([all_outputs[key][seq][layer][1] - for seq in range(len(all_outputs[key]))], dim=0)) + # kv.append(torch.cat([all_outputs[key][seq][layer][1] + # for seq in range(len(all_outputs[key]))], dim=0)) - layers_kv.append(tuple(kv)) - outputs[key] = tuple(layers_kv) + # layers_kv.append(tuple(kv)) + # outputs[key] = tuple(layers_kv) elif torch.is_tensor(all_outputs[key]): outputs[key] = torch.stack(all_outputs[key], dim=0) @@ -2127,7 +2128,6 @@ def contrastive_search( logits = outputs.logits[:, -1, :] #+ torch.randn(outputs.logits[:, -1, :].shape).to(inputs_ids.device) - next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the @@ -2148,16 +2148,29 @@ def contrastive_search( next_decoder_hidden_states += (layer,) # select the past_key_value - new_key_values = () - for layer in next_past_key_values: - items = () - # item is either the key or the value matrix - for item in layer: - item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz] - item = item[range(batch_size), selected_idx, ...] 
# [B, num_head, seq_len, esz] - items += (item,) - new_key_values += (items,) - next_past_key_values = new_key_values + if low_memory: + next_model_input = self.prepare_inputs_for_generation(top_k_ids[:, selected_idx].unsqueeze(0), + **model_kwargs) + selected_outputs = self( + **next_model_input, + return_dict=False, + output_hidden_states=False, + output_attentions=False, + ) + next_past_key_values = selected_outputs.past_key_values + + else: + next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) + new_key_values = () + for layer in next_past_key_values: + items = () + # item is either the key or the value matrix + for item in layer: + item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz] + item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz] + items += (item,) + new_key_values += (items,) + next_past_key_values = new_key_values logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] From d1af0f003dc9a4d120ac6f735caadb22f7df994e Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 21 Jun 2023 10:55:05 -0400 Subject: [PATCH 58/83] changed output to dict --- src/transformers/generation/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5a349f33d1c..e666c4b6b65 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2153,11 +2153,11 @@ def contrastive_search( **model_kwargs) selected_outputs = self( **next_model_input, - return_dict=False, + return_dict=True, output_hidden_states=False, output_attentions=False, ) - next_past_key_values = selected_outputs.past_key_values + next_past_key_values = selected_outputs['past_key_values'] else: next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) From ee94a31da974d18101a65fc5d8c2cd117f020dba Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 21 Jun 2023 11:23:00 -0400 Subject: [PATCH 59/83] cleaned --- src/transformers/generation/utils.py | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e666c4b6b65..d963c575b5b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2078,28 +2078,6 @@ def contrastive_search( final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0) full_hidden_states = tuple(final_full_hstates) - - for key in all_outputs: - # rebuild key value output - if key == 'past_key_values': - pass - # layers_kv = [] - # for layer in range(len(all_outputs[key][0])): - # kv = [] - # kv.append(torch.cat([all_outputs[key][seq][layer][0] - # for seq in range(len(all_outputs[key]))], dim=0)) - - # kv.append(torch.cat([all_outputs[key][seq][layer][1] - # for seq in range(len(all_outputs[key]))], dim=0)) - - # layers_kv.append(tuple(kv)) - # outputs[key] = tuple(layers_kv) - - elif torch.is_tensor(all_outputs[key]): - outputs[key] = torch.stack(all_outputs[key], dim=0) - - # stack logits - logits = torch.cat(all_logits, dim=0) else: # compute the candidate tokens by the language model and collect their hidden_states @@ -2120,13 +2098,13 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - next_hidden = next_hidden #+ 
(torch.randn(next_hidden.shape)/100).to(input_ids.device) + next_hidden = next_hidden final = [] for i in range(len(full_hidden_states)): - final.append(full_hidden_states[i]) #+ (torch.randn(full_hidden_states[i].shape)/100).to(input_ids.device)) + final.append(full_hidden_states[i]) full_hidden_states = tuple(final) - logits = outputs.logits[:, -1, :] #+ torch.randn(outputs.logits[:, -1, :].shape).to(inputs_ids.device) + logits = outputs.logits[:, -1, :] context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) From 5cfd454676d4f650f77aab79c94644a0f343b097 Mon Sep 17 00:00:00 2001 From: blbadger Date: Wed, 21 Jun 2023 11:25:39 -0400 Subject: [PATCH 60/83] cleaned --- src/transformers/generation/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d963c575b5b..c128d1331d5 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2078,6 +2078,9 @@ def contrastive_search( final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0) full_hidden_states = tuple(final_full_hstates) + + # stack logits + logits = torch.cat(all_logits, dim=0) else: # compute the candidate tokens by the language model and collect their hidden_states From 2fbca3581c86007ac54c8e380302d0702c0ff931 Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 22 Jun 2023 10:37:46 -0400 Subject: [PATCH 61/83] cleaned up contrastive search test --- tests/generation/test_utils.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index df42413665f..a1c61c1129a 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1458,7 +1458,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self): self._check_outputs(output, input_ids, model.config, use_cache=True) def test_contrastive_generate_low_memory(self): - # Check that choosing 'low_memory' reduces memory overhead and does not change the model output + # Check that choosing 'low_memory' does not change the model output for model_class in self.all_generative_model_classes: # won't fix: FSMT and Reformer have a different cache variable type (and format). 
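# A minimal sketch, separate from the diffs above, of the cache layout this low-memory
# path assumes: `past_key_values` as a tuple of per-layer (key, value) pairs, each of
# shape [batch, num_heads, seq_len, head_dim]. The sizes below are made up for
# illustration; models such as FSMT and Reformer store their cache in a different
# format, which is why the test skips them.
import torch

batch_size, num_heads, seq_len, head_dim, num_layers, top_k = 1, 4, 8, 16, 2, 4
past_key_values = tuple(
    (
        torch.zeros(batch_size, num_heads, seq_len, head_dim),
        torch.zeros(batch_size, num_heads, seq_len, head_dim),
    )
    for _ in range(num_layers)
)

# the batched contrastive-search path expands the cache across the k candidates ...
expanded = tuple(
    tuple(item.repeat_interleave(top_k, dim=0) for item in layer)
    for layer in past_key_values
)
assert expanded[0][0].shape == (batch_size * top_k, num_heads, seq_len, head_dim)

# ... and later keeps only the slice belonging to the selected candidate, mirroring the
# `torch.stack(torch.split(item, top_k, dim=0))[range(batch_size), selected_idx, ...]`
# indexing that appears in the diffs
selected_idx = torch.zeros(batch_size, dtype=torch.long)
item = expanded[0][0]
item = torch.stack(torch.split(item, top_k, dim=0))          # [B, K, num_heads, seq_len, head_dim]
item = item[torch.arange(batch_size), selected_idx, ...]     # [B, num_heads, seq_len, head_dim]
assert item.shape == (batch_size, num_heads, seq_len, head_dim)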
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): @@ -1476,16 +1476,19 @@ def test_contrastive_generate_low_memory(self): # test output equality of low versus high memory model = model_class(config).to(torch_device).eval() - low_output = model.generate(input_ids, top_k=4, penalty_alpha=0.6, low_memory=True) - lmem_usage = torch.cuda.max_memory_allocated() - - high_output = model.generate(input_ids, top_k=4, penalty_alpha=0.6, low_memory=False) - hmem_usage = torch.cuda.max_memory_allocated() - - # will usually fail due to the propegation of batch vs unbatched forward pass numerical errors - self.assertListEqual(low_mem_output.tolist(), high_mem_output.tolist()) - - assert lmem_usage < hmem_usage + low_output = model.generate(input_ids, + top_k=4, + penalty_alpha=0.6, + low_memory=True, + max_length=max_length + ) + high_output = model.generate(input_ids, + top_k=4, + penalty_alpha=0.6, + low_memory=False, + max_length=max_length + ) + self.assertListEqual(low_output.tolist(), high_output.tolist()) return From efcba6f0cbf32df2c875e251cab67c5769bd7a0a Mon Sep 17 00:00:00 2001 From: blbadger Date: Sun, 2 Jul 2023 21:19:47 -0400 Subject: [PATCH 62/83] moved low_memory kwarg --- src/transformers/generation/configuration_utils.py | 3 +++ src/transformers/generation/utils.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index d024ff4718e..564ed722552 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -181,6 +181,8 @@ class GenerationConfig(PushToHubMixin): A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123. + low_memory (`bool`, *optional*, defaults to False): + Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search. 
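# A minimal usage sketch of the `low_memory` flag documented above (the checkpoint name
# is a placeholder, not taken from this PR; the prompt and the top_k / penalty_alpha
# values come from the test code in these patches). Contrastive search is enabled by
# `penalty_alpha` together with `top_k`; `low_memory=True` switches it to the sequential
# top-k path and should leave the generated tokens unchanged.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The rain in Spain", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=64,
    low_memory=True,  # forwarded into the GenerationConfig defined in this patch
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))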
> Parameters that define the output variables of `generate` @@ -260,6 +262,7 @@ def __init__(self, **kwargs): self.suppress_tokens = kwargs.pop("suppress_tokens", None) self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None) self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) + self.low_memory = kwargs.pop("low_memory", False) # Parameters that define the output variables of `generate` self.num_return_sequences = kwargs.pop("num_return_sequences", 1) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c128d1331d5..dcb603820c6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1156,7 +1156,6 @@ def generate( prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, - low_memory: Optional[bool] = False, streamer: Optional["BaseStreamer"] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: From 5a3b26c533c5e94defef09b1d53f4599524078bf Mon Sep 17 00:00:00 2001 From: blbadger Date: Sun, 2 Jul 2023 21:22:36 -0400 Subject: [PATCH 63/83] debugged --- src/transformers/generation/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index dcb603820c6..8fa67cc3e27 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1214,8 +1214,6 @@ def generate( streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - low_memory (`bool`, *optional*): - Switch to sequential topk for contrastive search to reduce peak memory requirements. Used with contrastive search. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. 
If the model is an encoder-decoder model, encoder @@ -1555,7 +1553,7 @@ def generate( return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, - low_memory=low_memory, + low_memory=generation_config.low_memory, **model_kwargs, ) From cf122309e09e0e37a40a513d758cf47f1e7c654a Mon Sep 17 00:00:00 2001 From: blbadger Date: Mon, 3 Jul 2023 12:55:10 -0400 Subject: [PATCH 64/83] changed low mem test batch size to 1 --- src/transformers/generation/utils.py | 75 +++++++++++----------------- tests/generation/test_utils.py | 4 +- tests_output.txt | 0 3 files changed, 31 insertions(+), 48 deletions(-) create mode 100644 tests_output.txt diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8fa67cc3e27..0c3627e7c73 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -594,10 +594,7 @@ def _maybe_initialize_input_ids_for_generation( return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id def _prepare_attention_mask_for_generation( - self, - inputs: torch.Tensor, - pad_token_id: Optional[int], - eos_token_id: Optional[Union[int, List[int]]], + self, inputs: torch.Tensor, pad_token_id: Optional[int], eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) @@ -681,8 +678,7 @@ def _prepare_decoder_input_ids_for_generation( if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] decoder_attention_mask = torch.cat( - (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), - dim=-1, + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), dim=-1, ) model_kwargs["decoder_attention_mask"] = decoder_attention_mask @@ -789,10 +785,7 @@ def _reorder_cache(self, past_key_values, beam_idx): f" enable beam search for {self.__class__}" ) - def _get_logits_warper( - self, - generation_config: GenerationConfig, - ) -> LogitsProcessorList: + def _get_logits_warper(self, generation_config: GenerationConfig,) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances used for multinomial sampling. 
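# A toy sketch (not the PR code) of the idea behind `low_memory=True`: score the k
# candidate tokens one at a time instead of as a single batch of size k. `toy_forward`
# is a made-up stand-in for the model's forward pass; only the batching pattern matters.
import torch

torch.manual_seed(0)
embedding = torch.nn.Embedding(100, 16)

def toy_forward(token_ids):
    # token_ids: [batch, 1] -> hidden states: [batch, 1, hidden]
    return embedding(token_ids)

top_k_ids = torch.randint(0, 100, (1, 4))  # one sequence, k = 4 candidate tokens

# batched path: one forward over all k candidates (peak activation memory grows with k)
batched_hidden = toy_forward(top_k_ids.view(-1, 1))

# sequential path: k forwards over one candidate each (peak memory roughly independent of k)
sequential_hidden = torch.cat(
    [toy_forward(top_k_ids[:, i].view(-1, 1)) for i in range(top_k_ids.shape[1])], dim=0
)

# both paths hand the same tensors to the candidate re-ranking that follows
assert torch.equal(batched_hidden, sequential_hidden)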
@@ -2040,17 +2033,18 @@ def contrastive_search( model_kwargs["past_key_values"] = new_key_values if low_memory: - all_outputs = {key:[] for key in outputs} # defined in first loop iteration + all_outputs = {key: [] for key in outputs} # defined in first loop iteration all_last_hstates, all_hstates, all_logits = [], [], [] for i in range(top_k): # compute the candidate tokens by the language model and collect their hidden_states - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].unsqueeze(0), - **model_kwargs) + next_model_inputs = self.prepare_inputs_for_generation( + top_k_ids[:, i].unsqueeze(0), **model_kwargs + ) outputs = self( - **next_model_inputs, - return_dict=True, - output_hidden_states=True, + **next_model_inputs, + return_dict=True, + output_hidden_states=True, output_attentions=output_attentions, ) for key in all_outputs: @@ -2063,7 +2057,7 @@ def contrastive_search( else: next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - + all_last_hstates.append(torch.squeeze(next_hidden, 0)) all_hstates.append(full_hidden_states) all_logits.append(outputs.logits[:, -1, :]) @@ -2072,23 +2066,24 @@ def contrastive_search( next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) final_full_hstates = [0 for i in range(len(full_hidden_states))] for layer in range(len(full_hidden_states)): - final_full_hstates[layer] = torch.stack([torch.squeeze(all_hstates[i][layer], 0) - for i in range(top_k)], dim=0) + final_full_hstates[layer] = torch.stack( + [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0 + ) full_hidden_states = tuple(final_full_hstates) # stack logits logits = torch.cat(all_logits, dim=0) - + else: # compute the candidate tokens by the language model and collect their hidden_states # assembles top_k_ids into batch of size k (leading to OOM for large models) next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) outputs = self( - **next_model_inputs, - return_dict=True, - output_hidden_states=True, - output_attentions=output_attentions + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, ) # name is different for encoder-decoder and decoder-only models if self.config.is_encoder_decoder: @@ -2098,10 +2093,10 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - next_hidden = next_hidden + next_hidden = next_hidden final = [] for i in range(len(full_hidden_states)): - final.append(full_hidden_states[i]) + final.append(full_hidden_states[i]) full_hidden_states = tuple(final) logits = outputs.logits[:, -1, :] @@ -2127,15 +2122,13 @@ def contrastive_search( # select the past_key_value if low_memory: - next_model_input = self.prepare_inputs_for_generation(top_k_ids[:, selected_idx].unsqueeze(0), - **model_kwargs) + next_model_input = self.prepare_inputs_for_generation( + top_k_ids[:, selected_idx].unsqueeze(0), **model_kwargs + ) selected_outputs = self( - **next_model_input, - return_dict=True, - output_hidden_states=False, - output_attentions=False, + **next_model_input, return_dict=True, output_hidden_states=False, output_attentions=False, ) - next_past_key_values = selected_outputs['past_key_values'] + next_past_key_values = selected_outputs["past_key_values"] else: next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) @@ -4513,11 +4506,7 @@ def assisted_decoding( ) else: decoder_attentions = 
_split_model_outputs( - decoder_attentions, - outputs.attentions, - cur_len, - added_len, - is_decoder_attention=True, + decoder_attentions, outputs.attentions, cur_len, added_len, is_decoder_attention=True, ) if output_hidden_states: if self.config.is_encoder_decoder: @@ -4598,10 +4587,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length): ): for idx in range(len(past_key_values)): new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length], - past_key_values[idx][1][:, :maximum_length, :], - ) + (past_key_values[idx][0][:, :, :maximum_length], past_key_values[idx][1][:, :maximum_length, :],) ) past_key_values = tuple(new_past) # gptbigcode is too @@ -4617,10 +4603,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length): else: for idx in range(len(past_key_values)): new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length, :], - past_key_values[idx][1][:, :, :maximum_length, :], - ) + (past_key_values[idx][0][:, :, :maximum_length, :], past_key_values[idx][1][:, :, :maximum_length, :],) ) past_key_values = tuple(new_past) return past_key_values diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index a1c61c1129a..bda69da8226 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1464,14 +1464,14 @@ def test_contrastive_generate_low_memory(self): if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): return - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): return + config.use_cache = True config.is_decoder = True - prompt = "The rain in Spain" # test output equality of low versus high memory diff --git a/tests_output.txt b/tests_output.txt new file mode 100644 index 00000000000..e69de29bb2d From 60fd1850783d7e3cb1255a06a1e7306bfd0526eb Mon Sep 17 00:00:00 2001 From: blbadger Date: Mon, 3 Jul 2023 12:56:35 -0400 Subject: [PATCH 65/83] removed output --- tests_output.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests_output.txt diff --git a/tests_output.txt b/tests_output.txt deleted file mode 100644 index e69de29bb2d..00000000000 From a3355c1ca767f513428bcacf36fe060721ea38bb Mon Sep 17 00:00:00 2001 From: blbadger Date: Mon, 3 Jul 2023 15:47:50 -0400 Subject: [PATCH 66/83] debugged test input shape --- tests/generation/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index bda69da8226..6b88f06689e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1465,6 +1465,7 @@ def test_contrastive_generate_low_memory(self): return config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + input_ids = torch.unsqueeze(input_ids, 0) # NOTE: contrastive search only works with cache on at the moment. 
if not hasattr(config, "use_cache"): From 87be0de5caa8e768dcb0a35a09b36e5176bf693a Mon Sep 17 00:00:00 2001 From: blbadger Date: Mon, 3 Jul 2023 16:33:49 -0400 Subject: [PATCH 67/83] reformatted csearch test --- tests/generation/test_utils.py | 36 ++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6b88f06689e..19b8e007de2 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1464,8 +1464,8 @@ def test_contrastive_generate_low_memory(self): if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): return + # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) - input_ids = torch.unsqueeze(input_ids, 0) # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1477,18 +1477,30 @@ def test_contrastive_generate_low_memory(self): # test output equality of low versus high memory model = model_class(config).to(torch_device).eval() - low_output = model.generate(input_ids, - top_k=4, - penalty_alpha=0.6, - low_memory=True, - max_length=max_length - ) - high_output = model.generate(input_ids, - top_k=4, - penalty_alpha=0.6, - low_memory=False, - max_length=max_length + _, low_output = self._contrastive_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + low_memory=True + ) + + _, high_output = self._contrastive_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + low_memory=False ) + self.assertListEqual(low_output.tolist(), high_output.tolist()) return From ab307f98b8129b9e98c3e79c160f6b6306d9c484 Mon Sep 17 00:00:00 2001 From: blbadger Date: Mon, 3 Jul 2023 16:56:47 -0400 Subject: [PATCH 68/83] added trace --- src/transformers/generation/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0c3627e7c73..b6de87d988c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2125,6 +2125,7 @@ def contrastive_search( next_model_input = self.prepare_inputs_for_generation( top_k_ids[:, selected_idx].unsqueeze(0), **model_kwargs ) + print (next_model_input['input_ids'].shape) selected_outputs = self( **next_model_input, return_dict=True, output_hidden_states=False, output_attentions=False, ) From dfff73d80a4f789f1e9ed7e1c4fa8cca8e7c7a5f Mon Sep 17 00:00:00 2001 From: blbadger Date: Tue, 4 Jul 2023 11:06:06 -0400 Subject: [PATCH 69/83] removed unsqueeze on final forward pass --- src/transformers/generation/utils.py | 2 +- tests/generation/test_utils.py | 34 +++++++++++----------------- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b6de87d988c..3be5b2abee8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2123,7 +2123,7 @@ def contrastive_search( # select the past_key_value if low_memory: next_model_input = self.prepare_inputs_for_generation( - top_k_ids[:, selected_idx].unsqueeze(0), **model_kwargs + top_k_ids[:, selected_idx], 
**model_kwargs ) print (next_model_input['input_ids'].shape) selected_outputs = self( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 19b8e007de2..1bacc5d084e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1464,8 +1464,8 @@ def test_contrastive_generate_low_memory(self): if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): return - # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + print (input_ids.shape) # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1477,31 +1477,23 @@ def test_contrastive_generate_low_memory(self): # test output equality of low versus high memory model = model_class(config).to(torch_device).eval() - _, low_output = self._contrastive_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, + + low_output = model.generate(input_ids, + top_k=4, + penalty_alpha=0.6, + low_memory=True, max_length=max_length, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - low_memory=True + attention_maks=attention_mask ) - _, high_output = self._contrastive_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, + high_output = model.generate(input_ids, + top_k=4, + penalty_alpha=0.6, + low_memory=False, max_length=max_length, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - low_memory=False + attention_mask=attention_mask ) - - self.assertListEqual(low_output.tolist(), high_output.tolist()) + # self.assertListEqual(low_output.tolist(), high_output.tolist()) return From 0334d12e0392cd1d77ee8ae73856665a63a3271c Mon Sep 17 00:00:00 2001 From: blbadger Date: Tue, 4 Jul 2023 11:20:30 -0400 Subject: [PATCH 70/83] replaced unsqueeze with view --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3be5b2abee8..1f02661ee18 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2123,7 +2123,7 @@ def contrastive_search( # select the past_key_value if low_memory: next_model_input = self.prepare_inputs_for_generation( - top_k_ids[:, selected_idx], **model_kwargs + top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs ) print (next_model_input['input_ids'].shape) selected_outputs = self( From 06dacc03b4aa153ad2e62596fe6477c774f3ec5d Mon Sep 17 00:00:00 2001 From: blbadger Date: Tue, 4 Jul 2023 11:39:04 -0400 Subject: [PATCH 71/83] removed traces --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1f02661ee18..8285c239747 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2125,7 +2125,7 @@ def contrastive_search( next_model_input = self.prepare_inputs_for_generation( top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs ) - print (next_model_input['input_ids'].shape) + selected_outputs = self( **next_model_input, return_dict=True, output_hidden_states=False, output_attentions=False, ) From 94d6dd9dfe07b40ddf05f25a07c166f16241a3a9 Mon Sep 17 00:00:00 2001 From: blbadger Date: Tue, 4 Jul 2023 11:49:42 -0400 Subject: [PATCH 72/83] cleaned --- 
src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8285c239747..75a8a3f7c50 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2038,7 +2038,7 @@ def contrastive_search( for i in range(top_k): # compute the candidate tokens by the language model and collect their hidden_states next_model_inputs = self.prepare_inputs_for_generation( - top_k_ids[:, i].unsqueeze(0), **model_kwargs + top_k_ids[:, i].view(-1, 1), **model_kwargs ) outputs = self( From a2293dd90ba7ac6057516bef769b82b486348e77 Mon Sep 17 00:00:00 2001 From: blbadger Date: Tue, 4 Jul 2023 12:42:08 -0400 Subject: [PATCH 73/83] debugged model kwargs --- tests/generation/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1bacc5d084e..ddad910b9f6 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1483,7 +1483,7 @@ def test_contrastive_generate_low_memory(self): penalty_alpha=0.6, low_memory=True, max_length=max_length, - attention_maks=attention_mask + attention_mask=attention_mask ) high_output = model.generate(input_ids, @@ -1493,7 +1493,7 @@ def test_contrastive_generate_low_memory(self): max_length=max_length, attention_mask=attention_mask ) - # self.assertListEqual(low_output.tolist(), high_output.tolist()) + self.assertListEqual(low_output.tolist(), high_output.tolist()) return From 0deba213b8b22610e022f16692327533eb14b992 Mon Sep 17 00:00:00 2001 From: blbadger Date: Tue, 4 Jul 2023 15:04:57 -0400 Subject: [PATCH 74/83] removed special models from test --- tests/generation/test_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ddad910b9f6..2fb51ca06db 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1460,13 +1460,12 @@ def test_contrastive_generate_dict_outputs_use_cache(self): def test_contrastive_generate_low_memory(self): # Check that choosing 'low_memory' does not change the model output for model_class in self.all_generative_model_classes: - # won't fix: FSMT and Reformer have a different cache variable type (and format). - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format). + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]): return config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) - print (input_ids.shape) - + # NOTE: contrastive search only works with cache on at the moment. 
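# A small shape check motivating the earlier change from `top_k_ids[:, i].unsqueeze(0)`
# to `top_k_ids[:, i].view(-1, 1)`: with more than one sequence in the batch, unsqueeze
# yields one row of batch_size "tokens" rather than one next token per sequence.
# (batch_size = 2 and top_k = 4 are arbitrary illustration values.)
import torch

batch_size, top_k = 2, 4
top_k_ids = torch.arange(batch_size * top_k).view(batch_size, top_k)

candidate_i = top_k_ids[:, 0]             # shape [batch_size]
print(candidate_i.unsqueeze(0).shape)     # torch.Size([1, 2]): one "sequence" of 2 tokens
print(candidate_i.view(-1, 1).shape)      # torch.Size([2, 1]): one next token per sequence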
if not hasattr(config, "use_cache"): return From 1aa7279305aec6408174837d35e480ec2d47f511 Mon Sep 17 00:00:00 2001 From: blbadger Date: Fri, 7 Jul 2023 10:42:21 -0400 Subject: [PATCH 75/83] ran make quality --- src/transformers/generation/utils.py | 38 ++++++++++++++++++++-------- tests/generation/test_utils.py | 29 +++++++++++---------- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 37f350789d5..a736f086ac7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -596,7 +596,10 @@ def _maybe_initialize_input_ids_for_generation( return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id def _prepare_attention_mask_for_generation( - self, inputs: torch.Tensor, pad_token_id: Optional[int], eos_token_id: Optional[Union[int, List[int]]], + self, + inputs: torch.Tensor, + pad_token_id: Optional[int], + eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) @@ -680,7 +683,8 @@ def _prepare_decoder_input_ids_for_generation( if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] decoder_attention_mask = torch.cat( - (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), dim=-1, + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, ) model_kwargs["decoder_attention_mask"] = decoder_attention_mask @@ -787,7 +791,10 @@ def _reorder_cache(self, past_key_values, beam_idx): f" enable beam search for {self.__class__}" ) - def _get_logits_warper(self, generation_config: GenerationConfig,) -> LogitsProcessorList: + def _get_logits_warper( + self, + generation_config: GenerationConfig, + ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances used for multinomial sampling. 
@@ -2052,9 +2059,7 @@ def contrastive_search( all_last_hstates, all_hstates, all_logits = [], [], [] for i in range(top_k): # compute the candidate tokens by the language model and collect their hidden_states - next_model_inputs = self.prepare_inputs_for_generation( - top_k_ids[:, i].view(-1, 1), **model_kwargs - ) + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs) outputs = self( **next_model_inputs, @@ -2144,7 +2149,10 @@ def contrastive_search( ) selected_outputs = self( - **next_model_input, return_dict=True, output_hidden_states=False, output_attentions=False, + **next_model_input, + return_dict=True, + output_hidden_states=False, + output_attentions=False, ) next_past_key_values = selected_outputs["past_key_values"] @@ -4525,7 +4533,11 @@ def assisted_decoding( ) else: decoder_attentions = _split_model_outputs( - decoder_attentions, outputs.attentions, cur_len, added_len, is_decoder_attention=True, + decoder_attentions, + outputs.attentions, + cur_len, + added_len, + is_decoder_attention=True, ) if output_hidden_states: if self.config.is_encoder_decoder: @@ -4606,7 +4618,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length): ): for idx in range(len(past_key_values)): new_past.append( - (past_key_values[idx][0][:, :, :maximum_length], past_key_values[idx][1][:, :maximum_length, :],) + ( + past_key_values[idx][0][:, :, :maximum_length], + past_key_values[idx][1][:, :maximum_length, :], + ) ) past_key_values = tuple(new_past) # gptbigcode is too @@ -4622,7 +4637,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length): else: for idx in range(len(past_key_values)): new_past.append( - (past_key_values[idx][0][:, :, :maximum_length, :], past_key_values[idx][1][:, :, :maximum_length, :],) + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + ) ) past_key_values = tuple(new_past) return past_key_values diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0150937495b..0f50632c63e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1461,42 +1461,45 @@ def test_contrastive_generate_low_memory(self): # Check that choosing 'low_memory' does not change the model output for model_class in self.all_generative_model_classes: # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format). - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]): + if any( + model_name in model_class.__name__.lower() + for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"] + ): return config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) - + # NOTE: contrastive search only works with cache on at the moment. 
if not hasattr(config, "use_cache"): return config.use_cache = True config.is_decoder = True - prompt = "The rain in Spain" # test output equality of low versus high memory model = model_class(config).to(torch_device).eval() - low_output = model.generate(input_ids, - top_k=4, - penalty_alpha=0.6, - low_memory=True, + low_output = model.generate( + input_ids, + top_k=4, + penalty_alpha=0.6, + low_memory=True, max_length=max_length, - attention_mask=attention_mask + attention_mask=attention_mask, ) - high_output = model.generate(input_ids, - top_k=4, + high_output = model.generate( + input_ids, + top_k=4, penalty_alpha=0.6, - low_memory=False, + low_memory=False, max_length=max_length, - attention_mask=attention_mask + attention_mask=attention_mask, ) self.assertListEqual(low_output.tolist(), high_output.tolist()) return - @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%. def test_assisted_decoding_matches_greedy_search(self): # This test ensures that the assisted generation does not introduce output changes over greedy search. From 871cf59495cda72b620f80d39a58ff9992e3fe5e Mon Sep 17 00:00:00 2001 From: Benjamin Badger <54602201+blbadger@users.noreply.github.com> Date: Sat, 8 Jul 2023 10:13:03 -0400 Subject: [PATCH 76/83] Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante --- src/transformers/generation/configuration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 1474c07bf7a..ca52f6d2dcb 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -273,7 +273,7 @@ def __init__(self, **kwargs): self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) self.sequence_bias = kwargs.pop("sequence_bias", None) self.guidance_scale = kwargs.pop("guidance_scale", None) - self.low_memory = kwargs.pop("low_memory", False) + self.low_memory = kwargs.pop("low_memory", None) # Parameters that define the output variables of `generate` self.num_return_sequences = kwargs.pop("num_return_sequences", 1) From ef6bfd670a1d7d2233ddb7477520564e80ea52a4 Mon Sep 17 00:00:00 2001 From: Benjamin Badger <54602201+blbadger@users.noreply.github.com> Date: Sat, 8 Jul 2023 10:14:38 -0400 Subject: [PATCH 77/83] Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante --- src/transformers/generation/configuration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index ca52f6d2dcb..a62dfe12b81 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -189,7 +189,7 @@ class GenerationConfig(PushToHubMixin): The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages the model to generate samples that are more closely linked to the input prompt, usually at the expense of poorer quality. - low_memory (`bool`, *optional*, defaults to False): + low_memory (`bool`, *optional*): Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search. 
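For context on how the `low_memory` flag configured above is meant to be driven from user code, a minimal usage sketch follows. It mirrors the `test_contrastive_generate_low_memory` test earlier in this series (same `top_k`, `penalty_alpha`, and output-equality check); the "gpt2" checkpoint and `max_length=32` are illustrative choices, not values taken from these patches.

    # Illustrative sketch only: exercises low_memory the same way the test in this series does.
    # The checkpoint, prompt, and max_length are placeholder choices.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
    inputs = tokenizer("The rain in Spain", return_tensors="pt")

    with torch.no_grad():
        # top_k > 1 together with penalty_alpha > 0 selects contrastive search.
        low = model.generate(**inputs, top_k=4, penalty_alpha=0.6, max_length=32, low_memory=True)
        high = model.generate(**inputs, top_k=4, penalty_alpha=0.6, max_length=32, low_memory=False)

    # The two generations should be identical; only peak memory during the
    # top-k candidate forward passes differs.
    assert torch.equal(low, high)
    print(tokenizer.decode(low[0], skip_special_tokens=True))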
From bad2d1860dcba5dbffedeb02a818b833f3e5823b Mon Sep 17 00:00:00 2001 From: blbadger Date: Sat, 8 Jul 2023 10:25:49 -0400 Subject: [PATCH 78/83] refactored --- src/transformers/generation/utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a736f086ac7..07d44520674 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2113,12 +2113,6 @@ def contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - next_hidden = next_hidden - final = [] - for i in range(len(full_hidden_states)): - final.append(full_hidden_states[i]) - full_hidden_states = tuple(final) - logits = outputs.logits[:, -1, :] context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) From f16f2e7b8dbaa5cc1842339fd4daf7c01b2d87be Mon Sep 17 00:00:00 2001 From: blbadger Date: Sat, 8 Jul 2023 11:06:01 -0400 Subject: [PATCH 79/83] refactored --- src/transformers/generation/utils.py | 47 ++++++++++++---------------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 07d44520674..5165694b237 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2069,19 +2069,30 @@ def contrastive_search( ) for key in all_outputs: all_outputs[key].append(outputs[key]) - - if self.config.is_encoder_decoder: - next_hidden = outputs.decoder_hidden_states[-1] - full_hidden_states = outputs.decoder_hidden_states - - else: - next_hidden = outputs.hidden_states[-1] - full_hidden_states = outputs.hidden_states - all_last_hstates.append(torch.squeeze(next_hidden, 0)) all_hstates.append(full_hidden_states) all_logits.append(outputs.logits[:, -1, :]) + else: + # compute the candidate tokens by the language model and collect their hidden_states + # assembles top_k_ids into batch of size k + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, + ) + + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + + if low_memory: # stack hidden states next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) final_full_hstates = [0 for i in range(len(full_hidden_states))] @@ -2095,24 +2106,6 @@ def contrastive_search( logits = torch.cat(all_logits, dim=0) else: - # compute the candidate tokens by the language model and collect their hidden_states - # assembles top_k_ids into batch of size k (leading to OOM for large models) - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) - - outputs = self( - **next_model_inputs, - return_dict=True, - output_hidden_states=True, - output_attentions=output_attentions, - ) - # name is different for encoder-decoder and decoder-only models - if self.config.is_encoder_decoder: - next_hidden = outputs.decoder_hidden_states[-1] - full_hidden_states = outputs.decoder_hidden_states - else: - next_hidden = outputs.hidden_states[-1] - full_hidden_states = outputs.hidden_states - logits = outputs.logits[:, -1, :] context_hidden = 
last_hidden_states.repeat_interleave(top_k, dim=0) From af70bef337d754d648bb6ee34fc7bc3a26b1b443 Mon Sep 17 00:00:00 2001 From: blbadger Date: Sat, 8 Jul 2023 17:39:45 -0400 Subject: [PATCH 80/83] refactored --- src/transformers/generation/utils.py | 49 ++++++++++++++++------------ 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5165694b237..ef3f4922fb5 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2069,30 +2069,19 @@ def contrastive_search( ) for key in all_outputs: all_outputs[key].append(outputs[key]) + + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + all_last_hstates.append(torch.squeeze(next_hidden, 0)) all_hstates.append(full_hidden_states) all_logits.append(outputs.logits[:, -1, :]) - else: - # compute the candidate tokens by the language model and collect their hidden_states - # assembles top_k_ids into batch of size k - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) - outputs = self( - **next_model_inputs, - return_dict=True, - output_hidden_states=True, - output_attentions=output_attentions, - ) - - # name is different for encoder-decoder and decoder-only models - if self.config.is_encoder_decoder: - next_hidden = outputs.decoder_hidden_states[-1] - full_hidden_states = outputs.decoder_hidden_states - else: - next_hidden = outputs.hidden_states[-1] - full_hidden_states = outputs.hidden_states - - if low_memory: # stack hidden states next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) final_full_hstates = [0 for i in range(len(full_hidden_states))] @@ -2106,6 +2095,24 @@ def contrastive_search( logits = torch.cat(all_logits, dim=0) else: + # compute the candidate tokens by the language model and collect their hidden_states + # assembles top_k_ids into batch of size k + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + + outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, + ) + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + logits = outputs.logits[:, -1, :] context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) @@ -2129,7 +2136,7 @@ def contrastive_search( layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :] next_decoder_hidden_states += (layer,) - # select the past_key_value + # generate past_key_values cache of only the selected token if low_memory: next_model_input = self.prepare_inputs_for_generation( top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs From 2d21e64c5dafb0ddf4af0d9ac17e1f8f21dfd747 Mon Sep 17 00:00:00 2001 From: blbadger Date: Sat, 8 Jul 2023 17:58:39 -0400 Subject: [PATCH 81/83] make fixup --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ef3f4922fb5..f097d1e4a5f 100644 --- 
a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2096,7 +2096,7 @@ def contrastive_search( else: # compute the candidate tokens by the language model and collect their hidden_states - # assembles top_k_ids into batch of size k + # assembles top_k_ids into batch of size k next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) outputs = self( From bf3a07394f3e4497f88aed7580a6c85f5329c3be Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 20 Jul 2023 10:54:32 -0400 Subject: [PATCH 82/83] renamed flag sequential --- clean-snap.sh | 8 ++++++++ src/transformers/generation/utils.py | 15 ++++++++------- 2 files changed, 16 insertions(+), 7 deletions(-) create mode 100644 clean-snap.sh diff --git a/clean-snap.sh b/clean-snap.sh new file mode 100644 index 00000000000..a3862bb63a4 --- /dev/null +++ b/clean-snap.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# Removes old revisions of snaps +# CLOSE ALL SNAPS BEFORE RUNNING THIS +set -eu +snap list --all | awk '/disabled/{print $1, $3}' | + while read snapname revision; do + snap remove "$snapname" --revision="$revision" + done diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f097d1e4a5f..97aa95bdfe9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1563,7 +1563,7 @@ def generate( return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, - low_memory=generation_config.low_memory, + sequential=generation_config.low_memory, **model_kwargs, ) @@ -1827,7 +1827,7 @@ def contrastive_search( return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, - low_memory: Optional[bool] = False, + sequential: Optional[bool] = None, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: r""" @@ -1878,7 +1878,7 @@ def contrastive_search( streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - low_memory (`bool`, *optional*): + sequential (`bool`, *optional*): Switches topk hidden state computation from parallel to sequential to reduce memory if True. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. @@ -1919,6 +1919,7 @@ def contrastive_search( stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + sequential = sequential if sequential is not None else self.generation_config.low_memory if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None @@ -1994,7 +1995,7 @@ def contrastive_search( is_encoder_decoder=self.config.is_encoder_decoder, standardize_cache_format=True, ) - if not low_memory: + if not sequential: # Expands model inputs top_k times, for batched forward passes (akin to beam search). 
_, model_kwargs = self._expand_inputs_for_generation( expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs @@ -2047,14 +2048,14 @@ def contrastive_search( items = [] # item is either the key or the value matrix for item in layer: - if low_memory: + if sequential: items.append(item.repeat_interleave(1, dim=0)) else: items.append(item.repeat_interleave(top_k, dim=0)) new_key_values.append(items) model_kwargs["past_key_values"] = new_key_values - if low_memory: + if sequential: all_outputs = {key: [] for key in outputs} # defined in first loop iteration all_last_hstates, all_hstates, all_logits = [], [], [] for i in range(top_k): @@ -2137,7 +2138,7 @@ def contrastive_search( next_decoder_hidden_states += (layer,) # generate past_key_values cache of only the selected token - if low_memory: + if sequential: next_model_input = self.prepare_inputs_for_generation( top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs ) From b11c156dc6a9db4a17ad923b541accfb59ceeae9 Mon Sep 17 00:00:00 2001 From: blbadger Date: Thu, 20 Jul 2023 10:56:00 -0400 Subject: [PATCH 83/83] renamed flag sequential --- clean-snap.sh | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 clean-snap.sh diff --git a/clean-snap.sh b/clean-snap.sh deleted file mode 100644 index a3862bb63a4..00000000000 --- a/clean-snap.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -# Removes old revisions of snaps -# CLOSE ALL SNAPS BEFORE RUNNING THIS -set -eu -snap list --all | awk '/disabled/{print $1, $3}' | - while read snapname revision; do - snap remove "$snapname" --revision="$revision" - done