support mergekit-extract-lora

Xuan Son Nguyen 2025-01-07 22:03:06 +01:00
parent 93fbfd022c
commit e444b8e0c2
4 changed files with 39 additions and 72 deletions

View file

@@ -382,13 +382,13 @@ if __name__ == '__main__':
                     if self.lazy:
                         tensor = LazyTorchTensor.from_eager(tensor)
                     base_name = get_base_tensor_name(name)
-                    # note: lora_embedding is transposed by mergekit-extract-lora, so it's reversed here
-                    is_lora_a = ".lora_A.weight" in name or ".lora_embedding_B" in name
-                    is_lora_b = ".lora_B.weight" in name or ".lora_embedding_A" in name
+                    # note: mergekit-extract-lora also adds token embeddings to the adapter
+                    is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
+                    is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
                     if not is_lora_a and not is_lora_b:
                         if ".base_layer.weight" in name:
                             continue
-                        # mergekit-extract-lora add these layernorm to the adapter
+                        # mergekit-extract-lora add these layernorm to the adapter, we need to keep them
                         if ".layernorm" or ".norm" in name:
                             yield (base_name, tensor)
                             continue
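
With this change the converter pairs embedding adapters the same way as linear ones: lora_embedding_A maps to the A side and lora_embedding_B to the B side, instead of being swapped to compensate for a transpose done at load time. A minimal sketch of that pairing (the example tensor names are assumed PEFT-style names, not taken from this commit):

    # illustration only: how adapter tensor names are paired after this change
    def classify(name: str) -> str:
        is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
        is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
        if is_lora_a:
            return "A"
        if is_lora_b:
            return "B"
        return "other (base weight, norm, ...)"

    print(classify("base_model.model.model.embed_tokens.lora_embedding_A"))      # A
    print(classify("base_model.model.layers.0.self_attn.q_proj.lora_B.weight"))  # B
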
@@ -398,10 +398,6 @@ if __name__ == '__main__':
                         logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948")
                         sys.exit(1)
-                    # mergekit-extract-lora transposes this tensor, we need to transpose it back
-                    if ".lora_embedding" in name:
-                        tensor = tensor.T
                     if base_name in tensor_map:
                         if is_lora_a:
                             tensor_map[base_name].A = tensor
@@ -437,6 +433,11 @@ if __name__ == '__main__':
                     assert isinstance(dest_data, LoraTorchTensor)
                     lora_a, lora_b = dest_data.get_lora_A_B()
+                    # token_embd A and B are already transposed by mergekit-extract-lora
+                    # we transpose A back again because it is used by llm_build_inp_embd()
+                    if "token_embd.weight" in dest_name:
+                        lora_a = lora_a.T
                     yield (dest_name + ".lora_a", lora_a)
                     yield (dest_name + ".lora_b", lora_b)
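
The transpose above is what turns the embedding adapter back into a per-token row lookup at inference time. A small numpy sketch of the shape argument, assuming the PEFT embedding-LoRA convention of A with shape (rank, n_vocab) and B with shape (n_embd, rank), which the comments in this commit imply but do not spell out:

    # illustration only (not part of the converter); scaling omitted
    import numpy as np

    rank, n_vocab, n_embd = 4, 32, 8
    A = np.random.randn(rank, n_vocab).astype(np.float32)
    B = np.random.randn(n_embd, rank).astype(np.float32)
    tokens = np.array([3, 7, 7, 31])

    # full-matrix view: delta for the whole embedding table, then select token rows
    delta_full = (B @ A).T          # (n_vocab, n_embd)
    ref = delta_full[tokens]        # (n_tokens, n_embd)

    # per-token view used at runtime: look up rows of A.T, then multiply by B.T
    A_t = A.T                       # what the converter stores as token_embd.lora_a
    out = A_t[tokens] @ B.T         # same result, without materializing the full table

    assert np.allclose(ref, out, atol=1e-5)
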

View file

@@ -243,8 +243,9 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
                 ab_map[name].b = cur;
             }
         } else if (str_endswith(name, "_norm.weight")) {
-            // norm only has 1 dim, so tensor b == nullptr
-            ab_map[name] = llama_lora_weight(cur);
+            // TODO: add support for norm vector
+            // for now, we don't really care because most adapters still work fine without it
+            continue;
         } else {
             throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
         }
@@ -254,9 +255,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
     for (auto & it : ab_map) {
         const std::string & name = it.first;
         llama_lora_weight & w = it.second;
-        if (w.is_norm) {
-            continue;
-        }
+        bool is_token_embd = str_endswith(name, "token_embd.weight");
         if (!w.a || !w.b) {
             throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
@@ -270,11 +269,18 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
         struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
         // validate tensor shape
-        if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
-            throw std::runtime_error("tensor '" + name + "' has incorrect shape");
-        }
-        if (w.a->ne[1] != w.b->ne[0]) {
-            throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
+        if (is_token_embd) {
+            // expect B to be transposed, see llm_build_inp_embd()
+            if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) {
+                throw std::runtime_error("tensor '" + name + "' has incorrect shape");
+            }
+        } else {
+            if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
+                throw std::runtime_error("tensor '" + name + "' has incorrect shape");
+            }
+            if (w.a->ne[1] != w.b->ne[0]) {
+                throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
+            }
         }
         // save tensor to adapter
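
The new check encodes both layouts in ggml's ne order, where ne[0] is the innermost (fastest-varying) dimension. The same rules restated as a small Python check on (ne0, ne1) tuples, with illustrative sizes:

    # illustration only: the shape rules above, expressed on (ne0, ne1) tuples
    def check_shapes(model_ne, a_ne, b_ne, is_token_embd):
        if is_token_embd:
            # token_embd: B is expected transposed, see llm_build_inp_embd()
            assert model_ne[0] == b_ne[1] and model_ne[1] == a_ne[1], "incorrect shape"
        else:
            assert model_ne[0] == a_ne[0] and model_ne[1] == b_ne[1], "incorrect shape"
            assert a_ne[1] == b_ne[0], "lora_a is not transposed"

    n_embd, n_vocab, n_ff, rank = 4096, 32000, 11008, 16
    check_shapes((n_embd, n_vocab), (rank, n_vocab), (rank, n_embd), is_token_embd=True)
    check_shapes((n_embd, n_ff),    (n_embd, rank),  (rank, n_ff),   is_token_embd=False)
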
@@ -285,24 +291,6 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
         adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
     }
-    // add norm vectors
-    for (auto & it : ab_map) {
-        const std::string & name = it.first;
-        llama_lora_weight & w = it.second;
-        if (w.is_norm) {
-            GGML_ASSERT(w.a != nullptr);
-            // device buft and device ctx
-            auto * model_tensor = llama_model_get_tensor(model, name.c_str());
-            if (!model_tensor) {
-                throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
-            }
-            struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
-            struct ggml_tensor * tensor_norm = ggml_dup_tensor(dev_ctx, w.a);
-            ggml_set_name(tensor_norm, w.a->name);
-            adapter.ab_map[it.first] = llama_lora_weight(tensor_norm);
-        }
-    }
     // allocate tensors / buffers and zero
     {
         adapter.ctxs.reserve(ctx_map.size());
@@ -335,9 +323,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
             auto orig = ab_map[it.first];
             auto dev = it.second;
             set_tensor(orig.a, dev.a);
-            if (!dev.is_norm) {
-                set_tensor(orig.b, dev.b);
-            }
+            set_tensor(orig.b, dev.b);
         }
     }

View file

@@ -45,11 +45,14 @@ struct llama_lora_weight {
     struct ggml_tensor * a = nullptr;
     struct ggml_tensor * b = nullptr;
-    // note: norm only has 1 dim, so tensor b == nullptr
-    bool is_norm = false; // is this a norm vector? (e.g. _norm.weight)
+    // get actual scale based on rank and alpha
+    float get_scale(float alpha, float adapter_scale) {
+        const float rank = (float) b->ne[0];
+        const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale;
+        return scale;
+    }
     llama_lora_weight() = default;
-    llama_lora_weight(struct ggml_tensor * a) : a(a), is_norm(true) {}
     llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
 };
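
get_scale() is the usual LoRA scaling: the stored alpha divided by the adapter rank (read from lora_b), multiplied by the user-supplied adapter scale; if no alpha is stored, the adapter scale is used as-is. The same formula as a standalone Python function, with illustrative values:

    # illustration only: the scale applied to the B*A delta before it is added
    def get_scale(alpha: float, adapter_scale: float, rank: int) -> float:
        return adapter_scale * alpha / rank if alpha else adapter_scale

    print(get_scale(alpha=32.0, adapter_scale=1.0, rank=16))  # 2.0
    print(get_scale(alpha=0.0,  adapter_scale=0.8, rank=16))  # 0.8 (no alpha stored)
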

View file

@@ -2545,8 +2545,6 @@ static struct ggml_tensor * llm_build_inp_embd(
         ggml_set_input(lctx.inp_tokens);
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
-        //printf("tok_embd shape: %d x %d\n", tok_embd->ne[0], tok_embd->ne[1]);
-        //printf("inpL shape: %d x %d\n", inpL->ne[0], inpL->ne[1]);
         // apply lora for embedding tokens if needed
         for (auto & it : lctx.lora_adapters) {
@@ -2554,18 +2552,13 @@ static struct ggml_tensor * llm_build_inp_embd(
             if (lora == nullptr) {
                 continue;
             }
-            const float alpha = it.first->alpha;
-            const float rank = (float) lora->b->ne[0];
-            const float scale = alpha ? it.second * alpha / rank : it.second;
-            auto ss = ggml_get_rows(ctx, lora->b, lctx.inp_tokens);
-            //printf("a shape: %d x %d\n", lora->a->ne[0], lora->a->ne[1]);
-            //printf("b shape: %d x %d\n", lora->b->ne[0], lora->b->ne[1]);
-            //printf("ss shape: %d x %d\n", ss->ne[0], ss->ne[1]);
+            const float adapter_scale = it.second;
+            const float scale = lora->get_scale(it.first->alpha, adapter_scale);
             struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
-                ctx, ss, ggml_transpose(ctx, lora->a)
+                ctx, lora->b, // non-transposed lora_b
+                ggml_get_rows(ctx, lora->a, lctx.inp_tokens)
             ), scale);
-            //printf("inpL_delta shape: %d x %d\n", inpL_delta->ne[0], inpL_delta->ne[1]);
-            inpL = ggml_add(ctx, inpL, ggml_cont(ctx, ggml_transpose(ctx, inpL_delta)));
+            inpL = ggml_add(ctx, inpL, inpL_delta);
         }
     } else {
         lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
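
The rewritten graph applies the row lookup to lora_a and multiplies by the non-transposed lora_b, so the delta comes out in the same layout as inpL and the old transpose/cont pair is no longer needed. A shape trace of the calls above in ggml's ne order, as a runnable Python sketch with symbolic sizes:

    # illustration only: ne-order shape trace of the embedding LoRA path
    rank, n_vocab, n_embd, n_tokens = 16, 32000, 4096, 5

    lora_a = (rank, n_vocab)   # token_embd.lora_a as stored in the GGUF adapter
    lora_b = (rank, n_embd)    # token_embd.lora_b (transposed w.r.t. regular weights)

    rows = (lora_a[0], n_tokens)    # ggml_get_rows(lora_a, inp_tokens)
    assert lora_b[0] == rows[0]     # ggml_mul_mat requires matching ne[0]
    delta = (lora_b[1], rows[1])    # -> (n_embd, n_tokens)
    inpL = (n_embd, n_tokens)       # ggml_get_rows(tok_embd, inp_tokens)
    assert delta == inpL            # ggml_add works without any transpose
    print(delta)
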
@@ -3919,17 +3912,9 @@ struct llm_build_context {
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
-            struct ggml_tensor * attn_norm = model.layers[il].attn_norm;
-            for (auto & it : lctx.lora_adapters) {
-                struct llama_lora_weight * lora = it.first->get_weight(model.layers[il].attn_norm);
-                if (lora && lora->is_norm) {
-                    attn_norm = ggml_add(ctx0, attn_norm, ggml_scale(ctx0, lora->a, 0.5));
-                }
-            }
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
-                    attn_norm, NULL,
+                    model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
@@ -3998,16 +3983,8 @@ struct llm_build_context {
             // feed-forward network
             if (model.layers[il].ffn_gate_inp == nullptr) {
-                struct ggml_tensor * ffn_norm = model.layers[il].ffn_norm;
-                // for (auto & it : lctx.lora_adapters) {
-                //     struct llama_lora_weight * lora = it.first->get_weight(ffn_norm);
-                //     if (lora && lora->is_norm) {
-                //         ffn_norm = ggml_add(ctx0, ffn_norm, lora->a);
-                //     }
-                // }
                 cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        ffn_norm, NULL,
+                        model.layers[il].ffn_norm, NULL,
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);