diff --git a/include/llama.h b/include/llama.h
index bf2761467..7c375522c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -443,6 +443,10 @@ extern "C" {
             struct llama_model * model,
             struct llama_context_params params);
 
+    LLAMA_API void llama_ctx_switch_adaptor(
+            struct llama_context * ctx,
+            const char * adaptor_alias);
+
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
 
diff --git a/src/llama.cpp b/src/llama.cpp
index 99a6d8b66..cbc9269d5 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2851,6 +2851,8 @@ struct llama_lora_weight {
 struct llama_lora_adapter {
     struct llama_model * base_model;
+    std::string alias;
+    bool enabled = true;
     // map tensor name to lora_a_b
     std::unordered_map<std::string, llama_lora_weight> ab_map;
 
     std::vector<struct ggml_context *> ctxs;
@@ -7892,6 +7894,9 @@ static struct ggml_tensor * llm_build_lora_mm(
         struct ggml_tensor * cur) {
     struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
     for (auto & it : lctx.lora_adapters) {
+        if (!it.first->enabled) {
+            continue;
+        }
         struct llama_lora_weight * lora = it.first->get_weight(w);
         if (lora == nullptr) {
             continue;
@@ -7918,6 +7923,9 @@ static struct ggml_tensor * llm_build_lora_mm_id(
         struct ggml_tensor * ids) {
     struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
     for (auto & it : lctx.lora_adapters) {
+        if (!it.first->enabled) {
+            continue;
+        }
         struct llama_lora_weight * lora = it.first->get_weight(w);
         if (lora == nullptr) {
             continue;
@@ -18604,7 +18612,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 }
 
 static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
-    LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
+    LLAMA_LOG_INFO("%s: loading lora adapter from '%s' as '%s' ...\n", __func__, path_lora, adapter.alias.c_str());
 
     ggml_context * ctx = nullptr;
     struct gguf_init_params meta_gguf_params = {
@@ -19404,6 +19412,18 @@ struct llama_context * llama_new_context_with_model(
     return ctx;
 }
 
+void llama_ctx_switch_adaptor(struct llama_context * ctx, const char * adaptor_alias) {
+    llama_synchronize(ctx);
+    for (auto & adaptor : ctx->lora_adapters) {
+        if (adaptor.first->alias == adaptor_alias) {
+            adaptor.first->enabled = true;
+        }
+        else {
+            adaptor.first->enabled = false;
+        }
+    }
+}
+
 void llama_free(struct llama_context * ctx) {
     delete ctx;
 }
@@ -19615,8 +19635,17 @@ uint32_t llama_model_quantize(
 
 struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
     try {
+        std::string alias;
+        std::string path = path_lora;
        struct llama_lora_adapter * adapter = new llama_lora_adapter(model);
-        llama_lora_adapter_init_internal(model, path_lora, *adapter);
+        size_t pos = path.find("=");
+        if (pos != std::string::npos) {
+            alias = path.substr(0, pos);
+            path = path.substr(pos + 1);
+            adapter->alias = alias;
+        }
+
+        llama_lora_adapter_init_internal(model, path.c_str(), *adapter);
         return adapter;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
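
For reference, a minimal caller-side sketch of how the alias-based switching above could be used (not part of the patch). The model path, adapter paths, and the aliases "summarize" and "translate" are placeholders, and error handling is omitted:

// Sketch only: assumes the patched API above; paths and aliases are made up.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model * model = llama_load_model_from_file("model.gguf", llama_model_default_params());
    llama_context * ctx = llama_new_context_with_model(model, llama_context_default_params());

    // With this patch, llama_lora_adapter_init() splits "alias=path" and records the alias.
    llama_lora_adapter * lora_sum = llama_lora_adapter_init(model, "summarize=lora-summarize.gguf");
    llama_lora_adapter * lora_tr  = llama_lora_adapter_init(model, "translate=lora-translate.gguf");

    // Attach both adapters to the context; both start out enabled.
    llama_lora_adapter_set(ctx, lora_sum, 1.0f);
    llama_lora_adapter_set(ctx, lora_tr,  1.0f);

    // Enable only the adapter whose alias matches; every other attached adapter is disabled.
    llama_ctx_switch_adaptor(ctx, "summarize");
    // ... decode with the summarization adapter ...

    llama_ctx_switch_adaptor(ctx, "translate");
    // ... decode with the translation adapter ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}

Because disabled adapters are simply skipped in llm_build_lora_mm / llm_build_lora_mm_id, switching costs nothing beyond the llama_synchronize() call; no tensors are reloaded or merged.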