clean up struct def

Xuan Son Nguyen 2024-08-05 23:23:37 +02:00
parent 21cb13384c
commit c58a332fcd
3 changed files with 22 additions and 14 deletions


@@ -687,7 +687,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.lora_adapters.push_back({
             std::string(argv[i]),
             1.0,
-            nullptr,
         });
         return true;
     }
@@ -698,7 +697,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.lora_adapters.push_back({
             lora_adapter,
             std::stof(argv[i]),
-            nullptr,
         });
         return true;
     }
@@ -2106,16 +2104,20 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     // load and optionally apply lora adapters
     for (auto & la : params.lora_adapters) {
-        la.adapter = llama_lora_adapter_init(model, la.path.c_str());
-        if (la.adapter == nullptr) {
+        llama_lora_adapter_container loaded_la;
+        loaded_la.path = la.path;
+        loaded_la.scale = la.scale;
+        loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
+        if (loaded_la.adapter == nullptr) {
             fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
             llama_free(lctx);
             llama_free_model(model);
             return iparams;
         }
+        iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
     }

     if (!params.lora_init_without_apply) {
-        llama_lora_adapters_apply(lctx, params.lora_adapters);
+        llama_lora_adapters_apply(lctx, iparams.lora_adapters);
     }

     if (params.ignore_eos) {
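
The llama_lora_adapters_apply helper called above is not shown in this diff. As a rough sketch only (assuming the llama_lora_adapter_clear and llama_lora_adapter_set functions from the llama.h C API), it would clear whatever adapters are currently attached to the context and re-attach every loaded adapter whose scale is non-zero:

// Sketch only: detach current adapters, then re-apply the loaded ones
// with their current scales; a scale of 0.0f leaves an adapter disabled.
void llama_lora_adapters_apply(struct llama_context * ctx,
                               std::vector<llama_lora_adapter_container> & lora_adapters) {
    llama_lora_adapter_clear(ctx);
    for (auto & la : lora_adapters) {
        if (la.scale != 0.0f) {
            llama_lora_adapter_set(ctx, la.adapter, la.scale);
        }
    }
}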


@@ -33,9 +33,12 @@
 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"

-struct llama_lora_adapter_container {
+struct llama_lora_adapter_info {
     std::string path;
     float scale;
+};
+
+struct llama_lora_adapter_container : llama_lora_adapter_info {
     struct llama_lora_adapter * adapter;
 };
@@ -133,7 +136,7 @@ struct gpt_params {
     std::vector<llama_model_kv_override> kv_overrides;

     bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
-    std::vector<llama_lora_adapter_container> lora_adapters; // lora adapter path with user defined scale
+    std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale

     std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
@@ -315,8 +318,9 @@ std::string fs_get_cache_file(const std::string & filename);
 //

 struct llama_init_result {
     struct llama_model * model = nullptr;
     struct llama_context * context = nullptr;
+    std::vector<llama_lora_adapter_container> lora_adapters;
 };

 struct llama_init_result llama_init_from_gpt_params(gpt_params & params);
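
For illustration (not part of the diff): with the split definition, gpt_params now carries only the user-supplied llama_lora_adapter_info entries (path and scale), while llama_init_from_gpt_params returns the loaded llama_lora_adapter_container entries on llama_init_result. A minimal caller sketch, with hypothetical file paths:

// Hypothetical usage sketch of the reworked types.
gpt_params params;
params.model = "model.gguf";                                     // hypothetical model path
params.lora_adapters.push_back({ "lora-adapter.gguf", 1.0f });   // info only: path + scale
params.lora_init_without_apply = true;                           // load adapters, do not attach yet

llama_init_result init = llama_init_from_gpt_params(params);
// init.lora_adapters now holds the containers (info fields plus the loaded
// llama_lora_adapter pointer), so they can be applied or re-scaled later:
llama_lora_adapters_apply(init.context, init.lora_adapters);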


@@ -623,6 +623,7 @@ struct server_response {
 struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
+    std::vector<llama_lora_adapter_container> lora_adapters;

     gpt_params params;
@@ -682,6 +683,7 @@ struct server_context {
         model = llama_init.model;
         ctx = llama_init.context;
+        lora_adapters = llama_init.lora_adapters;
         params.n_parallel -= 1; // but be sneaky about it
         if (model == nullptr) {
             LOG_ERROR("unable to load model", {{"model", params.model}});
@@ -1853,7 +1855,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
-                    llama_lora_adapters_apply(ctx, params.lora_adapters);
+                    llama_lora_adapters_apply(ctx, lora_adapters);
                     server_task_result result;
                     result.id = task.id;
                     result.data = json{{ "success", true }};
@@ -3340,8 +3342,8 @@ int main(int argc, char ** argv) {
     const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.params.lora_adapters.size(); ++i) {
-            auto & la = ctx_server.params.lora_adapters[i];
+        for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
+            auto & la = ctx_server.lora_adapters[i];
             result.push_back({
                 {"id", i},
                 {"path", la.path},
@@ -3356,10 +3358,10 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const std::vector<json> body = json::parse(req.body);
-        int max_idx = ctx_server.params.lora_adapters.size();
+        int max_idx = ctx_server.lora_adapters.size();

         // clear existing value
-        for (auto & la : ctx_server.params.lora_adapters) {
+        for (auto & la : ctx_server.lora_adapters) {
             la.scale = 0.0f;
         }
@@ -3368,7 +3370,7 @@ int main(int argc, char ** argv) {
             int id = entry.at("id");
             float scale = entry.at("scale");
             if (0 <= id && id < max_idx) {
-                ctx_server.params.lora_adapters[id].scale = scale;
+                ctx_server.lora_adapters[id].scale = scale;
             } else {
                 throw std::runtime_error("invalid adapter id");
             }
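
For reference (not part of the diff): the handler above that parses req.body expects a JSON array of {id, scale} objects, where id is the adapter's index in ctx_server.lora_adapters as reported by handle_lora_adapters_list, and a scale of 0.0 disables that adapter on the next apply. A hypothetical request body, built here with the same json alias (nlohmann::json) the server uses:

// Sketch of a body this handler would accept:
// enable adapter 0 at half strength and turn adapter 1 off.
json body = json::array({
    { { "id", 0 }, { "scale", 0.5f } },
    { { "id", 1 }, { "scale", 0.0f } },
});
// serializes to: [{"id":0,"scale":0.5},{"id":1,"scale":0.0}]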