[BLOOM] Support encoder chunk sizes > 1 (#738)
ddxxdd-code authored Oct 9, 2022
1 parent fe6bd5f commit fbcb2ab
Showing 2 changed files with 14 additions and 9 deletions.
15 changes: 14 additions & 1 deletion examples/llm_serving/model/bloom_model.py
@@ -102,6 +102,11 @@ def get_slopes_power_of_2(n):
     # shape of attention_mask: [B, 1, 1, S_max]
     batch_size = attention_mask.shape[0]
     key_length = attention_mask.shape[-1]
+
+    # Handle a special kind of internal padding added by alpa: a mask value of 2
+    # marks padding added when the encoder chunk size does not divide the input length.
+    attention_mask = (attention_mask == 1)
+
     attention_mask = attention_mask.reshape((batch_size, key_length))
     num_heads = n_head
     query_length = 1
@@ -178,6 +183,14 @@ def __call__(
             (0, 0, causal_attention_mask_shift, 0),
             (1, 1, seq_length, max_decoder_length)
         )
+        # Handle a special kind of internal padding added by alpa.
+        # Note that this kind of internal padding is different from
+        # the padding added by the tokenizer. This internal padding
+        # should not update the cache or step_ct.
+        # shape: [B, 1, 1, S_max]
+        is_internal_padding = (attention_mask == 2)
+        num_internal_pad = jnp.sum(is_internal_padding, axis=3).reshape(-1)
+        attention_mask = (attention_mask == 1)
 
         attention_mask = combine_masks(attention_mask, causal_attention_mask)
 
@@ -195,7 +208,7 @@ def __call__(
             cache_value = value
             num_updated_cache_vectors = query.shape[1]
             # A line added from bloom_model
-            attention_cache = key, value, cache_index + num_updated_cache_vectors
+            attention_cache = key, value, cache_index + num_updated_cache_vectors - num_internal_pad
             # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
             pad_mask = jnp.broadcast_to(
                 jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
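The heart of the change is easiest to see in isolation. The sketch below is an illustration, not code from the repository; it assumes the mask convention implied by the diff (0 = tokenizer padding, 1 = real token, 2 = alpa's internal padding that rounds the input up to a multiple of the encoder chunk size) and mirrors how the new lines count the internal padding and keep it from advancing the cache index. The helper name split_padding is made up for the example.

    # Minimal sketch of the new bloom_model.py logic; not repository code.
    # Assumed mask values: 0 = tokenizer padding, 1 = real token, 2 = internal chunk padding.
    import jax.numpy as jnp

    def split_padding(attention_mask):
        # attention_mask: int array of shape [B, 1, 1, S_max] with values {0, 1, 2}.
        is_internal_padding = (attention_mask == 2)
        # One count per batch element; these positions must not advance the cache.
        num_internal_pad = jnp.sum(is_internal_padding, axis=3).reshape(-1)
        # Only real tokens (value 1) remain attendable.
        boolean_mask = (attention_mask == 1)
        return boolean_mask, num_internal_pad

    # One real token followed by three internal-padding slots.
    mask = jnp.array([1, 2, 2, 2]).reshape(1, 1, 1, 4)
    bool_mask, n_pad = split_padding(mask)                # n_pad == [3]
    cache_index = jnp.array([0])
    num_updated_cache_vectors = 4                         # the whole chunk is fed in
    # Mirrors `cache_index + num_updated_cache_vectors - num_internal_pad` above.
    new_cache_index = cache_index + num_updated_cache_vectors - n_pad   # -> [1]
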
8 changes: 0 additions & 8 deletions examples/llm_serving/model/wrapper.py
@@ -406,10 +406,6 @@ def get_alpa_model(model_name: str,
         m = opt_model
     elif "bloom" in model_name:
         m = bloom_model
-        if any(x > 1 for x in encoder_chunk_sizes):
-            # TODO: support chunk size > 1
-            warnings.warn("encoder_chunk_size > 1 is not supported. Ignored.")
-            encoder_chunk_sizes = [1]
     config = m.get_config(name,
                           num_pp_stages=None,
                           mark_boundary=False,
@@ -436,10 +432,6 @@ def get_alpa_model(model_name: str,
         m = opt_model
     elif "bloom" in model_name:
         m = bloom_model
-        if any(x > 1 for x in encoder_chunk_sizes):
-            # TODO: support chunk size > 1
-            warnings.warn("encoder_chunk_size > 1 is not supported. Ignored.")
-            encoder_chunk_sizes = [1]
 
     alpa.init()
 
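The wrapper change is purely the removal of the old guard: before this commit, any BLOOM model silently forced encoder_chunk_sizes back to [1]. The sketch below reconstructs that deleted behavior for clarity; the function name _old_bloom_guard is invented, and the surrounding get_alpa_model signature is not shown here.

    import warnings

    def _old_bloom_guard(model_name, encoder_chunk_sizes):
        # Behavior deleted by this commit: BLOOM ignored any chunk size > 1.
        if "bloom" in model_name and any(x > 1 for x in encoder_chunk_sizes):
            # TODO: support chunk size > 1
            warnings.warn("encoder_chunk_size > 1 is not supported. Ignored.")
            return [1]
        return encoder_chunk_sizes

    print(_old_bloom_guard("bloom-560m", [1, 64]))   # old behavior: [1]

With the guard gone, whatever chunk sizes the caller passes are forwarded unchanged, and the bloom_model.py changes above handle the internal padding that larger chunks introduce.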

2 comments on commit fbcb2ab

@zhisbug
Member

For this PR, when encoder_chunk_size > len(seq), will it work correctly?

@ddxxdd-code
Contributor Author

> For this PR, when encoder_chunk_size > len(seq), will it work correctly?

Yes, it works. Here is one example of running with encoder_chunk_size = 64:
[Screenshot from 2022-10-10 18-22-25: example run with encoder_chunk_size = 64]
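As a rough illustration of that case (the numbers below are made up, not taken from the screenshot): with encoder_chunk_size = 64 and, say, a 5-token prompt, the chunk is filled with 59 value-2 positions, so num_internal_pad is 59 and the cache index advances by only the 5 real tokens.

    # Illustrative only: encoder_chunk_size > len(seq), with invented numbers.
    import jax.numpy as jnp

    chunk_size, seq_len = 64, 5
    mask = jnp.concatenate([jnp.ones(seq_len, dtype=jnp.int32),
                            jnp.full((chunk_size - seq_len,), 2, dtype=jnp.int32)])
    mask = mask.reshape(1, 1, 1, chunk_size)

    num_internal_pad = jnp.sum(mask == 2, axis=3).reshape(-1)    # [59]
    num_updated_cache_vectors = chunk_size                       # whole chunk fed to the model
    cache_index = jnp.array([0])
    # As in the diff: only real tokens advance the cache.
    print(cache_index + num_updated_cache_vectors - num_internal_pad)   # [5]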
