Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contrastive Search peak memory reduction #24120

Merged
merged 117 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
117 commits
Select commit Hold shift + click to select a range
e9c05dc
added hidden subset
blbadger Jun 6, 2023
9c4b96a
Merge pull request #1 from blbadger/master
blbadger Jun 6, 2023
fc52166
debugged hidden subset contrastive search
blbadger Jun 6, 2023
f16ac72
Merge pull request #2 from blbadger/master
blbadger Jun 6, 2023
3906d60
added contrastive search compression
blbadger Jun 6, 2023
40ebe76
Merge pull request #3 from blbadger/master
blbadger Jun 6, 2023
2881aef
debugged compressed contrastive search
blbadger Jun 6, 2023
4ddf45b
Merge pull request #4 from blbadger/master
blbadger Jun 6, 2023
7d29c55
memory reduction for contrastive search
blbadger Jun 7, 2023
b0b98cb
Merge pull request #5 from blbadger/master
blbadger Jun 7, 2023
57dfaac
debugged mem red
blbadger Jun 7, 2023
a419245
Merge pull request #6 from blbadger/master
blbadger Jun 7, 2023
fd0e19f
added low memory option feature
blbadger Jun 7, 2023
fc03ab2
Merge pull request #7 from blbadger/master
blbadger Jun 7, 2023
802cfd4
debugged mem optmimization output stack
blbadger Jun 7, 2023
0632f06
debugged mem optmimization output stack
blbadger Jun 7, 2023
8318968
Merge pull request #8 from blbadger/master
blbadger Jun 7, 2023
9bad256
debugged low mem
blbadger Jun 7, 2023
8fa1731
Merge pull request #9 from blbadger/master
blbadger Jun 7, 2023
a89bb8e
added low mem cache
blbadger Jun 7, 2023
cdbd070
Merge pull request #10 from blbadger/master
blbadger Jun 7, 2023
f90f948
fixed 2047 tensor view
blbadger Jun 7, 2023
65feec9
Merge pull request #11 from blbadger/master
blbadger Jun 7, 2023
e1718c3
debugged 2042 past key val inputs
blbadger Jun 7, 2023
089a299
Merge pull request #12 from blbadger/master
blbadger Jun 7, 2023
3fd54e6
reformatted tensors
blbadger Jun 7, 2023
6d6ac75
Merge pull request #13 from blbadger/master
blbadger Jun 7, 2023
12d5aea
changed low mem output
blbadger Jun 7, 2023
89f9b13
Merge pull request #14 from blbadger/master
blbadger Jun 7, 2023
44a9ec4
final clean
blbadger Jun 7, 2023
37bb62d
removed subset hidden csearch
blbadger Jun 8, 2023
68c1cd8
fixed hidden device
blbadger Jun 8, 2023
e199ddc
fixed hidden device
blbadger Jun 8, 2023
8ace5a3
changed compressor dtype
blbadger Jun 8, 2023
1ac80a0
removed hstate compression
blbadger Jun 8, 2023
1c3aae7
integrated csearch in generate
blbadger Jun 8, 2023
f18bccd
test csearch integration into generation
blbadger Jun 8, 2023
abf0a72
fixed csearch kwarg integration with generation
blbadger Jun 8, 2023
e517d5f
final wrap and added doc
blbadger Jun 8, 2023
cc1ea6d
Update src/transformers/generation/utils.py
blbadger Jun 15, 2023
bd2e36b
Update src/transformers/generation/utils.py
blbadger Jun 15, 2023
b59ec6d
Update src/transformers/generation/utils.py
blbadger Jun 15, 2023
a7fb76e
added debug print
blbadger Jun 16, 2023
961a1ba
direct hstate cat
blbadger Jun 16, 2023
882b6d2
direct hstate cat
blbadger Jun 16, 2023
c3f3db3
direct hstate cat debug
blbadger Jun 16, 2023
692b5e1
direct hstate cat debug
blbadger Jun 16, 2023
349bbf9
expanded full hidden state stack
blbadger Jun 16, 2023
cd4bed0
expanded full hidden state stack
blbadger Jun 16, 2023
ae41c50
matched dims for hstates
blbadger Jun 16, 2023
30baaa6
matched dims for hstates
blbadger Jun 16, 2023
ebc19ff
logits fix
blbadger Jun 16, 2023
752a488
equality test
blbadger Jun 16, 2023
4f973ba
equality hidden debug
blbadger Jun 16, 2023
b809415
debug
blbadger Jun 16, 2023
9230061
added prints for debug
blbadger Jun 16, 2023
2863471
added prints for debug
blbadger Jun 16, 2023
e653353
equality check
blbadger Jun 16, 2023
d790ea5
switched squeeze dim
blbadger Jun 16, 2023
f194221
input format debug
blbadger Jun 16, 2023
665c323
tracing top_k_ids
blbadger Jun 16, 2023
6259b56
removed trace
blbadger Jun 17, 2023
55561bb
Merge pull request #16 from blbadger/equal-csearch
blbadger Jun 17, 2023
7f52d87
Merge branch 'huggingface:main' into main
blbadger Jun 17, 2023
6d2734c
added test context
blbadger Jun 17, 2023
a873dfd
Merge pull request #17 from blbadger/equal-csearch
blbadger Jun 17, 2023
4033b19
added jitter
blbadger Jun 18, 2023
e2051a7
added jitter
blbadger Jun 18, 2023
e8f4cd1
added jitter
blbadger Jun 18, 2023
6bed197
returned state
blbadger Jun 18, 2023
67946f2
rebuilt past key value reconstruction
blbadger Jun 21, 2023
3dbd776
debugged
blbadger Jun 21, 2023
547df69
cleaned traces
blbadger Jun 21, 2023
f4b1f28
added selection for pkv
blbadger Jun 21, 2023
d1af0f0
changed output to dict
blbadger Jun 21, 2023
fbb11b5
Merge pull request #18 from blbadger/selected-pkv
blbadger Jun 21, 2023
ee94a31
cleaned
blbadger Jun 21, 2023
5cfd454
cleaned
blbadger Jun 21, 2023
b63ec63
Merge pull request #19 from blbadger/selected-pkv
blbadger Jun 21, 2023
2fbca35
cleaned up contrastive search test
blbadger Jun 22, 2023
29b16f7
Merge pull request #20 from blbadger/selected-pkv
blbadger Jun 22, 2023
efcba6f
moved low_memory kwarg
blbadger Jul 3, 2023
5a3b26c
debugged
blbadger Jul 3, 2023
fb337c3
Merge pull request #21 from blbadger/selected-pkv
blbadger Jul 3, 2023
cf12230
changed low mem test batch size to 1
blbadger Jul 3, 2023
60fd185
removed output
blbadger Jul 3, 2023
0e4fd99
Merge pull request #22 from blbadger/selected-pkv
blbadger Jul 3, 2023
a3355c1
debugged test input shape
blbadger Jul 3, 2023
704e9b1
Merge pull request #23 from blbadger/selected-pkv
blbadger Jul 3, 2023
87be0de
reformatted csearch test
blbadger Jul 3, 2023
8564437
Merge pull request #24 from blbadger/selected-pkv
blbadger Jul 3, 2023
ab307f9
added trace
blbadger Jul 3, 2023
dfff73d
removed unsqueeze on final forward pass
blbadger Jul 4, 2023
0334d12
replaced unsqueeze with view
blbadger Jul 4, 2023
06dacc0
removed traces
blbadger Jul 4, 2023
94d6dd9
cleaned
blbadger Jul 4, 2023
fe78f81
Merge pull request #25 from blbadger/selected-pkv
blbadger Jul 4, 2023
a2293dd
debugged model kwargs
blbadger Jul 4, 2023
150d1a1
Merge pull request #26 from blbadger/selected-pkv
blbadger Jul 4, 2023
0deba21
removed special models from test
blbadger Jul 4, 2023
5237cf0
Merge pull request #27 from blbadger/selected-pkv
blbadger Jul 4, 2023
05c408e
Merge branch 'main' into main
blbadger Jul 4, 2023
f9bd670
Merge branch 'huggingface:main' into main
blbadger Jul 5, 2023
1aa7279
ran make quality
blbadger Jul 7, 2023
8129e2a
Merge branch 'huggingface:main' into main
blbadger Jul 7, 2023
871cf59
Update src/transformers/generation/configuration_utils.py
blbadger Jul 8, 2023
ef6bfd6
Update src/transformers/generation/configuration_utils.py
blbadger Jul 8, 2023
bad2d18
refactored
blbadger Jul 8, 2023
f16f2e7
refactored
blbadger Jul 8, 2023
af70bef
refactored
blbadger Jul 8, 2023
d82e792
Merge pull request #28 from blbadger/selected-pkv
blbadger Jul 8, 2023
2d21e64
make fixup
blbadger Jul 8, 2023
f310f83
Merge pull request #29 from blbadger/selected-pkv
blbadger Jul 8, 2023
bf3a073
renamed flag sequential
blbadger Jul 20, 2023
b11c156
renamed flag sequential
blbadger Jul 20, 2023
c619204
Merge pull request #30 from blbadger/selected-pkv
blbadger Jul 20, 2023
1ae9d4a
Merge branch 'huggingface:main' into main
blbadger Jul 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ class GenerationConfig(PushToHubMixin):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
low_memory (`bool`, *optional*):
Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.


> Parameters that define the output variables of `generate`

Expand Down Expand Up @@ -270,6 +273,7 @@ def __init__(self, **kwargs):
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
self.sequence_bias = kwargs.pop("sequence_bias", None)
self.guidance_scale = kwargs.pop("guidance_scale", None)
self.low_memory = kwargs.pop("low_memory", None)

# Parameters that define the output variables of `generate`
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
Expand Down
131 changes: 100 additions & 31 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,7 @@ def generate(
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
sequential=generation_config.low_memory,
**model_kwargs,
)

Expand Down Expand Up @@ -1832,6 +1833,7 @@ def contrastive_search(
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
sequential: Optional[bool] = None,
**model_kwargs,
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
r"""
Expand Down Expand Up @@ -1882,6 +1884,8 @@ def contrastive_search(
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
sequential (`bool`, *optional*):
Switches topk hidden state computation from parallel to sequential to reduce memory if True.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Expand Down Expand Up @@ -1921,6 +1925,7 @@ def contrastive_search(
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
sequential = sequential if sequential is not None else self.generation_config.low_memory
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
Expand Down Expand Up @@ -1986,6 +1991,7 @@ def contrastive_search(
last_hidden_states = outputs.decoder_hidden_states[-1]
else:
last_hidden_states = outputs.hidden_states[-1]

# next logit for contrastive search to select top-k candidate tokens
logit_for_next_step = outputs.logits[:, -1, :]

Expand All @@ -1995,11 +2001,11 @@ def contrastive_search(
is_encoder_decoder=self.config.is_encoder_decoder,
standardize_cache_format=True,
)

# Expands model inputs top_k times, for batched forward passes (akin to beam search).
_, model_kwargs = self._expand_inputs_for_generation(
expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)
if not sequential:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
_, model_kwargs = self._expand_inputs_for_generation(
expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)

past_key_values = model_kwargs.get("past_key_values")
if past_key_values is None:
Expand All @@ -2019,7 +2025,6 @@ def contrastive_search(
# contrastive_search main logic start:
# contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
# degeneration penalty

logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
Expand Down Expand Up @@ -2049,25 +2054,74 @@ def contrastive_search(
items = []
# item is either the key or the value matrix
for item in layer:
items.append(item.repeat_interleave(top_k, dim=0))
if sequential:
items.append(item.repeat_interleave(1, dim=0))
else:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(items)
model_kwargs["past_key_values"] = new_key_values

# compute the candidate tokens by the language model and collects their hidden_states
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
outputs = self(
**next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
)
next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
if sequential:
all_outputs = {key: [] for key in outputs} # defined in first loop iteration
all_last_hstates, all_hstates, all_logits = [], [], []
for i in range(top_k):
# compute the candidate tokens by the language model and collect their hidden_states
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)

outputs = self(
**next_model_inputs,
return_dict=True,
output_hidden_states=True,
output_attentions=output_attentions,
)
for key in all_outputs:
all_outputs[key].append(outputs[key])

if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states

else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
Comment on lines +2080 to +2086
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines exit in both the if (low memory) and else (not low memory) code blocks, but note that they are not easily refactored because next_hidden and full_hidden_states must be returned iteratively for each top_k token when low memory is activated, but otherwise they are only returned once batch-wise.


all_last_hstates.append(torch.squeeze(next_hidden, 0))
all_hstates.append(full_hidden_states)
all_logits.append(outputs.logits[:, -1, :])

# stack hidden states
next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
final_full_hstates = [0 for i in range(len(full_hidden_states))]
for layer in range(len(full_hidden_states)):
final_full_hstates[layer] = torch.stack(
[torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
)
full_hidden_states = tuple(final_full_hstates)

# stack logits
logits = torch.cat(all_logits, dim=0)

logits = outputs.logits[:, -1, :]
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
# compute the candidate tokens by the language model and collect their hidden_states
# assembles top_k_ids into batch of size k
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)

outputs = self(
**next_model_inputs,
return_dict=True,
output_hidden_states=True,
output_attentions=output_attentions,
)
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states

logits = outputs.logits[:, -1, :]

context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
Expand All @@ -2089,17 +2143,32 @@ def contrastive_search(
layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
next_decoder_hidden_states += (layer,)

# select the past_key_value
new_key_values = ()
for layer in next_past_key_values:
items = ()
# item is either the key or the value matrix
for item in layer:
item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz]
item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz]
items += (item,)
new_key_values += (items,)
next_past_key_values = new_key_values
# generate past_key_values cache of only the selected token
if sequential:
next_model_input = self.prepare_inputs_for_generation(
top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
)

selected_outputs = self(
**next_model_input,
return_dict=True,
output_hidden_states=False,
output_attentions=False,
)
next_past_key_values = selected_outputs["past_key_values"]

else:
next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
new_key_values = ()
for layer in next_past_key_values:
items = ()
# item is either the key or the value matrix
for item in layer:
item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz]
item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz]
items += (item,)
new_key_values += (items,)
next_past_key_values = new_key_values

logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]

Expand Down
43 changes: 43 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,49 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
for output in (output_contrastive, output_generate):
self._check_outputs(output, input_ids, model.config, use_cache=True)

def test_contrastive_generate_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
# won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
if any(
model_name in model_class.__name__.lower()
for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
):
return

config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
return

config.use_cache = True
config.is_decoder = True

# test output equality of low versus high memory
model = model_class(config).to(torch_device).eval()

low_output = model.generate(
input_ids,
top_k=4,
penalty_alpha=0.6,
low_memory=True,
max_length=max_length,
attention_mask=attention_mask,
)

high_output = model.generate(
input_ids,
top_k=4,
penalty_alpha=0.6,
low_memory=False,
max_length=max_length,
attention_mask=attention_mask,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

return

@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
def test_assisted_decoding_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
Expand Down