fix: repair flash attention support (#386)
* repair flash attention in _ext
This does not fix the currently broken flash attention behind the compile-time define, which is only used by the VAE.

Co-authored-by: FSSRepo <[email protected]>

* make flash attention in the diffusion model a runtime flag
No support for SD3 or video models yet.

* remove old flash attention option and switch vae over to attn_ext

* update docs

* format code

---------

Co-authored-by: FSSRepo <[email protected]>
Co-authored-by: leejet <[email protected]>
3 people authored Nov 23, 2024
1 parent ea9b647 commit 1c168d9
Showing 17 changed files with 334 additions and 314 deletions.
CMakeLists.txt: 6 changes (0 additions, 6 deletions)
@@ -29,7 +29,6 @@ option(SD_HIPBLAS "sd: rocm backend" OFF)
 option(SD_METAL "sd: metal backend" OFF)
 option(SD_VULKAN "sd: vulkan backend" OFF)
 option(SD_SYCL "sd: sycl backend" OFF)
-option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
 option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
 option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
 #option(SD_BUILD_SERVER "sd: build server example" ON)
@@ -61,11 +60,6 @@ if (SD_HIPBLAS)
     endif()
 endif ()
 
-if(SD_FLASH_ATTN)
-    message("-- Use Flash Attention for memory optimization")
-    add_definitions(-DSD_USE_FLASH_ATTENTION)
-endif()
-
 set(SD_LIB stable-diffusion)
 
 file(GLOB SD_LIB_SOURCES
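With the `SD_FLASH_ATTN` option removed, nothing flash-attention-specific happens at configure time anymore; a plain build (as the README already describes, sketched below) covers both code paths, and the choice moves to run time via `--diffusion-fa`:

```
cmake ..
cmake --build . --config Release
```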
README.md: 21 changes (17 additions, 4 deletions)
@@ -24,7 +24,7 @@ Inference of Stable Diffusion and Flux in pure C/C++
 - Full CUDA, Metal, Vulkan and SYCL backend for GPU acceleration.
 - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models
   - No need to convert to `.ggml` or `.gguf` anymore!
-- Flash Attention for memory usage optimization (only cpu for now)
+- Flash Attention for memory usage optimization
 - Original `txt2img` and `img2img` mode
 - Negative prompt
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
@@ -182,11 +182,21 @@ Example of text2img by using SYCL backend:
 ##### Using Flash Attention
 
-Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing.
+Enabling flash attention for the diffusion model reduces memory usage by a model-dependent amount, e.g.:
+- flux 768x768: ~600 MB
+- SD2 768x768: ~1400 MB
+
+For most backends it slows generation down, but for CUDA it generally speeds it up as well.
+At the moment it is only supported for some models and some backends (cpu, cuda/rocm, metal).
+
+Enable it by adding `--diffusion-fa` to the arguments and watch for:
 ```
-cmake .. -DSD_FLASH_ATTN=ON
-cmake --build . --config Release
+[INFO ] stable-diffusion.cpp:312 - Using flash attention in the diffusion model
 ```
+and for the compute buffer shrinking in the debug log:
+```
+[DEBUG] ggml_extend.hpp:1004 - flux compute buffer size: 650.00 MB(VRAM)
+```
 
 ### Run
@@ -240,6 +250,9 @@ arguments:
   --vae-tiling                       process vae in tiles to reduce memory usage
   --vae-on-cpu                       keep vae in cpu (for low vram)
   --clip-on-cpu                      keep clip in cpu (for low vram)
+  --diffusion-fa                     use flash attention in the diffusion model (for low vram)
+                                     Might lower quality, since it implies converting k and v to f16.
+                                     This might crash if it is not supported by the backend.
   --control-net-cpu                  keep controlnet in cpu (for low vram)
   --canny                            apply canny preprocessor (edge detection)
   --color                            Colors the logging tags according to level
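Putting the new README instructions together, a typical run might look like this (model path and prompt are illustrative, not from the commit):

```
./bin/sd -m ../models/sd-v1-4.ckpt -p "a lovely cat" --diffusion-fa
```

If the backend lacks support, the help text warns the run may crash, so watch for the log line quoted above.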
clip.hpp: 18 changes (8 additions, 10 deletions)
@@ -343,8 +343,7 @@ class CLIPTokenizer {
         }
     }
 
-    std::string clean_up_tokenization(std::string &text){
-
+    std::string clean_up_tokenization(std::string& text) {
         std::regex pattern(R"( ,)");
         // Replace " ," with ","
         std::string result = std::regex_replace(text, pattern, ",");
@@ -359,10 +358,10 @@
             std::u32string ts = decoder[t];
             // printf("%d, %s \n", t, utf32_to_utf8(ts).c_str());
             std::string s = utf32_to_utf8(ts);
-            if (s.length() >= 4 ){
-                if(ends_with(s, "</w>")) {
+            if (s.length() >= 4) {
+                if (ends_with(s, "</w>")) {
                     text += s.replace(s.length() - 4, s.length() - 1, "") + " ";
-                }else{
+                } else {
                     text += s;
                 }
             } else {
@@ -768,8 +767,7 @@ class CLIPVisionModel : public GGMLBlock {
         blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values,
-                                bool return_pooled = true) {
+    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true) {
         // pixel_values: [N, num_channels, image_size, image_size]
         auto embeddings = std::dynamic_pointer_cast<CLIPVisionEmbeddings>(blocks["embeddings"]);
         auto pre_layernorm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_layernorm"]);
@@ -779,11 +777,11 @@
         auto x = embeddings->forward(ctx, pixel_values);  // [N, num_positions, embed_dim]
         x = pre_layernorm->forward(ctx, x);
         x = encoder->forward(ctx, x, -1, false);
-        // print_ggml_tensor(x, true, "ClipVisionModel x: ");
+        // print_ggml_tensor(x, true, "ClipVisionModel x: ");
         auto last_hidden_state = x;
-        x = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
+        x = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
 
-        GGML_ASSERT(x->ne[3] == 1);
+        GGML_ASSERT(x->ne[3] == 1);
         if (return_pooled) {
             ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0));
             return pooled;  // [N, hidden_size]
common.hpp: 23 changes (14 additions, 9 deletions)
@@ -245,16 +245,19 @@ class CrossAttention : public GGMLBlock {
     int64_t context_dim;
     int64_t n_head;
     int64_t d_head;
+    bool flash_attn;
 
 public:
     CrossAttention(int64_t query_dim,
                    int64_t context_dim,
                    int64_t n_head,
-                   int64_t d_head)
+                   int64_t d_head,
+                   bool flash_attn = false)
         : n_head(n_head),
           d_head(d_head),
           query_dim(query_dim),
-          context_dim(context_dim) {
+          context_dim(context_dim),
+          flash_attn(flash_attn) {
         int64_t inner_dim = d_head * n_head;
 
         blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
@@ -283,7 +286,7 @@ class CrossAttention : public GGMLBlock {
         auto k = to_k->forward(ctx, context);  // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context);  // [N, n_context, inner_dim]
 
-        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false);  // [N, n_token, inner_dim]
+        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn);  // [N, n_token, inner_dim]
 
         x = to_out_0->forward(ctx, x);  // [N, n_token, query_dim]
         return x;
@@ -301,15 +304,16 @@
                           int64_t n_head,
                           int64_t d_head,
                           int64_t context_dim,
-                          bool ff_in = false)
+                          bool ff_in = false,
+                          bool flash_attn = false)
         : n_head(n_head), d_head(d_head), ff_in(ff_in) {
         // disable_self_attn is always False
         // disable_temporal_crossattention is always False
         // switch_temporal_ca_to_sa is always False
         // inner_dim is always None or equal to dim
        // gated_ff is always True
-        blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head));
-        blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head));
+        blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
+        blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
         blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
         blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
         blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
@@ -374,7 +378,8 @@ class SpatialTransformer : public GGMLBlock {
                        int64_t n_head,
                        int64_t d_head,
                        int64_t depth,
-                       int64_t context_dim)
+                       int64_t context_dim,
+                       bool flash_attn = false)
         : in_channels(in_channels),
           n_head(n_head),
           d_head(d_head),
@@ -388,7 +393,7 @@
 
         for (int i = 0; i < depth; i++) {
             std::string name = "transformer_blocks." + std::to_string(i);
-            blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim));
+            blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
         }
 
         blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
@@ -511,4 +516,4 @@ class VideoResBlock : public ResBlock {
     }
 };
 
-#endif  // __COMMON_HPP__
+#endif  // __COMMON_HPP__
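The `flash_attn` flag threaded through these blocks ends up in `ggml_nn_attention_ext` (defined in ggml_extend.hpp, not shown in this excerpt). As a rough illustration of the dispatch it performs, here is a minimal sketch; it is not the project's verbatim helper, it assumes ggml's `ggml_flash_attn_ext`/`ggml_soft_max_ext` signatures from late 2024, and it omits the reshaping, padding and masking the real helper does:

```
#include <cmath>

#include "ggml.h"

// Sketch only: pick the fused flash-attention kernel or the naive
// softmax(QK^T * scale)V path at run time instead of behind a #define.
static ggml_tensor* attention_dispatch(ggml_context* ctx,
                                       ggml_tensor* q,  // [d_head, n_token, n_head*N]
                                       ggml_tensor* k,  // [d_head, n_kv,    n_head*N]
                                       ggml_tensor* v,  // [d_head, n_kv,    n_head*N]
                                       bool flash_attn) {
    const float scale = 1.0f / std::sqrt((float)q->ne[0]);
    if (flash_attn) {
        // The fused kernel wants k/v in f16, which is the quality caveat
        // mentioned in the new --diffusion-fa help text.
        k = ggml_cast(ctx, k, GGML_TYPE_F16);
        v = ggml_cast(ctx, v, GGML_TYPE_F16);
        return ggml_flash_attn_ext(ctx, q, k, v, /*mask*/ NULL, scale,
                                   /*max_bias*/ 0.0f, /*logit_softcap*/ 0.0f);
    }
    // Naive fallback: kq = softmax(k^T q * scale), out = v^T kq.
    ggml_tensor* kq = ggml_mul_mat(ctx, k, q);  // [n_kv, n_token, n_head*N]
    kq              = ggml_soft_max_ext(ctx, kq, NULL, scale, 0.0f);
    ggml_tensor* vt = ggml_cont(ctx, ggml_transpose(ctx, v));  // [n_kv, d_head, n_head*N]
    return ggml_mul_mat(ctx, vt, kq);  // [d_head, n_token, n_head*N]
}
```

The old code chose between these paths with `#ifdef SD_USE_FLASH_ATTENTION`; the point of this commit is that the same decision is now a constructor argument, so one binary can serve both.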
conditioner.hpp: 19 changes (9 additions, 10 deletions)
@@ -4,7 +4,6 @@
 #include "clip.hpp"
 #include "t5.hpp"
 
-
 struct SDCondition {
     struct ggml_tensor* c_crossattn = NULL;  // aka context
     struct ggml_tensor* c_vector = NULL;  // aka y
@@ -44,7 +43,7 @@ struct Conditioner {
 // ldm.modules.encoders.modules.FrozenCLIPEmbedder
 // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
-    SDVersion version = VERSION_SD1;
+    SDVersion version = VERSION_SD1;
     PMVersion pm_version = VERSION_1;
     CLIPTokenizer tokenizer;
     ggml_type wtype;
@@ -61,7 +60,7 @@
                                       ggml_type wtype,
                                       const std::string& embd_dir,
                                       SDVersion version = VERSION_SD1,
-                                      PMVersion pv = VERSION_1,
+                                      PMVersion pv = VERSION_1,
                                       int clip_skip = -1)
         : version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
         if (clip_skip <= 0) {
@@ -162,7 +161,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
     tokenize_with_trigger_token(std::string text,
                                 int num_input_imgs,
                                 int32_t image_token,
-                                bool padding = false){
+                                bool padding = false) {
         return tokenize_with_trigger_token(text, num_input_imgs, image_token,
                                            text_model->model.n_token, padding);
     }
@@ -271,7 +270,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
             std::vector<int> clean_input_ids_tmp;
             for (uint32_t i = 0; i < class_token_index[0]; i++)
                 clean_input_ids_tmp.push_back(clean_input_ids[i]);
-            for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs); i++)
+            for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
                 clean_input_ids_tmp.push_back(class_token);
             for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
                 clean_input_ids_tmp.push_back(clean_input_ids[i]);
@@ -287,11 +286,11 @@
             // weights.insert(weights.begin(), 1.0);
 
             tokenizer.pad_tokens(tokens, weights, max_length, padding);
-            int offset = pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs;
+            int offset = pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
             for (uint32_t i = 0; i < tokens.size(); i++) {
                 // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
-                if (class_idx + 1 <= i && i < class_idx + 1 + offset)  // photomaker V2 has num_tokens(=2)*num_input_imgs
-                    // hardcode for now
+                if (class_idx + 1 <= i && i < class_idx + 1 + offset)  // photomaker V2 has num_tokens(=2)*num_input_imgs
+                    // hardcode for now
                     class_token_mask.push_back(true);
                 else
                     class_token_mask.push_back(false);
@@ -536,7 +535,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                              int height,
                              int num_input_imgs,
                              int adm_in_channels = -1,
-                             bool force_zero_embeddings = false){
+                             bool force_zero_embeddings = false) {
         auto image_tokens = convert_token_to_id(trigger_word);
         // if(image_tokens.size() == 1){
         //     printf(" image token id is: %d \n", image_tokens[0]);
@@ -964,7 +963,7 @@ struct SD3CLIPEmbedder : public Conditioner {
                              int height,
                              int num_input_imgs,
                              int adm_in_channels = -1,
-                             bool force_zero_embeddings = false){
+                             bool force_zero_embeddings = false) {
         GGML_ASSERT(0 && "Not implemented yet!");
     }
 
diffusion_model.hpp: 12 changes (7 additions, 5 deletions)
@@ -32,8 +32,9 @@ struct UNetModel : public DiffusionModel {
 
     UNetModel(ggml_backend_t backend,
               ggml_type wtype,
-              SDVersion version = VERSION_SD1)
-        : unet(backend, wtype, version) {
+              SDVersion version = VERSION_SD1,
+              bool flash_attn = false)
+        : unet(backend, wtype, version, flash_attn) {
     }
 
     void alloc_params_buffer() {
@@ -133,8 +134,9 @@ struct FluxModel : public DiffusionModel {
 
     FluxModel(ggml_backend_t backend,
               ggml_type wtype,
-              SDVersion version = VERSION_FLUX_DEV)
-        : flux(backend, wtype, version) {
+              SDVersion version = VERSION_FLUX_DEV,
+              bool flash_attn = false)
+        : flux(backend, wtype, version, flash_attn) {
     }
 
     void alloc_params_buffer() {
@@ -178,4 +180,4 @@
     }
 };
 
-#endif
+#endif
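Both model wrappers thread the flag the same way. As a self-contained toy (invented names, not sd.cpp's classes), the shape of the change from the old compile-time define to a runtime flag is:

```
#include <iostream>

// Toy illustration of the commit's design: a runtime flag passed down
// through constructors replaces a compile-time #define.
struct Attention {
    bool flash_attn;
    explicit Attention(bool flash_attn = false) : flash_attn(flash_attn) {}
    void forward() const {
        if (flash_attn) {
            std::cout << "fused flash-attention kernel\n";
        } else {
            std::cout << "naive attention\n";
        }
    }
};

struct UNetToy {
    Attention attn;
    explicit UNetToy(bool flash_attn = false) : attn(flash_attn) {}
};

int main() {
    UNetToy unet(/*flash_attn=*/true);  // would come from --diffusion-fa
    unet.attn.forward();
}
```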
examples/cli/main.cpp: 10 changes (9 additions, 1 deletion)
@@ -116,6 +116,7 @@ struct SDParams {
     bool normalize_input = false;
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
+    bool diffusion_flash_attn = false;
     bool canny_preprocess = false;
     bool color = false;
     int upscale_repeats = 1;
@@ -151,6 +152,7 @@ void print_params(SDParams params) {
     printf("    clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
     printf("    controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
     printf("    vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
+    printf("    diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
     printf("    strength(control): %.2f\n", params.control_strength);
     printf("    prompt: %s\n", params.prompt.c_str());
     printf("    negative_prompt: %s\n", params.negative_prompt.c_str());
@@ -227,6 +229,9 @@ void print_usage(int argc, const char* argv[]) {
     printf("  --vae-tiling                       process vae in tiles to reduce memory usage\n");
     printf("  --vae-on-cpu                       keep vae in cpu (for low vram)\n");
     printf("  --clip-on-cpu                      keep clip in cpu (for low vram)\n");
+    printf("  --diffusion-fa                     use flash attention in the diffusion model (for low vram)\n");
+    printf("                                     Might lower quality, since it implies converting k and v to f16.\n");
+    printf("                                     This might crash if it is not supported by the backend.\n");
     printf("  --control-net-cpu                  keep controlnet in cpu (for low vram)\n");
     printf("  --canny                            apply canny preprocessor (edge detection)\n");
     printf("  --color                            Colors the logging tags according to level\n");
@@ -477,6 +482,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
             params.clip_on_cpu = true;  // will slow down get_learned_condiotion but necessary for low MEM GPUs
         } else if (arg == "--vae-on-cpu") {
             params.vae_on_cpu = true;  // will slow down latent decoding but necessary for low MEM GPUs
+        } else if (arg == "--diffusion-fa") {
+            params.diffusion_flash_attn = true;  // can reduce MEM significantly
         } else if (arg == "--canny") {
             params.canny_preprocess = true;
         } else if (arg == "-b" || arg == "--batch-count") {
@@ -868,7 +875,8 @@ int main(int argc, const char* argv[]) {
                                   params.schedule,
                                   params.clip_on_cpu,
                                   params.control_net_cpu,
-                                  params.vae_on_cpu);
+                                  params.vae_on_cpu,
+                                  params.diffusion_flash_attn);
 
     if (sd_ctx == NULL) {
         printf("new_sd_ctx_t failed\n");
(diffs for the remaining 10 changed files are not shown)
