feat: add SDXL support #117

Merged 6 commits on Dec 28, 2023
2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,3 +1,3 @@
 [submodule "ggml"]
 	path = ggml
-	url = https://github.com/FSSRepo/ggml.git
+	url = https://github.com/leejet/ggml.git
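The ggml submodule now tracks leejet's fork instead of FSSRepo's. Existing checkouts pick up the retargeted URL only after `git submodule sync` followed by `git submodule update`.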
5 changes: 3 additions & 2 deletions README.md
@@ -10,8 +10,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
 - Super lightweight and without external dependencies
-- SD1.x and SD2.x support
-- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) support
+- SD1.x, SD2.x and SDXL support
+- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support
 - 16-bit, 32-bit float support
 - 4-bit, 5-bit and 8-bit integer quantization support
 - Accelerated memory-efficient CPU inference
@@ -302,3 +302,4 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
 - [k-diffusion](https://github.com/crowsonkb/k-diffusion)
 - [latent-consistency-model](https://github.com/luosiallen/latent-consistency-model)
+- [generative-models](https://github.com/Stability-AI/generative-models/)
2 changes: 1 addition & 1 deletion ggml
48 changes: 32 additions & 16 deletions model.cpp
@@ -78,8 +78,9 @@ const char* unused_tensors[] = {
     "cond_stage_model.transformer.text_model.embeddings.position_ids",
     "cond_stage_model.model.logit_scale",
     "cond_stage_model.model.text_projection",
+    "conditioner.embedders.0.transformer.text_model.embeddings.position_ids",
     "conditioner.embedders.0.model.logit_scale",
-    "conditioner.embedders.0.model.text_projection",
+    "conditioner.embedders.1.model.logit_scale",
     "model.diffusion_model.time_embedding.cond_proj.weight",
     "unet.time_embedding.cond_proj.weight",
     "model_ema.decay",
@@ -100,11 +101,11 @@ bool is_unused_tensor(std::string name) {
 }
 
 std::unordered_map<std::string, std::string> open_clip_to_hf_clip_model = {
-    {"cond_stage_model.model.ln_final.bias", "cond_stage_model.transformer.text_model.final_layer_norm.bias"},
-    {"cond_stage_model.model.ln_final.weight", "cond_stage_model.transformer.text_model.final_layer_norm.weight"},
-    {"cond_stage_model.model.positional_embedding", "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight"},
-    {"cond_stage_model.model.token_embedding.weight", "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"},
-
+    {"model.ln_final.bias", "transformer.text_model.final_layer_norm.bias"},
+    {"model.ln_final.weight", "transformer.text_model.final_layer_norm.weight"},
+    {"model.positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"},
+    {"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"},
+    {"model.text_projection", "transformer.text_model.text_projection"},
 };
 
 std::unordered_map<std::string, std::string> open_clip_to_hk_clip_resblock = {
@@ -133,11 +134,21 @@ std::unordered_map<std::string, std::string> vae_decoder_name_map = {
 
 std::string convert_open_clip_to_hf_clip(const std::string& name) {
     std::string new_name = name;
+    std::string prefix;
     if (starts_with(new_name, "conditioner.embedders.0.")) {
-        new_name = "cond_stage_model." + new_name.substr(strlen("conditioner.embedders.0."));
+        prefix = "cond_stage_model.";
+        new_name = new_name.substr(strlen("conditioner.embedders.0."));
+    } else if (starts_with(new_name, "conditioner.embedders.1.")) {
+        prefix = "cond_stage_model.1.";
+        new_name = new_name.substr(strlen("conditioner.embedders.1."));
+    } else if (starts_with(new_name, "cond_stage_model.")) {
+        prefix = "cond_stage_model.";
+        new_name = new_name.substr(strlen("cond_stage_model."));
+    } else {
+        return new_name;
     }
-    std::string open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks.";
-    std::string hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers.";
+    std::string open_clip_resblock_prefix = "model.transformer.resblocks.";
+    std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers.";
 
     if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) {
         new_name = open_clip_to_hf_clip_model[new_name];
@@ -156,7 +167,7 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) {
         }
     }
 
-    return new_name;
+    return prefix + new_name;
 }
 
 std::string convert_vae_decoder_name(const std::string& name) {
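The key change here: instead of hard-coding `cond_stage_model.` into every map entry, the encoder-specific prefix is split off first, the remainder goes through one shared OpenCLIP-to-HF-CLIP table, and the prefix is re-attached on return. That is what lets SDXL's second text encoder land under `cond_stage_model.1.` without duplicating the table. A condensed, self-contained sketch of the scheme (one map entry kept for brevity; all names are taken from the diff itself):

```cpp
// Sketch of the prefix-plus-suffix renaming introduced for SDXL:
// strip the encoder prefix, map the remainder, re-attach the prefix.
#include <cstring>
#include <iostream>
#include <string>
#include <unordered_map>

static bool starts_with(const std::string& s, const std::string& p) {
    return s.compare(0, p.size(), p) == 0;
}

static std::unordered_map<std::string, std::string> open_clip_to_hf_clip_model = {
    {"model.ln_final.weight", "transformer.text_model.final_layer_norm.weight"},
};

static std::string convert_open_clip_to_hf_clip(const std::string& name) {
    std::string new_name = name;
    std::string prefix;
    if (starts_with(new_name, "conditioner.embedders.0.")) {
        prefix = "cond_stage_model.";      // SDXL text encoder #1
        new_name = new_name.substr(strlen("conditioner.embedders.0."));
    } else if (starts_with(new_name, "conditioner.embedders.1.")) {
        prefix = "cond_stage_model.1.";    // SDXL text encoder #2
        new_name = new_name.substr(strlen("conditioner.embedders.1."));
    } else if (starts_with(new_name, "cond_stage_model.")) {
        prefix = "cond_stage_model.";      // SD1.x / SD2.x checkpoints
        new_name = new_name.substr(strlen("cond_stage_model."));
    } else {
        return new_name;
    }
    auto it = open_clip_to_hf_clip_model.find(new_name);
    if (it != open_clip_to_hf_clip_model.end()) {
        new_name = it->second;
    }
    return prefix + new_name;
}

int main() {
    // The second encoder's OpenCLIP name comes back under "cond_stage_model.1.":
    std::cout << convert_open_clip_to_hf_clip("conditioner.embedders.1.model.ln_final.weight") << "\n";
    // -> cond_stage_model.1.transformer.text_model.final_layer_norm.weight
    return 0;
}
```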
@@ -358,7 +369,7 @@ std::string convert_diffusers_name_to_compvis(const std::string& key, char seq)
 
 std::string convert_tensor_name(const std::string& name) {
     std::string new_name;
-    if (starts_with(name, "cond_stage_model.model") || starts_with(name, "conditioner.embedders.0.model")) {
+    if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.")) {
         new_name = convert_open_clip_to_hf_clip(name);
     } else if (starts_with(name, "first_stage_model.decoder")) {
         new_name = convert_vae_decoder_name(name);
@@ -419,7 +430,7 @@ void preprocess_tensor(TensorStorage tensor_storage,
 
     tensor_storage.name = new_name;
 
-    if (starts_with(new_name, "cond_stage_model.transformer.text_model.encoder.layers.") &&
+    if (new_name.find("transformer.text_model.encoder.layers.") != std::string::npos &&
         ends_with(new_name, "attn.in_proj_weight")) {
         size_t prefix_size = new_name.find("attn.in_proj_weight");
         std::string prefix = new_name.substr(0, prefix_size);
@@ -431,7 +442,7 @@ void preprocess_tensor(TensorStorage tensor_storage,
 
         processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end());
 
-    } else if (starts_with(new_name, "cond_stage_model.transformer.text_model.encoder.layers.") &&
+    } else if (new_name.find("transformer.text_model.encoder.layers.") != std::string::npos &&
               ends_with(new_name, "attn.in_proj_bias")) {
        size_t prefix_size = new_name.find("attn.in_proj_bias");
        std::string prefix = new_name.substr(0, prefix_size);
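The loosened predicate is what routes the second encoder's fused attention tensors through the existing QKV chunking: after renaming, those tensors start with `cond_stage_model.1.`, which the old exact-prefix check rejected. A small demonstration of just the predicate change (the layer index is arbitrary):

```cpp
// Why the predicate was loosened: starts_with only matched the first text
// encoder's prefix, while find() matches any prefix as long as the encoder
// layer pattern appears somewhere in the converted name.
#include <iostream>
#include <string>

static bool starts_with(const std::string& s, const std::string& p) {
    return s.compare(0, p.size(), p) == 0;
}

int main() {
    // converted name of the second SDXL text encoder's fused QKV weight
    std::string name = "cond_stage_model.1.transformer.text_model.encoder.layers.0.attn.in_proj_weight";

    bool old_match = starts_with(name, "cond_stage_model.transformer.text_model.encoder.layers.");
    bool new_match = name.find("transformer.text_model.encoder.layers.") != std::string::npos;

    std::cout << "old predicate: " << old_match << "\n";  // 0 - second encoder skipped
    std::cout << "new predicate: " << new_match << "\n";  // 1 - both encoders handled
    return 0;
}
```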
@@ -1163,15 +1174,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
 }
 
 SDVersion ModelLoader::get_sd_version() {
+    // return VERSION_1_x;
     TensorStorage token_embedding_weight;
     for (auto& tensor_storage : tensor_storages) {
+        if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
+            return VERSION_XL;
+        }
         if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
             tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
             tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
             tensor_storage.name == "te.text_model.embeddings.token_embedding.weight" ||
-            tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight") {
+            tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight" ||
+            tensor_storage.name == "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight") {
             token_embedding_weight = tensor_storage;
-            break;
+            // break;
         }
     }
     if (token_embedding_weight.ne[0] == 768) {
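Two details worth noting: any tensor under `conditioner.embedders.1` short-circuits to `VERSION_XL` before the token-embedding width is even consulted, and the `break` is commented out because the embedders.1 tensor can appear after the token embedding in file order. A standalone sketch of the flow; the `VERSION_COUNT` fallback and the 1024 → SD2.x branch are assumptions based on the surrounding code (the 768 test is the hunk's last context line above):

```cpp
// Standalone sketch of version detection: embedders.1 => SDXL, else fall back
// to token-embedding width (768 -> SD1.x; 1024 -> SD2.x is an assumption).
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

enum SDVersion { VERSION_1_x, VERSION_2_x, VERSION_XL, VERSION_COUNT };  // VERSION_COUNT assumed

SDVersion guess_sd_version(const std::vector<std::pair<std::string, int64_t>>& tensors) {
    int64_t token_embedding_width = 0;
    for (const auto& [name, ne0] : tensors) {
        if (name.find("conditioner.embedders.1") != std::string::npos) {
            return VERSION_XL;  // a second text encoder only exists in SDXL
        }
        if (name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
            name == "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight") {
            token_embedding_width = ne0;  // keep scanning: embedders.1 may still appear later
        }
    }
    if (token_embedding_width == 768)  return VERSION_1_x;
    if (token_embedding_width == 1024) return VERSION_2_x;
    return VERSION_COUNT;  // unknown
}

int main() {
    std::cout << guess_sd_version({{"conditioner.embedders.1.model.logit_scale", 1}}) << "\n";  // 2 (VERSION_XL)
    std::cout << guess_sd_version({{"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight", 768}}) << "\n";  // 0
    return 0;
}
```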
@@ -1275,7 +1291,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
     }
 
     for (auto& tensor_storage : processed_tensor_storages) {
-        // LOG_DEBUG("%s", name.c_str());
+        // LOG_DEBUG("%s", tensor_storage.name.c_str());
 
         ggml_tensor* dst_tensor = NULL;
 