
Add Multi Resolution Analysis (MRA) (New PR) #24513

Merged
merged 16 commits into huggingface:main from add-mra-2 on Jul 10, 2023

Conversation

novice03 (Contributor) commented Jun 27, 2023

Add Multi Resolution Analysis (MRA) for Approximate Self-Attention

This PR adds the MRA model to the repository.

Paper: https://arxiv.org/pdf/2207.10284.pdf
Code: https://github.com/mlpen/mra-attention

To-do:

  • Improve loading of CUDA kernels
  • Improve formatting and documentation
  • Upload checkpoints
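
A quick usage sketch once the checkpoints are uploaded (the checkpoint name below is assumed from the released uw-madison MRA checkpoints and should be verified on the Hub):

    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # Assumed checkpoint name; check the Hub for the uploaded MRA checkpoints.
    checkpoint = "uw-madison/mra-base-512-4"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForMaskedLM.from_pretrained(checkpoint)

    inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
    logits = model(**inputs).logits  # masked-LM logits over the vocabulary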

novice03 (Contributor Author):

Copied all files over from #20573

HuggingFaceDocBuilderDev commented Jun 27, 2023

The documentation is not available anymore as the PR was closed or merged.

sgugger (Collaborator) commented Jun 27, 2023

Could you fix the failing tests?

novice03 (Contributor Author):

Hello @sgugger, I've made sure all checks pass and fixed conflicts.

sgugger (Collaborator) left a review comment:

Thanks for cleaning everything! I just have one tiny nit. @amyeroberts could you have one final look and merge?

Review comment on src/transformers/models/auto/tokenization_auto.py (outdated, resolved)
amyeroberts (Collaborator) left a review comment:

Thanks for adding this model!

Really nice PR. Mostly a few very small nits. The only main comment to address before merging is the implementation of test_attention_outputs.

Review comment on src/transformers/models/mra/configuration_mra.py (outdated, resolved)
Review comment on src/transformers/models/mra/__init__.py (outdated, resolved)
Review comment on src/transformers/models/mra/__init__.py (outdated, resolved)
Review comment on tests/models/mra/test_modeling_mra.py (outdated, resolved)
Review comment on src/transformers/models/mra/modeling_mra.py (outdated, resolved)
Comment on lines +616 to +619
query_layer.float(),
key_layer.float(),
value_layer.float(),
attention_mask.float(),
Reviewer comment (Collaborator):

I'm not super familiar with the assumptions we have about our models and the layer types. @sgugger - is it OK to call float() like this?

Reply (Collaborator):

There is a custom CUDA kernel which I'm guessing cannot handle other dtypes.
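
For illustration, the upcast pattern under discussion looks roughly like this (a sketch with illustrative names, not the exact modeling code):

    import torch

    def run_kernel_in_fp32(kernel_fn, query, key, value, mask):
        # The custom CUDA kernel only supports float32, so inputs are upcast
        # before the call and the result is cast back to the original dtype.
        orig_dtype = query.dtype
        out = kernel_fn(query.float(), key.float(), value.float(), mask.float())
        return out.to(orig_dtype)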

Comment on lines 45 to 59
batch_size=2,
seq_length=256,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=128,
num_hidden_layers=5,
num_attention_heads=2,
intermediate_size=36,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
Reviewer comment (Collaborator):

Some of the argument values here defining the model architecture are quite large, which will make running the test suite slow. Could you reduce seq_length and hidden_size?

Reply (Contributor Author):

Reduced seq_length to 8 and hidden_size to 16.

Comment on lines +345 to +346
    def test_attention_outputs(self):
        return
Reviewer comment (Collaborator):

Tests that are skipped should be skipped explicitly with a unittest.skip(reason) decorator. In this case, as the model outputs attentions, a custom implementation should be added.

Reply (Contributor Author):

Hello, actually, MRA does not output attentions. All of the computation is done by the kernels, and the output of mra2_attention is the product of attention and value. For this reason, I've removed output_attentions from the modeling file and skipped this test.
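
For reference, the explicit skip the reviewer asked for would look roughly like this (a sketch; the exact reason string in the test file may differ):

    import unittest

    class MraModelTest(unittest.TestCase):
        # MRA's kernels return the attention-value product directly,
        # so there are no attention probabilities to check.
        @unittest.skip(reason="MRA does not output attentions")
        def test_attention_outputs(self):
            pass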

novice03 (Contributor Author) commented Jul 4, 2023

Hello @amyeroberts, I've addressed your comments and made some code changes. Please take a look at the updated files.

novice03 (Contributor Author) commented Jul 6, 2023

Hi @amyeroberts, I've addressed the suggestions from the code review. Please take a look at the updated code.

amyeroberts (Collaborator) left a review comment:

Thanks for adding this model and iterating!

All LGTM - just two tiny, tiny nits. Otherwise, we're good to merge :)

Review comment on docs/source/en/model_doc/mra.md (outdated, resolved)
Review comment on docs/source/en/model_doc/mra.md (outdated, resolved)
novice03 (Contributor Author):

Thanks for catching these errors @amyeroberts! I've applied both changes.

@amyeroberts amyeroberts merged commit 30ed3ad into huggingface:main Jul 10, 2023
@novice03 novice03 deleted the add-mra-2 branch July 10, 2023 17:18
ydshieh (Collaborator) commented Jul 26, 2023

@novice03

It seems the CI hits

(line 403)  ValueError: sequence length must be divisible by the block_size.

when load_cuda_kernels loads the custom kernels successfully.

It's likely due to seq_length=8 from MraModelTester, but I am not able to find the right combination of seq_length, block_size, and num_blocks to make it work.

Note that our daily CI (with torch 2.0.1 + CUDA 11.8) fails to load the custom CUDA kernels, so execution falls through to

    if cuda_kernel is None:
        return torch.zeros_like(query).requires_grad_()

in mra2_attention, and the tests pass.

However, in our CI with torch 1.13 (and CUDA 11.6.2), the kernel is loaded, but the tests fail.

It would be great if you could help us find the right settings so that the CI passes when the kernel is loaded.

Thanks in advance 🤗 .
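
For illustration, the divisibility constraint behind this error could in principle be satisfied by padding inputs up to a multiple of the block size before the kernel path; the helper below is a hypothetical sketch, not code from modeling_mra.py:

    import torch
    import torch.nn.functional as F

    def pad_to_block_size(hidden_states, attention_mask, block_size=32):
        # Right-pad (batch, seq_len, hidden) states and their (batch, seq_len)
        # mask so that seq_len is divisible by block_size.
        seq_len = hidden_states.size(1)
        pad_len = (block_size - seq_len % block_size) % block_size
        if pad_len == 0:
            return hidden_states, attention_mask
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len))
        attention_mask = F.pad(attention_mask, (0, pad_len))
        return hidden_states, attention_mask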

ydshieh (Collaborator) commented Jul 26, 2023

You can run

python3 -m pytest -v tests/models/mra/test_modeling_mra.py::MraModelTest::test_for_masked_lm

The full error log (if the custom CUDA kernel is loaded successfully) is:

self = <tests.models.mra.test_modeling_mra.MraModelTest testMethod=test_for_masked_lm>

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
>       self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

tests/models/mra/test_modeling_mra.py:322: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/models/mra/test_modeling_mra.py:210: in create_and_check_for_masked_lm
    result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194: in _call_impl
    return forward_call(*input, **kwargs)
src/transformers/models/mra/modeling_mra.py:1093: in forward
    outputs = self.mra(
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194: in _call_impl
    return forward_call(*input, **kwargs)
src/transformers/models/mra/modeling_mra.py:1028: in forward
    encoder_outputs = self.encoder(
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194: in _call_impl
    return forward_call(*input, **kwargs)
src/transformers/models/mra/modeling_mra.py:782: in forward
    layer_outputs = layer_module(hidden_states, attention_mask)
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194: in _call_impl
    return forward_call(*input, **kwargs)
src/transformers/models/mra/modeling_mra.py:729: in forward
    self_attention_outputs = self.attention(hidden_states, attention_mask)
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194: in _call_impl
    return forward_call(*input, **kwargs)
src/transformers/models/mra/modeling_mra.py:681: in forward
    self_outputs = self.self(hidden_states, attention_mask)
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194: in _call_impl
    return forward_call(*input, **kwargs)
src/transformers/models/mra/modeling_mra.py:615: in forward
    context_layer = mra2_attention(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

query = tensor([[[[ 0.0500, -0.0523, -0.0260,  ...,  0.0000,  0.0000,  0.0000],
          [-0.1339,  0.0844,  0.0287,  ...,  0...       [ 0.0293,  0.1609,  0.0547,  ...,  0.0000,  0.0000,  0.0000]]]],
       device='cuda:0', grad_fn=<CatBackward0>)
key = tensor([[[[ 0.0185, -0.0316,  0.0150,  ...,  0.0000,  0.0000,  0.0000],
          [-0.0575, -0.1123,  0.0832,  ...,  0...       [ 0.0608,  0.0932, -0.0973,  ...,  0.0000,  0.0000,  0.0000]]]],
       device='cuda:0', grad_fn=<CatBackward0>)
value = tensor([[[[ 0.0131,  0.1242,  0.0672,  ...,  0.0000,  0.0000,  0.0000],
          [-0.0212,  0.0600,  0.0269,  ...,  0...       [-0.1005, -0.0048,  0.0561,  ...,  0.0000,  0.0000,  0.0000]]]],
       device='cuda:0', grad_fn=<CatBackward0>)
mask = tensor([[-2.1475e+09,  1.0000e+00,  1.0000e+00, -2.1475e+09,  1.0000e+00,
         -2.1475e+09, -2.1475e+09,  1.0000e+...  1.0000e+00,  1.0000e+00, -2.1475e+09,  1.0000e+00,
         -2.1475e+09, -2.1475e+09,  1.0000e+00]], device='cuda:0')
num_blocks = 64, approx_mode = 'full', block_size = 32, initial_prior_first_n_blocks = 0, initial_prior_diagonal_n_blocks = 0

    def mra2_attention(
        query,
        key,
        value,
        mask,
        num_blocks,
        approx_mode,
        block_size=32,
        initial_prior_first_n_blocks=0,
        initial_prior_diagonal_n_blocks=0,
    ):
        """
        Use Mra to approximate self-attention.
        """
        if cuda_kernel is None:
            return torch.zeros_like(query).requires_grad_()
    
        batch_size, num_head, seq_len, head_dim = query.size()
        meta_batch = batch_size * num_head
    
        if seq_len % block_size != 0:
>           raise ValueError("sequence length must be divisible by the block_size.")
E           ValueError: sequence length must be divisible by the block_size.

src/transformers/models/mra/modeling_mra.py:403: ValueError

novice03 (Contributor Author) commented Aug 2, 2023

Hello @ydshieh, thanks for bringing this up. We will likely have to use larger values for seq_len and hidden_size. Can you please try with the values here?

ydshieh (Collaborator) commented Aug 4, 2023

Hi @novice03, I really appreciate you taking the time on this. I tried it, and there are still 5 failures (it's already a great improvement!).

However, we (transformers) are in the middle of a series of efforts to reduce CI time and cost, and switching to large values is exactly what we have tried very hard to avoid, as you can see in #24824, #25005 and #25266. Also, large values are very likely to introduce OOM errors when running tests in multi-process settings (we use 8 processes to reduce the CI cost), and it's very hard to diagnose when that happens.

I think it would be great if we could have a block_size attribute in the config classes with a default of 32, and, in the modeling file, pass config.block_size everywhere methods like sparse_mask, mm_to_sparse, etc. are called (see the sketch after this comment).

This way, we would have a way to use small values in the tests. Furthermore, users of this model would have more flexibility when running it, and we could also document more clearly how to set the config values and the inputs to make it work.

Let me know WDYT 🙏 Thanks again!
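
A minimal sketch of the proposed change (hypothetical; this is not the merged MraConfig, and only block_size is the suggested attribute, the rest is illustrative):

    from transformers import PretrainedConfig

    class MraConfig(PretrainedConfig):
        model_type = "mra"

        def __init__(self, block_size=32, **kwargs):
            super().__init__(**kwargs)
            # Tile size used by the sparse-attention kernels; defaulting to 32
            # keeps released checkpoints working while tests could use smaller values.
            self.block_size = block_size

    # In modeling_mra.py, calls such as mm_to_sparse(...) would then receive
    # config.block_size instead of a hard-coded block_size=32.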

novice03 (Contributor Author) commented Aug 7, 2023

Hello @ydshieh, thanks for your reply. I understand that using large values increases the time and memory cost. However, since MRA was specifically designed for large sequences, it will be very tricky to run tests with small seq_len and hidden_size.

Unfortunately, I don't think that the tests can be fixed by lowering the block size. I've tried setting block size to 4 or 8, and got multiple other errors (index out of bounds errors, CUDA errors, etc.). Also, all of the released checkpoints are with block size = 32, so users cannot use the pretrained models with a different block size.

I hope I'm not asking too much, but is there an alternative or exception that can be made, either by allowing larger values or by running the MRA tests without the CUDA kernels? I've already verified that the HF model and the original code output similar logits and hidden states when the CUDA kernels are loaded (with large sequence lengths).

ydshieh (Collaborator) commented Aug 7, 2023

Also, all of the released checkpoints are with block size = 32, so users cannot use the pretrained models with a different block size.

Fair point!

We will discuss internally how to deal with this model's testing, but could you check the following 5 remaining failed tests, which come from the new values you provided in an earlier comment, and see if you are able to fix them 🙏? Thanks!

(It's run on torch 1.13 + CUDA 11.6.2)

FAILED tests/models/mra/test_modeling_mra.py::MraModelTest::test_determinism - ValueError: zero-size array to reduction operation maximum which has no identity
FAILED tests/models/mra/test_modeling_mra.py::MraModelTest::test_feed_forward_chunking - AssertionError: False is not true
FAILED tests/models/mra/test_modeling_mra.py::MraModelTest::test_load_with_mismatched_shapes - ValueError: sequence length must be divisible by the block_size.
FAILED tests/models/mra/test_modeling_mra.py::MraModelTest::test_model_outputs_equivalence - TypeError: forward() got an unexpected keyword argument 'output_attentions'
FAILED tests/models/mra/test_modeling_mra.py::MraModelTest::test_retain_grad_hidden_states_attentions - TypeError: 'NoneType' object is not subscriptable

blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
* Add all files

* Update masked_language_modeling.md

* fix mlm models

* fix conflicts

* fix conflicts

* fix copies

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: amyeroberts <[email protected]>

* Reduce seq_len and hidden_size in ModelTester

* remove output_attentions

* fix conflicts

* remove copied from statements

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: amyeroberts <[email protected]>