
Add patches for memory_efficient_attention and NTK scaling #743

Merged: 9 commits into main from patches on Jul 18, 2023

Conversation

@airaria (Contributor) commented Jul 13, 2023

Description

What does this PR do?

  • Adding the NTK scaling patch function apply_ntk_scaling_patch()

    1. Refactoring the NTK scaling code (Extend context size without fine-tuning #705) by moving the relevant code into patches.py, which keeps the inference code clean and simplifies usage. A minimal sketch of the idea behind NTK scaling is given after this list.

    2. Adding the alpha parameter for NTK scaling to the patch function and the inference scripts. See the Usage section below for details.

  • Adding the attention patch function apply_attention_patch()

    1. Adding support for inference with memory_efficient_attention. On a single 24GB GPU, the maximum context size can be scaled up to about 5K tokens without exceeding GPU memory (model loaded in fp16).

    2. Adding an option for storing the KV cache before applying RoPE.

  • Updating inference_hf.py, gradio_demo.py and openai_api_server.py to showcase NTK scaling and memory_efficient_attention.
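
For readers unfamiliar with NTK scaling, the sketch below illustrates the underlying idea of enlarging the RoPE base by a factor derived from alpha. It is a minimal, hedged illustration and not the code in patches.py; the function name and the default base value are assumptions made for this example.

import torch

def ntk_scaled_inv_freq(dim: int, alpha: float, base: float = 10000.0) -> torch.Tensor:
    # NTK-aware scaling enlarges the RoPE base so that rotary frequencies are
    # stretched, extending the usable context without fine-tuning (illustrative only).
    scaled_base = base * alpha ** (dim / (dim - 2))
    return 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))

# Example: head dimension 128 with alpha=2.0
inv_freq = ntk_scaled_inv_freq(dim=128, alpha=2.0)
print(inv_freq.shape)  # torch.Size([64])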

Usage

alpha=2.0 # alpha can be a float, a string representing a float, or 'auto'
use_memory_efficient_attention=True # True or False
store_kv_before_rope=False # True or False

# The following code should be placed before model initialization
from patches import apply_attention_patch, apply_ntk_scaling_patch
apply_attention_patch(
    use_memory_efficient_attention=use_memory_efficient_attention,
    store_kv_before_rope=store_kv_before_rope
)
apply_ntk_scaling_patch(alpha=alpha)
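
After the patches are applied, the model can be initialized as usual. The snippet below is a hedged example of what that initialization might look like; the checkpoint path is a placeholder and is not part of this PR.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/chinese-alpaca-merged"  # placeholder path; substitute your own checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,   # fp16, matching the memory figures quoted above
    device_map="auto",
)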

Parameters

  • alpha: If 'auto', alpha is computed during generation with the empirical formula alpha = (seq_len / 1024 - 1) * 1.1; otherwise alpha is set to the given fixed float value (see the sketch after this list).
  • use_memory_efficient_attention: Whether to use memory_efficient_attention from xformers. Default is False.
  • store_kv_before_rope: Whether to store the KV cache before applying RoPE. Default is False.
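
As a hedged illustration of the 'auto' rule above (the clamp to a minimum of 1.0 is an assumption made for this sketch, not necessarily what patches.py does):

def auto_alpha(seq_len: int) -> float:
    # Empirical formula from this PR: alpha = (seq_len / 1024 - 1) * 1.1.
    # Clamping to 1.0 (assumption) leaves short sequences with unscaled RoPE.
    return max(1.0, (seq_len / 1024 - 1) * 1.1)

print(auto_alpha(1024))  # 1.0 (no scaling)
print(auto_alpha(4096))  # about 3.3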

Advice

  • Set use_memory_efficient_attention=True to save GPU memory when processing long texts (a sketch of the underlying xformers call is given after this list).
  • Set alpha to a float value (>1) to apply NTK scaling and support longer contexts. Empirically, we find alpha = (seq_len / 1024 - 1) may be a good choice, where seq_len is the estimated context size (the sum of the lengths of the input and the output).
  • Set alpha to 'auto' to let the model determine the value of alpha dynamically and adaptively.
  • Set store_kv_before_rope=True if alpha='auto' and you encounter performance degradation. See the discussion here.
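
To give a sense of what use_memory_efficient_attention enables, below is a hedged sketch of the core xformers call that such an attention patch typically wraps. It is not the patch code from this PR; the tensor shapes are illustrative and a CUDA device is assumed.

import torch
import xformers.ops as xops

# Illustrative shapes: batch 1, sequence length 4096, 32 heads, head dim 128.
q = torch.randn(1, 4096, 32, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 4096, 32, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 4096, 32, 128, dtype=torch.float16, device="cuda")

# memory_efficient_attention avoids materializing the full attention matrix,
# which is why ~5K-token contexts can fit on a single 24GB GPU in fp16.
out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),  # causal mask for decoder-only LMs
)
print(out.shape)  # torch.Size([1, 4096, 32, 128])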

@airaria marked this pull request as ready for review on July 15, 2023
@airaria requested a review from ymcui on July 17, 2023
@ymcui changed the title from "add pathces for memory_efficient_attention and NTK scaling" to "Add patches for memory_efficient_attention and NTK scaling" on Jul 18, 2023

@ymcui (Owner) left a comment

Let's merge this first; we need some wiki documentation updates after the release of v5.0.

@ymcui merged commit f27945f into main on Jul 18, 2023
@ymcui deleted the patches branch on July 19, 2023

IT-five commented Dec 10, 2023

I have a question. In the source code below, prompts longer than max_length are truncated, but the NTK implementation checks "if seq_len > self.max_seq_len_cached:". Doesn't that mean the sequence length will never exceed self.max_seq_len_cached? If so, how does NTK extrapolation to longer contexts work?

if len(tokenized_prompt) > max_length:
    half = int(max_length / 2)
    prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + \
             tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
