Contrastive Search peak memory reduction #24120

Merged
merged 117 commits into from
Jul 20, 2023
Commits
e9c05dc
added hidden subset
blbadger Jun 6, 2023
9c4b96a
Merge pull request #1 from blbadger/master
blbadger Jun 6, 2023
fc52166
debugged hidden subset contrastive search
blbadger Jun 6, 2023
f16ac72
Merge pull request #2 from blbadger/master
blbadger Jun 6, 2023
3906d60
added contrastive search compression
blbadger Jun 6, 2023
40ebe76
Merge pull request #3 from blbadger/master
blbadger Jun 6, 2023
2881aef
debugged compressed contrastive search
blbadger Jun 6, 2023
4ddf45b
Merge pull request #4 from blbadger/master
blbadger Jun 6, 2023
7d29c55
memory reduction for contrastive search
blbadger Jun 7, 2023
b0b98cb
Merge pull request #5 from blbadger/master
blbadger Jun 7, 2023
57dfaac
debugged mem red
blbadger Jun 7, 2023
a419245
Merge pull request #6 from blbadger/master
blbadger Jun 7, 2023
fd0e19f
added low memory option feature
blbadger Jun 7, 2023
fc03ab2
Merge pull request #7 from blbadger/master
blbadger Jun 7, 2023
802cfd4
debugged mem optmimization output stack
blbadger Jun 7, 2023
0632f06
debugged mem optmimization output stack
blbadger Jun 7, 2023
8318968
Merge pull request #8 from blbadger/master
blbadger Jun 7, 2023
9bad256
debugged low mem
blbadger Jun 7, 2023
8fa1731
Merge pull request #9 from blbadger/master
blbadger Jun 7, 2023
a89bb8e
added low mem cache
blbadger Jun 7, 2023
cdbd070
Merge pull request #10 from blbadger/master
blbadger Jun 7, 2023
f90f948
fixed 2047 tensor view
blbadger Jun 7, 2023
65feec9
Merge pull request #11 from blbadger/master
blbadger Jun 7, 2023
e1718c3
debugged 2042 past key val inputs
blbadger Jun 7, 2023
089a299
Merge pull request #12 from blbadger/master
blbadger Jun 7, 2023
3fd54e6
reformatted tensors
blbadger Jun 7, 2023
6d6ac75
Merge pull request #13 from blbadger/master
blbadger Jun 7, 2023
12d5aea
changed low mem output
blbadger Jun 7, 2023
89f9b13
Merge pull request #14 from blbadger/master
blbadger Jun 7, 2023
44a9ec4
final clean
blbadger Jun 7, 2023
37bb62d
removed subset hidden csearch
blbadger Jun 8, 2023
68c1cd8
fixed hidden device
blbadger Jun 8, 2023
e199ddc
fixed hidden device
blbadger Jun 8, 2023
8ace5a3
changed compressor dtype
blbadger Jun 8, 2023
1ac80a0
removed hstate compression
blbadger Jun 8, 2023
1c3aae7
integrated csearch in generate
blbadger Jun 8, 2023
f18bccd
test csearch integration into generation
blbadger Jun 8, 2023
abf0a72
fixed csearch kwarg integration with generation
blbadger Jun 8, 2023
e517d5f
final wrap and added doc
blbadger Jun 8, 2023
cc1ea6d
Update src/transformers/generation/utils.py
blbadger Jun 15, 2023
bd2e36b
Update src/transformers/generation/utils.py
blbadger Jun 15, 2023
b59ec6d
Update src/transformers/generation/utils.py
blbadger Jun 15, 2023
a7fb76e
added debug print
blbadger Jun 16, 2023
961a1ba
direct hstate cat
blbadger Jun 16, 2023
882b6d2
direct hstate cat
blbadger Jun 16, 2023
c3f3db3
direct hstate cat debug
blbadger Jun 16, 2023
692b5e1
direct hstate cat debug
blbadger Jun 16, 2023
349bbf9
expanded full hidden state stack
blbadger Jun 16, 2023
cd4bed0
expanded full hidden state stack
blbadger Jun 16, 2023
ae41c50
matched dims for hstates
blbadger Jun 16, 2023
30baaa6
matched dims for hstates
blbadger Jun 16, 2023
ebc19ff
logits fix
blbadger Jun 16, 2023
752a488
equality test
blbadger Jun 16, 2023
4f973ba
equality hidden debug
blbadger Jun 16, 2023
b809415
debug
blbadger Jun 16, 2023
9230061
added prints for debug
blbadger Jun 16, 2023
2863471
added prints for debug
blbadger Jun 16, 2023
e653353
equality check
blbadger Jun 16, 2023
d790ea5
switched squeeze dim
blbadger Jun 16, 2023
f194221
input format debug
blbadger Jun 16, 2023
665c323
tracing top_k_ids
blbadger Jun 16, 2023
6259b56
removed trace
blbadger Jun 17, 2023
55561bb
Merge pull request #16 from blbadger/equal-csearch
blbadger Jun 17, 2023
7f52d87
Merge branch 'huggingface:main' into main
blbadger Jun 17, 2023
6d2734c
added test context
blbadger Jun 17, 2023
a873dfd
Merge pull request #17 from blbadger/equal-csearch
blbadger Jun 17, 2023
4033b19
added jitter
blbadger Jun 18, 2023
e2051a7
added jitter
blbadger Jun 18, 2023
e8f4cd1
added jitter
blbadger Jun 18, 2023
6bed197
returned state
blbadger Jun 18, 2023
67946f2
rebuilt past key value reconstruction
blbadger Jun 21, 2023
3dbd776
debugged
blbadger Jun 21, 2023
547df69
cleaned traces
blbadger Jun 21, 2023
f4b1f28
added selection for pkv
blbadger Jun 21, 2023
d1af0f0
changed output to dict
blbadger Jun 21, 2023
fbb11b5
Merge pull request #18 from blbadger/selected-pkv
blbadger Jun 21, 2023
ee94a31
cleaned
blbadger Jun 21, 2023
5cfd454
cleaned
blbadger Jun 21, 2023
b63ec63
Merge pull request #19 from blbadger/selected-pkv
blbadger Jun 21, 2023
2fbca35
cleaned up contrastive search test
blbadger Jun 22, 2023
29b16f7
Merge pull request #20 from blbadger/selected-pkv
blbadger Jun 22, 2023
efcba6f
moved low_memory kwarg
blbadger Jul 3, 2023
5a3b26c
debugged
blbadger Jul 3, 2023
fb337c3
Merge pull request #21 from blbadger/selected-pkv
blbadger Jul 3, 2023
cf12230
changed low mem test batch size to 1
blbadger Jul 3, 2023
60fd185
removed output
blbadger Jul 3, 2023
0e4fd99
Merge pull request #22 from blbadger/selected-pkv
blbadger Jul 3, 2023
a3355c1
debugged test input shape
blbadger Jul 3, 2023
704e9b1
Merge pull request #23 from blbadger/selected-pkv
blbadger Jul 3, 2023
87be0de
reformatted csearch test
blbadger Jul 3, 2023
8564437
Merge pull request #24 from blbadger/selected-pkv
blbadger Jul 3, 2023
ab307f9
added trace
blbadger Jul 3, 2023
dfff73d
removed unsqueeze on final forward pass
blbadger Jul 4, 2023
0334d12
replaced unsqueeze with view
blbadger Jul 4, 2023
06dacc0
removed traces
blbadger Jul 4, 2023
94d6dd9
cleaned
blbadger Jul 4, 2023
fe78f81
Merge pull request #25 from blbadger/selected-pkv
blbadger Jul 4, 2023
a2293dd
debugged model kwargs
blbadger Jul 4, 2023
150d1a1
Merge pull request #26 from blbadger/selected-pkv
blbadger Jul 4, 2023
0deba21
removed special models from test
blbadger Jul 4, 2023
5237cf0
Merge pull request #27 from blbadger/selected-pkv
blbadger Jul 4, 2023
05c408e
Merge branch 'main' into main
blbadger Jul 4, 2023
f9bd670
Merge branch 'huggingface:main' into main
blbadger Jul 5, 2023
1aa7279
ran make quality
blbadger Jul 7, 2023
8129e2a
Merge branch 'huggingface:main' into main
blbadger Jul 7, 2023
871cf59
Update src/transformers/generation/configuration_utils.py
blbadger Jul 8, 2023
ef6bfd6
Update src/transformers/generation/configuration_utils.py
blbadger Jul 8, 2023
bad2d18
refactored
blbadger Jul 8, 2023
f16f2e7
refactored
blbadger Jul 8, 2023
af70bef
refactored
blbadger Jul 8, 2023
d82e792
Merge pull request #28 from blbadger/selected-pkv
blbadger Jul 8, 2023
2d21e64
make fixup
blbadger Jul 8, 2023
f310f83
Merge pull request #29 from blbadger/selected-pkv
blbadger Jul 8, 2023
bf3a073
renamed flag sequential
blbadger Jul 20, 2023
b11c156
renamed flag sequential
blbadger Jul 20, 2023
c619204
Merge pull request #30 from blbadger/selected-pkv
blbadger Jul 20, 2023
1ae9d4a
Merge branch 'huggingface:main' into main
blbadger Jul 20, 2023
49 changes: 28 additions & 21 deletions src/transformers/generation/utils.py
@@ -2069,30 +2069,19 @@ def contrastive_search(
)
for key in all_outputs:
all_outputs[key].append(outputs[key])

if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states

else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
Comment on lines +2080 to +2086
Contributor Author

These lines exist in both the if (low memory) and else (not low memory) code blocks, but they are not easily refactored: when low memory is activated, next_hidden and full_hidden_states must be collected iteratively, once for each top_k token, whereas otherwise they are returned only once for the whole batch.
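
The trade-off described in this comment can be illustrated with a toy sketch. It does not use the transformers API: `toy_forward` is a hypothetical stand-in for the model's forward pass, and the token ids are made up. The point is only that running the top_k candidates one at a time and collecting the results gives the same output as a single batched call, while each pass handles one candidate instead of k.

```python
# Toy illustration only: `toy_forward` is a hypothetical stand-in for a model
# forward pass and is not part of transformers.


def toy_forward(batch):
    """Map each candidate token id to a fake 'hidden state' (a list of floats)."""
    return [[float(tok), float(tok) * 0.5] for tok in batch]


def batched_candidates(top_k_ids):
    # One forward pass over all k candidates: the activations for every
    # candidate exist at the same time.
    hidden = toy_forward(top_k_ids)
    return hidden, len(top_k_ids)  # (outputs, candidates per forward pass)


def sequential_candidates(top_k_ids):
    # k forward passes with batch size 1: only one candidate is processed per
    # pass, and the per-candidate outputs are collected for stacking later.
    all_hidden = []
    for tok in top_k_ids:
        all_hidden.append(toy_forward([tok])[0])
    return all_hidden, 1  # (outputs, candidates per forward pass)


top_k_ids = [11, 42, 7, 99]
batched, batched_width = batched_candidates(top_k_ids)
sequential, sequential_width = sequential_candidates(top_k_ids)
```

In this sketch the sequential path pays for k separate calls, which mirrors why the low-memory branch needs its own collection loop rather than sharing the batched code path.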


all_last_hstates.append(torch.squeeze(next_hidden, 0))
all_hstates.append(full_hidden_states)
all_logits.append(outputs.logits[:, -1, :])
else:
# compute the candidate tokens by the language model and collect their hidden_states
# assembles top_k_ids into batch of size k
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)

outputs = self(
**next_model_inputs,
return_dict=True,
output_hidden_states=True,
output_attentions=output_attentions,
)

# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states

if low_memory:
# stack hidden states
next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
final_full_hstates = [0 for i in range(len(full_hidden_states))]
@@ -2106,6 +2095,24 @@ def contrastive_search(
logits = torch.cat(all_logits, dim=0)

else:
# compute the candidate tokens by the language model and collect their hidden_states
# assembles top_k_ids into batch of size k
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)

outputs = self(
**next_model_inputs,
return_dict=True,
output_hidden_states=True,
output_attentions=output_attentions,
)
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states

logits = outputs.logits[:, -1, :]

context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
@@ -2129,7 +2136,7 @@ def contrastive_search(
layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
next_decoder_hidden_states += (layer,)

# select the past_key_value
# generate past_key_values cache of only the selected token
if low_memory:
next_model_input = self.prepare_inputs_for_generation(
top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
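
For context on how `selected_idx` is chosen among the top_k candidates, here is a minimal pure-Python sketch of the contrastive search scoring rule. It assumes the standard formulation: model confidence, minus a degeneration penalty equal to the maximum cosine similarity between a candidate's hidden state and the context hidden states, weighted by `penalty_alpha`. The helper names (`cosine`, `select_candidate`) and all numbers are illustrative and are not taken from this diff.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def select_candidate(context_hidden, candidate_hidden, candidate_probs, penalty_alpha):
    """Return (index of best candidate, list of contrastive scores)."""
    scores = []
    for hidden, prob in zip(candidate_hidden, candidate_probs):
        # Degeneration penalty: how similar the candidate is to the most
        # similar token already in the context.
        penalty = max(cosine(hidden, ctx) for ctx in context_hidden)
        scores.append((1.0 - penalty_alpha) * prob - penalty_alpha * penalty)
    best = max(range(len(scores)), key=scores.__getitem__)
    return best, scores


# Two context tokens and three candidates, with made-up 2-d hidden states.
context_hidden = [[1.0, 0.0], [0.0, 1.0]]
candidate_hidden = [[1.0, 0.0], [1.0, 1.0], [-1.0, -1.0]]
candidate_probs = [0.6, 0.3, 0.1]

selected_idx, scores = select_candidate(
    context_hidden, candidate_hidden, candidate_probs, penalty_alpha=0.6
)
```

With these inputs the most probable candidate (index 0) duplicates a context hidden state and is penalized heavily, so the least similar candidate (index 2) is selected instead. Only that selected candidate's cache needs to be carried into the next step, which is what the low-memory branch above regenerates.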