clean up struct def
This commit is contained in:
parent
21cb13384c
commit
c58a332fcd
3 changed files with 22 additions and 14 deletions
|
@ -687,7 +687,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||
params.lora_adapters.push_back({
|
||||
std::string(argv[i]),
|
||||
1.0,
|
||||
nullptr,
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
@ -698,7 +697,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||
params.lora_adapters.push_back({
|
||||
lora_adapter,
|
||||
std::stof(argv[i]),
|
||||
nullptr,
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
@ -2106,16 +2104,20 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
|
|||
|
||||
// load and optionally apply lora adapters
|
||||
for (auto & la : params.lora_adapters) {
|
||||
la.adapter = llama_lora_adapter_init(model, la.path.c_str());
|
||||
if (la.adapter == nullptr) {
|
||||
llama_lora_adapter_container loaded_la;
|
||||
loaded_la.path = la.path;
|
||||
loaded_la.scale = la.scale;
|
||||
loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
|
||||
if (loaded_la.adapter == nullptr) {
|
||||
fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
llama_free(lctx);
|
||||
llama_free_model(model);
|
||||
return iparams;
|
||||
}
|
||||
iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
|
||||
}
|
||||
if (!params.lora_init_without_apply) {
|
||||
llama_lora_adapters_apply(lctx, params.lora_adapters);
|
||||
llama_lora_adapters_apply(lctx, iparams.lora_adapters);
|
||||
}
|
||||
|
||||
if (params.ignore_eos) {
|
||||
|
|
|
@ -33,9 +33,12 @@
|
|||
|
||||
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
||||
|
||||
struct llama_lora_adapter_container {
|
||||
struct llama_lora_adapter_info {
|
||||
std::string path;
|
||||
float scale;
|
||||
};
|
||||
|
||||
struct llama_lora_adapter_container : llama_lora_adapter_info {
|
||||
struct llama_lora_adapter * adapter;
|
||||
};
|
||||
|
||||
|
@ -133,7 +136,7 @@ struct gpt_params {
|
|||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
|
||||
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
|
||||
std::vector<llama_lora_adapter_container> lora_adapters; // lora adapter path with user defined scale
|
||||
std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
|
||||
|
||||
std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
|
||||
|
||||
|
@ -315,8 +318,9 @@ std::string fs_get_cache_file(const std::string & filename);
|
|||
//
|
||||
|
||||
struct llama_init_result {
|
||||
struct llama_model * model = nullptr;
|
||||
struct llama_model * model = nullptr;
|
||||
struct llama_context * context = nullptr;
|
||||
std::vector<llama_lora_adapter_container> lora_adapters;
|
||||
};
|
||||
|
||||
struct llama_init_result llama_init_from_gpt_params(gpt_params & params);
|
||||
|
|
|
@ -623,6 +623,7 @@ struct server_response {
|
|||
struct server_context {
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
std::vector<llama_lora_adapter_container> lora_adapters;
|
||||
|
||||
gpt_params params;
|
||||
|
||||
|
@ -682,6 +683,7 @@ struct server_context {
|
|||
|
||||
model = llama_init.model;
|
||||
ctx = llama_init.context;
|
||||
lora_adapters = llama_init.lora_adapters;
|
||||
params.n_parallel -= 1; // but be sneaky about it
|
||||
if (model == nullptr) {
|
||||
LOG_ERROR("unable to load model", {{"model", params.model}});
|
||||
|
@ -1853,7 +1855,7 @@ struct server_context {
|
|||
} break;
|
||||
case SERVER_TASK_TYPE_SET_LORA:
|
||||
{
|
||||
llama_lora_adapters_apply(ctx, params.lora_adapters);
|
||||
llama_lora_adapters_apply(ctx, lora_adapters);
|
||||
server_task_result result;
|
||||
result.id = task.id;
|
||||
result.data = json{{ "success", true }};
|
||||
|
@ -3340,8 +3342,8 @@ int main(int argc, char ** argv) {
|
|||
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
json result = json::array();
|
||||
for (size_t i = 0; i < ctx_server.params.lora_adapters.size(); ++i) {
|
||||
auto & la = ctx_server.params.lora_adapters[i];
|
||||
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
|
||||
auto & la = ctx_server.lora_adapters[i];
|
||||
result.push_back({
|
||||
{"id", i},
|
||||
{"path", la.path},
|
||||
|
@ -3356,10 +3358,10 @@ int main(int argc, char ** argv) {
|
|||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
const std::vector<json> body = json::parse(req.body);
|
||||
int max_idx = ctx_server.params.lora_adapters.size();
|
||||
int max_idx = ctx_server.lora_adapters.size();
|
||||
|
||||
// clear existing value
|
||||
for (auto & la : ctx_server.params.lora_adapters) {
|
||||
for (auto & la : ctx_server.lora_adapters) {
|
||||
la.scale = 0.0f;
|
||||
}
|
||||
|
||||
|
@ -3368,7 +3370,7 @@ int main(int argc, char ** argv) {
|
|||
int id = entry.at("id");
|
||||
float scale = entry.at("scale");
|
||||
if (0 <= id && id < max_idx) {
|
||||
ctx_server.params.lora_adapters[id].scale = scale;
|
||||
ctx_server.lora_adapters[id].scale = scale;
|
||||
} else {
|
||||
throw std::runtime_error("invalid adapter id");
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue