Skip to content

Commit

Permalink
Add new model (huggingface#32615)
Browse files Browse the repository at this point in the history
* v1 - working version

* fix

* fix

* fix

* fix

* rename to correct name

* fix title

* fixup

* rename files

* fix

* add copied from on tests

* rename to `FalconMamba` everywhere and fix bugs

* fix quantization + accelerate

* fix copies

* add `torch.compile` support

* fix tests

* fix tests and add slow tests

* copies on config

* merge the latest changes

* fix tests

* add few lines about instruct

* Apply suggestions from code review

Co-authored-by: Arthur <[email protected]>

* fix

* fix tests

---------

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
2 people authored and stevhliu committed Oct 22, 2024
1 parent 251b564 commit b31d7a7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 49 deletions.
55 changes: 10 additions & 45 deletions src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)

_CHECKPOINT_FOR_DOC = "tiiuae/falcon-mamba-7b"
_CHECKPOINT_FOR_DOC = "tiiuae/falcon_mamba-7b"
_CONFIG_FOR_DOC = "FalconMambaConfig"


Expand Down Expand Up @@ -167,7 +167,6 @@ def cuda_kernels_forward(
hidden_states: torch.Tensor,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)
Expand Down Expand Up @@ -196,9 +195,6 @@ def cuda_kernels_forward(
else:
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_position[0] > 0:
Expand All @@ -220,9 +216,6 @@ def cuda_kernels_forward(
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
Expand Down Expand Up @@ -282,17 +275,13 @@ def slow_forward(
input_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
Expand Down Expand Up @@ -321,9 +310,6 @@ def slow_forward(
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
Expand Down Expand Up @@ -385,11 +371,10 @@ def forward(
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
return self.slow_forward(hidden_states, cache_params, cache_position)


# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
Expand Down Expand Up @@ -427,16 +412,13 @@ def forward(
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)

hidden_states = self.mixer(
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
)
hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
hidden_states = residual + hidden_states
return hidden_states

Expand Down Expand Up @@ -635,13 +617,14 @@ def set_input_embeddings(self, new_embeddings):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # Ignored arg
inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[MambaCache] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
) -> Union[Tuple, FalconMambaOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand Down Expand Up @@ -680,15 +663,10 @@ def forward(
for mixer_block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
mixer_block.__call__, hidden_states, cache_params, cache_position
)
else:
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
Expand Down Expand Up @@ -748,13 +726,6 @@ def _update_model_kwargs_for_generation(
and model_kwargs["cache_position"] is not None
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)

return model_kwargs

def prepare_inputs_for_generation(
Expand All @@ -764,7 +735,6 @@ def prepare_inputs_for_generation(
use_cache=None,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**kwargs,
):
# Overwitten -- uses `cache_params` as opposed to `past_key_values`
Expand All @@ -779,10 +749,6 @@ def prepare_inputs_for_generation(
)
if cache_position[0] > 0:
input_ids = input_ids[:, -1].unsqueeze(-1)

if attention_mask is not None:
attention_mask = None

else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
Expand All @@ -800,7 +766,6 @@ def prepare_inputs_for_generation(
"cache_params": cache_params,
"use_cache": use_cache,
"cache_position": cache_position,
"attention_mask": attention_mask,
}
)
return model_inputs
Expand All @@ -811,10 +776,11 @@ def prepare_inputs_for_generation(
output_type=FalconMambaCausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
# Ignore copy
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # Ignored copy
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[MambaCache] = None,
labels: Optional[torch.LongTensor] = None,
Expand All @@ -840,7 +806,6 @@ def forward(
return_dict=return_dict,
use_cache=use_cache,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = falcon_mamba_outputs[0]

Expand Down
24 changes: 20 additions & 4 deletions tests/models/falcon_mamba/test_modeling_falcon_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def prepare_config_and_inputs(
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = ids_tensor([self.batch_size, self.seq_length], 1)

sequence_labels = None
token_labels = None
Expand All @@ -120,7 +119,7 @@ def prepare_config_and_inputs(
return (
config,
input_ids,
attention_mask,
None,
sequence_labels,
token_labels,
choice_labels,
Expand Down Expand Up @@ -150,6 +149,23 @@ def get_pipeline_config(self):
config.vocab_size = 300
return config

def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()

return (
config,
input_ids,
sequence_labels,
token_labels,
choice_labels,
)

def create_and_check_falcon_mamba_model(self, config, input_ids, *args):
config.output_hidden_states = True
model = FalconMambaModel(config=config)
Expand Down Expand Up @@ -237,12 +253,12 @@ def prepare_config_and_inputs_for_common(self):
(
config,
input_ids,
attention_mask,
_,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
inputs_dict = {"input_ids": input_ids}
return config, inputs_dict


Expand Down

0 comments on commit b31d7a7

Please sign in to comment.