diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 055e2c5b8..67d704f1b 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -64,7 +64,6 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
-    SERVER_TASK_TYPE_SET_LORA,
 };
 
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
@@ -91,6 +90,8 @@ struct slot_params {
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
+    std::vector<common_lora_adapter_container> lora;
+
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
@@ -114,6 +115,11 @@ struct slot_params {
             samplers.emplace_back(common_sampler_type_to_str(sampler));
         }
 
+        json lora = json::array();
+        for (size_t i = 0; i < this->lora.size(); ++i) {
+            lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+        }
+
         return json {
             {"n_predict",                 n_predict},     // Server configured n_predict
             {"seed",                      sampling.seed},
@@ -154,6 +160,7 @@ struct slot_params {
             {"speculative.p_min",         speculative.p_min},
             {"timings_per_token",         timings_per_token},
             {"post_sampling_probs",       post_sampling_probs},
+            {"lora",                      lora},
         };
     }
 };
@@ -189,6 +196,7 @@ struct server_task {
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
+            const std::vector<common_lora_adapter_container> & base_lora,
             const json & data) {
         slot_params params;
 
@@ -245,6 +253,16 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        if (data.contains("lora")) {
+            if (data.at("lora").is_array()) {
+                params.lora = parse_lora_request(base_lora, data.at("lora"));
+            } else {
+                throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+            }
+        } else {
+            params.lora = base_lora;
+        }
+
         // TODO: add more sanity checks for the input parameters
 
         if (params.sampling.penalty_last_n < -1) {
@@ -989,12 +1007,6 @@ struct server_task_result_slot_erase : server_task_result {
     }
 };
 
-struct server_task_result_apply_lora : server_task_result {
-    virtual json to_json() override {
-        return json {{ "success", true }};
-    }
-};
-
 struct server_slot {
     int id;
     int id_task = -1;
@@ -1009,6 +1021,8 @@ struct server_slot {
 
     common_speculative * spec = nullptr;
 
+    std::vector<common_lora_adapter_container> lora;
+
     // the index relative to completion multi-task request
     size_t index = 0;
 
@@ -1091,7 +1105,8 @@ struct server_slot {
     }
 
     bool can_batch_with(server_slot & other_slot) {
-        return is_non_causal() == other_slot.is_non_causal();
+        return is_non_causal() == other_slot.is_non_causal()
+            && are_lora_equal(lora, other_slot.lora);
     }
 
     bool has_budget(const common_params & global_params) {
@@ -1503,7 +1518,7 @@ struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
 
-    std::vector<common_lora_adapter_container> loras;
+    std::vector<common_lora_adapter_container> lora;
 
     llama_model * model_dft = nullptr;
     llama_context_params cparams_dft;
@@ -1570,7 +1585,7 @@ struct server_context {
 
         model = llama_init.model;
         ctx   = llama_init.context;
-        loras = llama_init.lora_adapters;
+        lora  = llama_init.lora_adapters;
 
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1776,6 +1791,12 @@ struct server_context {
         slot.params        = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
+        if (!are_lora_equal(slot.params.lora, slot.lora)) {
+            // if lora is changed, we cannot reuse cached tokens
+            slot.cache_tokens.clear();
+            slot.lora = slot.params.lora;
+        }
+
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -2465,13 +2486,6 @@ struct server_context {
                     res->n_erased = n_erased;
                     queue_results.send(std::move(res));
                 } break;
-            case SERVER_TASK_TYPE_SET_LORA:
-                {
-                    common_lora_adapters_apply(ctx, loras);
-                    auto res = std::make_unique<server_task_result_apply_lora>();
-                    res->id = task.id;
-                    queue_results.send(std::move(res));
-                } break;
         }
     }
 
@@ -2808,8 +2822,12 @@ struct server_context {
 
         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
-        // make sure we're in the right embedding mode
-        llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal());
+        if (slot_batched) {
+            // make sure we're in the right embedding mode
+            llama_set_embeddings(ctx, slot_batched->is_non_causal());
+            // apply lora, only need to do it once per batch
+            common_lora_adapters_apply(ctx, slot_batched->lora);
+        }
 
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
@@ -3530,7 +3548,12 @@ int main(int argc, char ** argv) {
             task.index = i;
 
             task.prompt_tokens    = std::move(tokenized_prompts[i]);
-            task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
+            task.params           = server_task::params_from_json_cmpl(
+                                        ctx_server.model,
+                                        ctx_server.ctx,
+                                        ctx_server.params_base,
+                                        ctx_server.lora,
+                                        data);
             task.id_selected_slot = json_value(data, "id_slot", -1);
 
             // OAI-compat
@@ -3944,8 +3967,8 @@ int main(int argc, char ** argv) {
 
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
-            auto & lora = ctx_server.loras[i];
+        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
+            auto & lora = ctx_server.lora[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},
@@ -3957,40 +3980,13 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
-        const std::vector<json> body = json::parse(req.body);
-        int max_idx = ctx_server.loras.size();
-
-        // clear existing value
-        for (auto & lora : ctx_server.loras) {
-            lora.scale = 0.0f;
-        }
-
-        // set value
-        for (auto entry : body) {
-            int id      = entry.at("id");
-            float scale = entry.at("scale");
-            if (0 <= id && id < max_idx) {
-                ctx_server.loras[id].scale = scale;
-            } else {
-                throw std::runtime_error("invalid adapter id");
-            }
-        }
-
-        server_task task(SERVER_TASK_TYPE_SET_LORA);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
-
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-        if (result->is_error()) {
-            res_error(res, result->to_json());
+        const json body = json::parse(req.body);
+        if (!body.is_array()) {
+            res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
-
-        GGML_ASSERT(dynamic_cast<server_task_result_apply_lora *>(result.get()) != nullptr);
-        res_ok(res, result->to_json());
+        ctx_server.lora = parse_lora_request(ctx_server.lora, body);
+        res_ok(res, json{{"success", true}});
     };
 
     //
diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py
index 749615449..9167c2f8e 100644
--- a/examples/server/tests/unit/test_lora.py
+++ b/examples/server/tests/unit/test_lora.py
@@ -40,3 +40,37 @@ def test_lora(scale: float, re_content: str):
 
     assert res.status_code == 200
     assert match_regex(re_content, res.body["content"])
+
+def test_lora_per_request():
+    global server
+    server.n_slots = 4
+    server.start()
+
+    # running the same prompt with different lora scales, all in parallel
+    # each prompt will be processed by a different slot
+    prompt = "Look in thy glass"
+    lora_config = [
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+    ]
+    # FIXME: testing with scales between 0.0 and 1.0 (e.g. 0.2, 0.5, 0.7) produces unreliable results
+
+    tasks = [(
+        server.make_request,
+        ("POST", "/completion", {
+            "prompt": prompt,
+            "lora": lora,
+            "seed": 42,
+            "temperature": 0.0,
+        })
+    ) for lora, re_test in lora_config]
+    results = parallel_function_calls(tasks)
+
+    print(results)
+    assert all([res.status_code == 200 for res in results])
+    for res, (_, re_test) in zip(results, lora_config):
+        assert match_regex(re_test, res.body["content"])
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 334f2f192..573c379f1 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -771,3 +771,44 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ct
 
     return cur;
 }
+
+static bool are_lora_equal(
+        const std::vector<common_lora_adapter_container> & l1,
+        const std::vector<common_lora_adapter_container> & l2) {
+    if (l1.size() != l2.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < l1.size(); ++i) {
+        // we don't check lora.path to reduce the time complexity
+        if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
+            return false;
+        }
+    }
+    return true;
+}
+
+// parse lora config from JSON request, returns a copy of base_lora with updated scales
+static std::vector<common_lora_adapter_container> parse_lora_request(
+        const std::vector<common_lora_adapter_container> & base_lora,
+        const json & data) {
+    std::vector<common_lora_adapter_container> lora(base_lora);
+    int max_idx = lora.size();
+
+    // clear existing values
+    for (auto & entry : lora) {
+        entry.scale = 0.0f;
+    }
+
+    // set new values from the request
+    for (const auto & entry : data) {
+        int id      = json_value(entry, "id", -1);
+        float scale = json_value(entry, "scale", 0.0f);
+        if (0 <= id && id < max_idx) {
+            lora[id].scale = scale;
+        } else {
+            throw std::runtime_error("invalid adapter id");
+        }
+    }
+
+    return lora;
+}
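
A minimal usage sketch of the resulting API, assuming a `llama-server` instance on the default `http://localhost:8080` started with a single `--lora` adapter (adapter ids follow load order, starting at 0). The endpoints and the shape of the `lora` field come from the patch above; the host, port, and adapter setup are assumptions.

```python
# Sketch of the per-request "lora" field and the simplified /lora-adapters
# endpoint. Assumes llama-server is reachable on localhost:8080 with one
# LoRA adapter loaded; adapter id 0 refers to the first loaded adapter.
import requests

BASE_URL = "http://localhost:8080"

# Per-request scaling: the "lora" array applies to this completion only and
# leaves the server-wide scales untouched. As the can_batch_with() change
# shows, requests with different scales are decoded in separate batches.
res = requests.post(f"{BASE_URL}/completion", json={
    "prompt": "Look in thy glass",
    "lora": [{"id": 0, "scale": 0.5}],  # array of {id, scale} objects
})
print(res.json()["content"])

# Server-wide default: requests that omit "lora" fall back to the scales set
# here, which now simply update ctx_server.lora via parse_lora_request.
res = requests.post(f"{BASE_URL}/lora-adapters", json=[{"id": 0, "scale": 1.0}])
assert res.json() == {"success": True}

# List the loaded adapters together with their current default scales.
print(requests.get(f"{BASE_URL}/lora-adapters").json())
```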