
DiG: Scalable and Efficient Diffusion Models with Gated Linear Attention


Diffusion GLA (DiG)

Scalable and Efficient Diffusion Models with Gated Linear Attention

Lianghui Zhu1,2, Zilong Huang2 📧, Bencheng Liao1, Jun Hao Liew2, Hanshu Yan2, Jiashi Feng2, Xinggang Wang1 📧

1 School of EIC, Huazhong University of Science and Technology, 2 ByteDance

(📧) corresponding author.

arXiv Preprint (arXiv:2405.18428)

News

  • May 28th, 2024: We released our paper on arXiv. Code and models are coming soon. Please stay tuned! ☕️

Abstract

Diffusion models with large-scale pre-training have achieved significant success in the field of visual content generation, particularly exemplified by Diffusion Transformers (DiT). However, DiT models suffer from the quadratic complexity of self-attention, especially when handling long sequences. In this paper, we aim to incorporate the sub-quadratic modeling capability of Gated Linear Attention (GLA) into the 2D diffusion backbone. Specifically, we introduce Diffusion Gated Linear Attention Transformers (DiG), a simple, adoptable solution with minimal parameter overhead. We offer two variants, i.e., a plain and a U-shaped architecture, showing superior efficiency and competitive effectiveness. In addition to outperforming DiT and other sub-quadratic-time diffusion models at $256 \times 256$ resolution, DiG demonstrates greater efficiency than these methods starting from a $512$ resolution. Specifically, DiG-S/2 is $2.5\times$ faster and saves $75.7\%$ GPU memory compared to DiT-S/2 at a $1792$ resolution. Additionally, DiG-XL/2 is $4.2\times$ faster than the Mamba-based model at a $1024$ resolution and $1.8\times$ faster than DiT with FlashAttention-2 at a $2048$ resolution. We will release the code soon.
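The key idea behind GLA's sub-quadratic cost is that it replaces softmax attention's $L \times L$ score matrix with a recurrently updated state that is decayed by a data-dependent gate. A minimal NumPy sketch of the causal gated recurrence for a single head (illustrative only; the actual implementation lives in flash-linear-attention and uses fused Triton kernels):

```python
import numpy as np

def gated_linear_attention(q, k, v, g):
    """Causal gated linear attention for one head.

    q, k, v: (L, d) queries, keys, values; g: (L, d) per-step decay
    gates in (0, 1). A (d, d) state S is decayed by the gate and
    updated with the outer product k_t v_t^T, so the whole pass costs
    O(L * d^2) rather than the O(L^2 * d) of softmax attention.
    """
    L, d = q.shape
    S = np.zeros((d, d))
    out = np.zeros((L, d))
    for t in range(L):
        S = g[t][:, None] * S + np.outer(k[t], v[t])  # gated state update
        out[t] = q[t] @ S                             # readout for step t
    return out
```

With the gate fixed to 1 this reduces to plain (unnormalized) causal linear attention, where the output at step $t$ is $\sum_{s \le t} (q_t \cdot k_s)\, v_s$.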

Overview

Envs. for Training

  • Python 3.9.2

    • conda create -n your_env_name python=3.9.2
  • torch 2.2.0 + cu121

    • pip3 install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121
  • Requirements:

    # triton
    pip3 install triton
    
    # GLA
    git clone https://github.com/sustcsonglin/flash-linear-attention
    cd flash-linear-attention
    git checkout 36743f3f14e47f23c1ad45cf5de727dbacb5600e
    pip3 install -e .
    
    # others
    pip3 install diffusers
    pip3 install tensorboard
    pip3 install timm
    pip3 install transformers
    pip3 install accelerate
    pip3 install fvcore
    pip3 install opt_einsum
    pip3 install torchdiffeq
    pip3 install ftfy
    pip3 install PyAV

Train Your DiG

  • Set your VAE path in train-multi-nodes.py.
  • Set your DATA_PATH in scripts/dig_s_d2_in1k_256_bs256_1node.sh.
  • Run bash DiG/scripts/dig_s_d2_in1k_256_bs256_1node.sh no_env_install.
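The script above wires together the VAE, the data pipeline, and the diffusion objective. For orientation, a single latent-diffusion training step can be sketched as follows; the function name, the linear beta schedule, and the `denoiser(noisy, t)` interface are illustrative assumptions, not the repository's actual API:

```python
import torch
import torch.nn as nn

def diffusion_loss(denoiser, latents, num_timesteps=1000):
    """Hypothetical sketch: noise-prediction loss on VAE latents.

    latents: (B, C, H, W) tensor from a frozen VAE encoder.
    denoiser: any module mapping (noisy_latents, timesteps) -> noise estimate.
    """
    b = latents.shape[0]
    t = torch.randint(0, num_timesteps, (b,))          # random timestep per sample
    betas = torch.linspace(1e-4, 0.02, num_timesteps)  # simple linear schedule
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)[t].view(b, 1, 1, 1)
    noise = torch.randn_like(latents)
    # forward-process sample: sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
    noisy = alpha_bar.sqrt() * latents + (1.0 - alpha_bar).sqrt() * noise
    pred = denoiser(noisy, t)                          # predict the added noise
    return nn.functional.mse_loss(pred, noise)
```

In the real training loop this loss would be backpropagated through the DiG denoiser only, with the VAE kept frozen.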

Acknowledgement ❤️

This project is based on GLA (paper, code), flash-linear-attention (code), DiT (paper, code), DiS (paper, code), and OpenDiT (code). Thanks for their wonderful work.

Citation

If you find DiG useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.

@article{dig,
      title={DiG: Scalable and Efficient Diffusion Models with Gated Linear Attention}, 
      author={Lianghui Zhu and Zilong Huang and Bencheng Liao and Jun Hao Liew and Hanshu Yan and Jiashi Feng and Xinggang Wang},
      year={2024},
      eprint={2405.18428},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
