Add Zamba #30950

Merged
merged 115 commits into huggingface:main on Oct 4, 2024

Conversation

pglorio
Contributor

@pglorio pglorio commented May 22, 2024

What does this PR do?

This PR adds support for the Zamba architecture created by Zyphra Technologies.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @younesbelkada

@pglorio
Contributor Author

pglorio commented May 23, 2024

Just for future reference, we measured the latency of a single mamba layer in Zamba and compared it to that of a single layer in Mamba; the two have very similar implementations (Zamba adds some reshapings, which should be no-ops, and a concatenation). We found that the mamba layer in Zamba has the same speed in a single forward pass, but is slower during generation.

More specifically, we instantiated these two (random) models:

import torch
from transformers import MambaConfig, MambaForCausalLM, ZambaConfig, ZambaForCausalLM

config = ZambaConfig(num_hidden_layers=81, hidden_size=3712, n_mamba_heads=1, use_cache=True)
model_1 = ZambaForCausalLM(config).cuda()
config = MambaConfig(num_hidden_layers=81, hidden_size=3712, use_cache=True)
model_2 = MambaForCausalLM(config).cuda()

(here n_mamba_heads=1 corresponds to the original Mamba architecture), and used this code for generation:

model.eval()
input_ids = torch.randint(1000, (1, 2048)).to(device=model.device)
with torch.no_grad():
    output = model.generate(
        input_ids, max_new_tokens=300, return_dict_in_generate=False,
        output_scores=False, use_cache=True, num_beams=1, do_sample=False,
    )

We found that the total time spent computing this line


is 8.1s, and for this line
hidden_states = self.mixer(hidden_states, cache_params=cache_params)

is 6.3s.
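
For reference, per-call latencies like the above can be collected with a small CUDA-event helper along these lines (this is only a sketch of the approach, not the exact instrumentation we used):

import torch

def time_cuda(fn, warmup=3, iters=10):
    # Illustrative helper: times a GPU callable with CUDA events.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

Wrapping the single mixer call of each model with a helper like this makes it easy to compare prefill latency against per-token decoding latency.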

@ArthurZucker
Collaborator

cc @younesbelkada !

Contributor

@younesbelkada younesbelkada left a comment


Thanks a lot for this PR! I left some minor suggestions in the modeling code for general improvements.
Can you also make sure to rebase with main and check that `make fixup` passes locally? Let me know if you need any assistance!

@amazingvince

I tried running a basic training script, with and without gradient accumulation, on this fork and am getting this error:
File "/home/user/transformers_zamba/src/transformers/models/zamba/modeling_zamba.py", line 1051, in forward
hidden_states = hidden_states + from_tf if from_tf is not None else hidden_states
RuntimeError: The size of tensor a (7424) must match the size of tensor b (3712) at non-singleton dimension 2

The from_tf input is not well described in the docstrings. Not sure what is not working here.

@pglorio
Contributor Author

pglorio commented Jun 4, 2024

I tried running a basic training script, with and without gradient accumulation, on this fork and am getting this error:
File "/home/user/transformers_zamba/src/transformers/models/zamba/modeling_zamba.py", line 1051, in forward
hidden_states = hidden_states + from_tf if from_tf is not None else hidden_states
RuntimeError: The size of tensor a (7424) must match the size of tensor b (3712) at non-singleton dimension 2

The from_tf input is not well described in the docstrings. Not sure what is not working here.

Thanks for spotting this. We fixed the issue in the most recent push. Please try again and let us know if you still encounter issues.

We are adding more docstrings to explain various parts of the architecture. We will add the description below for from_tf around this line:

from_tf is the output of the shared transformer + linear layer (these layers are shown in Fig. 2 of https://arxiv.org/pdf/2405.16712). from_tf is then added to the input of the mamba layer (as described in Eq. (6) of https://arxiv.org/pdf/2405.16712, where y_l in that equation is from_tf).
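
As a rough sketch of that data flow (module and variable names below are illustrative, not the actual attributes in modeling_zamba.py):

from torch import nn

class HybridBlockSketch(nn.Module):
    # Illustrative only: shows where from_tf enters the mamba layer.
    # shared_transformer and mamba stand in for the real Zamba sub-modules.
    def __init__(self, hidden_size, shared_transformer, mamba):
        super().__init__()
        self.shared_transformer = shared_transformer        # shared attention block
        self.linear = nn.Linear(hidden_size, hidden_size)   # projection after the shared block
        self.mamba = mamba                                   # mamba mixer layer

    def forward(self, hidden_states):
        from_tf = self.linear(self.shared_transformer(hidden_states))  # y_l in Eq. (6)
        return self.mamba(hidden_states + from_tf)  # from_tf added to the mamba layer input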

@pglorio
Contributor Author

pglorio commented Jun 5, 2024

Thanks a lot for this PR! I left some minor suggestions in the modeling code for general improvements.
Can you also make sure to rebase with main and check that `make fixup` passes locally? Let me know if you need any assistance!

Thank you for the thorough review!

We ran make fixup and make fix-copies. Running make fixup again gives this output:

Checking/fixing src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py src/transformers/models/zamba/configuration_zamba.py src/transformers/models/zamba/modeling_zamba.py tests/models/roc_bert/test_tokenization_roc_bert.py
All checks passed!
4 files left unchanged
python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
python utils/check_doc_toc.py --fix_and_overwrite
running deps_table_update
updating src/transformers/dependency_versions_table.py
python utils/check_copies.py
Traceback (most recent call last):
  File "/workspace/transformers_zamba/utils/check_copies.py", line 1106, in <module>
    check_copies(args.fix_and_overwrite, args.file)
  File "/workspace/transformers_zamba/utils/check_copies.py", line 856, in check_copies
    raise Exception(
Exception: Found the following copy inconsistencies:
- tests/models/roc_bert/test_tokenization_roc_bert.py: copy does not match models.bert.test_tokenization_bert.BertTokenizationTest.test_is_whitespace at line 167
Run `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them.

It looks like make fix-copies is trying to correct parts of the code that are outside of our PR, and some of those fixes still fail. However, we no longer seem to get errors related to our PR.

We pushed the fixes done by make fix-copies.

@amazingvince

amazingvince commented Jun 6, 2024

Tried training again and am now getting this:

/trainer.py", line 3250, in training_step
self.accelerator.backward(loss)
File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/accelerate/accelerator.py", line 2127, in backward
loss.backward(**kwargs)
File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
torch.autograd.backward(
File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/autograd/init.py", line 267, in backward
_engine_run_backward(
File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
return user_fn(self, *args)
File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 320, in backward
torch.autograd.backward(outputs_with_grad, args_with_grad)
File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/autograd/init.py", line 267, in backward
_engine_run_backward(
File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to create tensor with negative dimension -274340773632: [-274340773632]

@Quentin-Anthony
Contributor

Tried training again and am now getting this:

/trainer.py", line 3250, in training_step self.accelerator.backward(loss) File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/accelerate/accelerator.py", line 2127, in backward loss.backward(**kwargs) File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward torch.autograd.backward( File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/autograd/init.py", line 267, in backward _engine_run_backward( File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply return user_fn(self, *args) File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 320, in backward torch.autograd.backward(outputs_with_grad, args_with_grad) File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/autograd/init.py", line 267, in backward _engine_run_backward( File "/home/user/mambaforge/envs/zamba/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass RuntimeError: Trying to create tensor with negative dimension -274340773632: [-274340773632]

Hey there! We've recently been successfully using huggingface/alignment-handbook@main...Zyphra:alignment-handbook:zamba-instruct to do SFT of Zamba. Does your setup meaningfully differ from this? I can't seem to reproduce; can you provide us with a reproducer?

@ArthurZucker
Collaborator

cc @younesbelkada should I review this or do you want to do another pass? 🤗

@amazingvince

I am trying to extend the max context length.
{
"max_position_embeddings": 32768,
"rope_theta": 192144,
}

I also tried 16k.

I tried running in your fork of the alignment handbook and saw the same results.

Contributor

@younesbelkada younesbelkada left a comment


Thanks very much for your great work on this! I left a few minor improvements to address and some file changes to revert. Can you make sure to make our CI happy (by making sure the make fixup command passes and that pytest tests/models/zamba/ passes)? Let me know if you need any help or have any questions.

@pglorio
Contributor Author

pglorio commented Jun 18, 2024

Thanks very much for your great work on this! I left a few minor improvements to address and some file changes to revert. Can you make sure to make our CI happy (by making sure the make fixup command passes and that pytest tests/models/zamba/ passes)? Let me know if you need any help or have any questions.

Thank you for your help, @younesbelkada!

We believe we have addressed most of the concerns you raised; we still have two pending questions:

  • pytest tests/models/zamba/: Pytest flags only test_initialization as failing. The specific issue arises with x_proj_weight and dt_proj_weight, whose mean is approximately 10^-2 rather than the expected 10^-9. This discrepancy is expected: it is due to the initialization scheme using a variance of (d_input)^(-0.5), where d_input is approximately 100 in the test configuration. We implemented nn.Parameter(torch.rand(...)) for the initialization of these parameters, which we verified is equivalent to the Kaiming initialization typically used for nn.Linear. It seems that transformers may apply additional steps for the initialization of various layer types, which might not extend to parameters such as x_proj_weight. We have adjusted the tolerance for these parameters to 10^-2 in the initialization test in this line of the test script; a minimal sketch of this initialization is included at the end of this comment. Please let us know if additional steps are required.

  • make fixup: After running it, we have this output:

- src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py: copy does not match models.big_bird.modeling_big_bird.BigBirdBlockSparseAttention at line 720
- tests/models/roc_bert/test_tokenization_roc_bert.py: copy does not match models.bert.test_tokenization_bert.BertTokenizationTest.test_chinese at line 76
- tests/models/roc_bert/test_tokenization_roc_bert.py: copy does not match models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower at line 85
- tests/models/roc_bert/test_tokenization_roc_bert.py: copy does not match models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_false at line 94
- tests/models/roc_bert/test_tokenization_roc_bert.py: copy does not match models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_true at line 103
- tests/models/roc_bert/test_tokenization_roc_bert.py: copy does not match models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_default at line 112
Run `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them.
make: *** [Makefile:38: repo-consistency] Error 1

All of these lines relate to files outside of our PR, so we did not change those files, although I do see that the CircleCI tests for this PR still fail. Please let us know if further action is needed here and what steps we would need to take.

Thank you so much for your time and help!
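
As promised above, here is a minimal sketch (not the exact test or modeling code) of why a parameter initialized like nn.Linear does not pass the near-zero initialization check:

import math
import torch
from torch import nn

d_input, d_output = 100, 32  # rough sizes, as in the test configuration (illustrative)

# nn.Linear's default initialization is Kaiming-uniform with a=sqrt(5),
# i.e. weights uniform in [-1/sqrt(d_input), 1/sqrt(d_input)].
linear = nn.Linear(d_input, d_output, bias=False)

# An equivalent initialization applied to a raw parameter such as x_proj_weight:
x_proj_weight = nn.Parameter(torch.empty(d_output, d_input))
nn.init.kaiming_uniform_(x_proj_weight, a=math.sqrt(5))

# Both draw values of magnitude ~d_input**-0.5 (std ~0.06 here), so the sample
# statistics checked by test_initialization are far from ~1e-9 and need a looser tolerance.
print(linear.weight.std().item(), x_proj_weight.std().item())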

Contributor

@younesbelkada younesbelkada left a comment


Thanks a lot for iterating! I left one minor suggestion. Can you also merge your branch with the upstream main branch? This should make the CI happy and the tests should be green.

@pglorio pglorio force-pushed the main branch 2 times, most recently from 0ec2417 to 18e8372 on June 21, 2024 at 03:28
@pglorio
Contributor Author

pglorio commented Jun 21, 2024

Thank you so much for your guidance. We tried to rebase our PR and ran into an error related to model generation. It looks like the rebased GenerationMixin.generate method instantiates Zamba's cache as a DynamicCache class (https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L1775). This is different from HybridMambaAttentionDynamicCache, which would be the expected class for Zamba's cache (defined here: https://github.com/Zyphra/transformers_zamba/blob/main/src/transformers/models/zamba/modeling_zamba.py#L130). In the non-rebased fork, the cache is instantiated by GenerationMixin.generate in this line: https://github.com/Zyphra/transformers_zamba/blob/main/src/transformers/generation/utils.py#L2379, which correctly instantiates the cache as HybridMambaAttentionDynamicCache.

For reference, these are the calls performed from model.generate to the instantiation of the cache object:
using the rebased fork:

-> output = model.generate(**tokenized_prompt, max_new_tokens=300, return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
  /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py(115)decorate_context()
-> return func(*args, **kwargs)
  /workspace/transformers_zamba_rebased/src/transformers/generation/utils.py(1775)generate()
-> model_kwargs["past_key_values"] = DynamicCache()
> /workspace/transformers_zamba_rebased/src/transformers/cache_utils.py(305)__init__()

and using the fork before rebasing:

-> output = model.generate(**tokenized_prompt, max_new_tokens=300, return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
  /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py(115)decorate_context()
-> return func(*args, **kwargs)
  /workspace/transformers_zamba/src/transformers/generation/utils.py(1743)generate()
-> result = self._sample(
  /workspace/transformers_zamba/src/transformers/generation/utils.py(2379)_sample()
-> model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  /workspace/transformers_zamba/src/transformers/models/zamba/modeling_zamba.py(1588)prepare_inputs_for_generation()
-> past_key_values = HybridMambaAttentionDynamicCache(
> /workspace/transformers_zamba/src/transformers/models/zamba/modeling_zamba.py(146)__init__()

Could you please let us know how we can force Zamba's cache to be HybridMambaAttentionDynamicCache in the rebased fork?

Thanks very much!

@ArthurZucker
Collaborator

Hey! Super late in coming back to you; you should be able to force it by setting cache_class = "" in the ZambaPreTrainedModel class!

@pglorio
Contributor Author

pglorio commented Aug 6, 2024

Hey! Super late in coming back to you; you should be able to force it by setting cache_class = "" in the ZambaPreTrainedModel class!

Hello @ArthurZucker, thank you for the suggestion! We tried adding cache_class = "" to this line but we still couldn't make generation work. As an alternative fix, we added "zamba" to this line, similarly to what was done with Jamba, in which case generation works fine.
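
For context, the behavior of the non-rebased fork (shown in the trace in our June 21 comment) corresponds roughly to the following sketch, where prepare_inputs_for_generation builds the hybrid cache itself when generate() did not pass one; the constructor arguments follow the Jamba-style cache and are an assumption here, not the final code:

# Sketch only (not the final implementation):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
    if past_key_values is None:
        # Constructor signature assumed from the Jamba-style HybridMambaAttentionDynamicCache.
        past_key_values = HybridMambaAttentionDynamicCache(
            self.config, input_ids.shape[0], dtype=self.dtype, device=self.device
        )
    return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}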

We are happy to either keep this fix or to use the one you suggested, in which case we would appreciate if you could say a few more words on how to implement it.

Meanwhile, we rebased our fork. All the local tests with make fixup passed, except for a few warnings, shown below, which are unrelated to the updates we implemented:

/workspace/transformers_zamba/src/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
/workspace/transformers_zamba/utils/check_repo.py:376: UserWarning: Full repo consistency checks require all backends to be installed (with `pip install -e '.[dev]'` in the Transformers repo, the following are missing: TensorFlow, Flax. While it's probably fine as long as you didn't make any change in one of those backends modeling files, you should probably execute the command above to be on the safe side.
/workspace/transformers_zamba/src/transformers/models/deit/image_processing_deit.py:87: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  resample: PILImageResampling = PIL.Image.BICUBIC,
/workspace/transformers_zamba/src/transformers/models/chameleon/image_processing_chameleon.py:116: DeprecationWarning: LANCZOS is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.LANCZOS instead.
  resample: PILImageResampling = PIL.Image.LANCZOS,
/workspace/transformers_zamba/src/transformers/models/efficientnet/image_processing_efficientnet.py:92: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  resample: PILImageResampling = PIL.Image.NEAREST,

I see most of the CircleCI tests for this PR are still failing, please let me know if more needs to be done to fix those.

Thank you so much!

@ArthurZucker
Collaborator

I'll have a look!

@HuggingFaceDocBuilderDev

Hey! 🤗 Thanks for your contribution to the transformers library!

Before merging this pull request, slow tests CI should be triggered. To enable this:

  • Add the run-slow label to the PR
  • When your PR is ready for merge and all reviewers' comments have been addressed, push an empty commit with the command [run-slow] followed by a comma separated list of all the models to be tested, i.e. [run_slow] model_to_test_1, model_to_test_2
    • If the pull request affects a lot of models, put at most 10 models in the commit message
  • A transformers maintainer will then approve the workflow to start the tests

(For maintainers) The documentation for slow tests CI on PRs is here.

@hg0428

hg0428 commented Oct 1, 2024

Does this include support for Zamba2?

@pglorio
Contributor Author

pglorio commented Oct 1, 2024

@hg0428 thanks for asking! Support for Zamba2 will be added in a follow-up PR. Meanwhile, you can install Zyphra's local transformers as described in the Zamba2 model card.

Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks for your contribution! 🤗

@hg0428

hg0428 commented Oct 1, 2024

@hg0428 thanks for asking! Support for Zamba2 will be added in a follow-up PR. Meanwhile, you can install Zyphra's local transformers as described in the Zamba2 model card.

Unfortunately, that does not work on my device. The Zamba2 transformers fork relies on mamba_ssm, which requires an NVIDIA GPU, and I have Apple Silicon. See my issue: Zyphra/transformers_zamba2#3

@ArthurZucker
Collaborator

Just waiting for https://github.com/huggingface/transformers/actions/runs/11116137341/job/30897696145?pr=30950#step:12:64 to be fixed! (related to accelerate and auto device, good that we have this test!)

@pglorio
Contributor Author

pglorio commented Oct 4, 2024

Hi @Arthur, thank you again for reviewing.

The test mentioned above test_multi_gpu_data_parallel_forward now passes. We had to change some of the shared layers logic for it to work. Previously, self.mamba_layers and self.linear_layers were both nn.ModuleList objects and self.layers was not, which prevented most of the layers from being scattered across devices. Now only self.layers is nn.ModuleList and everything seems to work.
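
As a generic illustration of why the container type matters here (names simplified, not the actual Zamba attributes): layers held in a plain Python list are not registered as sub-modules, so the device-mapping logic never sees them, whereas an nn.ModuleList exposes each layer for placement:

from torch import nn

class LayerStackSketch(nn.Module):
    def __init__(self, layers, use_module_list=True):
        super().__init__()
        if use_module_list:
            # Registered: each layer appears in named_modules() / state_dict()
            # and can be dispatched to its own device.
            self.layers = nn.ModuleList(layers)
        else:
            # Not registered: invisible to nn.Module, so the stack cannot be
            # scattered across devices by the dispatch logic.
            self.layers = list(layers)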

Additionally, we updated all the model's checkpoints on the hub since this involved changing some of the weight keys related to the shared layers. Separately, given we updated the checkpoints, we also swapped up<->gate in the MLP weight keys as well as in the forward pass so this issue is now addressed.
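
For readers wondering about the up<->gate swap: in the usual transformers convention for gated MLPs, the activation is applied to the gate branch and multiplied elementwise by the up branch, so the checkpoint weight keys and the forward pass have to agree on which projection is which. A rough sketch (not the exact Zamba module; the activation here is a placeholder):

from torch import nn

class GatedMLPSketch(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.GELU()  # placeholder activation

    def forward(self, x):
        # activation on the gate branch, elementwise product with the up branch
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))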

All tests related to zamba appear to pass. Thank you!

@hg0428

hg0428 commented Oct 4, 2024

Hi @Arthur, thank you again for reviewing.

The test mentioned above test_multi_gpu_data_parallel_forward now passes. We had to change some of the shared layers logic for it to work. Previously, self.mamba_layers and self.linear_layers were both nn.ModuleList objects and self.layers was not, which prevented most of the layers from being scattered across devices. Now only self.layers is nn.ModuleList and everything seems to work.

Additionally, we updated all the model's checkpoints on the hub since this involved changing some of the weight keys related to the shared layers. Separately, given we updated the checkpoints, we also swapped up<->gate in the MLP weight keys as well as in the forward pass so this issue is now addressed.

All tests related to zamba appear to pass. Thank you!

Does this Zamba support work on Apple Silicon?

@Quentin-Anthony
Contributor

Hi @Arthur, thank you again for reviewing.
The test mentioned above test_multi_gpu_data_parallel_forward now passes. We had to change some of the shared layers logic for it to work. Previously, self.mamba_layers and self.linear_layers were both nn.ModuleList objects and self.layers was not, which prevented most of the layers from being scattered across devices. Now only self.layers is nn.ModuleList and everything seems to work.
Additionally, we updated all the model's checkpoints on the hub since this involved changing some of the weight keys related to the shared layers. Separately, given we updated the checkpoints, we also swapped up<->gate in the MLP weight keys as well as in the forward pass so this issue is now addressed.
All tests related to zamba appear to pass. Thank you!

Does this Zamba support work on Apple Silicon?

I don't believe so. We're working on MLX support in a separate (private for now) vein of work from this PR, which just seeks to get basic GPU integration into upstream HuggingFace Transformers.

@ArthurZucker ArthurZucker merged commit f319ba1 into huggingface:main Oct 4, 2024
17 of 21 checks passed
@Quentin-Anthony
Contributor

🎉 🎉 🎉 🎉 🎉 🎉 🎉 🎉 🎉 🎉

@ArthurZucker
Collaborator

🚀

@fakerybakery

Hi,
Are there any plans to add Zamba2 to Transformers?
Thanks!

@ArthurZucker
Collaborator

I think the Zyphra team is already working on it!

@hg0428

hg0428 commented Oct 15, 2024

Hopefully we get Apple Silicon support for Zamba and Zamba2 soon.

NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
* Update index.md

* Rebase

* Rebase

* Updates from make fixup

* Update zamba.md

* Batched inference

* Update

* Fix tests

* Fix tests

* Fix tests

* Fix tests

* Update docs/source/en/model_doc/zamba.md

Co-authored-by: Arthur <[email protected]>

* Update docs/source/en/model_doc/zamba.md

Co-authored-by: Arthur <[email protected]>

* Update configuration_zamba.py

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Update modeling_zamba.py

* Update modeling_zamba.py

* Update modeling_zamba.py

* Update configuration_zamba.py

* Update modeling_zamba.py

* Update modeling_zamba.py

* Merge branch 'main' of https://github.com/Zyphra/transformers_zamba

* Update ZambaForCausalLM

* Update ZambaForCausalLM

* Describe diffs with original mamba layer

* Moved mamba init into `_init_weights`

* Update index.md

* Rebase

* Rebase

* Updates from make fixup

* Update zamba.md

* Batched inference

* Update

* Fix tests

* Fix tests

* Fix tests

* Fix tests

* Update docs/source/en/model_doc/zamba.md

Co-authored-by: Arthur <[email protected]>

* Update docs/source/en/model_doc/zamba.md

Co-authored-by: Arthur <[email protected]>

* Update configuration_zamba.py

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Update modeling_zamba.py

* Update modeling_zamba.py

* Update modeling_zamba.py

* Update configuration_zamba.py

* Update modeling_zamba.py

* Update modeling_zamba.py

* Merge branch 'main' of https://github.com/Zyphra/transformers_zamba

* Update ZambaForCausalLM

* Moved mamba init into `_init_weights`

* Update ZambaForCausalLM

* Describe diffs with original mamba layer

* make fixup fixes

* quality test fixes

* Fix Zamba model path

* circleci fixes

* circleci fixes

* circleci fixes

* circleci fixes

* circleci fixes

* circleci fixes

* circleci fixes

* circleci fixes

* circleci fixes

* Update

* circleci fixes

* fix zamba test from merge

* fix ValueError for disabling mamba kernels

* add HF copyright

Co-authored-by: Arthur <[email protected]>

* shared_transf --> shared_transformer

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Fixes

* Move attention head dim to config

* Fix circle/ci tests

* Update modeling_zamba.py

* apply GenerationMixin inheritance change from upstream

* apply import ordering

* update needed transformers version for zamba

Co-authored-by: Arthur <[email protected]>

* add contribution author

* add @slow to avoid CI

* Update src/transformers/models/zamba/modeling_zamba.py

Co-authored-by: Arthur <[email protected]>

* Define attention_hidden_size

* Added doc for attention_head_size

* trigger CI

* Fix doc of attention_hidden_size

* [run-slow] zamba

* Fixed shared layer logic, swapped up<->gate in mlp

* shared_transformer -> shared_transf

* reformat HybridLayer __init__

* fix docstrings in zamba config

* added definition of _get_input_ids_and_config

* fixed formatting of _get_input_ids_and_config

---------

Co-authored-by: root <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>