Optimizations for mamba1 #1213

Open · wants to merge 4 commits into base: main
Conversation

@Goekdeniz-Guelmez (Contributor) commented on Jan 20, 2025

This PR optimizes the MambaBlock implementation to improve performance and cache handling while maintaining compatibility with the existing MambaCache interface.

  1. Batched Input Processing

    • Processes all tokens simultaneously through in_proj instead of one at a time (see the sketch after this list)
    • Reduces the number of operations and improves parallelization
    • Uses reshape to maintain batch dimensions efficiently
  2. Refactored Core Logic

    • Introduced _process_sequence method to separate main processing logic from cache handling
    • Improves code maintainability and makes future optimizations easier
    • Clearer separation between state management and computation
  3. Enhanced Cache Handling

    • Added robust type checking for different cache formats
    • Maintains backwards compatibility with list-based caches
    • Explicit handling of MambaCache objects
  4. Memory and Computation Optimizations

    • Pre-computes the A matrix once, outside the token loop
    • Better state tracking through sequence processing
    • More explicit memory management
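The sketch below illustrates points 1, 3, and 4 on a stripped-down block. It is not the PR's actual diff: the depthwise convolution is omitted, `MambaCache` here is a minimal stand-in for the real cache class in mlx_lm, and names such as `MambaBlockSketch` are illustrative.

```python
import mlx.core as mx
import mlx.nn as nn


class MambaCache:
    """Stand-in cache holding the SSM hidden state (the real class also tracks the conv state)."""

    def __init__(self):
        self.cache = [None, None]

    def __getitem__(self, idx):
        return self.cache[idx]

    def __setitem__(self, idx, value):
        self.cache[idx] = value


class MambaBlockSketch(nn.Module):
    def __init__(self, hidden_size, intermediate_size, state_size, time_step_rank):
        super().__init__()
        self.intermediate_size = intermediate_size
        self.state_size = state_size
        self.time_step_rank = time_step_rank
        self.in_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.x_proj = nn.Linear(intermediate_size, time_step_rank + 2 * state_size, bias=False)
        self.dt_proj = nn.Linear(time_step_rank, intermediate_size, bias=True)
        self.A_log = mx.zeros((intermediate_size, state_size))
        self.D = mx.ones((intermediate_size,))
        self.out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def _process_sequence(self, x, h):
        B, T, _ = x.shape
        # (1) Batched input projection: one in_proj call over all T tokens.
        xz = self.in_proj(x)                              # (B, T, 2 * d_inner)
        xt, z = mx.split(xz, 2, axis=-1)
        # (4) Pre-compute A once, outside the per-token recurrence.
        A = -mx.exp(self.A_log)                           # (d_inner, d_state)
        ys = []
        for t in range(T):                                # sequential scan over the SSM state
            u = nn.silu(xt[:, t])                         # (B, d_inner); conv1d step omitted
            dbc = self.x_proj(u)
            delta, Bt, C = mx.split(
                dbc, [self.time_step_rank, self.time_step_rank + self.state_size], axis=-1
            )
            delta = nn.softplus(self.dt_proj(delta))      # (B, d_inner)
            # Selective state update: h <- exp(delta * A) * h + (delta * u) * B
            h = mx.exp(delta[..., None] * A) * h + (delta * u)[..., None] * Bt[:, None, :]
            y = (h * C[:, None, :]).sum(axis=-1) + self.D * u
            ys.append(y * nn.silu(z[:, t]))
        # Explicit return of both the output and the final state.
        return self.out_proj(mx.stack(ys, axis=1)), h

    def __call__(self, x, cache=None):
        B, _, _ = x.shape
        # (3) Explicit cache handling: MambaCache object, legacy list, or no cache at all.
        h = cache[0] if isinstance(cache, (MambaCache, list)) else None
        if h is None:
            h = mx.zeros((B, self.intermediate_size, self.state_size))
        y, h = self._process_sequence(x, h)
        if isinstance(cache, (MambaCache, list)):
            cache[0] = h
        return y
```

With a (1, T, hidden_size) prompt, the batched in_proj runs once over all T tokens, while single-token decode steps reuse the cached state through the same _process_sequence loop; this is consistent with the prompt-speed gains and roughly unchanged generation speed reported below.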

Hardware: M4 Mac mini

Before:

state-spaces/mamba-130m-hf

Prompt: 5 tokens, 41.293 tokens-per-sec
Generation: 100 tokens, 113.233 tokens-per-sec
Peak memory: 0.529 GB

mlx-community/Falcon3-Mamba-7B-Instruct-4bits

Prompt: 22 tokens, 14.359 tokens-per-sec
Generation: 100 tokens, 15.100 tokens-per-sec
Peak memory: 4.218 GB

After:

state-spaces/mamba-130m-hf

Prompt: 5 tokens, 129.130 tokens-per-sec
Generation: 100 tokens, 106.952 tokens-per-sec
Peak memory: 0.530 GB

mlx-community/Falcon3-Mamba-7B-Instruct-4bits

Prompt: 22 tokens, 28.364 tokens-per-sec
Generation: 100 tokens, 14.512 tokens-per-sec
Peak memory: 4.164 GB
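For context, readouts in this format are what mlx_lm prints when generating with verbose output enabled. A rough sketch of such a measurement is below; the prompt text and settings are placeholders, since the exact prompt used for the figures above is not stated here.

```python
# Hedged sketch of reproducing the benchmark readout with the mlx_lm Python API.
from mlx_lm import load, generate

model, tokenizer = load("state-spaces/mamba-130m-hf")
generate(
    model,
    tokenizer,
    prompt="hello",   # placeholder prompt; not the one behind the numbers above
    max_tokens=100,
    verbose=True,     # prints prompt/generation tokens-per-sec and peak memory
)
```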

… Pre-computed Constants, Cleaner State Management, Explicit Return Values. Before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec.