diff --git a/common/arg.cpp b/common/arg.cpp
index deb113786..c81b15217 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1512,7 +1512,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--lora"}, "FNAME",
         "path to LoRA adapter (can be repeated to use multiple adapters)",
         [](common_params & params, const std::string & value) {
-            params.lora_adapters.push_back({ std::string(value), 1.0 });
+            params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
@@ -1520,7 +1520,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--lora-scaled"}, "FNAME", "SCALE",
         "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
         [](common_params & params, const std::string & fname, const std::string & scale) {
-            params.lora_adapters.push_back({ fname, std::stof(scale) });
+            params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
diff --git a/common/common.cpp b/common/common.cpp
index 6c0d24688..3e37039ca 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -922,20 +922,21 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     // load and optionally apply lora adapters
     for (auto & la : params.lora_adapters) {
-        common_lora_adapter_container loaded_la;
-        loaded_la.path = la.path;
-        loaded_la.scale = la.scale;
-        loaded_la.adapter.reset(llama_lora_adapter_init(model, la.path.c_str()));
-        if (loaded_la.adapter == nullptr) {
+        llama_lora_adapter_ptr lora;
+        lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
+        if (lora == nullptr) {
             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
             llama_free(lctx);
             llama_free_model(model);
             return iparams;
         }
-        iparams.lora_adapters.emplace_back(std::move(loaded_la)); // copy to list of loaded adapters
+
+        la.ptr = lora.get();
+        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
     }
+
     if (!params.lora_init_without_apply) {
-        common_lora_adapters_apply(lctx, iparams.lora_adapters);
+        common_lora_adapters_apply(lctx, params.lora_adapters);
     }
 
     if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
@@ -1002,11 +1003,11 @@ struct common_init_result common_init_from_params(common_params & params) {
     return iparams;
 }
 
-void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters) {
+void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
     llama_lora_adapter_clear(ctx);
-    for (auto & la : lora_adapters) {
+    for (auto & la : lora) {
         if (la.scale != 0.0f) {
-            llama_lora_adapter_set(ctx, la.adapter.get(), la.scale);
+            llama_lora_adapter_set(ctx, la.ptr, la.scale);
         }
     }
 }
diff --git a/common/common.h b/common/common.h
index 2802675c2..e64292152 100644
--- a/common/common.h
+++ b/common/common.h
@@ -24,13 +24,12 @@
 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
 
+// TODO: "lora_adapter" is tautology
 struct common_lora_adapter_info {
     std::string path;
     float scale;
-};
 
-struct common_lora_adapter_container : common_lora_adapter_info {
-    llama_lora_adapter_ptr adapter;
+    struct llama_lora_adapter * ptr;
 };
 
 using llama_tokens = std::vector<llama_token>;
@@ -478,11 +477,12 @@ std::string fs_get_cache_file(const std::string & filename);
 // Model utils
 //
 
+// note: defines object's lifetime
 struct common_init_result {
     llama_model_ptr model;
     llama_context_ptr context;
 
-    std::vector<common_lora_adapter_container> lora_adapters;
+    std::vector<llama_lora_adapter_ptr> lora;
 };
 
 struct common_init_result common_init_from_params(common_params & params);
@@ -504,7 +504,7 @@ struct llama_model * common_load_model_from_hf(
     const struct llama_model_params & params);
 
 // clear LoRA adapters from context, then apply new list of adapters
-void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
+void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);
 
 //
 // Batch utils
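To make the new ownership split concrete, here is a minimal sketch (not part of the diff) of how the reworked `common` API fits together: the CLI flags push `{ path, scale, nullptr }` into `params.lora_adapters`, `common_init_from_params()` loads each adapter, keeps ownership in `common_init_result::lora`, and records the raw handle in `la.ptr`, and `common_lora_adapters_apply()` only reads `scale` and `ptr`. The model and adapter paths below are placeholders.

```cpp
// Sketch only: placeholder paths, minimal error handling.
#include "common.h"
#include "llama.h"

int main() {
    llama_backend_init();
    {
        common_params params;
        params.model = "model.gguf";                                        // placeholder
        params.lora_adapters.push_back({ "adapter.gguf", 1.0f, nullptr }); // same shape as --lora

        common_init_result llama_init = common_init_from_params(params);
        llama_context * ctx = llama_init.context.get();

        if (ctx != nullptr) {
            // the adapters are owned by llama_init.lora, while
            // params.lora_adapters[0].ptr points at the loaded adapter,
            // so changing the scale and re-applying needs no reload
            params.lora_adapters[0].scale = 0.5f;
            common_lora_adapters_apply(ctx, params.lora_adapters);
        }
    }
    llama_backend_free();
    return 0;
}
```
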
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index f50d75964..441c58e38 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -98,7 +98,7 @@ struct slot_params {
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
-    std::vector<common_lora_adapter_container> lora;
+    std::vector<common_lora_adapter_info> lora;
 
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
@@ -198,7 +198,7 @@ struct server_task {
     bool metrics_reset_bucket = false;
 
     // used by SERVER_TASK_TYPE_SET_LORA
-    std::vector<common_lora_adapter_container> set_lora;
+    std::vector<common_lora_adapter_info> set_lora;
 
     server_task(server_task_type type) : type(type) {}
 
@@ -206,7 +206,6 @@ struct server_task {
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
-            const std::vector<common_lora_adapter_container> & lora_base,
             const json & data) {
         slot_params params;
 
@@ -265,12 +264,12 @@ struct server_task {
 
         if (data.contains("lora")) {
             if (data.at("lora").is_array()) {
-                params.lora = parse_lora_request(lora_base, data.at("lora"));
+                params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
             } else {
                 throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
             }
         } else {
-            params.lora = lora_base;
+            params.lora = params_base.lora_adapters;
        }
 
         // TODO: add more sanity checks for the input parameters
@@ -1132,7 +1131,7 @@ struct server_slot {
 
     common_speculative * spec = nullptr;
 
-    std::vector<common_lora_adapter_container> lora;
+    std::vector<common_lora_adapter_info> lora;
 
     // the index relative to completion multi-task request
     size_t index = 0;
@@ -1633,8 +1632,6 @@ struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
 
-    std::vector<common_lora_adapter_container> lora;
-
     llama_model * model_dft = nullptr;
 
     llama_context_params cparams_dft;
@@ -1687,8 +1684,6 @@ struct server_context {
         model = llama_init.model.get();
         ctx   = llama_init.context.get();
 
-        lora = std::move(llama_init.lora_adapters);
-
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
             return false;
@@ -1883,7 +1878,7 @@ struct server_context {
         if (!are_lora_equal(task.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
-            slot.lora = std::move(task.params.lora);
+            slot.lora = task.params.lora;
         }
 
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
@@ -2577,7 +2572,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
-                    lora = std::move(task.set_lora);
+                    params_base.lora_adapters = std::move(task.set_lora);
                     auto res = std::make_unique<server_task_result_apply_lora>();
                     res->id = task.id;
                     queue_results.send(std::move(res));
@@ -3656,7 +3651,6 @@ int main(int argc, char ** argv) {
                     ctx_server.model,
                     ctx_server.ctx,
                     ctx_server.params_base,
-                    ctx_server.lora,
                     data);
             task.id_selected_slot = json_value(data, "id_slot", -1);
 
@@ -4083,8 +4077,9 @@ int main(int argc, char ** argv) {
 
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
-            auto & lora = ctx_server.lora[i];
+        const auto & loras = ctx_server.params_base.lora_adapters;
+        for (size_t i = 0; i < loras.size(); ++i) {
+            auto & lora = loras[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},
@@ -4103,7 +4098,7 @@ int main(int argc, char ** argv) {
         }
         server_task task(SERVER_TASK_TYPE_SET_LORA);
         task.id = ctx_server.queue_tasks.get_new_id();
-        task.set_lora = parse_lora_request(ctx_server.lora, body);
+        task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
         ctx_server.queue_results.add_waiting_task_id(task.id);
         ctx_server.queue_tasks.post(task);
 
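The per-request selection that `params_from_json_cmpl()` now performs against `params_base.lora_adapters` is sketched below. This is a simplified illustration of the behaviour implied by the diff (the actual `parse_lora_request()` follows in utils.hpp): the request's `"lora"` array carries `{ "id", "scale" }` objects, ids index into the server-wide adapter list, and adapters not mentioned in the request are disabled by setting their scale to 0. The helper name `select_adapters_for_request` is invented for this example.

```cpp
// Illustration only; see parse_lora_request() in examples/server/utils.hpp
// for the real implementation.
#include <vector>

#include "common.h"

#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

static std::vector<common_lora_adapter_info> select_adapters_for_request(
        const std::vector<common_lora_adapter_info> & lora_base,
        const json & req) {
    // start from a copy of the server-wide list so path/ptr stay valid
    std::vector<common_lora_adapter_info> lora(lora_base);

    // adapters not named in the request are disabled
    for (auto & la : lora) {
        la.scale = 0.0f;
    }

    // e.g. req = [{"id": 0, "scale": 0.5}]
    for (const auto & entry : req) {
        const int   id    = entry.at("id").get<int>();
        const float scale = entry.at("scale").get<float>();
        if (id >= 0 && id < (int) lora.size()) {
            lora[id].scale = scale;
        }
    }
    return lora;
}
```
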
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 1cf08bb0a..dc6e6e67e 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -799,25 +799,25 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
 }
 
 static bool are_lora_equal(
-        const std::vector<common_lora_adapter_container> & l1,
-        const std::vector<common_lora_adapter_container> & l2) {
+        const std::vector<common_lora_adapter_info> & l1,
+        const std::vector<common_lora_adapter_info> & l2) {
     if (l1.size() != l2.size()) {
         return false;
     }
     for (size_t i = 0; i < l1.size(); ++i) {
         // we don't check lora.path to reduce the time complexity
-        if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
+        if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) {
             return false;
         }
     }
     return true;
 }
 
-// parse lora config from JSON request, returned a copy of base_lora with updated scale
-static std::vector<common_lora_adapter_container> parse_lora_request(
-        const std::vector<common_lora_adapter_container> & base_lora,
+// parse lora config from JSON request, returned a copy of lora_base with updated scale
+static std::vector<common_lora_adapter_info> parse_lora_request(
+        const std::vector<common_lora_adapter_info> & lora_base,
         const json & data) {
-    std::vector<common_lora_adapter_container> lora(base_lora);
+    std::vector<common_lora_adapter_info> lora(lora_base);
     int max_idx = lora.size();
 
     // clear existing value
diff --git a/src/llama-impl.cpp b/src/llama-impl.cpp
index ca9f1ea8e..a05ba4f63 100644
--- a/src/llama-impl.cpp
+++ b/src/llama-impl.cpp
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 87ffdc599..7743b4652 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -2,7 +2,9 @@
 
 #include "ggml.h"
 
+#include
 #include
+#include
 #include
 
 const char * llama_file_version_name(llama_fver version) {
diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h
index c9f274fc3..1ec478195 100644
--- a/src/llama-model-loader.h
+++ b/src/llama-model-loader.h
@@ -10,6 +10,7 @@
 #include
 #include
+#include
 #include
 
 using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
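Finally, a small sketch (not part of the diff) of the invariant the server relies on after this refactor: `common_lora_adapter_info` is copied by value between `params_base`, tasks, and slots, but the loaded adapters stay owned by the `common_init_result`, so `ptr` is a stable identity and `are_lora_equal()` above only has to compare `{ scale, ptr }` to decide whether a slot's prompt cache is still valid. `check_lora_equality` is a made-up helper; `utils.hpp` refers to `examples/server/utils.hpp`.

```cpp
// Illustration of the {scale, ptr} comparison rule; assumes it is compiled
// inside the server example so that utils.hpp resolves.
#include <cassert>
#include <vector>

#include "utils.hpp"

static void check_lora_equality(const std::vector<common_lora_adapter_info> & base) {
    if (base.empty()) {
        return;
    }

    std::vector<common_lora_adapter_info> slot_lora = base; // copy held by a slot
    std::vector<common_lora_adapter_info> task_lora = base; // copy built for a request

    // same adapters, same scales -> the slot may keep its cached tokens
    assert(are_lora_equal(slot_lora, task_lora));

    // same adapter, different scale -> different configuration,
    // so the slot clears cache_tokens before accepting the task
    task_lora[0].scale += 1.0f;
    assert(!are_lora_equal(slot_lora, task_lora));
}
```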