lora: lokr tweaks + reorganize
stduhpf committed Jan 29, 2025
1 parent 653d56a commit bf221a4
Showing 1 changed file with 30 additions and 21 deletions.
51 changes: 30 additions & 21 deletions lora.hpp
@@ -266,8 +266,13 @@ struct LoraModel : public GGMLRunner {
float scale_value = 1.0f;
std::string fk = lora_pre[type] + key;
if (lora_tensors.find(fk + ".hada_w1_a") != lora_tensors.end()) {
// loHa mode
// LoHa mode

// TODO: split qkv convention for LoHas (is it ever used?)
if (is_qkv_split || is_qkvm_split) {
LOG_ERROR("Split qkv isn't supported for LoHa models.");
break;
}
std::string alpha_name = "";

ggml_tensor* hada_1_mid = NULL; // tau for tucker decomposition
@@ -286,11 +291,6 @@ struct LoraModel : public GGMLRunner {
std::string hada_2_down_name = "";
std::string hada_2_up_name = "";

// TODO: split qkv convention for LoHas (is it ever used?)
if (is_qkv_split || is_qkvm_split) {
LOG_ERROR("Split qkv isn't supported for LoHa models.");
break;
}

hada_1_down_name = fk + ".hada_w1_b";
hada_1_up_name = fk + ".hada_w1_a";
@@ -340,12 +340,20 @@ struct LoraModel : public GGMLRunner {

// calc_scale
// TODO: .dora_scale?
int64_t dim = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1];
int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1];
if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / dim;
scale_value = alpha / rank;
}
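// Sketch of what this branch reconstructs (hedged; LyCORIS LoHa convention):
//   delta_W ~= (hada_w1_a @ hada_w1_b) .* (hada_w2_a @ hada_w2_b) * scale_value
// i.e. the element-wise (Hadamard) product of two low-rank products, with the
// optional *_mid ("tau") tensors acting as Tucker cores for conv weights, and
// scale_value = alpha / rank as computed just above.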
} else if (lora_tensors.find(fk + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(fk + ".lokr_w1_a") != lora_tensors.end()) {
// LoKr mode

// TODO: split qkv convention for LoKrs (is it ever used?)
if (is_qkv_split || is_qkvm_split) {
LOG_ERROR("Split qkv isn't supported for LoKr models.");
break;
}

std::string alpha_name = fk + ".alpha";

ggml_tensor* lokr_w1 = NULL;
@@ -354,12 +362,6 @@ struct LoraModel : public GGMLRunner {
std::string lokr_w1_name = "";
std::string lokr_w2_name = "";

// TODO: split qkv convention for LoKrs (is it ever used?)
if (is_qkv_split || is_qkvm_split) {
LOG_ERROR("Split qkv isn't supported for LoKr models.");
break;
}

lokr_w1_name = fk + ".lokr_w1";
lokr_w2_name = fk + ".lokr_w2";

@@ -372,14 +374,14 @@ struct LoraModel : public GGMLRunner {
std::string down_name = lokr_w1_name + "_b";
std::string up_name = lokr_w1_name + "_a";
if (lora_tensors.find(down_name) != lora_tensors.end()) {
// w1 should not be low rank normally, sometimes w1 and w2 are swapped
down = to_f32(compute_ctx, lora_tensors[down_name]);
applied_lora_tensors.insert(down_name);

// scale != 1 only when using Low rank form (?)
int64_t dim = down->ne[ggml_n_dims(down) - 1];
int64_t rank = down->ne[ggml_n_dims(down) - 1];
if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / dim;
scale_value = alpha / rank;
}
}
if (lora_tensors.find(up_name) != lora_tensors.end()) {
@@ -399,18 +401,25 @@ struct LoraModel : public GGMLRunner {
if (lora_tensors.find(down_name) != lora_tensors.end()) {
down = to_f32(compute_ctx, lora_tensors[down_name]);
applied_lora_tensors.insert(down_name);

int64_t rank = down->ne[ggml_n_dims(down) - 1];
if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / rank;
}
}
if (lora_tensors.find(up_name) != lora_tensors.end()) {
up = to_f32(compute_ctx, lora_tensors[up_name]);
applied_lora_tensors.insert(up_name);
}
lokr_w2 = ggml_merge_lora(compute_ctx, down, up);
}

// Technically it might be unused, but I believe it's the expected behavior
applied_lora_tensors.insert(alpha_name);

updown = ggml_kronecker(compute_ctx, lokr_w1, lokr_w2);

// TODO: double check alpha implementation, it seems strange to not use them most of the time
applied_lora_tensors.insert(alpha_name);
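// Rough picture of the LoKr merge (hedged; LyCORIS convention): the delta is a
// Kronecker product, delta_W ~= kron(lokr_w1, lokr_w2), with scale_value presumably
// applied further down. Either factor may itself be stored in low-rank form
// (w = a @ b) and is rebuilt first. Shapes: if lokr_w1 is (o1 x i1) and lokr_w2 is
// (o2 x i2), the product is (o1*o2 x i1*i2), which has to match the target weight.
// A standalone sketch of the product itself follows the diff below.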
} else {
// LoRA mode
ggml_tensor* lora_mid = NULL; // tau for tucker decomposition
@@ -770,12 +779,12 @@ struct LoraModel : public GGMLRunner {
}
// calc_scale
// TODO: .dora_scale?
int64_t dim = lora_down->ne[ggml_n_dims(lora_down) - 1];
int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1];
if (lora_tensors.find(scale_name) != lora_tensors.end()) {
scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]);
} else if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
scale_value = alpha / dim;
scale_value = alpha / rank;
}

updown = ggml_merge_lora(compute_ctx, lora_down, lora_up, lora_mid);
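// Plain LoRA case for comparison (hedged): delta_W ~= scale_value * (lora_up @ lora_down),
// with lora_mid inserted as a Tucker core for conv layers when present.
// Worked scale example: alpha = 4 with rank = 16 gives scale_value = 0.25.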
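For reference, a minimal standalone sketch of the Kronecker product that the LoKr branch relies on (ggml_kronecker above). This is plain row-major C++ with made-up names, not code from lora.hpp or ggml, and it folds the alpha / rank scale directly into the product:

#include <cstddef>
#include <vector>

// kron(w1, w2) for row-major matrices: w1 is (r1 x c1), w2 is (r2 x c2),
// result is (r1*r2 x c1*c2), optionally scaled (e.g. by alpha / rank).
std::vector<float> kronecker(const std::vector<float>& w1, size_t r1, size_t c1,
                             const std::vector<float>& w2, size_t r2, size_t c2,
                             float scale = 1.0f) {
    std::vector<float> out(r1 * r2 * c1 * c2);
    for (size_t i = 0; i < r1; ++i)
        for (size_t j = 0; j < c1; ++j)
            for (size_t k = 0; k < r2; ++k)
                for (size_t l = 0; l < c2; ++l)
                    // out[i*r2 + k][j*c2 + l] = w1[i][j] * w2[k][l] * scale
                    out[(i * r2 + k) * (c1 * c2) + (j * c2 + l)] =
                        w1[i * c1 + j] * w2[k * c2 + l] * scale;
    return out;
}

With o1*o2 equal to the target's output dimension and i1*i2 equal to its input dimension, this rebuilds a full-size weight delta from two much smaller factors, which is the point of LoKr.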
