common : update lora

ggml-ci
Georgi Gerganov, 2025-01-02 17:26:18 +02:00
parent 8d117a518d
commit 272cd0eaea
8 changed files with 40 additions and 40 deletions


@@ -1512,7 +1512,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--lora"}, "FNAME",
         "path to LoRA adapter (can be repeated to use multiple adapters)",
         [](common_params & params, const std::string & value) {
-            params.lora_adapters.push_back({ std::string(value), 1.0 });
+            params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
@@ -1520,7 +1520,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--lora-scaled"}, "FNAME", "SCALE",
         "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
         [](common_params & params, const std::string & fname, const std::string & scale) {
-            params.lora_adapters.push_back({ fname, std::stof(scale) });
+            params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
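For reference, a hedged sketch of what the two flags now produce (file names here are illustrative, not from the commit): --lora entries get the default scale 1.0, --lora-scaled a user-supplied one, and both leave the new third field null until common_init_from_params() actually loads the adapter:

    // `llama-cli -m model.gguf --lora lora-a.gguf --lora-scaled lora-b.gguf 0.5`
    // fills common_params::lora_adapters roughly as if by:
    params.lora_adapters.push_back({ "lora-a.gguf", 1.0f, nullptr }); // --lora
    params.lora_adapters.push_back({ "lora-b.gguf", 0.5f, nullptr }); // --lora-scaled 0.5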


@@ -922,20 +922,21 @@ struct common_init_result common_init_from_params(common_params & params) {
     // load and optionally apply lora adapters
     for (auto & la : params.lora_adapters) {
-        common_lora_adapter_container loaded_la;
-        loaded_la.path = la.path;
-        loaded_la.scale = la.scale;
-        loaded_la.adapter.reset(llama_lora_adapter_init(model, la.path.c_str()));
-        if (loaded_la.adapter == nullptr) {
+        llama_lora_adapter_ptr lora;
+        lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
+        if (lora == nullptr) {
             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
             llama_free(lctx);
             llama_free_model(model);
             return iparams;
         }
-        iparams.lora_adapters.emplace_back(std::move(loaded_la)); // copy to list of loaded adapters
+
+        la.ptr = lora.get();
+        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
     }
 
     if (!params.lora_init_without_apply) {
-        common_lora_adapters_apply(lctx, iparams.lora_adapters);
+        common_lora_adapters_apply(lctx, params.lora_adapters);
     }
 
     if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
@@ -1002,11 +1003,11 @@ struct common_init_result common_init_from_params(common_params & params) {
     return iparams;
 }
 
-void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters) {
+void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
     llama_lora_adapter_clear(ctx);
-    for (auto & la : lora_adapters) {
+    for (auto & la : lora) {
         if (la.scale != 0.0f) {
-            llama_lora_adapter_set(ctx, la.adapter.get(), la.scale);
+            llama_lora_adapter_set(ctx, la.ptr, la.scale);
         }
     }
 }
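A minimal caller-side sketch of the new split (my example, assuming common_init_from_params() succeeds and at least one adapter was requested): ownership stays in the returned common_init_result, while params.lora_adapters keeps non-owning {path, scale, ptr} entries that can be re-applied with new scales without touching the disk again:

    common_init_result init = common_init_from_params(params);
    llama_context * lctx = init.context.get();

    params.lora_adapters[0].scale = 0.0f;                    // disable the first adapter
    common_lora_adapters_apply(lctx, params.lora_adapters);  // clear + re-apply, no reload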


@@ -24,13 +24,12 @@
 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
 
+// TODO: "lora_adapter" is tautology
 struct common_lora_adapter_info {
     std::string path;
     float scale;
-};
 
-struct common_lora_adapter_container : common_lora_adapter_info {
-    llama_lora_adapter_ptr adapter;
+    struct llama_lora_adapter * ptr;
 };
 
 using llama_tokens = std::vector<llama_token>;
@@ -478,11 +477,12 @@ std::string fs_get_cache_file(const std::string & filename);
 //
 // Model utils
 //
 
+// note: defines object's lifetime
 struct common_init_result {
     llama_model_ptr model;
     llama_context_ptr context;
-    std::vector<common_lora_adapter_container> lora_adapters;
+    std::vector<llama_lora_adapter_ptr> lora;
 };
 
 struct common_init_result common_init_from_params(common_params & params);
@@ -504,7 +504,7 @@ struct llama_model * common_load_model_from_hf(
     const struct llama_model_params & params);
 
 // clear LoRA adapters from context, then apply new list of adapters
-void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
+void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);
 
 //
 // Batch utils
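The lifetime note is worth spelling out (my reading, assuming llama-cpp.h defines llama_lora_adapter_ptr as a std::unique_ptr whose deleter calls llama_lora_adapter_free, as I understand it does): common_init_result owns the adapters, common_lora_adapter_info only points at them:

    {
        common_init_result init = common_init_from_params(params);
        // params.lora_adapters[i].ptr is valid here ...
    }
    // ... and dangling here: init.lora freed every adapter on scope exit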


@@ -98,7 +98,7 @@ struct slot_params {
     int64_t t_max_prompt_ms = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
-    std::vector<common_lora_adapter_container> lora;
+    std::vector<common_lora_adapter_info> lora;
 
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
@@ -198,7 +198,7 @@ struct server_task {
     bool metrics_reset_bucket = false;
 
     // used by SERVER_TASK_TYPE_SET_LORA
-    std::vector<common_lora_adapter_container> set_lora;
+    std::vector<common_lora_adapter_info> set_lora;
 
     server_task(server_task_type type) : type(type) {}
 
@@ -206,7 +206,6 @@ struct server_task {
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
-            const std::vector<common_lora_adapter_container> & lora_base,
             const json & data) {
         slot_params params;
 
@@ -265,12 +264,12 @@ struct server_task {
 
         if (data.contains("lora")) {
             if (data.at("lora").is_array()) {
-                params.lora = parse_lora_request(lora_base, data.at("lora"));
+                params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
             } else {
                 throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
             }
         } else {
-            params.lora = lora_base;
+            params.lora = params_base.lora_adapters;
        }
 
        // TODO: add more sanity checks for the input parameters
@@ -1132,7 +1131,7 @@ struct server_slot {
 
     common_speculative * spec = nullptr;
 
-    std::vector<common_lora_adapter_container> lora;
+    std::vector<common_lora_adapter_info> lora;
 
     // the index relative to completion multi-task request
     size_t index = 0;
@@ -1633,8 +1632,6 @@ struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
 
-    std::vector<common_lora_adapter_container> lora;
-
     llama_model * model_dft = nullptr;
 
     llama_context_params cparams_dft;
@@ -1687,8 +1684,6 @@ struct server_context {
         model = llama_init.model.get();
         ctx = llama_init.context.get();
 
-        lora = std::move(llama_init.lora_adapters);
-
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
             return false;
@@ -1883,7 +1878,7 @@ struct server_context {
         if (!are_lora_equal(task.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
-            slot.lora = std::move(task.params.lora);
+            slot.lora = task.params.lora;
         }
 
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
@@ -2577,7 +2572,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
-                    lora = std::move(task.set_lora);
+                    params_base.lora_adapters = std::move(task.set_lora);
                     auto res = std::make_unique<server_task_result_apply_lora>();
                     res->id = task.id;
                     queue_results.send(std::move(res));
@@ -3656,7 +3651,6 @@ int main(int argc, char ** argv) {
                 ctx_server.model,
                 ctx_server.ctx,
                 ctx_server.params_base,
-                ctx_server.lora,
                 data);
 
             task.id_selected_slot = json_value(data, "id_slot", -1);
@@ -4083,8 +4077,9 @@ int main(int argc, char ** argv) {
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
-            auto & lora = ctx_server.lora[i];
+        const auto & loras = ctx_server.params_base.lora_adapters;
+        for (size_t i = 0; i < loras.size(); ++i) {
+            auto & lora = loras[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},
@@ -4103,7 +4098,7 @@ int main(int argc, char ** argv) {
         }
 
         server_task task(SERVER_TASK_TYPE_SET_LORA);
         task.id = ctx_server.queue_tasks.get_new_id();
-        task.set_lora = parse_lora_request(ctx_server.lora, body);
+        task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
         ctx_server.queue_results.add_waiting_task_id(task.id);
         ctx_server.queue_tasks.post(task);
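Net effect on the server (a summary sketch, not code from the commit): ctx_server.params_base.lora_adapters becomes the single source of truth that both endpoints and per-request overrides read from and write to:

    // GET  /lora-adapters -> [{"id": 0, "path": "...", "scale": 1.0}, ...]
    //      built from params_base.lora_adapters in handle_lora_adapters_list
    // POST /lora-adapters with [{"id": 0, "scale": 0.5}]
    //      -> parse_lora_request(params_base.lora_adapters, body)
    //      -> SERVER_TASK_TYPE_SET_LORA commits via
    //         params_base.lora_adapters = std::move(task.set_lora)
    // per-request {"lora": [...]} overrides only reach slot_params::lora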


@@ -799,25 +799,25 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
 }
 
 static bool are_lora_equal(
-    const std::vector<common_lora_adapter_container> & l1,
-    const std::vector<common_lora_adapter_container> & l2) {
+    const std::vector<common_lora_adapter_info> & l1,
+    const std::vector<common_lora_adapter_info> & l2) {
     if (l1.size() != l2.size()) {
         return false;
     }
     for (size_t i = 0; i < l1.size(); ++i) {
         // we don't check lora.path to reduce the time complexity
-        if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
+        if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) {
             return false;
         }
     }
     return true;
 }
 
-// parse lora config from JSON request, returned a copy of base_lora with updated scale
-static std::vector<common_lora_adapter_container> parse_lora_request(
-    const std::vector<common_lora_adapter_container> & base_lora,
+// parse lora config from JSON request, returned a copy of lora_base with updated scale
+static std::vector<common_lora_adapter_info> parse_lora_request(
+    const std::vector<common_lora_adapter_info> & lora_base,
     const json & data) {
-    std::vector<common_lora_adapter_container> lora(base_lora);
+    std::vector<common_lora_adapter_info> lora(lora_base);
     int max_idx = lora.size();
 
     // clear existing value
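To make the contract concrete, a hedged example (ids and scales illustrative; the "clear existing value" step above suggests adapters missing from the request are reset rather than kept):

    // lora_base = { {"lora-a.gguf", 1.0f, ptr_a}, {"lora-b.gguf", 0.5f, ptr_b} }
    // data      = [ {"id": 1, "scale": 0.75} ]
    auto lora = parse_lora_request(lora_base, data);
    // expected: lora[0].scale == 0.0f, lora[1].scale == 0.75f, ptr fields unchanged,
    // so are_lora_equal() can compare requests cheaply by {scale, ptr} alone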


@@ -5,6 +5,7 @@
 #include <cinttypes>
 #include <climits>
 #include <cstdarg>
+#include <cstring>
 #include <vector>
 #include <sstream>
 


@@ -2,7 +2,9 @@
 
 #include "ggml.h"
 
+#include <array>
 #include <cinttypes>
+#include <cstring>
 #include <future>
 
 const char * llama_file_version_name(llama_fver version) {


@@ -10,6 +10,7 @@
 
 #include <cstddef>
 #include <map>
+#include <stdexcept>
 #include <unordered_map>
 
 using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>; using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;