From 9cc12a44bc8bf47730af9a80e7f154fabc83653e Mon Sep 17 00:00:00 2001 From: naubull2 Date: Wed, 27 Sep 2023 23:40:27 +0900 Subject: [PATCH 01/16] [add] gpt-neox support --- fine-tune.py | 18 ++++- gptneox_attn_replace.py | 169 ++++++++++++++++++++++++++++++++++++++++ supervised-fine-tune.py | 16 +++- 3 files changed, 198 insertions(+), 5 deletions(-) create mode 100644 gptneox_attn_replace.py diff --git a/fine-tune.py b/fine-tune.py index 2bbc82ce..a9a3e72b 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -24,6 +24,7 @@ from torch.utils.data import Dataset from transformers import Trainer, DataCollatorForLanguageModeling from llama_attn_replace import replace_llama_attn +from gptneox_attn_replace import replace_gpt_neox_attn from peft import LoraConfig, get_peft_model from torch.distributed import barrier @@ -39,7 +40,8 @@ @dataclass class ModelArguments: - model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + model_name_or_path: Optional[str] = field(default="EleutherAI/gpt-neox-20b") + model_type: Optional[str] = field(default="gpt-neox") @dataclass class TrainingArguments(transformers.TrainingArguments): @@ -99,7 +101,11 @@ def train(): parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) model_args, training_args = parser.parse_args_into_dataclasses() - replace_llama_attn(training_args.use_flash_attn) + # NOTE: May expand supported model types in the future + if model_args.model_type == "gpt-neox": + replace_gpt_neox_attn(training_args.use_flash_attn) + else: + replace_llama_attn(training_args.use_flash_attn) # Set RoPE scaling factor config = transformers.AutoConfig.from_pretrained( @@ -157,10 +163,16 @@ def train(): data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) if training_args.low_rank_training: + if model_args.model_type == "gpt-neox": + # added `dense` to match with llama as the basic LoRA would only target 'query_key_value' + targets = ["query_key_value", "dense"] + else: + targets=["q_proj", "k_proj", "v_proj", "o_proj"], + config = LoraConfig( r=8, lora_alpha=16, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + target_modules=targets, lora_dropout=0, bias="none", task_type="CAUSAL_LM", diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py new file mode 100644 index 00000000..df9b9a4d --- /dev/null +++ b/gptneox_attn_replace.py @@ -0,0 +1,169 @@ +# Modified based on https://github.com/dvlab-research/LongLoRA + +from typing import Optional, Tuple +import warnings +import torch +import transformers +from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb + +from flash_attn import flash_attn_varlen_func + + +group_size_ratio = 1/4 + + +def _flash_attn(query, key, value, attention_mask=None, head_mask=None): + # Flash attention codes from + # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + + # q, k, v: [bs, nh, seq_len, hd] + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + value_length = value.size(-2) + + # q, k, v: [bs, nh, seq_len, hd] -> [bs, seq_len, nh, hd] -> [bs * seq_len, nh, hd] + query = query.transpose(1, 2).reshape(batch_size * query_length , num_attention_heads, attn_head_size) + key = key.transpose(1, 2).reshape(batch_size * key_length, num_attention_heads, attn_head_size) + value = value.transpose(1, 2).reshape(batch_size * value_length, num_attention_heads, attn_head_size) + + attn_dropout = 0.0 # TODO: attach to config + + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * query_length, + step=query_length, + dtype=torch.int32, + device=query.device, + ) + + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * key_length, + step=key_length, + dtype=torch.int32, + device=key.device, + ) + + attn_output, attn_weights, _ = flash_attn_varlen_func( + query, key, value, cu_seqlens_q, cu_seqlens_k, query_length, value_length, dropout_p=attn_dropout, + softmax_scale=None, causal=True, return_attn_probs=True + ) + + attn_output = attn_output.view(batch_size, query_length, num_attention_heads, attn_head_size).transpose(1, 2) + return attn_output, attn_weights + + +def get_forward_function(use_flash_attn=True, use_full=False): + + def forward_attention( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + head_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + # NOTE: compute SS group size + bsz, q_len, _ = hidden_states.size() + has_layer_past = layer_past is not None + + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.query_key_value(hidden_states) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] + # --> 3 [batch, num_attention_heads, seq_len, head_size] + query = qkv[..., : self.head_size].permute(0, 2, 1, 3) + key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + # [bsz, nh, q_len, hd] + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + query_pass = query[..., self.rotary_ndims :] + key_rot = key[..., : self.rotary_ndims] + key_pass = key[..., self.rotary_ndims :] + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + if has_layer_past: + seq_len += layer_past[0].shape[-2] + cos, sin = self.rotary_emb(value, seq_len=seq_len) + query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + query = torch.cat((query, query_pass), dim=-1) + key = torch.cat((key, key_pass), dim=-1) + + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) if use_cache else None + + # NOTE: apply shift + if self.training and not use_full: + def shift(qkv, num_heads, head_dim): + # qkv = [bsz, nh, q_len, d] + group_size = int(q_len * group_size_ratio) + if q_len % group_size > 0: + raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size)) + num_group = q_len // group_size + qkv = qkv.transpose(1, 2) + # qkv = [bsz, q_len, nh, d] + qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) + # -> [bsz * n_group, group_s, nh, d) + # -> [bsz * n_group, nh, group_s, d) + qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) + return qkv + + query = shift(query, self.num_attention_heads, self.head_size) + key = shift(key, self.num_attention_heads, self.head_size) + value = shift(value, self.num_attention_heads, self.head_size) + + # Compute attention + if use_flash_attn: + attn_output, attn_weights = _flash_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + # NOTE: shift back + if self.training and not use_full: + attn_output = attn_output.transpose(1, 2) + # [bsz, q_len, nh, hd] + attn_output[:, :, num_heads//2:] = attn_output[:, :, num_heads//2:].roll(group_size//2, dims=1) + attn_output = attn_output.transpose(1, 2) + + # Reshape outputs + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + return forward_attention + + +def replace_gpt_neox_attn(use_flash_attn=True, use_full=False): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if use_flash_attn and cuda_major < 8: + warnings.warn( + "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." + "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" + "Resorting to plain attention..." + ) + use_flash_attn = False + + forward_fn = get_forward_function(use_flash_attn, use_full) + transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention.forward = forward_fn diff --git a/supervised-fine-tune.py b/supervised-fine-tune.py index bf22029a..cecfb9b9 100644 --- a/supervised-fine-tune.py +++ b/supervised-fine-tune.py @@ -27,6 +27,7 @@ from torch.utils.data import Dataset from transformers import Trainer, DataCollatorForLanguageModeling from llama_attn_replace import replace_llama_attn +from gptneox_attn_replace import replace_gpt_neox_attn from peft import LoraConfig, get_peft_model from torch.distributed import barrier @@ -65,6 +66,7 @@ def jload(f, mode="r"): @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + model_type: Optional[str] = field(default="gpt-neox") @dataclass @@ -219,7 +221,11 @@ def train(): parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() - replace_llama_attn(training_args.use_flash_attn, True) + # NOTE: May expand supported model types in the future + if model_args.model_type == "gpt-neox": + replace_gpt_neox_attn(training_args.use_flash_attn) + else: + replace_llama_attn(training_args.use_flash_attn) # Set RoPE scaling factor config = transformers.AutoConfig.from_pretrained( @@ -266,10 +272,16 @@ def train(): data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) if training_args.low_rank_training: + if model_args.model_type == "gpt-neox": + # added `dense` to match with llama as the basic LoRA would only target 'query_key_value' + targets = ["query_key_value", "dense"] + else: + targets=["q_proj", "k_proj", "v_proj", "o_proj"], + config = LoraConfig( r=8, lora_alpha=16, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + target_modules=targets, lora_dropout=0, bias="none", task_type="CAUSAL_LM", From 0d476ea4d51cf54304290e57b27649cb39aa0e22 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Wed, 27 Sep 2023 23:40:42 +0900 Subject: [PATCH 02/16] [update] readme --- README.md | 81 ++++++++----------------------------------------------- 1 file changed, 11 insertions(+), 70 deletions(-) diff --git a/README.md b/README.md index ff2c503b..59512a95 100644 --- a/README.md +++ b/README.md @@ -1,41 +1,11 @@ -[![Gradio](https://img.shields.io/badge/Gradio-Online%20Demo-blue)](https://2060079530708e861d.gradio.live) - -# LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models - -## News -- [x] [2023.9.22] We release our **13B and 70B 32k models with the supervised fine-tuning**, which is feasible for long context QA. Please check [Llama-2-13b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) and [Llama-2-70b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft). To our best knowledge, **this is the first work that release 70B model with 32k context length**. -- [x] [2023.9.22] We release all our fine-tuned [models](https://huggingface.co/Yukang), including **70B-32k models**, [LLaMA2-LongLoRA-70B-32k](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k), [LLaMA2-LongLoRA-7B-100k](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft). Welcome to check them out! -- [x] [2023.9.22] We release [Paper](http://arxiv.org/abs/2309.12307) and this GitHub repo, including training and evaluation code. - -**LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models [[Paper](http://arxiv.org/abs/2309.12307)]**
-[Yukang Chen](https://scholar.google.com/citations?user=6p0ygKUAAAAJ&hl=en), -[Shengju Qian](https://scholar.google.com/citations?user=QNnWmasAAAAJ), -[Haotian Tang](https://scholar.google.com/citations?user=WxL13BAAAAAJ&hl), -[Xin Lai](https://scholar.google.com/citations?user=tqNDPA4AAAAJ&hl=zh-CN), -[Zhijian Liu](https://scholar.google.com/citations?user=3coYSTUAAAAJ&hl=en), -[Song Han](https://scholar.google.com/citations?user=E0iCaa4AAAAJ&hl=zh-CN), -[Jiaya Jia](https://scholar.google.com/citations?user=XPAkzTEAAAAJ&hl=en)
- -
**Paper** | **Models** | [**Training**](#training) | [**Inference**](#inference) | **Online Demo**
- -

-

-

-

-

-

- -## Abstract -We present LongLoRA, an efficient fine-tuning approach that extends the context sizes of pre-trained large language models (LLMs), with limited computation cost. -Typically, training LLMs with long context sizes is computationally expensive, requiring extensive training hours and GPU resources. -In this paper, we speed up the context extension of LLMs in two aspects. On the one hand, although dense global attention is needed during inference, fine-tuning the model can be effectively and efficiently done by sparse local attention. The proposed shift short attention effectively enables context extension, leading to non-trivial computation saving with similar performance to fine-tuning with vanilla attention. On the other hand, we find that LoRA for context extension works well under the premise of trainable embedding and normalization. LongLoRA demonstrates strong empirical results on various tasks on LLaMA2 models from 7B/13B to 70B. LongLoRA adopts LLaMA2 7B from 4k context to 100k, or LLaMA2 70B to 32k on a single 8x A100 machine. LongLoRA extends models' context while retaining their original architectures, and is compatible with most existing techniques, like FlashAttention-2. In addition, to make LongLoRA practical, we collect a dataset, LongQA, for supervised fine-tuning. It contains more than 3k long context question-answer pairs. For more details, please refer to the [paper](http://arxiv.org/abs/2309.12307). +# LongLoRA (with GPTNeoX support): Efficient Fine-tuning of Long-Context Large Language Models +This repo provides on top of the original implementation, support for GPTNeoX with Flash-Attention and the LongLoRA's shifted short attention as needed. ## Highlights **LongLoRA** speed up the context extension of pre-trained large language models in both attention-level and weight-level. 1. The proposed shifted short attention is easy to implement, compatible with Flash-Attention, and not required during inference. -2. We release all our models, including models from 7B to 70B, context length from 8k to 100k, including [LLaMA2-LongLoRA-7B-100k](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft), [LLaMA2-LongLoRA-13B-64k](https://huggingface.co/Yukang/Llama-2-13b-longlora-64k), and [LLaMA2-LongLoRA-70B-32k](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k). -3. We build up a long-context QA dataset, LongQA, for supervised fine-tuning (SFT). We release 13B and 70B 32k models with SFT, [Llama-2-13b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) and [Llama-2-70b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft). We will further release the dataset in the next month. + ## Installation ``` @@ -43,46 +13,16 @@ pip install -r requirements.txt pip install flash-attn --no-build-isolation ``` -## Released models - -### Models with supervised fine-tuning -| Model | Size | Context | Train | Link | -|:----------------------------------|------|---------|---------|-------------------------------------------------------------------------| -| Llama-2-13b-chat-longlora-32k-sft | 13B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) | -| Llama-2-70b-chat-longlora-32k-sft | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft) | - -### Models with context extension via fully fine-tuning -| Model | Size | Context | Train | Link | -|:----------------------------|------|---------|-------|-------------------------------------------------------------------| -| Llama-2-7b-longlora-8k-ft | 7B | 8192 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-8k-ft) | -| Llama-2-7b-longlora-16k-ft | 7B | 16384 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-16k-ft) | -| Llama-2-7b-longlora-32k-ft | 7B | 32768 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-32k-ft) | -| Llama-2-7b-longlora-100k-ft | 7B | 100000 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft) | -| Llama-2-13b-longlora-8k-ft | 13B | 8192 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-8k-ft) | -| Llama-2-13b-longlora-16k-ft | 13B | 16384 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-16k-ft) | -| Llama-2-13b-longlora-32k-ft | 13B | 32768 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-32k-ft) | - -### Models with context extension via improved LoRA fine-tuning -| Model | Size | Context | Train | Link | -|:----------------------------|------|---------|-------|-------------------------------------------------------------------| -| Llama-2-7b-longlora-8k | 7B | 8192 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-8k) | -| Llama-2-7b-longlora-16k | 7B | 16384 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-16k) | -| Llama-2-7b-longlora-32k | 7B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-32k) | -| Llama-2-13b-longlora-8k | 13B | 8192 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-8k) | -| Llama-2-13b-longlora-16k | 13B | 16384 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-16k) | -| Llama-2-13b-longlora-32k | 13B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-32k) | -| Llama-2-13b-longlora-64k | 13B | 65536 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-64k) | -| Llama-2-70b-longlora-32k | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k) | -| Llama-2-70b-chat-longlora-32k | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k) | - ## Training ### Pre-trained weights -We use LLaMA2 models as the pre-trained weights and fine-tune them to long context window sizes. Please download [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf), [Llama-2-13b-hf](https://huggingface.co/meta-llama/Llama-2-13b-hf), and [Llama-2-70b-hf](https://huggingface.co/meta-llama/Llama-2-70b-hf), based on your choices. +I used GPTNeoX model as the base model architecture, which was ported from the authors' original repo where Llama2 was used. +Some candidate pre-trained weights may include [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b), [Polyglot-ko-12.8B](https://huggingface.co/EleutherAI/polyglot-ko-12.8b) and other variants. + ### Fine-tuning ``` torchrun --nproc_per_node=8 fine-tune.py \ - --model_name_or_path path_to/Llama-2-7b-hf \ + --model_name_or_path path_to/gpt_neox_model_hf \ --bf16 True \ --output_dir path_to_saving_checkpoints \ --cache_dir path_to_cache \ @@ -107,7 +47,7 @@ torchrun --nproc_per_node=8 fine-tune.py \ --max_steps 1000 ``` -- Please remember to change `path_to/Llama-2-7b-hf`, `path_to_saving_checkpoints`, `path_to_cache` to your own directory. +- Please remember to change `path_to/gpt_neox_model_hf`, `path_to_saving_checkpoints`, `path_to_cache` to your own directory. - Note that you can change `model_max_length` to other values. - You could change `ds_configs/stage2.json` to `ds_configs/stage3.json` if you want. - Please set `use_flash_attn` as `False` if you use V100 machines or do not install flash attention. @@ -143,7 +83,7 @@ torchrun --nproc_per_node=8 supervised-fine-tune.py \ --deepspeed "ds_configs/stage2.json" \ --tf32 True ``` -- We typically make supervised fine-tuning upon the fine-tuned context extended models, `path_to_finetuned_models`, like `Llama-2-13b-longlora-32k` or `Llama-2-13b-longlora-32k-ft`. +- We typically make supervised fine-tuning upon the fine-tuned context extended models, `path_to_finetuned_models` - During our dataset collection, it is hard for us to collect many high-quality QA that are larger than 32768. Thus, if you use our `LongQA.json`, please also set `model_max_length` as 32768. @@ -282,7 +222,8 @@ If you find this project useful in your research, please consider citing: ``` ## Acknowledgement -- This work is built upon the [LLaMA2](https://ai.meta.com/llama) as the pre-trained models. +- This work is an GPTNeoX port of the work from the original authors' code. [LongLoRA](https://github.com/dvlab-research/LongLoRA) +- This work is built upon the [GPTNeoX-HF](https://huggingface.co/docs/transformers/model_doc/gpt_neox) which is based upon [EleutherAI/GPTNeoX](https://github.com/EleutherAI/gpt-neox) as the pre-trained model architecture. - This work is based on [DeepSpeed](https://github.com/microsoft/DeepSpeed), [peft](https://github.com/huggingface/peft), and [Flash-Attention2](https://github.com/Dao-AILab/flash-attention) for acceleration. - Some evaluation code is modified upon [Landmark Attention](https://github.com/epfml/landmark-attention). - We use [LongChat](https://github.com/DachengLi1/LongChat) for the retrieval evaluation. From 11baea3b7231ec627a6a9884253abd04e8e024a0 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Thu, 28 Sep 2023 23:17:40 +0900 Subject: [PATCH 03/16] [fix] some of the bugs preventing fine-tune run + There's still bugs in the attention dimensions mismatch --- fine-tune.py | 5 +++-- gptneox_attn_replace.py | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index a9a3e72b..185344fd 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -40,7 +40,7 @@ @dataclass class ModelArguments: - model_name_or_path: Optional[str] = field(default="EleutherAI/gpt-neox-20b") + model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped") model_type: Optional[str] = field(default="gpt-neox") @dataclass @@ -123,6 +123,7 @@ def train(): model_args.model_name_or_path, config=config, cache_dir=training_args.cache_dir, + torch_dtype=torch.bfloat16, ) tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -130,7 +131,7 @@ def train(): cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", - use_fast=False, + use_fast=True, ) special_tokens_dict = dict() diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index df9b9a4d..4670e5dd 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -110,13 +110,13 @@ def forward_attention( present = (key, value) if use_cache else None # NOTE: apply shift + group_size = int(q_len * group_size_ratio) + if q_len % group_size > 0: + raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size)) + num_group = q_len // group_size if self.training and not use_full: def shift(qkv, num_heads, head_dim): # qkv = [bsz, nh, q_len, d] - group_size = int(q_len * group_size_ratio) - if q_len % group_size > 0: - raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size)) - num_group = q_len // group_size qkv = qkv.transpose(1, 2) # qkv = [bsz, q_len, nh, d] qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) @@ -139,7 +139,7 @@ def shift(qkv, num_heads, head_dim): if self.training and not use_full: attn_output = attn_output.transpose(1, 2) # [bsz, q_len, nh, hd] - attn_output[:, :, num_heads//2:] = attn_output[:, :, num_heads//2:].roll(group_size//2, dims=1) + attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1) attn_output = attn_output.transpose(1, 2) # Reshape outputs From fcd5b3c6b549bf1ebcebdd56e232a6a5def96a30 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Fri, 29 Sep 2023 00:00:12 +0900 Subject: [PATCH 04/16] [fix] dimesion discrepancy between attention mask and the query length + group batch attention is skipped to avoid this problem for now --- gptneox_attn_replace.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 4670e5dd..e2a48184 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -120,9 +120,12 @@ def shift(qkv, num_heads, head_dim): qkv = qkv.transpose(1, 2) # qkv = [bsz, q_len, nh, d] qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) + qkv = qkv.transpose(1, 2) + + # TODO: Changing the q_len to group_size, will require attention mask to be adjusted as well # -> [bsz * n_group, group_s, nh, d) # -> [bsz * n_group, nh, group_s, d) - qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) + #qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) return qkv query = shift(query, self.num_attention_heads, self.head_size) From ec3a31ff7b6bdf22f4ba531fff675008f0d63b3e Mon Sep 17 00:00:00 2001 From: naubull2 Date: Fri, 29 Sep 2023 00:03:57 +0900 Subject: [PATCH 05/16] [fix] SFT to match the same mods in finetune.py --- supervised-fine-tune.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/supervised-fine-tune.py b/supervised-fine-tune.py index cecfb9b9..fefbc2b2 100644 --- a/supervised-fine-tune.py +++ b/supervised-fine-tune.py @@ -65,7 +65,7 @@ def jload(f, mode="r"): @dataclass class ModelArguments: - model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped") model_type: Optional[str] = field(default="gpt-neox") @@ -243,6 +243,7 @@ def train(): model_args.model_name_or_path, config=config, cache_dir=training_args.cache_dir, + torch_dtype=torch.bfloat16, ) tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -250,7 +251,7 @@ def train(): cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", - use_fast=False, + use_fast=True, ) special_tokens_dict = dict() @@ -290,6 +291,7 @@ def train(): # enable trainable params [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] + model.config.use_cache = False # required for gradient checkpointing model.enable_input_require_grads() # required for gradient checkpointing model.gradient_checkpointing_enable() # enable gradient checkpointing From 8887749f598daef3f1df441458fa294f5f8950e6 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Sat, 30 Sep 2023 22:05:28 +0900 Subject: [PATCH 06/16] [add] parallel group attention then reshape back to original form --- gptneox_attn_replace.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index e2a48184..3f398580 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -120,12 +120,12 @@ def shift(qkv, num_heads, head_dim): qkv = qkv.transpose(1, 2) # qkv = [bsz, q_len, nh, d] qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) - qkv = qkv.transpose(1, 2) + #qkv = qkv.transpose(1, 2) # TODO: Changing the q_len to group_size, will require attention mask to be adjusted as well # -> [bsz * n_group, group_s, nh, d) # -> [bsz * n_group, nh, group_s, d) - #qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) + qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) return qkv query = shift(query, self.num_attention_heads, self.head_size) @@ -140,7 +140,9 @@ def shift(qkv, num_heads, head_dim): # NOTE: shift back if self.training and not use_full: - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, num_heads, head_dim) + #attn_output = attn_output.transpose(1, 2) # [bsz, q_len, nh, hd] attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1) attn_output = attn_output.transpose(1, 2) From 7c617fe60219cf47758c754c2a6e097951c2d9b9 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Sat, 30 Sep 2023 22:39:39 +0900 Subject: [PATCH 07/16] [fix] non-contiguous dimensions changing view issue --- gptneox_attn_replace.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 3f398580..13531d13 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -128,9 +128,10 @@ def shift(qkv, num_heads, head_dim): qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) return qkv - query = shift(query, self.num_attention_heads, self.head_size) - key = shift(key, self.num_attention_heads, self.head_size) - value = shift(value, self.num_attention_heads, self.head_size) + # contiguous is required as self._attn() will attempt to apply .view() on them + query = shift(query, self.num_attention_heads, self.head_size).contiguous() + key = shift(key, self.num_attention_heads, self.head_size).contiguous() + value = shift(value, self.num_attention_heads, self.head_size).contiguous() # Compute attention if use_flash_attn: @@ -141,7 +142,7 @@ def shift(qkv, num_heads, head_dim): # NOTE: shift back if self.training and not use_full: attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, num_heads, head_dim) + attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size) #attn_output = attn_output.transpose(1, 2) # [bsz, q_len, nh, hd] attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1) From 765a7f2cc5fb800de39c1b22eef0a377cb1a233c Mon Sep 17 00:00:00 2001 From: naubull2 Date: Sat, 30 Sep 2023 22:51:58 +0900 Subject: [PATCH 08/16] [add] attention mask to align with the grouped batching --- gptneox_attn_replace.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 13531d13..37a2c405 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -132,6 +132,7 @@ def shift(qkv, num_heads, head_dim): query = shift(query, self.num_attention_heads, self.head_size).contiguous() key = shift(key, self.num_attention_heads, self.head_size).contiguous() value = shift(value, self.num_attention_heads, self.head_size).contiguous() + attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1) # Compute attention if use_flash_attn: From 2cd510587cbfb75727e71515d03cbf6f6fa6b607 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Mon, 2 Oct 2023 19:47:11 +0900 Subject: [PATCH 09/16] [add] torch autocast for flash attention safety + flash attention only supports in fp16/bf16 --- fine-tune.py | 19 ++++++++++--------- supervised-fine-tune.py | 11 ++++++----- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index 185344fd..e05d81dc 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -182,15 +182,16 @@ def train(): # enable trainable params [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] - model.enable_input_require_grads() # required for gradient checkpointing - model.gradient_checkpointing_enable() # enable gradient checkpointing - - trainer = Trainer( - model=model, tokenizer=tokenizer, args=training_args, - train_dataset=dataset["train"], - eval_dataset=None, - data_collator=data_collator) - trainer.train() + with torch.cuda.amp.autocast(dtype=model.dtype): + model.config.use_cache = False # required for gradient checkpointing + model.enable_input_require_grads() # required for gradient checkpointing + model.gradient_checkpointing_enable() # enable gradient checkpointing + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, + train_dataset=dataset["train"], + eval_dataset=None, + data_collator=data_collator) + trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) diff --git a/supervised-fine-tune.py b/supervised-fine-tune.py index fefbc2b2..1f19db73 100644 --- a/supervised-fine-tune.py +++ b/supervised-fine-tune.py @@ -291,12 +291,13 @@ def train(): # enable trainable params [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] - model.config.use_cache = False # required for gradient checkpointing - model.enable_input_require_grads() # required for gradient checkpointing - model.gradient_checkpointing_enable() # enable gradient checkpointing + with torch.cuda.amp.autocast(dtype=model.dtype): + model.config.use_cache = False # required for gradient checkpointing + model.enable_input_require_grads() # required for gradient checkpointing + model.gradient_checkpointing_enable() # enable gradient checkpointing - trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) - trainer.train() + trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) + trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) From b8b343feef72208336c4612ced80273254de260a Mon Sep 17 00:00:00 2001 From: naubull2 Date: Mon, 2 Oct 2023 19:47:48 +0900 Subject: [PATCH 10/16] [fix] HF built-in rotary embedding is not compatible with flash-attention + cos/sin cache tensor is not trained parameter, so it's not autocast along with other model parameters through `torch_dtype`. --- gptneox_attn_replace.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 37a2c405..1ffea54e 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -4,13 +4,20 @@ import warnings import torch import transformers -from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb from flash_attn import flash_attn_varlen_func group_size_ratio = 1/4 +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1).to(q.dtype), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1).to(k.dtype), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed def _flash_attn(query, key, value, attention_mask=None, head_mask=None): # Flash attention codes from From fb6e5a250a5726fd5a83f4973e02d7aa178b91ef Mon Sep 17 00:00:00 2001 From: naubull2 Date: Mon, 2 Oct 2023 20:06:01 +0900 Subject: [PATCH 11/16] [add] missing local reference for rotate_half --- gptneox_attn_replace.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 1ffea54e..73b44c42 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -10,6 +10,12 @@ group_size_ratio = 1/4 +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) From efd13eacf63d01c66905ee78cbfed5499690235f Mon Sep 17 00:00:00 2001 From: naubull2 Date: Mon, 2 Oct 2023 21:21:36 +0900 Subject: [PATCH 12/16] [rollback] torch.cuda autocast causes half precision error + Works fine without the torch.cuda autocast context, so rollback. --- fine-tune.py | 19 +++++++++---------- supervised-fine-tune.py | 11 +++++------ 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index e05d81dc..9d2e37fc 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -182,16 +182,15 @@ def train(): # enable trainable params [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] - with torch.cuda.amp.autocast(dtype=model.dtype): - model.config.use_cache = False # required for gradient checkpointing - model.enable_input_require_grads() # required for gradient checkpointing - model.gradient_checkpointing_enable() # enable gradient checkpointing - trainer = Trainer( - model=model, tokenizer=tokenizer, args=training_args, - train_dataset=dataset["train"], - eval_dataset=None, - data_collator=data_collator) - trainer.train() + model.config.use_cache = False # required for gradient checkpointing + model.enable_input_require_grads() # required for gradient checkpointing + model.gradient_checkpointing_enable() # enable gradient checkpointing + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, + train_dataset=dataset["train"], + eval_dataset=None, + data_collator=data_collator) + trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) diff --git a/supervised-fine-tune.py b/supervised-fine-tune.py index 1f19db73..fefbc2b2 100644 --- a/supervised-fine-tune.py +++ b/supervised-fine-tune.py @@ -291,13 +291,12 @@ def train(): # enable trainable params [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] - with torch.cuda.amp.autocast(dtype=model.dtype): - model.config.use_cache = False # required for gradient checkpointing - model.enable_input_require_grads() # required for gradient checkpointing - model.gradient_checkpointing_enable() # enable gradient checkpointing + model.config.use_cache = False # required for gradient checkpointing + model.enable_input_require_grads() # required for gradient checkpointing + model.gradient_checkpointing_enable() # enable gradient checkpointing - trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) - trainer.train() + trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) + trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) From 4d330d04391daa96a9c79c9dd5aaa72a66d02073 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Tue, 3 Oct 2023 17:35:10 +0900 Subject: [PATCH 13/16] [fix] flash attention causing in-place operation runtime errors --- gptneox_attn_replace.py | 57 +++++++++++++---------------------------- 1 file changed, 18 insertions(+), 39 deletions(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 73b44c42..94dc5a9c 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -5,7 +5,9 @@ import torch import transformers -from flash_attn import flash_attn_varlen_func +from einops import rearrange +from flash_attn import flash_attn_varlen_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input group_size_ratio = 1/4 @@ -25,45 +27,22 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + def _flash_attn(query, key, value, attention_mask=None, head_mask=None): - # Flash attention codes from - # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py - - # q, k, v: [bs, nh, seq_len, hd] - batch_size, num_attention_heads, query_length, attn_head_size = query.size() - key_length = key.size(-2) - value_length = value.size(-2) - - # q, k, v: [bs, nh, seq_len, hd] -> [bs, seq_len, nh, hd] -> [bs * seq_len, nh, hd] - query = query.transpose(1, 2).reshape(batch_size * query_length , num_attention_heads, attn_head_size) - key = key.transpose(1, 2).reshape(batch_size * key_length, num_attention_heads, attn_head_size) - value = value.transpose(1, 2).reshape(batch_size * value_length, num_attention_heads, attn_head_size) - - attn_dropout = 0.0 # TODO: attach to config - - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * query_length, - step=query_length, - dtype=torch.int32, - device=query.device, - ) - - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * key_length, - step=key_length, - dtype=torch.int32, - device=key.device, - ) - - attn_output, attn_weights, _ = flash_attn_varlen_func( - query, key, value, cu_seqlens_q, cu_seqlens_k, query_length, value_length, dropout_p=attn_dropout, - softmax_scale=None, causal=True, return_attn_probs=True - ) - - attn_output = attn_output.view(batch_size, query_length, num_attention_heads, attn_head_size).transpose(1, 2) - return attn_output, attn_weights + # transform the data into the qkv packed form + qkv = torch.stack( + [query, key, value], dim=2 + ) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + bsz, q_len = qkv.shape[:2] + + qkv = rearrange(qkv, "b s ... -> (b s) ...") + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) + output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, q_len, 0.0, softmax_scale=None, causal=True) + output = rearrange(output, "(b s) ... -> b s ...", b=bsz) + + # disable attn weights by returning None when using flash attention + return output, None def get_forward_function(use_flash_attn=True, use_full=False): From 641b907b02439e54e4a8b0b91d7c8e1c2ca4b2cb Mon Sep 17 00:00:00 2001 From: naubull2 Date: Tue, 3 Oct 2023 17:36:42 +0900 Subject: [PATCH 14/16] [fix] mixed use of tabs and spaces --- gptneox_attn_replace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 94dc5a9c..fda6a221 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -41,8 +41,8 @@ def _flash_attn(query, key, value, attention_mask=None, head_mask=None): output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, q_len, 0.0, softmax_scale=None, causal=True) output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - # disable attn weights by returning None when using flash attention - return output, None + # disable attn weights by returning None when using flash attention + return output, None def get_forward_function(use_flash_attn=True, use_full=False): From 72ce677297f8272b4210d3f6979e8f91b4fc546c Mon Sep 17 00:00:00 2001 From: naubull2 Date: Tue, 3 Oct 2023 17:45:21 +0900 Subject: [PATCH 15/16] [change] readme back to where it came from the original repo --- README.md | 81 +++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 59512a95..ff2c503b 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,41 @@ -# LongLoRA (with GPTNeoX support): Efficient Fine-tuning of Long-Context Large Language Models +[![Gradio](https://img.shields.io/badge/Gradio-Online%20Demo-blue)](https://2060079530708e861d.gradio.live) + +# LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models + +## News +- [x] [2023.9.22] We release our **13B and 70B 32k models with the supervised fine-tuning**, which is feasible for long context QA. Please check [Llama-2-13b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) and [Llama-2-70b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft). To our best knowledge, **this is the first work that release 70B model with 32k context length**. +- [x] [2023.9.22] We release all our fine-tuned [models](https://huggingface.co/Yukang), including **70B-32k models**, [LLaMA2-LongLoRA-70B-32k](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k), [LLaMA2-LongLoRA-7B-100k](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft). Welcome to check them out! +- [x] [2023.9.22] We release [Paper](http://arxiv.org/abs/2309.12307) and this GitHub repo, including training and evaluation code. + +**LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models [[Paper](http://arxiv.org/abs/2309.12307)]**
+[Yukang Chen](https://scholar.google.com/citations?user=6p0ygKUAAAAJ&hl=en), +[Shengju Qian](https://scholar.google.com/citations?user=QNnWmasAAAAJ), +[Haotian Tang](https://scholar.google.com/citations?user=WxL13BAAAAAJ&hl), +[Xin Lai](https://scholar.google.com/citations?user=tqNDPA4AAAAJ&hl=zh-CN), +[Zhijian Liu](https://scholar.google.com/citations?user=3coYSTUAAAAJ&hl=en), +[Song Han](https://scholar.google.com/citations?user=E0iCaa4AAAAJ&hl=zh-CN), +[Jiaya Jia](https://scholar.google.com/citations?user=XPAkzTEAAAAJ&hl=en)
+ +
**Paper** | **Models** | [**Training**](#training) | [**Inference**](#inference) | **Online Demo**
+ +

+

+

+

+

+

+ +## Abstract +We present LongLoRA, an efficient fine-tuning approach that extends the context sizes of pre-trained large language models (LLMs), with limited computation cost. +Typically, training LLMs with long context sizes is computationally expensive, requiring extensive training hours and GPU resources. +In this paper, we speed up the context extension of LLMs in two aspects. On the one hand, although dense global attention is needed during inference, fine-tuning the model can be effectively and efficiently done by sparse local attention. The proposed shift short attention effectively enables context extension, leading to non-trivial computation saving with similar performance to fine-tuning with vanilla attention. On the other hand, we find that LoRA for context extension works well under the premise of trainable embedding and normalization. LongLoRA demonstrates strong empirical results on various tasks on LLaMA2 models from 7B/13B to 70B. LongLoRA adopts LLaMA2 7B from 4k context to 100k, or LLaMA2 70B to 32k on a single 8x A100 machine. LongLoRA extends models' context while retaining their original architectures, and is compatible with most existing techniques, like FlashAttention-2. In addition, to make LongLoRA practical, we collect a dataset, LongQA, for supervised fine-tuning. It contains more than 3k long context question-answer pairs. For more details, please refer to the [paper](http://arxiv.org/abs/2309.12307). -This repo provides on top of the original implementation, support for GPTNeoX with Flash-Attention and the LongLoRA's shifted short attention as needed. ## Highlights **LongLoRA** speed up the context extension of pre-trained large language models in both attention-level and weight-level. 1. The proposed shifted short attention is easy to implement, compatible with Flash-Attention, and not required during inference. - +2. We release all our models, including models from 7B to 70B, context length from 8k to 100k, including [LLaMA2-LongLoRA-7B-100k](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft), [LLaMA2-LongLoRA-13B-64k](https://huggingface.co/Yukang/Llama-2-13b-longlora-64k), and [LLaMA2-LongLoRA-70B-32k](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k). +3. We build up a long-context QA dataset, LongQA, for supervised fine-tuning (SFT). We release 13B and 70B 32k models with SFT, [Llama-2-13b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) and [Llama-2-70b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft). We will further release the dataset in the next month. ## Installation ``` @@ -13,16 +43,46 @@ pip install -r requirements.txt pip install flash-attn --no-build-isolation ``` +## Released models + +### Models with supervised fine-tuning +| Model | Size | Context | Train | Link | +|:----------------------------------|------|---------|---------|-------------------------------------------------------------------------| +| Llama-2-13b-chat-longlora-32k-sft | 13B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) | +| Llama-2-70b-chat-longlora-32k-sft | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft) | + +### Models with context extension via fully fine-tuning +| Model | Size | Context | Train | Link | +|:----------------------------|------|---------|-------|-------------------------------------------------------------------| +| Llama-2-7b-longlora-8k-ft | 7B | 8192 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-8k-ft) | +| Llama-2-7b-longlora-16k-ft | 7B | 16384 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-16k-ft) | +| Llama-2-7b-longlora-32k-ft | 7B | 32768 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-32k-ft) | +| Llama-2-7b-longlora-100k-ft | 7B | 100000 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft) | +| Llama-2-13b-longlora-8k-ft | 13B | 8192 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-8k-ft) | +| Llama-2-13b-longlora-16k-ft | 13B | 16384 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-16k-ft) | +| Llama-2-13b-longlora-32k-ft | 13B | 32768 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-32k-ft) | + +### Models with context extension via improved LoRA fine-tuning +| Model | Size | Context | Train | Link | +|:----------------------------|------|---------|-------|-------------------------------------------------------------------| +| Llama-2-7b-longlora-8k | 7B | 8192 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-8k) | +| Llama-2-7b-longlora-16k | 7B | 16384 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-16k) | +| Llama-2-7b-longlora-32k | 7B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-32k) | +| Llama-2-13b-longlora-8k | 13B | 8192 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-8k) | +| Llama-2-13b-longlora-16k | 13B | 16384 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-16k) | +| Llama-2-13b-longlora-32k | 13B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-32k) | +| Llama-2-13b-longlora-64k | 13B | 65536 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-64k) | +| Llama-2-70b-longlora-32k | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k) | +| Llama-2-70b-chat-longlora-32k | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k) | + ## Training ### Pre-trained weights -I used GPTNeoX model as the base model architecture, which was ported from the authors' original repo where Llama2 was used. -Some candidate pre-trained weights may include [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b), [Polyglot-ko-12.8B](https://huggingface.co/EleutherAI/polyglot-ko-12.8b) and other variants. - +We use LLaMA2 models as the pre-trained weights and fine-tune them to long context window sizes. Please download [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf), [Llama-2-13b-hf](https://huggingface.co/meta-llama/Llama-2-13b-hf), and [Llama-2-70b-hf](https://huggingface.co/meta-llama/Llama-2-70b-hf), based on your choices. ### Fine-tuning ``` torchrun --nproc_per_node=8 fine-tune.py \ - --model_name_or_path path_to/gpt_neox_model_hf \ + --model_name_or_path path_to/Llama-2-7b-hf \ --bf16 True \ --output_dir path_to_saving_checkpoints \ --cache_dir path_to_cache \ @@ -47,7 +107,7 @@ torchrun --nproc_per_node=8 fine-tune.py \ --max_steps 1000 ``` -- Please remember to change `path_to/gpt_neox_model_hf`, `path_to_saving_checkpoints`, `path_to_cache` to your own directory. +- Please remember to change `path_to/Llama-2-7b-hf`, `path_to_saving_checkpoints`, `path_to_cache` to your own directory. - Note that you can change `model_max_length` to other values. - You could change `ds_configs/stage2.json` to `ds_configs/stage3.json` if you want. - Please set `use_flash_attn` as `False` if you use V100 machines or do not install flash attention. @@ -83,7 +143,7 @@ torchrun --nproc_per_node=8 supervised-fine-tune.py \ --deepspeed "ds_configs/stage2.json" \ --tf32 True ``` -- We typically make supervised fine-tuning upon the fine-tuned context extended models, `path_to_finetuned_models` +- We typically make supervised fine-tuning upon the fine-tuned context extended models, `path_to_finetuned_models`, like `Llama-2-13b-longlora-32k` or `Llama-2-13b-longlora-32k-ft`. - During our dataset collection, it is hard for us to collect many high-quality QA that are larger than 32768. Thus, if you use our `LongQA.json`, please also set `model_max_length` as 32768. @@ -222,8 +282,7 @@ If you find this project useful in your research, please consider citing: ``` ## Acknowledgement -- This work is an GPTNeoX port of the work from the original authors' code. [LongLoRA](https://github.com/dvlab-research/LongLoRA) -- This work is built upon the [GPTNeoX-HF](https://huggingface.co/docs/transformers/model_doc/gpt_neox) which is based upon [EleutherAI/GPTNeoX](https://github.com/EleutherAI/gpt-neox) as the pre-trained model architecture. +- This work is built upon the [LLaMA2](https://ai.meta.com/llama) as the pre-trained models. - This work is based on [DeepSpeed](https://github.com/microsoft/DeepSpeed), [peft](https://github.com/huggingface/peft), and [Flash-Attention2](https://github.com/Dao-AILab/flash-attention) for acceleration. - Some evaluation code is modified upon [Landmark Attention](https://github.com/epfml/landmark-attention). - We use [LongChat](https://github.com/DachengLi1/LongChat) for the retrieval evaluation. From 28d452ec7355d7289a5176fac09f05038a16e5f9 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Tue, 3 Oct 2023 17:47:40 +0900 Subject: [PATCH 16/16] [remove] unused comments --- gptneox_attn_replace.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index fda6a221..ee16bb0b 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -112,9 +112,7 @@ def shift(qkv, num_heads, head_dim): qkv = qkv.transpose(1, 2) # qkv = [bsz, q_len, nh, d] qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) - #qkv = qkv.transpose(1, 2) - # TODO: Changing the q_len to group_size, will require attention mask to be adjusted as well # -> [bsz * n_group, group_s, nh, d) # -> [bsz * n_group, nh, group_s, d) qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) @@ -124,6 +122,7 @@ def shift(qkv, num_heads, head_dim): query = shift(query, self.num_attention_heads, self.head_size).contiguous() key = shift(key, self.num_attention_heads, self.head_size).contiguous() value = shift(value, self.num_attention_heads, self.head_size).contiguous() + attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1) # Compute attention @@ -136,7 +135,6 @@ def shift(qkv, num_heads, head_dim): if self.training and not use_full: attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size) - #attn_output = attn_output.transpose(1, 2) # [bsz, q_len, nh, hd] attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1) attn_output = attn_output.transpose(1, 2)