From fef64c587d385d65f7c14ccfccb9f00b2afd02db Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Mar 2024 15:36:14 +0200 Subject: [PATCH] server : code style --- examples/server/oai.hpp | 112 +-- examples/server/server.cpp | 1500 ++++++++++++++++-------------------- examples/server/utils.hpp | 65 +- 3 files changed, 737 insertions(+), 940 deletions(-) diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp index dddc26b69..e57eb01ba 100644 --- a/examples/server/oai.hpp +++ b/examples/server/oai.hpp @@ -12,9 +12,8 @@ using json = nlohmann::json; inline static json oaicompat_completion_params_parse( const struct llama_model * model, - const json &body, /* openai api json semantics */ - const std::string &chat_template) -{ + const json & body, /* openai api json semantics */ + const std::string & chat_template) { json llama_params; llama_params["__oaicompat"] = true; @@ -27,26 +26,26 @@ inline static json oaicompat_completion_params_parse( // // https://platform.openai.com/docs/api-reference/chat/create llama_sampling_params default_sparams; - llama_params["model"] = json_value(body, "model", std::string("unknown")); - llama_params["prompt"] = format_chat(model, chat_template, body["messages"]); - llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); - llama_params["temperature"] = json_value(body, "temperature", 0.0); - llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); - llama_params["top_p"] = json_value(body, "top_p", 1.0); - llama_params["n_predict"] = json_value(body, "max_tokens", -1); - llama_params["logit_bias"] = json_value(body, "logit_bias",json::object()); - llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0); - llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0); - llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED); - llama_params["stream"] = json_value(body, "stream", false); - llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat); - llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); - llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); - llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl); - llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p); - llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n); - llama_params["ignore_eos"] = json_value(body, "ignore_eos", false); - llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z); + llama_params["model"] = json_value(body, "model", std::string("unknown")); + llama_params["prompt"] = format_chat(model, chat_template, body["messages"]); + llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); + llama_params["temperature"] = json_value(body, "temperature", 0.0); + llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); + llama_params["top_p"] = json_value(body, "top_p", 1.0); + llama_params["n_predict"] = json_value(body, "max_tokens", -1); + llama_params["logit_bias"] = json_value(body, "logit_bias", json::object()); + llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0); + llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0); + llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED); + llama_params["stream"] = json_value(body, "stream", false); + llama_params["mirostat"] 
= json_value(body, "mirostat", default_sparams.mirostat); + llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); + llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); + llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl); + llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p); + llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n); + llama_params["ignore_eos"] = json_value(body, "ignore_eos", false); + llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z); if (body.count("grammar") != 0) { llama_params["grammar"] = json_value(body, "grammar", json::object()); @@ -65,8 +64,7 @@ inline static json oaicompat_completion_params_parse( return llama_params; } -inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false) -{ +inline static json format_final_response_oaicompat(const json & request, const task_result & response, bool streaming = false) { json result = response.result_json; bool stopped_word = result.count("stopped_word") != 0; @@ -91,17 +89,19 @@ inline static json format_final_response_oaicompat(const json &request, const ta std::time_t t = std::time(0); - json res = - json{{"choices", choices}, - {"created", t}, - {"model", - json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, - {"usage", - json{{"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, - {"id", gen_chatcmplid()}}; + json res = json { + {"choices", choices}, + {"created", t}, + {"model", + json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", streaming ? 
"chat.completion.chunk" : "chat.completion"}, + {"usage", json { + {"completion_tokens", num_tokens_predicted}, + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens} + }}, + {"id", gen_chatcmplid()} + }; if (server_verbose) { res["__verbose"] = result; @@ -125,10 +125,10 @@ inline static std::vector format_partial_response_oaicompat(const task_res bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); + bool stopped_word = json_value(result, "stopped_word", false); + bool stopped_eos = json_value(result, "stopped_eos", false); bool stopped_limit = json_value(result, "stopped_limit", false); - std::string content = json_value(result, "content", std::string("")); + std::string content = json_value(result, "content", std::string("")); std::string finish_reason; if (stopped_word || stopped_eos) { @@ -196,26 +196,28 @@ inline static std::vector format_partial_response_oaicompat(const task_res } } - json ret = json{{"choices", choices}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; + json ret = json { + {"choices", choices}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"} + }; return std::vector({ret}); } -inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings) -{ - json res = - json{ - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", - json{{"prompt_tokens", 0}, - {"total_tokens", 0}}}, - {"data", embeddings} - }; +inline static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) { + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", 0}, + {"total_tokens", 0} + }}, + {"data", embeddings} + }; + return res; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 29fe96f83..d918cdbe4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1068,8 +1068,8 @@ struct llama_server_context { task.id_multi = id_multi; task.id_target = 0; task.data = std::move(data); - task.infill_mode = infill; - task.embedding_mode = embedding; + task.infill = infill; + task.embedding = embedding; task.type = TASK_TYPE_COMPLETION; // when a completion task's prompt array is not a singleton, we split it into multiple requests @@ -1132,71 +1132,74 @@ struct llama_server_context { subtask_data["prompt"] = subtask_data["prompt"][i]; // subtasks inherit everything else (infill mode, embedding mode, etc.) 
- request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode); + request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding); } } void process_single_task(task_server & task) { - switch (task.type) - { - case TASK_TYPE_COMPLETION: { - server_slot * slot = get_slot(json_value(task.data, "id_slot", -1)); - if (slot == nullptr) { - // if no slot is available, we defer this task for processing later - LOG_VERBOSE("no slot is available", {{"id_task", task.id}}); - queue_tasks.defer(task); - break; - } - - if (task.data.contains("system_prompt")) { - system_prompt_process(task.data["system_prompt"]); - - // reset cache_tokens for all slots - for (server_slot & slot : slots) - { - slot.cache_tokens.clear(); - slot.n_past = 0; - slot.n_past_se = 0; - } - } - - slot->reset(); - - slot->id_task = task.id; - slot->id_multi = task.id_multi; - slot->infill = task.infill_mode; - slot->embedding = task.embedding_mode; - - if (!launch_slot_with_data(*slot, task.data)) + switch (task.type) { + case TASK_TYPE_COMPLETION: { - // send error result - send_error(task, "internal_error"); - break; - } - } break; - case TASK_TYPE_CANCEL: { // release slot linked with the task id - for (auto & slot : slots) { - if (slot.id_task == task.id_target) { - slot.release(); + server_slot * slot = get_slot(json_value(task.data, "id_slot", -1)); + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + LOG_VERBOSE("no slot is available", {{"id_task", task.id}}); + queue_tasks.defer(task); break; } - } - } break; - case TASK_TYPE_NEXT_RESPONSE: { - // do nothing - } break; - case TASK_TYPE_METRICS: { - json slots_data = json::array(); - int n_idle_slots = 0; - int n_processing_slots = 0; - for (server_slot & slot: slots) { - json slot_data = get_formated_generation(slot); - slot_data["id"] = slot.id; - slot_data["id_task"] = slot.id_task; - slot_data["state"] = slot.state; - slot_data["prompt"] = slot.prompt; - slot_data["next_token"] = { + if (task.data.contains("system_prompt")) { + system_prompt_process(task.data["system_prompt"]); + + // reset cache_tokens for all slots + for (server_slot & slot : slots) + { + slot.cache_tokens.clear(); + slot.n_past = 0; + slot.n_past_se = 0; + } + } + + slot->reset(); + + slot->id_task = task.id; + slot->id_multi = task.id_multi; + slot->infill = task.infill; + slot->embedding = task.embedding; + + if (!launch_slot_with_data(*slot, task.data)) { + // send error result + send_error(task, "internal_error"); + break; + } + } break; + case TASK_TYPE_CANCEL: + { + // release slot linked with the task id + for (auto & slot : slots) { + if (slot.id_task == task.id_target) { + slot.release(); + break; + } + } + } break; + case TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case TASK_TYPE_METRICS: + { + json slots_data = json::array(); + int n_idle_slots = 0; + int n_processing_slots = 0; + + for (server_slot & slot: slots) { + json slot_data = get_formated_generation(slot); + slot_data["id"] = slot.id; + slot_data["id_task"] = slot.id_task; + slot_data["state"] = slot.state; + slot_data["prompt"] = slot.prompt; + slot_data["next_token"] = { {"has_next_token", slot.has_next_token}, {"n_remain", slot.n_remaining}, {"num_tokens_predicted", slot.n_decoded}, @@ -1204,33 +1207,33 @@ struct llama_server_context { {"stopped_word", slot.stopped_word}, {"stopped_limit", slot.stopped_limit}, {"stopping_word", slot.stopping_word}, - }; - if 
(slot_data["state"] == IDLE) { - n_idle_slots++; - } else { - n_processing_slots++; + }; + if (slot_data["state"] == IDLE) { + n_idle_slots++; + } else { + n_processing_slots++; + } + slots_data.push_back(slot_data); } - slots_data.push_back(slot_data); - } - LOG_INFO("slot data", { - {"id_task", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots} - }); + LOG_INFO("slot data", { + {"id_task", task.id}, + {"n_idle_slots", n_idle_slots}, + {"n_processing_slots", n_processing_slots} + }); - LOG_VERBOSE("slot data", { - {"id_task", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots}, - {"slots", slots_data} - }); + LOG_VERBOSE("slot data", { + {"id_task", task.id}, + {"n_idle_slots", n_idle_slots}, + {"n_processing_slots", n_processing_slots}, + {"slots", slots_data} + }); - task_result res; - res.id = task.id; - res.id_multi = task.id_multi; - res.stop = true; - res.error = false; - res.result_json = { + task_result res; + res.id = task.id; + res.id_multi = task.id_multi; + res.stop = true; + res.error = false; + res.result_json = { { "idle", n_idle_slots }, { "processing", n_processing_slots }, { "deferred", queue_tasks.queue_tasks_deferred.size() }, @@ -1247,42 +1250,37 @@ struct llama_server_context { { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)}, { "slots", slots_data }, - }; + }; - metrics.reset_bucket(); - queue_results.send(res); - } break; + metrics.reset_bucket(); + queue_results.send(res); + } break; } } - void on_finish_multitask(task_multi& multitask) - { + void on_finish_multitask(const task_multi & multitask) { // all subtasks done == multitask is done task_result result; - result.id = multitask.id; - result.stop = true; + result.id = multitask.id; + result.stop = true; result.error = false; // collect json results into one json result std::vector result_jsons; - for (auto& subres : multitask.results) - { + for (const auto & subres : multitask.results) { result_jsons.push_back(subres.result_json); result.error = result.error && subres.error; } result.result_json = json{ { "results", result_jsons } }; + queue_results.send(result); } bool update_slots() { - if (system_need_update) - { - LOG_INFO("updating system prompt", {}); + if (system_need_update) { system_prompt_update(); } - llama_batch_clear(batch); - // release slots for (auto & slot : slots) { if (slot.command == RELEASE) { @@ -1371,6 +1369,8 @@ struct llama_server_context { } } + llama_batch_clear(batch); + // decode any currently ongoing sequences for (auto & slot : slots) { if (slot.state == IDLE) { @@ -1619,17 +1619,13 @@ struct llama_server_context { return true; } - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) - { + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - for (auto & slot : slots) - { - if (slot.ga_n != 1) - { + for (auto & slot : slots) { + if (slot.ga_n != 1) { // context extension via Self-Extend - while (slot.n_past_se >= slot.ga_i + slot.ga_w) - { + while (slot.n_past_se >= slot.ga_i + slot.ga_w) { const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; @@ -1649,12 +1645,12 @@ struct llama_server_context { LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i); } + slot.n_past_se += n_tokens; } } - llama_batch batch_view = - { + llama_batch batch_view 
= { n_tokens, batch.token + i, nullptr, @@ -1667,10 +1663,8 @@ struct llama_server_context { const int ret = llama_decode(ctx, batch_view); - if (ret != 0) - { - if (n_batch == 1 || ret < 0) - { + if (ret != 0) { + if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); return false; @@ -1681,19 +1675,17 @@ struct llama_server_context { // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; i -= n_batch; + continue; } - for (auto & slot : slots) - { - if (slot.state != PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) - { + for (auto & slot : slots) { + if (slot.state != PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { continue; } // prompt evaluated for embedding - if (slot.embedding) - { + if (slot.embedding) { send_embedding(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -1706,8 +1698,7 @@ struct llama_server_context { llama_sampling_accept(slot.ctx_sampling, ctx, id, true); slot.n_decoded += 1; - if (slot.n_decoded == 1) - { + if (slot.n_decoded == 1) { slot.t_start_genereration = ggml_time_us(); slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3; metrics.on_prompt_eval(slot); @@ -1717,19 +1708,16 @@ struct llama_server_context { result.tok = id; const int32_t n_probs = slot.sparams.n_probs; - if (slot.sparams.temp <= 0 && n_probs > 0) - { + if (slot.sparams.temp <= 0 && n_probs > 0) { // for llama_sample_token_greedy we need to sort candidates llama_sample_softmax(ctx, &cur_p); } - for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i) - { + for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) { result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p}); } - if (!process_token(result, slot)) - { + if (!process_token(result, slot)) { slot.release(); slot.print_timings(); send_final_response(slot); @@ -1741,10 +1729,11 @@ struct llama_server_context { } LOG_VERBOSE("slots updated", {}); + return true; } - json model_meta() { + json model_meta() const { return json { {"vocab_type", llama_vocab_type(model)}, {"n_vocab", llama_n_vocab(model)}, @@ -1849,49 +1838,34 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, std::string arg; bool invalid_param = false; - for (int i = 1; i < argc; i++) - { + for (int i = 1; i < argc; i++) { arg = argv[i]; - if (arg == "--port") - { - if (++i >= argc) - { + if (arg == "--port") { + if (++i >= argc) { invalid_param = true; break; } sparams.port = std::stoi(argv[i]); - } - else if (arg == "--host") - { - if (++i >= argc) - { + } else if (arg == "--host") { + if (++i >= argc) { invalid_param = true; break; } sparams.hostname = argv[i]; - } - else if (arg == "--path") - { - if (++i >= argc) - { + } else if (arg == "--path") { + if (++i >= argc) { invalid_param = true; break; } sparams.public_path = argv[i]; - } - else if (arg == "--api-key") - { - if (++i >= argc) - { + } else if (arg == "--api-key") { + if (++i >= argc) { invalid_param = true; break; } sparams.api_keys.emplace_back(argv[i]); - } - else if (arg == "--api-key-file") - { - if (++i >= argc) - { + } else if (arg == "--api-key-file") { + if (++i >= argc) { invalid_param = true; break; } @@ -1908,53 +1882,36 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, } } key_file.close(); - } - else if (arg == 
"--timeout" || arg == "-to") - { - if (++i >= argc) - { + } else if (arg == "--timeout" || arg == "-to") { + if (++i >= argc) { invalid_param = true; break; } sparams.read_timeout = std::stoi(argv[i]); sparams.write_timeout = std::stoi(argv[i]); - } - else if (arg == "-m" || arg == "--model") - { - if (++i >= argc) - { + } else if (arg == "-m" || arg == "--model") { + if (++i >= argc) { invalid_param = true; break; } params.model = argv[i]; - } - else if (arg == "-a" || arg == "--alias") - { - if (++i >= argc) - { + } else if (arg == "-a" || arg == "--alias") { + if (++i >= argc) { invalid_param = true; break; } params.model_alias = argv[i]; - } - else if (arg == "-h" || arg == "--help") - { + } else if (arg == "-h" || arg == "--help") { server_print_usage(argv[0], default_params, default_sparams); exit(0); - } - else if (arg == "-c" || arg == "--ctx-size" || arg == "--ctx_size") - { - if (++i >= argc) - { + } else if (arg == "-c" || arg == "--ctx-size" || arg == "--ctx_size") { + if (++i >= argc) { invalid_param = true; break; } params.n_ctx = std::stoi(argv[i]); - } - else if (arg == "--rope-scaling") - { - if (++i >= argc) - { + } else if (arg == "--rope-scaling") { + if (++i >= argc) { invalid_param = true; break; } @@ -1963,59 +1920,44 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } else { invalid_param = true; break; } - } - else if (arg == "--rope-freq-base") - { - if (++i >= argc) - { + } else if (arg == "--rope-freq-base") { + if (++i >= argc) { invalid_param = true; break; } params.rope_freq_base = std::stof(argv[i]); - } - else if (arg == "--rope-freq-scale") - { - if (++i >= argc) - { + } else if (arg == "--rope-freq-scale") { + if (++i >= argc) { invalid_param = true; break; } params.rope_freq_scale = std::stof(argv[i]); - } - else if (arg == "--yarn-ext-factor") - { + } else if (arg == "--yarn-ext-factor") { if (++i >= argc) { invalid_param = true; break; } params.yarn_ext_factor = std::stof(argv[i]); } - else if (arg == "--yarn-attn-factor") - { + else if (arg == "--yarn-attn-factor") { if (++i >= argc) { invalid_param = true; break; } params.yarn_attn_factor = std::stof(argv[i]); - } - else if (arg == "--yarn-beta-fast") - { + } else if (arg == "--yarn-beta-fast") { if (++i >= argc) { invalid_param = true; break; } params.yarn_beta_fast = std::stof(argv[i]); - } - else if (arg == "--yarn-beta-slow") - { + } else if (arg == "--yarn-beta-slow") { if (++i >= argc) { invalid_param = true; break; } params.yarn_beta_slow = std::stof(argv[i]); - } - else if (arg == "--pooling") - { + } else if (arg == "--pooling") { if (++i >= argc) { invalid_param = true; break; @@ -2025,108 +1967,79 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else { invalid_param = true; break; } - } - else if (arg == "--threads" || arg == "-t") - { + } else if (arg == "--threads" || arg == "-t") { if (++i >= argc) { invalid_param = true; break; } params.n_threads = std::stoi(argv[i]); - } - else if (arg == "--grp-attn-n" || arg == "-gan") - { + } else if (arg == "--grp-attn-n" || arg == "-gan") { if (++i >= argc) { invalid_param = true; break; } params.grp_attn_n = std::stoi(argv[i]); - } - else if (arg == 
"--grp-attn-w" || arg == "-gaw") - { - if (++i >= argc) - { + } else if (arg == "--grp-attn-w" || arg == "-gaw") { + if (++i >= argc) { invalid_param = true; break; } params.grp_attn_w = std::stoi(argv[i]); - } - else if (arg == "--threads-batch" || arg == "-tb") - { - if (++i >= argc) - { + } else if (arg == "--threads-batch" || arg == "-tb") { + if (++i >= argc) { invalid_param = true; break; } params.n_threads_batch = std::stoi(argv[i]); - } - else if (arg == "--threads-http") - { - if (++i >= argc) - { + } else if (arg == "--threads-http") { + if (++i >= argc) { invalid_param = true; break; } sparams.n_threads_http = std::stoi(argv[i]); - } - else if (arg == "-b" || arg == "--batch-size") - { - if (++i >= argc) - { + } else if (arg == "-b" || arg == "--batch-size") { + if (++i >= argc) { invalid_param = true; break; } params.n_batch = std::stoi(argv[i]); - } - else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") - { - if (++i >= argc) - { + } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { + if (++i >= argc) { invalid_param = true; break; } if (llama_supports_gpu_offload()) { params.n_gpu_layers = std::stoi(argv[i]); } else { - LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " - "See main README.md for information on enabling GPU BLAS support", - {{"n_gpu_layers", params.n_gpu_layers}}); + LOG_WARNING( + "Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " + "See main README.md for information on enabling GPU BLAS support", + {{"n_gpu_layers", params.n_gpu_layers}}); } - } - else if (arg == "--split-mode" || arg == "-sm") - { + } else if (arg == "--split-mode" || arg == "-sm") { if (++i >= argc) { invalid_param = true; break; } std::string arg_next = argv[i]; - if (arg_next == "none") - { + if (arg_next == "none") { params.split_mode = LLAMA_SPLIT_MODE_NONE; - } - else if (arg_next == "layer") - { + } else if (arg_next == "layer") { params.split_mode = LLAMA_SPLIT_MODE_LAYER; - } - else if (arg_next == "row") - { + } else if (arg_next == "row") { params.split_mode = LLAMA_SPLIT_MODE_ROW; - } - else { + } else { invalid_param = true; break; } #ifndef GGML_USE_CUBLAS fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n"); #endif // GGML_USE_CUBLAS - } - else if (arg == "--tensor-split" || arg == "-ts") - { - if (++i >= argc) - { + } else if (arg == "--tensor-split" || arg == "-ts") { + if (++i >= argc) { invalid_param = true; break; } @@ -2139,25 +2052,18 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, std::vector split_arg{it, {}}; GGML_ASSERT(split_arg.size() <= llama_max_devices()); - for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) - { - if (i_device < split_arg.size()) - { + for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) { + if (i_device < split_arg.size()) { params.tensor_split[i_device] = std::stof(split_arg[i_device]); - } - else - { + } else { params.tensor_split[i_device] = 0.0f; } } #else LOG_WARNING("llama.cpp was compiled without cuBLAS. 
It is not possible to set a tensor split.\n", {}); #endif // GGML_USE_CUBLAS - } - else if (arg == "--main-gpu" || arg == "-mg") - { - if (++i >= argc) - { + } else if (arg == "--main-gpu" || arg == "-mg") { + if (++i >= argc) { invalid_param = true; break; } @@ -2166,59 +2072,42 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, #else LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.", {}); #endif - } - else if (arg == "--lora") - { - if (++i >= argc) - { + } else if (arg == "--lora") { + if (++i >= argc) { invalid_param = true; break; } params.lora_adapter.emplace_back(argv[i], 1.0f); params.use_mmap = false; - } - else if (arg == "--lora-scaled") - { - if (++i >= argc) - { + } else if (arg == "--lora-scaled") { + if (++i >= argc) { invalid_param = true; break; } const char * lora_adapter = argv[i]; - if (++i >= argc) - { + if (++i >= argc) { invalid_param = true; break; } params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); params.use_mmap = false; - } - else if (arg == "--lora-base") - { - if (++i >= argc) - { + } else if (arg == "--lora-base") { + if (++i >= argc) { invalid_param = true; break; } params.lora_base = argv[i]; - } - else if (arg == "-v" || arg == "--verbose") - { + } else if (arg == "-v" || arg == "--verbose") { #if SERVER_VERBOSE != 1 LOG_WARNING("server.cpp is not built with verbose logging.", {}); #else server_verbose = true; #endif - } - else if (arg == "--mlock") - { + } else if (arg == "--mlock") { params.use_mlock = true; - } - else if (arg == "--no-mmap") - { + } else if (arg == "--no-mmap") { params.use_mmap = false; - } - else if (arg == "--numa") { + } else if (arg == "--numa") { if (++i >= argc) { invalid_param = true; break; @@ -2229,35 +2118,24 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } else { invalid_param = true; break; } } - } - else if (arg == "--embedding") - { + } else if (arg == "--embedding") { params.embedding = true; - } - else if (arg == "-cb" || arg == "--cont-batching") - { + } else if (arg == "-cb" || arg == "--cont-batching") { params.cont_batching = true; - } - else if (arg == "-np" || arg == "--parallel") - { - if (++i >= argc) - { + } else if (arg == "-np" || arg == "--parallel") { + if (++i >= argc) { invalid_param = true; break; } params.n_parallel = std::stoi(argv[i]); - } else if (arg == "-n" || arg == "--n-predict") - { - if (++i >= argc) - { + } else if (arg == "-n" || arg == "--n-predict") { + if (++i >= argc) { invalid_param = true; break; } params.n_predict = std::stoi(argv[i]); - } else if (arg == "-spf" || arg == "--system-prompt-file") - { - if (++i >= argc) - { + } else if (arg == "-spf" || arg == "--system-prompt-file") { + if (++i >= argc) { invalid_param = true; break; } @@ -2274,51 +2152,32 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, std::back_inserter(systm_content) ); llama.system_prompt_process(json::parse(systm_content)); - } - else if (arg == "-ctk" || arg == "--cache-type-k") { + } else if (arg == "-ctk" || arg == "--cache-type-k") { params.cache_type_k = argv[++i]; - } - else if (arg == "-ctv" || arg == "--cache-type-v") { + } else if (arg == "-ctv" || arg == "--cache-type-v") { params.cache_type_v = argv[++i]; - } - else if (arg == "--log-format") - { - if (++i >= argc) - { + } else if (arg == "--log-format") { + if (++i >= argc) { invalid_param = true; break; } - if 
(std::strcmp(argv[i], "json") == 0) - { + if (std::strcmp(argv[i], "json") == 0) { server_log_json = true; - } - else if (std::strcmp(argv[i], "text") == 0) - { + } else if (std::strcmp(argv[i], "text") == 0) { server_log_json = false; - } - else - { + } else { invalid_param = true; break; } - } - else if (arg == "--log-disable") - { + } else if (arg == "--log-disable") { log_set_target(stdout); LOG_INFO("logging to file is disabled.", {}); - } - else if (arg == "--slots-endpoint-disable") - { + } else if (arg == "--slots-endpoint-disable") { sparams.slots_endpoint = false; - } - else if (arg == "--metrics") - { + } else if (arg == "--metrics") { sparams.metrics_endpoint = true; - } - else if (arg == "--chat-template") - { - if (++i >= argc) - { + } else if (arg == "--chat-template") { + if (++i >= argc) { invalid_param = true; break; } @@ -2329,9 +2188,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, break; } sparams.chat_template = argv[i]; - } - else if (arg == "--override-kv") - { + } else if (arg == "--override-kv") { if (++i >= argc) { invalid_param = true; break; @@ -2342,6 +2199,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, invalid_param = true; break; } + struct llama_model_kv_override kvo; std::strncpy(kvo.key, argv[i], sep - argv[i]); kvo.key[sep - argv[i]] = 0; @@ -2372,66 +2230,40 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, break; } params.kv_overrides.push_back(kvo); - } - else - { + } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); server_print_usage(argv[0], default_params, default_sparams); exit(1); } } + if (!params.kv_overrides.empty()) { params.kv_overrides.emplace_back(); params.kv_overrides.back().key[0] = 0; } - if (invalid_param) - { + if (invalid_param) { fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); server_print_usage(argv[0], default_params, default_sparams); exit(1); } } -/* llama.cpp completion api semantics */ -static json format_partial_response( - llama_server_context & llama, server_slot * slot, const std::string & content, const std::vector & probs -) { - json res = json { - {"content", content}, - {"stop", false}, - {"id_slot", slot->id}, - {"multimodal", false}, - }; - - if (slot->sparams.n_probs > 0) - { - res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); - } - - return res; -} - -static json format_tokenizer_response(const std::vector &tokens) -{ +static json format_tokenizer_response(const std::vector & tokens) { return json { {"tokens", tokens} }; } -static json format_detokenized_response(std::string content) -{ +static json format_detokenized_response(const std::string & content) { return json { {"content", content} }; } - -static void log_server_request(const httplib::Request &req, const httplib::Response &res) -{ +static void log_server_request(const httplib::Request & req, const httplib::Response & res) { // skip GH copilot requests when using default port - if (req.path == "/v1/health" || req.path == "/v1/completions") - { + if (req.path == "/v1/health" || req.path == "/v1/completions") { return; } @@ -2450,22 +2282,6 @@ static void log_server_request(const httplib::Request &req, const httplib::Respo }); } -static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, server_slot *slot) -{ - auto & gtps = slot->generated_token_probs; - auto translator = token_translator{llama.ctx}; - auto add_strlen = [=](size_t sum, const 
completion_token_output & cto) { return sum + translator(cto).size(); }; - const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen); - if (slot->generated_text.capacity() < slot->generated_text.size() + len) - { - slot->generated_text.reserve(slot->generated_text.size() + len); - } - for (const completion_token_output & cto : gtps) - { - slot->generated_text += translator(cto); - } -} - std::function shutdown_handler; std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; @@ -2476,6 +2292,7 @@ inline void signal_handler(int signal) { fprintf(stderr, "Received second interrupt, terminating immediately.\n"); exit(1); } + shutdown_handler(signal); } @@ -2500,15 +2317,17 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); - LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, - {"commit", LLAMA_COMMIT}}); + LOG_INFO("build info", { + {"build", LLAMA_BUILD_NUMBER}, + {"commit", LLAMA_COMMIT} + }); LOG_INFO("system info", { - {"n_threads", params.n_threads}, - {"n_threads_batch", params.n_threads_batch}, - {"total_threads", std::thread::hardware_concurrency()}, - {"system_info", llama_print_system_info()}, - }); + {"n_threads", params.n_threads}, + {"n_threads_batch", params.n_threads_batch}, + {"total_threads", std::thread::hardware_concurrency()}, + {"system_info", llama_print_system_info()}, + }); httplib::Server svr; @@ -2517,69 +2336,76 @@ int main(int argc, char ** argv) { svr.set_default_headers({{"Server", "llama.cpp"}}); // CORS preflight - svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + svr.Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "POST"); - res.set_header("Access-Control-Allow-Headers", "*"); + res.set_header("Access-Control-Allow-Methods", "POST"); + res.set_header("Access-Control-Allow-Headers", "*"); }); - svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) { + svr.Get("/health", [&](const httplib::Request & req, httplib::Response & res) { server_state current_state = state.load(); - switch(current_state) { - case SERVER_STATE_READY: { - // request slots data using task queue - task_server task; - task.id = llama.queue_tasks.get_new_id(); - task.type = TASK_TYPE_METRICS; - task.id_target = -1; + switch (current_state) { + case SERVER_STATE_READY: + { + // request slots data using task queue + task_server task; + task.id = llama.queue_tasks.get_new_id(); + task.type = TASK_TYPE_METRICS; + task.id_target = -1; - llama.queue_results.add_waiting_task_id(task.id); - llama.queue_tasks.post(task); + llama.queue_results.add_waiting_task_id(task.id); + llama.queue_tasks.post(task); - // get the result - task_result result = llama.queue_results.recv(task.id); - llama.queue_results.remove_waiting_task_id(task.id); + // get the result + task_result result = llama.queue_results.recv(task.id); + llama.queue_results.remove_waiting_task_id(task.id); - int n_idle_slots = result.result_json["idle"]; - int n_processing_slots = result.result_json["processing"]; + const int n_idle_slots = result.result_json["idle"]; + const int n_processing_slots = result.result_json["processing"]; - json health = { + json health = { {"status", "ok"}, {"slots_idle", n_idle_slots}, - 
{"slots_processing", n_processing_slots}}; - res.status = 200; // HTTP OK - if (sparams.slots_endpoint && req.has_param("include_slots")) { - health["slots"] = result.result_json["slots"]; - } + {"slots_processing", n_processing_slots} + }; - if (n_idle_slots == 0) { - health["status"] = "no slot available"; - if (req.has_param("fail_on_no_slot")) { - res.status = 503; // HTTP Service Unavailable + res.status = 200; // HTTP OK + if (sparams.slots_endpoint && req.has_param("include_slots")) { + health["slots"] = result.result_json["slots"]; } + + if (n_idle_slots == 0) { + health["status"] = "no slot available"; + if (req.has_param("fail_on_no_slot")) { + res.status = 503; // HTTP Service Unavailable + } + } + + res.set_content(health.dump(), "application/json"); + break; } - res.set_content(health.dump(), "application/json"); - break; - } case SERVER_STATE_LOADING_MODEL: - res.set_content(R"({"status": "loading model"})", "application/json"); - res.status = 503; // HTTP Service Unavailable - break; + { + res.set_content(R"({"status": "loading model"})", "application/json"); + res.status = 503; // HTTP Service Unavailable + } break; case SERVER_STATE_ERROR: - res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); - res.status = 500; // HTTP Internal Server Error - break; + { + res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); + res.status = 500; // HTTP Internal Server Error + } break; } }); if (sparams.slots_endpoint) { - svr.Get("/slots", [&](const httplib::Request&, httplib::Response& res) { + svr.Get("/slots", [&](const httplib::Request &, httplib::Response & res) { // request slots data using task queue task_server task; task.id = llama.queue_tasks.get_new_id(); - task.type = TASK_TYPE_METRICS; + task.id_multi = -1; task.id_target = -1; + task.type = TASK_TYPE_METRICS; llama.queue_results.add_waiting_task_id(task.id); llama.queue_tasks.post(task); @@ -2594,12 +2420,13 @@ int main(int argc, char ** argv) { } if (sparams.metrics_endpoint) { - svr.Get("/metrics", [&](const httplib::Request&, httplib::Response& res) { + svr.Get("/metrics", [&](const httplib::Request &, httplib::Response & res) { // request slots data using task queue task_server task; task.id = llama.queue_tasks.get_new_id(); - task.type = TASK_TYPE_METRICS; + task.id_multi = -1; task.id_target = -1; + task.type = TASK_TYPE_METRICS; llama.queue_results.add_waiting_task_id(task.id); llama.queue_tasks.post(task); @@ -2610,59 +2437,62 @@ int main(int argc, char ** argv) { json data = result.result_json; - uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"]; - uint64_t t_prompt_processing = data["t_prompt_processing"]; + const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"]; + const uint64_t t_prompt_processing = data["t_prompt_processing"]; - uint64_t n_tokens_predicted = data["n_tokens_predicted"]; - uint64_t t_tokens_generation = data["t_tokens_generation"]; + const uint64_t n_tokens_predicted = data["n_tokens_predicted"]; + const uint64_t t_tokens_generation = data["t_tokens_generation"]; - int32_t kv_cache_used_cells = data["kv_cache_used_cells"]; + const int32_t kv_cache_used_cells = data["kv_cache_used_cells"]; // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names json all_metrics_def = json { - {"counter", {{ - {"name", "prompt_tokens_total"}, - {"help", "Number of prompt tokens processed."}, - {"value", data["n_prompt_tokens_processed_total"]} - }, { - {"name", 
"tokens_predicted_total"}, - {"help", "Number of generation tokens processed."}, - {"value", data["n_tokens_predicted_total"]} - }}}, - {"gauge", {{ - {"name", "prompt_tokens_seconds"}, - {"help", "Average prompt throughput in tokens/s."}, - {"value", n_prompt_tokens_processed ? 1e3 / t_prompt_processing * n_prompt_tokens_processed : 0} - },{ - {"name", "predicted_tokens_seconds"}, - {"help", "Average generation throughput in tokens/s."}, - {"value", n_tokens_predicted ? 1e3 / t_tokens_generation * n_tokens_predicted : 0} - },{ - {"name", "kv_cache_usage_ratio"}, - {"help", "KV-cache usage. 1 means 100 percent usage."}, - {"value", 1. * kv_cache_used_cells / params.n_ctx} - },{ - {"name", "kv_cache_tokens"}, - {"help", "KV-cache tokens."}, - {"value", data["kv_cache_tokens_count"]} - },{ - {"name", "requests_processing"}, - {"help", "Number of request processing."}, - {"value", data["processing"]} - },{ - {"name", "requests_deferred"}, - {"help", "Number of request deferred."}, - {"value", data["deferred"]} - }}} + {"counter", {{ + {"name", "prompt_tokens_total"}, + {"help", "Number of prompt tokens processed."}, + {"value", data["n_prompt_tokens_processed_total"]} + }, { + {"name", "tokens_predicted_total"}, + {"help", "Number of generation tokens processed."}, + {"value", data["n_tokens_predicted_total"]} + }}}, + {"gauge", {{ + {"name", "prompt_tokens_seconds"}, + {"help", "Average prompt throughput in tokens/s."}, + {"value", n_prompt_tokens_processed ? 1e3 / t_prompt_processing * n_prompt_tokens_processed : 0} + },{ + {"name", "predicted_tokens_seconds"}, + {"help", "Average generation throughput in tokens/s."}, + {"value", n_tokens_predicted ? 1e3 / t_tokens_generation * n_tokens_predicted : 0} + },{ + {"name", "kv_cache_usage_ratio"}, + {"help", "KV-cache usage. 1 means 100 percent usage."}, + {"value", 1. * kv_cache_used_cells / params.n_ctx} + },{ + {"name", "kv_cache_tokens"}, + {"help", "KV-cache tokens."}, + {"value", data["kv_cache_tokens_count"]} + },{ + {"name", "requests_processing"}, + {"help", "Number of request processing."}, + {"value", data["processing"]} + },{ + {"name", "requests_deferred"}, + {"help", "Number of request deferred."}, + {"value", data["deferred"]} + }}} }; std::stringstream prometheus; + for (const auto & el : all_metrics_def.items()) { - const auto & type = el.key(); + const auto & type = el.key(); const auto & metrics_def = el.value(); + for (const auto & metric_def : metrics_def) { - std::string name = metric_def["name"]; - std::string help = metric_def["help"]; + const std::string name = metric_def["name"]; + const std::string help = metric_def["help"]; + auto value = json_value(metric_def, "value", 0); prometheus << "# HELP llamacpp:" << name << " " << help << "\n" << "# TYPE llamacpp:" << name << " " << type << "\n" @@ -2677,49 +2507,39 @@ int main(int argc, char ** argv) { svr.set_logger(log_server_request); - svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) - { - const char fmt[] = "500 Internal Server Error\n%s"; - char buf[BUFSIZ]; - try - { - std::rethrow_exception(std::move(ep)); - } - catch (std::exception &e) - { - snprintf(buf, sizeof(buf), fmt, e.what()); - } - catch (...) 
- { - snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); - } - res.set_content(buf, "text/plain; charset=utf-8"); - res.status = 500; - }); + svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + const char fmt[] = "500 Internal Server Error\n%s"; - svr.set_error_handler([](const httplib::Request &, httplib::Response &res) - { - if (res.status == 401) - { - res.set_content("Unauthorized", "text/plain; charset=utf-8"); - } - if (res.status == 400) - { - res.set_content("Invalid request", "text/plain; charset=utf-8"); - } - else if (res.status == 404) - { - res.set_content("File Not Found", "text/plain; charset=utf-8"); - res.status = 404; - } - }); + char buf[BUFSIZ]; + try { + std::rethrow_exception(std::move(ep)); + } catch (std::exception &e) { + snprintf(buf, sizeof(buf), fmt, e.what()); + } catch (...) { + snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); + } + + res.set_content(buf, "text/plain; charset=utf-8"); + res.status = 500; + }); + + svr.set_error_handler([](const httplib::Request &, httplib::Response & res) { + if (res.status == 401) { + res.set_content("Unauthorized", "text/plain; charset=utf-8"); + } + if (res.status == 400) { + res.set_content("Invalid request", "text/plain; charset=utf-8"); + } + if (res.status == 404) { + res.set_content("File Not Found", "text/plain; charset=utf-8"); + } + }); // set timeouts and change hostname and port svr.set_read_timeout (sparams.read_timeout); svr.set_write_timeout(sparams.write_timeout); - if (!svr.bind_to_port(sparams.hostname, sparams.port)) - { + if (!svr.bind_to_port(sparams.hostname, sparams.port)) { fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); return 1; } @@ -2728,8 +2548,9 @@ int main(int argc, char ** argv) { svr.set_base_dir(sparams.public_path); std::unordered_map log_data; + log_data["hostname"] = sparams.hostname; - log_data["port"] = std::to_string(sparams.port); + log_data["port"] = std::to_string(sparams.port); if (sparams.api_keys.size() == 1) { log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4); @@ -2738,19 +2559,19 @@ int main(int argc, char ** argv) { } // load the model - if (!llama.load_model(params)) - { + if (!llama.load_model(params)) { state.store(SERVER_STATE_ERROR); return 1; } else { llama.initialize(); state.store(SERVER_STATE_READY); - LOG_INFO("model loaded", {}); } + + LOG_INFO("model loaded", {}); + const auto model_meta = llama.model_meta(); if (sparams.chat_template.empty()) { // custom chat template is not supplied - // check if the template comes with the model is supported by us llama.validate_model_chat_template(sparams); } @@ -2763,6 +2584,7 @@ int main(int argc, char ** argv) { // Check for API key in the header auto auth_header = req.get_header_value("Authorization"); + std::string prefix = "Bearer "; if (auth_header.substr(0, prefix.size()) == prefix) { std::string received_api_key = auth_header.substr(prefix.size()); @@ -2781,149 +2603,145 @@ int main(int argc, char ** argv) { }; // this is only called if no index.html is found in the public --path - svr.Get("/", [](const httplib::Request &, httplib::Response &res) - { - res.set_content(reinterpret_cast(&index_html), index_html_len, "text/html; charset=utf-8"); - return false; - }); + svr.Get("/", [](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(&index_html), index_html_len, "text/html; charset=utf-8"); + return 
false; + }); // this is only called if no index.js is found in the public --path - svr.Get("/index.js", [](const httplib::Request &, httplib::Response &res) - { - res.set_content(reinterpret_cast(&index_js), index_js_len, "text/javascript; charset=utf-8"); - return false; - }); + svr.Get("/index.js", [](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(&index_js), index_js_len, "text/javascript; charset=utf-8"); + return false; + }); // this is only called if no index.html is found in the public --path - svr.Get("/completion.js", [](const httplib::Request &, httplib::Response &res) - { - res.set_content(reinterpret_cast(&completion_js), completion_js_len, "application/javascript; charset=utf-8"); - return false; - }); + svr.Get("/completion.js", [](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(&completion_js), completion_js_len, "application/javascript; charset=utf-8"); + return false; + }); // this is only called if no index.html is found in the public --path - svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response &res) - { - res.set_content(reinterpret_cast(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8"); - return false; - }); + svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8"); + return false; + }); - svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - json data = { - { "user_name", llama.name_user.c_str() }, - { "assistant_name", llama.name_assistant.c_str() }, - { "default_generation_settings", llama.default_generation_settings_for_props }, - { "total_slots", llama.params.n_parallel } - }; - res.set_content(data.dump(), "application/json; charset=utf-8"); - }); + svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + json data = { + { "user_name", llama.name_user.c_str() }, + { "assistant_name", llama.name_assistant.c_str() }, + { "default_generation_settings", llama.default_generation_settings_for_props }, + { "total_slots", llama.params.n_parallel } + }; - svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { - return; - } - json data = json::parse(req.body); - const int id_task = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, -1, data, false, false); - if (!json_value(data, "stream", false)) { - std::string completion_text; - task_result result = llama.queue_results.recv(id_task); - if (!result.error && result.stop) { - res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); - } - else - { - res.status = 404; - res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); - } - llama.queue_results.remove_waiting_task_id(id_task); - } else { - const auto chunked_content_provider = [id_task, &llama](size_t, httplib::DataSink & sink) - { - while (true) - { - task_result result = 
llama.queue_results.recv(id_task); - if (!result.error) { - const std::string str = - "data: " + - result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - if (!sink.write(str.c_str(), str.size())) - { - llama.queue_results.remove_waiting_task_id(id_task); - return false; - } - if (result.stop) { - break; - } - } else { - const std::string str = - "error: " + - result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - if (!sink.write(str.c_str(), str.size())) - { - llama.queue_results.remove_waiting_task_id(id_task); - return false; - } - break; - } - } + res.set_content(data.dump(), "application/json; charset=utf-8"); + }); - llama.queue_results.remove_waiting_task_id(id_task); - sink.done(); - return true; - }; - - auto on_complete = [id_task, &llama] (bool) - { - // cancel - llama.request_cancel(id_task); - llama.queue_results.remove_waiting_task_id(id_task); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }); - - svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request& req, httplib::Response& res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - std::time_t t = std::time(0); - - json models = { - {"object", "list"}, - {"data", { - { - {"id", params.model_alias}, - {"object", "model"}, - {"created", t}, - {"owned_by", "llamacpp"}, - {"meta", model_meta} - }, - }} - }; - - res.set_content(models.dump(), "application/json; charset=utf-8"); - }); - - const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) - { + svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } + + json data = json::parse(req.body); + + const int id_task = llama.queue_tasks.get_new_id(); + llama.queue_results.add_waiting_task_id(id_task); + llama.request_completion(id_task, -1, data, false, false); + + if (!json_value(data, "stream", false)) { + task_result result = llama.queue_results.recv(id_task); + if (!result.error && result.stop) { + res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + } else { + res.status = 404; + res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); + } + + llama.queue_results.remove_waiting_task_id(id_task); + } else { + const auto chunked_content_provider = [id_task, &llama](size_t, httplib::DataSink & sink) { + while (true) { + task_result result = llama.queue_results.recv(id_task); + if (!result.error) { + const std::string str = + "data: " + + result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + + if (!sink.write(str.c_str(), str.size())) { + llama.queue_results.remove_waiting_task_id(id_task); + return false; + } + + if (result.stop) { + break; + } + } else { + const std::string str = + "error: " + + result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + + if (!sink.write(str.c_str(), str.size())) { + llama.queue_results.remove_waiting_task_id(id_task); + return false; + } + + break; + } + } + + 
llama.queue_results.remove_waiting_task_id(id_task); + sink.done(); + + return true; + }; + + auto on_complete = [id_task, &llama] (bool) { + // cancel + llama.request_cancel(id_task); + llama.queue_results.remove_waiting_task_id(id_task); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } + }); + + svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + + json models = { + {"object", "list"}, + {"data", { + { + {"id", params.model_alias}, + {"object", "model"}, + {"created", std::time(0)}, + {"owned_by", "llamacpp"}, + {"meta", model_meta} + }, + }} + }; + + res.set_content(models.dump(), "application/json; charset=utf-8"); + }); + + const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!validate_api_key(req, res)) { + return; + } + json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template); const int id_task = llama.queue_tasks.get_new_id(); @@ -2931,22 +2749,19 @@ int main(int argc, char ** argv) { llama.request_completion(id_task, -1, data, false, false); if (!json_value(data, "stream", false)) { - std::string completion_text; task_result result = llama.queue_results.recv(id_task); if (!result.error && result.stop) { json oaicompat_result = format_final_response_oaicompat(data, result); - res.set_content(oaicompat_result.dump(-1, ' ', false, - json::error_handler_t::replace), - "application/json; charset=utf-8"); + res.set_content(oaicompat_result.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); } else { res.status = 500; res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); } llama.queue_results.remove_waiting_task_id(id_task); } else { - const auto chunked_content_provider = [id_task, &llama](size_t, httplib::DataSink &sink) { + const auto chunked_content_provider = [id_task, &llama](size_t, httplib::DataSink & sink) { while (true) { task_result llama_result = llama.queue_results.recv(id_task); if (!llama_result.error) { @@ -2998,193 +2813,180 @@ int main(int argc, char ** argv) { } }; - svr.Post("/chat/completions", chat_completions); + svr.Post("/chat/completions", chat_completions); svr.Post("/v1/chat/completions", chat_completions); - svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { - return; - } - json data = json::parse(req.body); - const int id_task = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, -1, data, true, false); - if (!json_value(data, "stream", false)) { - std::string completion_text; + svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!validate_api_key(req, res)) { + return; + } + + json data = json::parse(req.body); + + const int id_task = llama.queue_tasks.get_new_id(); + llama.queue_results.add_waiting_task_id(id_task); + llama.request_completion(id_task, -1, data, true, false); + if (!json_value(data, "stream", false)) { + task_result 
result = llama.queue_results.recv(id_task); + if (!result.error && result.stop) { + res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + } else { + res.status = 404; + res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); + } + + llama.queue_results.remove_waiting_task_id(id_task); + } else { + const auto chunked_content_provider = [id_task, &llama](size_t, httplib::DataSink & sink) { + while (true) { task_result result = llama.queue_results.recv(id_task); - if (!result.error && result.stop) - { - res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); - } - else - { - res.status = 404; - res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); - } - llama.queue_results.remove_waiting_task_id(id_task); - } else { - const auto chunked_content_provider = [id_task, &llama](size_t, httplib::DataSink & sink) { - while (true) - { - task_result result = llama.queue_results.recv(id_task); - if (!result.error) { - const std::string str = - "data: " + - result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - if (!sink.write(str.c_str(), str.size())) - { - llama.queue_results.remove_waiting_task_id(id_task); - return false; - } - if (result.stop) - { - break; - } - } - else - { - break; - } - } + if (!result.error) { + const std::string str = + "data: " + + result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; - llama.queue_results.remove_waiting_task_id(id_task); - sink.done(); - return true; - }; + LOG_VERBOSE("data stream", { + { "to_send", str } + }); - auto on_complete = [id_task, &llama] (bool) - { - // cancel - llama.request_cancel(id_task); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }); - - svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response &res) - { return res.set_content("", "application/json; charset=utf-8"); }); - - svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); - std::vector tokens; - if (body.count("content") != 0) - { - tokens = llama.tokenize(body["content"], false); - } - const json data = format_tokenizer_response(tokens); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }); - - svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); - std::string content; - if (body.count("tokens") != 0) - { - const std::vector tokens = body["tokens"]; - content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend()); - } - - const json data = format_detokenized_response(content); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }); - - svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); - json prompt; - if (body.count("content") != 0) - { - prompt = body["content"]; - } - else - { - prompt = ""; - } - - // create and queue the task - const int id_task = llama.queue_tasks.get_new_id(); - 
llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true); - - // get the result - task_result result = llama.queue_results.recv(id_task); - llama.queue_results.remove_waiting_task_id(id_task); - - // send the result - return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); - }); - - svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); - - json prompt; - if (body.count("input") != 0) - { - prompt = body["input"]; - // batch - if(prompt.is_array()) { - json data = json::array(); - int i = 0; - for (const json &elem : prompt) { - const int id_task = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true); - - // get the result - task_result result = llama.queue_results.recv(id_task); + if (!sink.write(str.c_str(), str.size())) { llama.queue_results.remove_waiting_task_id(id_task); - - json embedding = json{ - {"embedding", json_value(result.result_json, "embedding", json::array())}, - {"index", i++}, - {"object", "embedding"} - }; - data.push_back(embedding); + return false; } - json result = format_embeddings_response_oaicompat(body, data); - return res.set_content(result.dump(), "application/json; charset=utf-8"); + + if (result.stop) { + break; + } + } else { + break; } } - else - { - prompt = ""; + + llama.queue_results.remove_waiting_task_id(id_task); + sink.done(); + + return true; + }; + + auto on_complete = [id_task, &llama] (bool) { + llama.request_cancel(id_task); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } + }); + + svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) { + return res.set_content("", "application/json; charset=utf-8"); + }); + + svr.Post("/tokenize", [&llama](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + const json body = json::parse(req.body); + std::vector tokens; + if (body.count("content") != 0) { + tokens = llama.tokenize(body["content"], false); + } + const json data = format_tokenizer_response(tokens); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }); + + svr.Post("/detokenize", [&llama](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + const json body = json::parse(req.body); + std::string content; + if (body.count("tokens") != 0) { + const std::vector tokens = body["tokens"]; + content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend()); + } + + const json data = format_detokenized_response(content); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }); + + svr.Post("/embedding", [&llama](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + const json body = json::parse(req.body); + + json prompt; + if (body.count("content") != 0) { + prompt = body["content"]; + } else { + prompt = ""; + } + + // create and queue the task + const int id_task = llama.queue_tasks.get_new_id(); + llama.queue_results.add_waiting_task_id(id_task); + 
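+ // n_predict = 0 with the embedding flag set: only the prompt embedding is computed, no tokens are generated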
llama.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true); + + // get the result + task_result result = llama.queue_results.recv(id_task); + llama.queue_results.remove_waiting_task_id(id_task); + + // send the result + return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); + }); + + svr.Post("/v1/embeddings", [&llama](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + const json body = json::parse(req.body); + + json prompt; + if (body.count("input") != 0) { + prompt = body["input"]; + // batch + if (prompt.is_array()) { + json data = json::array(); + int i = 0; + for (const json & elem : prompt) { + const int id_task = llama.queue_tasks.get_new_id(); + llama.queue_results.add_waiting_task_id(id_task); + llama.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true); + + // get the result + task_result result = llama.queue_results.recv(id_task); + llama.queue_results.remove_waiting_task_id(id_task); + + json embedding = json{ + {"embedding", json_value(result.result_json, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + + data.push_back(embedding); } - // create and queue the task - const int id_task = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true); + json result = format_embeddings_response_oaicompat(body, data); - // get the result - task_result result = llama.queue_results.recv(id_task); - llama.queue_results.remove_waiting_task_id(id_task); + return res.set_content(result.dump(), "application/json; charset=utf-8"); + } + } else { + prompt = ""; + } - json data = json::array({json{ - {"embedding", json_value(result.result_json, "embedding", json::array())}, - {"index", 0}, - {"object", "embedding"} - }} - ); + // create and queue the task + const int id_task = llama.queue_tasks.get_new_id(); + llama.queue_results.add_waiting_task_id(id_task); + llama.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true); - json root = format_embeddings_response_oaicompat(body, data); + // get the result + task_result result = llama.queue_results.recv(id_task); + llama.queue_results.remove_waiting_task_id(id_task); - // send the result - return res.set_content(root.dump(), "application/json; charset=utf-8"); - }); + json data = json::array({json{ + {"embedding", json_value(result.result_json, "embedding", json::array())}, + {"index", 0}, + {"object", "embedding"} + }} + ); + + json root = format_embeddings_response_oaicompat(body, data); + + return res.set_content(root.dump(), "application/json; charset=utf-8"); + }); if (sparams.n_threads_http < 1) { // +2 threads for monitoring endpoints @@ -3194,17 +2996,16 @@ int main(int argc, char ** argv) { svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); }; LOG_INFO("HTTP server listening", log_data); - // run the HTTP server in a thread - see comment below - std::thread t([&]() - { - if (!svr.listen_after_bind()) - { - state.store(SERVER_STATE_ERROR); - return 1; - } - return 0; - }); + // run the HTTP server in a thread - see comment below + std::thread t([&]() { + if (!svr.listen_after_bind()) { + state.store(SERVER_STATE_ERROR); + return 1; + } + + return 0; + }); llama.queue_tasks.on_new_task(std::bind( 
&llama_server_context::process_single_task, &llama, std::placeholders::_1)); @@ -3241,5 +3042,6 @@ int main(int argc, char ** argv) { t.join(); llama_backend_free(); + return 0; } diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 89cd5bcb7..db5654069 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -58,8 +58,8 @@ struct task_server { task_type type; json data; - bool infill_mode = false; - bool embedding_mode = false; + bool infill = false; + bool embedding = false; }; struct task_result { @@ -187,7 +187,8 @@ inline std::string format_chat(const struct llama_model * model, const std::stri res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); } - std::string formatted_chat(buf.data(), res); + const std::string formatted_chat(buf.data(), res); + LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); return formatted_chat; @@ -201,17 +202,18 @@ struct llama_server_queue { int id = 0; bool running; - std::mutex mutex_tasks; - // queues std::vector queue_tasks; std::vector queue_tasks_deferred; - std::vector queue_multitasks; + std::vector queue_multitasks; + + std::mutex mutex_tasks; std::condition_variable condition_tasks; + // callback functions - std::function callback_new_task; - std::function callback_finish_multitask; - std::function callback_run_slots; + std::function callback_new_task; + std::function callback_finish_multitask; + std::function callback_run_slots; // Add a new task to the end of the queue int post(task_server task) { @@ -265,10 +267,9 @@ struct llama_server_queue { } // end the start_loop routine - void terminate() { { - std::unique_lock lock(mutex_tasks); - running = false; - } + void terminate() { + std::unique_lock lock(mutex_tasks); + running = false; condition_tasks.notify_all(); } @@ -350,14 +351,11 @@ struct llama_server_queue { } // updatethe remaining subtasks, while appending results to multitask - void update_multitask(int id_multi, int subtask_id, task_result& result) - { + void update_multitask(int id_multi, int id_sub, task_result& result) { std::lock_guard lock(mutex_tasks); - for (auto& multitask : queue_multitasks) - { - if (multitask.id == id_multi) - { - multitask.subtasks_remaining.erase(subtask_id); + for (auto & multitask : queue_multitasks) { + if (multitask.id == id_multi) { + multitask.subtasks_remaining.erase(id_sub); multitask.results.push_back(result); } } @@ -468,13 +466,10 @@ static inline std::vector base64_decode(const std::string & encoded_str std::vector ret; - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) - { + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) - { - for (i = 0; i <4; i++) - { + if (i == 4) { + for (i = 0; i < 4; i++) { char_array_4[i] = base64_chars.find(char_array_4[i]); } @@ -482,23 +477,20 @@ static inline std::vector base64_decode(const std::string & encoded_str char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (i = 0; (i < 3); i++) - { + for (i = 0; (i < 3); i++) { ret.push_back(char_array_3[i]); } + i = 0; } } - if (i) - { - for (j = i; j <4; j++) - { + if (i) { + for (j = i; j < 4; j++) { char_array_4[j] = 0; } - for (j = 0; j <4; j++) - { + for (j = 0; j < 4; j++) { char_array_4[j] = base64_chars.find(char_array_4[j]); } @@ -506,8 +498,7 @@ static inline 
std::vector<uint8_t> base64_decode(const std::string & encoded_str
 char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
 char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
- for (j = 0; (j < i - 1); j++)
- {
+ for (j = 0; j < i - 1; j++) {
 ret.push_back(char_array_3[j]);
 }
 }
@@ -586,6 +577,7 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
 // format incomplete utf-8 multibyte character for output
 static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
 std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
+
 // if the size is 1 and first bit is 1, meaning it's a partial character
 // (size > 1 meaning it's already a known token)
 if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
@@ -601,6 +593,7 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
 // convert a vector of completion_token_output to json
 static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
 json out = json::array();
+
 for (const auto & prob : probs) {
 json probs_for_token = json::array();
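+ // per-position list of candidate tokens and their probabilities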