Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RWKV] Add Gradient Checkpointing support for RWKV #24955

Merged
merged 1 commit into from
Jul 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions src/transformers/models/rwkv/modeling_rwkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
base_model_prefix = "rwkv"
_no_split_modules = ["RwkvBlock"]
_keep_in_fp32_modules = ["time_decay", "time_first"]
supports_gradient_checkpointing = True

def _init_weights(self, module):
"""Initialize the weights."""
Expand Down Expand Up @@ -605,6 +606,8 @@ def __init__(self, config):

self.layers_are_rescaled = False

self.gradient_checkpointing = False

# Initialize weights and apply final processing
self.post_init()

Expand Down Expand Up @@ -659,14 +662,35 @@ def forward(
]
state[4] -= 1e30

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

hidden_states = inputs_embeds

all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for idx, block in enumerate(self.blocks):
hidden_states, state, attentions = block(
hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
)
if self.gradient_checkpointing and self.training:

def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

return custom_forward

hidden_states, state, attentions = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), hidden_states, state
)
else:
hidden_states, state, attentions = block(
hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
)

if (
self.layers_are_rescaled
and self.config.rescale_every > 0
Expand Down