add alias for lora adaptors

zhhan 2024-07-22 16:15:46 -07:00
parent 081fe431aa
commit 0e2a0d4d09
2 changed files with 35 additions and 2 deletions


@@ -443,6 +443,10 @@ extern "C" {
             struct llama_model * model,
             struct llama_context_params params);
+    LLAMA_API void llama_ctx_switch_adaptor(
+            struct llama_context* ctx,
+            const char* adaptor_alias);
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
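For context, a minimal caller-side sketch of how the API declared above would be used once several LoRA adapters are attached to a context; the "chat" and "translate" aliases are placeholder names, not part of this commit:

// Sketch only: select the active LoRA adapter by alias at runtime.
// llama_ctx_switch_adaptor() enables the adapter whose alias matches
// adaptor_alias and disables every other adapter attached to the context.
llama_ctx_switch_adaptor(ctx, "chat");       // placeholder alias
// ... llama_decode() now applies only the "chat" adapter ...
llama_ctx_switch_adaptor(ctx, "translate");  // switch to another placeholder adapter
llama_free(ctx);                             // frees all allocated memory when done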


@@ -2851,6 +2851,8 @@ struct llama_lora_weight {
 struct llama_lora_adapter {
     struct llama_model * base_model;
+    std::string alias;
+    bool enabled = true;
     // map tensor name to lora_a_b
     std::unordered_map<std::string, struct llama_lora_weight> ab_map;
     std::vector<struct ggml_context *> ctxs;
@@ -7892,6 +7894,9 @@ static struct ggml_tensor * llm_build_lora_mm(
         struct ggml_tensor * cur) {
     struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
     for (auto & it : lctx.lora_adapters) {
+        if(!it.first->enabled) {
+            continue;
+        }
         struct llama_lora_weight * lora = it.first->get_weight(w);
         if (lora == nullptr) {
             continue;
@@ -7918,6 +7923,9 @@ static struct ggml_tensor * llm_build_lora_mm_id(
         struct ggml_tensor * ids) {
     struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
     for (auto & it : lctx.lora_adapters) {
+        if (!it.first->enabled) {
+            continue;
+        }
         struct llama_lora_weight * lora = it.first->get_weight(w);
         if (lora == nullptr) {
             continue;
@@ -18604,7 +18612,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 }
 static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
-    LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
+    LLAMA_LOG_INFO("%s: loading lora adapter from '%s' as '%s' ...\n", __func__, path_lora, adapter.alias.c_str());
     ggml_context * ctx = nullptr;
     struct gguf_init_params meta_gguf_params = {
@@ -19404,6 +19412,18 @@ struct llama_context * llama_new_context_with_model(
     return ctx;
 }
+void llama_ctx_switch_adaptor(struct llama_context* ctx, const char* adaptor_alias) {
+    llama_synchronize(ctx);
+    for (auto& adaptor : ctx->lora_adapters) {
+        if (adaptor.first->alias == adaptor_alias) {
+            adaptor.first->enabled = true;
+        }
+        else {
+            adaptor.first->enabled = false;
+        }
+    }
+}
 void llama_free(struct llama_context * ctx) {
     delete ctx;
 }
@@ -19615,8 +19635,17 @@ uint32_t llama_model_quantize(
 struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
     try {
+        std::string alias;
+        std::string path = path_lora;
         struct llama_lora_adapter * adapter = new llama_lora_adapter(model);
-        llama_lora_adapter_init_internal(model, path_lora, *adapter);
+        size_t pos = path.find("=");
+        if (pos != std::string::npos) {
+            alias = path.substr(0, pos);
+            path = path.substr(pos + 1);
+            adapter->alias = alias;
+        }
+        llama_lora_adapter_init_internal(model, path.c_str(), *adapter);
         return adapter;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
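To tie the pieces together, a minimal usage sketch of the "alias=path" convention parsed above; the aliases, the .gguf paths, and the use of the pre-existing llama_lora_adapter_set() call to attach the adapters are assumptions for illustration, not part of this commit:

// Sketch only: encode the alias in the path argument as "alias=path".
struct llama_lora_adapter * a1 = llama_lora_adapter_init(model, "summarize=./lora-summarize.gguf");
struct llama_lora_adapter * a2 = llama_lora_adapter_init(model, "chat=./lora-chat.gguf");
llama_lora_adapter_set(ctx, a1, 1.0f);       // attach both adapters (existing API, assumed unchanged here)
llama_lora_adapter_set(ctx, a2, 1.0f);
llama_ctx_switch_adaptor(ctx, "summarize");  // only the adapter aliased "summarize" stays enabled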