diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 679301005..8b96a7a62 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -61,28 +61,29 @@ static bool server_verbose = false; } while (0) #endif -#define LOG_ERROR(MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__) +#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) +#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) json oaicompat_completion_params_parse(const json &body); std::string format_chatml(std::vector messages); + // // base64 utils (TODO: move to common in the future) // static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static std::vector base64_decode(const std::string &encoded_string) +static std::vector base64_decode(const std::string & encoded_string) { int i = 0; int j = 0; @@ -97,18 +98,17 @@ static std::vector base64_decode(const std::string &encoded_string) while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; - in_++; + char_array_4[i++] = encoded_string[in_]; in_++; if (i == 4) { - for (i = 0; i < 4; i++) + for (i = 0; i <4; i++) { char_array_4[i] = base64_chars.find(char_array_4[i]); } - char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; for (i = 0; (i < 3); i++) { @@ -120,19 +120,19 @@ static std::vector base64_decode(const std::string &encoded_string) if (i) { - for (j = i; j < 4; j++) + for (j = i; j <4; j++) { char_array_4[j] = 0; } - for (j = 0; j < 4; j++) + for (j = 0; j <4; j++) { char_array_4[j] = base64_chars.find(char_array_4[j]); } - char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; for (j = 0; (j < i - 1); j++) { @@ -147,21 +147,18 @@ static std::vector base64_decode(const std::string &encoded_string) // parallel // -enum server_state -{ - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded - SERVER_STATE_ERROR // An error occurred, load_model failed +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded + SERVER_STATE_ERROR // An error occurred, load_model failed }; -enum task_type -{ +enum task_type { TASK_TYPE_COMPLETION, TASK_TYPE_CANCEL, }; -struct task_server -{ +struct task_server { int id; int target_id; task_type type; @@ -171,8 +168,7 @@ struct task_server int multitask_id = -1; }; -struct task_result -{ +struct task_result { int id; int multitask_id = -1; bool stop; @@ -180,8 +176,7 @@ struct task_result json result_json; }; -struct task_multi -{ +struct task_multi { int id; std::set subtasks_remaining{}; std::vector results{}; @@ -203,12 +198,12 @@ enum slot_command struct slot_params { - bool stream = true; + bool stream = true; bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt - uint32_t seed = -1; // RNG seed - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_predict = -1; // new tokens to predict + uint32_t seed = -1; // RNG seed + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_predict = -1; // new tokens to predict std::vector antiprompt; @@ -221,10 +216,10 @@ struct slot_image int32_t id; bool request_encode_image = false; - float *image_embedding = nullptr; + float * image_embedding = nullptr; int32_t image_tokens = 0; - clip_image_u8 *img_data; + clip_image_u8 * img_data; std::string prefix_prompt; // before of this image }; @@ -300,12 +295,13 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) static void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) { - nlohmann::ordered_json log{ + nlohmann::ordered_json log + { {"timestamp", time(nullptr)}, - {"level", level}, - {"function", function}, - {"line", line}, - {"message", message}, + {"level", level}, + {"function", function}, + {"line", line}, + {"message", message}, }; if (!extra.empty()) @@ -344,15 +340,16 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector 0; // no budget } - bool available() const - { + bool available() const { return state == IDLE && command == NONE; } - bool is_processing() const - { + bool is_processing() const { return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING; } - void add_token_string(const completion_token_output &token) - { + void add_token_string(const completion_token_output &token) { if (command == RELEASE) { return; @@ -499,8 +490,7 @@ struct llama_client_slot generated_token_probs.push_back(token); } - void release() - { + void release() { if (state == IDLE || state == PROCESSING) { t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3; @@ -508,28 +498,27 @@ struct llama_client_slot } } - json get_formated_timings() - { - return json{ - {"prompt_n", num_prompt_tokens_processed}, - {"prompt_ms", t_prompt_processing}, - {"prompt_per_token_ms", t_prompt_processing / num_prompt_tokens_processed}, - {"prompt_per_second", 1e3 / t_prompt_processing * num_prompt_tokens_processed}, + json get_formated_timings() { + return json + { + {"prompt_n", num_prompt_tokens_processed}, + {"prompt_ms", t_prompt_processing}, + {"prompt_per_token_ms", t_prompt_processing / num_prompt_tokens_processed}, + {"prompt_per_second", 1e3 / t_prompt_processing * num_prompt_tokens_processed}, - {"predicted_n", n_decoded}, - {"predicted_ms", t_token_generation}, + {"predicted_n", n_decoded}, + {"predicted_ms", t_token_generation}, {"predicted_per_token_ms", t_token_generation / n_decoded}, - {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, + {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, }; } - void print_timings() const - { + void print_timings() const { LOG_TEE("\n"); LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing / num_prompt_tokens_processed, 1e3 / t_prompt_processing * num_prompt_tokens_processed); + __func__, t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing / num_prompt_tokens_processed, 1e3 / t_prompt_processing * num_prompt_tokens_processed); LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, t_token_generation, n_decoded, t_token_generation / n_decoded, 1e3 / t_token_generation * n_decoded); + __func__, t_token_generation, n_decoded,t_token_generation / n_decoded, 1e3 / t_token_generation * n_decoded); LOG_TEE("%s: total time = %10.2f ms\n", __func__, t_prompt_processing + t_token_generation); } }; @@ -545,21 +534,21 @@ struct llama_server_context llama_batch batch; - bool multimodal = false; - bool clean_kv_cache = true; + bool multimodal = false; + bool clean_kv_cache = true; bool all_slots_are_idle = false; - bool add_bos_token = true; + bool add_bos_token = true; int32_t id_gen; - int32_t n_ctx; // total context for all clients / slots + int32_t n_ctx; // total context for all clients / slots // system prompt bool system_need_update = false; - std::string system_prompt; + std::string system_prompt; std::vector system_tokens; - std::string name_user; // this should be the antiprompt + std::string name_user; // this should be the antiprompt std::string name_assistant; // slots / clients @@ -567,7 +556,7 @@ struct llama_server_context std::vector queue_tasks; std::vector queue_results; - std::vector queue_multitasks; + std::vector queue_multitasks; std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks std::condition_variable condition_tasks; std::mutex mutex_results; @@ -590,19 +579,16 @@ struct llama_server_context bool load_model(const gpt_params ¶ms_) { params = params_; - if (!params.mmproj.empty()) - { + if (!params.mmproj.empty()) { multimodal = true; LOG_TEE("Multi Modal Mode Enabled"); - clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/1); - if (clp_ctx == nullptr) - { + clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1); + if(clp_ctx == nullptr) { LOG_ERROR("unable to load clip model", {{"model", params.mmproj}}); return false; } - if (params.n_ctx < 2048) - { // request larger context for the image embedding + if (params.n_ctx < 2048) { // request larger context for the image embedding params.n_ctx = 2048; } } @@ -614,12 +600,10 @@ struct llama_server_context return false; } - if (multimodal) - { + if (multimodal) { const int n_embd_clip = clip_n_mmproj_embd(clp_ctx); - const int n_embd_llm = llama_n_embd(model); - if (n_embd_clip != n_embd_llm) - { + const int n_embd_llm = llama_n_embd(model); + if (n_embd_clip != n_embd_llm) { LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm); llama_free(ctx); llama_free_model(model); @@ -634,8 +618,7 @@ struct llama_server_context return true; } - void initialize() - { + void initialize() { id_gen = 0; // create slots @@ -663,7 +646,7 @@ struct llama_server_context system_tokens.clear(); } - std::vector tokenize(const json &json_prompt, bool add_bos) const + std::vector tokenize(const json & json_prompt, bool add_bos) const { // TODO: currently, we tokenize using special tokens by default // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) @@ -677,7 +660,7 @@ struct llama_server_context if (json_prompt.is_array()) { bool first = true; - for (const auto &p : json_prompt) + for (const auto& p : json_prompt) { if (p.is_string()) { @@ -713,12 +696,11 @@ struct llama_server_context return prompt_tokens; } - llama_client_slot *get_slot(int id) - { + llama_client_slot* get_slot(int id) { int64_t t_last = ggml_time_us(); llama_client_slot *last_used = nullptr; - for (llama_client_slot &slot : slots) + for (llama_client_slot & slot : slots) { if (slot.id == id && slot.available()) { @@ -735,43 +717,39 @@ struct llama_server_context return last_used; } - bool launch_slot_with_data(llama_client_slot *&slot, json data) - { + bool launch_slot_with_data(llama_client_slot* &slot, json data) { slot_params default_params; llama_sampling_params default_sparams; - if (data.count("__oaicompat") != 0) - { + if (data.count("__oaicompat") != 0) { slot->oaicompat = true; slot->oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - } - else - { + } else { slot->oaicompat = false; slot->oaicompat_model = ""; } - slot->params.stream = json_value(data, "stream", false); - slot->params.cache_prompt = json_value(data, "cache_prompt", false); - slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict); - slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); - slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); - slot->params.seed = json_value(data, "seed", default_params.seed); - slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); - slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + slot->params.stream = json_value(data, "stream", false); + slot->params.cache_prompt = json_value(data, "cache_prompt", false); + slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict); + slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); + slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); + slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p); + slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); + slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); + slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); + slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); + slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); + slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); + slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); + slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); + slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); + slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); + slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); + slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); + slot->params.seed = json_value(data, "seed", default_params.seed); + slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); + slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); // infill if (data.count("input_prefix") != 0) @@ -910,8 +888,7 @@ struct llama_server_context std::string prompt = slot->prompt.get(); size_t pos = 0, begin_prefix = 0; std::string pattern = "[img-"; - while ((pos = prompt.find(pattern, pos)) != std::string::npos) - { + while ((pos = prompt.find(pattern, pos)) != std::string::npos) { size_t end_prefix = pos; pos += pattern.length(); size_t end_pos = prompt.find("]", pos); @@ -924,23 +901,19 @@ struct llama_server_context bool found = false; for (slot_image &img : slot->images) { - if (img.id == img_id) - { + if (img.id == img_id) { found = true; img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix); begin_prefix = end_pos + 1; break; } } - if (!found) - { + if (!found) { LOG_TEE("ERROR: Image with id: %i, not found.\n", img_id); slot->images.clear(); return false; } - } - catch (const std::invalid_argument &e) - { + } catch (const std::invalid_argument& e) { LOG_TEE("Invalid image number id in prompt\n"); slot->images.clear(); return false; @@ -969,24 +942,22 @@ struct llama_server_context return true; } - void kv_cache_clear() - { + void kv_cache_clear() { // clear the entire KV cache llama_kv_cache_clear(ctx); clean_kv_cache = false; } - void update_system_prompt() - { + void update_system_prompt() { system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token); llama_batch_clear(batch); kv_cache_clear(); - for (int i = 0; i < (int)system_tokens.size(); ++i) + for (int i = 0; i < (int) system_tokens.size(); ++i) { - llama_batch_add(batch, system_tokens[i], i, {0}, false); + llama_batch_add(batch, system_tokens[i], i, { 0 }, false); } if (llama_decode(ctx, batch) != 0) @@ -1005,8 +976,7 @@ struct llama_server_context system_need_update = false; } - void notify_system_prompt_changed() - { + void notify_system_prompt_changed() { // release all slots for (llama_client_slot &slot : slots) { @@ -1016,10 +986,9 @@ struct llama_server_context system_need_update = true; } - void process_system_prompt_data(const json &sys_props) - { - system_prompt = sys_props.value("prompt", ""); - name_user = sys_props.value("anti_prompt", ""); + void process_system_prompt_data(const json &sys_props) { + system_prompt = sys_props.value("prompt", ""); + name_user = sys_props.value("anti_prompt", ""); name_assistant = sys_props.value("assistant_name", ""); if (slots.size() > 0) @@ -1062,8 +1031,7 @@ struct llama_server_context return stop_pos; } - bool process_token(completion_token_output &result, llama_client_slot &slot) - { + bool process_token(completion_token_output &result, llama_client_slot &slot) { // remember which tokens were sampled - used for repetition penalties during sampling const std::string token_str = llama_token_to_piece(ctx, result.tok); slot.sampled = result.tok; @@ -1184,8 +1152,8 @@ struct llama_server_context { continue; } - clip_image_f32 *img_res = clip_image_f32_init(); - if (!clip_image_preprocess(clp_ctx, img.img_data, img_res, /*pad2square =*/true)) + clip_image_f32 * img_res = clip_image_f32_init(); + if (!clip_image_preprocess(clp_ctx, img.img_data, img_res, /*pad2square =*/ true)) { LOG_TEE("Error processing the given image"); clip_free(clp_ctx); @@ -1212,7 +1180,7 @@ struct llama_server_context return slot.images.size() > 0; } - void send_error(task_server &task, const std::string &error) + void send_error(task_server& task, const std::string &error) { LOG_TEE("task %i - error: %s\n", task.id, error.c_str()); std::unique_lock lock(mutex_results); @@ -1221,12 +1189,12 @@ struct llama_server_context res.multitask_id = task.multitask_id; res.stop = false; res.error = true; - res.result_json = {{"content", error}}; + res.result_json = { { "content", error } }; queue_results.push_back(res); condition_results.notify_all(); } - void add_multi_task(int id, std::vector &sub_ids) + void add_multi_task(int id, std::vector& sub_ids) { std::lock_guard lock(mutex_tasks); task_multi multi; @@ -1236,10 +1204,10 @@ struct llama_server_context condition_tasks.notify_one(); } - void update_multi_task(int multitask_id, int subtask_id, task_result &result) + void update_multi_task(int multitask_id, int subtask_id, task_result& result) { std::lock_guard lock(mutex_tasks); - for (auto &multitask : queue_multitasks) + for (auto& multitask : queue_multitasks) { if (multitask.id == multitask_id) { @@ -1260,34 +1228,34 @@ struct llama_server_context const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); - return json{ - {"n_ctx", slot.n_ctx}, - {"model", params.model_alias}, - {"seed", slot.params.seed}, - {"temperature", slot.sparams.temp}, - {"top_k", slot.sparams.top_k}, - {"top_p", slot.sparams.top_p}, - {"min_p", slot.sparams.min_p}, - {"tfs_z", slot.sparams.tfs_z}, - {"typical_p", slot.sparams.typical_p}, - {"repeat_last_n", slot.sparams.penalty_last_n}, - {"repeat_penalty", slot.sparams.penalty_repeat}, - {"presence_penalty", slot.sparams.penalty_present}, + return json { + {"n_ctx", slot.n_ctx}, + {"model", params.model_alias}, + {"seed", slot.params.seed}, + {"temperature", slot.sparams.temp}, + {"top_k", slot.sparams.top_k}, + {"top_p", slot.sparams.top_p}, + {"min_p", slot.sparams.min_p}, + {"tfs_z", slot.sparams.tfs_z}, + {"typical_p", slot.sparams.typical_p}, + {"repeat_last_n", slot.sparams.penalty_last_n}, + {"repeat_penalty", slot.sparams.penalty_repeat}, + {"presence_penalty", slot.sparams.penalty_present}, {"frequency_penalty", slot.sparams.penalty_freq}, {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, - {"mirostat", slot.sparams.mirostat}, - {"mirostat_tau", slot.sparams.mirostat_tau}, - {"mirostat_eta", slot.sparams.mirostat_eta}, - {"penalize_nl", slot.sparams.penalize_nl}, - {"stop", slot.params.antiprompt}, - {"n_predict", slot.params.n_predict}, - {"n_keep", params.n_keep}, - {"ignore_eos", ignore_eos}, - {"stream", slot.params.stream}, - {"logit_bias", slot.sparams.logit_bias}, - {"n_probs", slot.sparams.n_probs}, - {"grammar", slot.sparams.grammar}, + {"mirostat", slot.sparams.mirostat}, + {"mirostat_tau", slot.sparams.mirostat_tau}, + {"mirostat_eta", slot.sparams.mirostat_eta}, + {"penalize_nl", slot.sparams.penalize_nl}, + {"stop", slot.params.antiprompt}, + {"n_predict", slot.params.n_predict}, + {"n_keep", params.n_keep}, + {"ignore_eos", ignore_eos}, + {"stream", slot.params.stream}, + {"logit_bias", slot.sparams.logit_bias}, + {"n_probs", slot.sparams.n_probs}, + {"grammar", slot.sparams.grammar}, }; } @@ -1300,17 +1268,19 @@ struct llama_server_context res.error = false; res.stop = false; - res.result_json = json{ - {"content", tkn.text_to_send}, - {"stop", false}, - {"slot_id", slot.id}, - {"multimodal", multimodal}}; + res.result_json = json + { + {"content", tkn.text_to_send}, + {"stop", false}, + {"slot_id", slot.id}, + {"multimodal", multimodal} + }; if (slot.sparams.n_probs > 0) { std::vector probs_output = {}; const std::vector to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false); - size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size()); + size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size()); size_t probs_stop_pos = std::min(slot.sent_token_probs_index + to_send_toks.size(), slot.generated_token_probs.size()); if (probs_pos < probs_stop_pos) { @@ -1339,22 +1309,24 @@ struct llama_server_context res.error = false; res.stop = true; - res.result_json = json{ - {"content", !slot.params.stream ? slot.generated_text : ""}, - {"slot_id", slot.id}, - {"stop", true}, - {"model", params.model_alias}, - {"tokens_predicted", slot.n_decoded}, - {"tokens_evaluated", slot.num_prompt_tokens}, + res.result_json = json + { + {"content", !slot.params.stream ? slot.generated_text : ""}, + {"slot_id", slot.id}, + {"stop", true}, + {"model", params.model_alias}, + {"tokens_predicted", slot.n_decoded}, + {"tokens_evaluated", slot.num_prompt_tokens}, {"generation_settings", get_formated_generation(slot)}, - {"prompt", slot.prompt}, - {"truncated", slot.truncated}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()}}; + {"prompt", slot.prompt}, + {"truncated", slot.truncated}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + {"tokens_cached", slot.n_past}, + {"timings", slot.get_formated_timings()} + }; if (slot.sparams.n_probs > 0) { @@ -1367,8 +1339,8 @@ struct llama_server_context else { probs = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); } res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs); } @@ -1407,7 +1379,8 @@ struct llama_server_context LOG_WARNING("embedding disabled", { {"params.embedding", params.embedding}, }); - res.result_json = json{ + res.result_json = json + { {"embedding", std::vector(n_embd, 0.0f)}, }; } @@ -1415,8 +1388,9 @@ struct llama_server_context { const float *data = llama_get_embeddings(ctx); std::vector embedding(data, data + n_embd); - res.result_json = json{ - {"embedding", embedding}, + res.result_json = json + { + {"embedding", embedding }, }; } queue_results.push_back(res); @@ -1453,10 +1427,11 @@ struct llama_server_context while (true) { std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&] - { return !queue_results.empty(); }); + condition_results.wait(lock, [&]{ + return !queue_results.empty(); + }); - for (int i = 0; i < (int)queue_results.size(); i++) + for (int i = 0; i < (int) queue_results.size(); i++) { // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result if (queue_results[i].multitask_id == task_id) @@ -1477,7 +1452,7 @@ struct llama_server_context } // never reached - // return task_result{-1, false, false, {}}; + //return task_result{-1, false, false, {}}; } // for multiple images processing @@ -1485,22 +1460,22 @@ struct llama_server_context { int image_idx = 0; - while (image_idx < (int)slot.images.size()) + while (image_idx < (int) slot.images.size()) { slot_image &img = slot.images[image_idx]; // process prefix prompt - for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t)(batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); llama_batch batch_view = { n_tokens, - batch.token + i, + batch.token + i, nullptr, - batch.pos + i, + batch.pos + i, batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, + batch.seq_id + i, + batch.logits + i, 0, 0, 0, // unused }; if (llama_decode(ctx, batch_view)) @@ -1520,18 +1495,7 @@ struct llama_server_context } const int n_embd = llama_n_embd(model); - llama_batch batch_img = { - n_eval, - nullptr, - (img.image_embedding + i * n_embd), - nullptr, - nullptr, - nullptr, - nullptr, - slot.n_past, - 1, - 0, - }; + llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, }; if (llama_decode(ctx, batch_img)) { LOG_TEE("%s : failed to eval image\n", __func__); @@ -1544,13 +1508,14 @@ struct llama_server_context llama_batch_clear(batch); // append prefix of next image - const auto json_prompt = (image_idx >= (int)slot.images.size()) ? slot.params.input_suffix : // no more images, then process suffix prompt - (json)(slot.images[image_idx].prefix_prompt); + const auto json_prompt = (image_idx >= (int) slot.images.size()) ? + slot.params.input_suffix : // no more images, then process suffix prompt + (json)(slot.images[image_idx].prefix_prompt); std::vector append_tokens = tokenize(json_prompt, false); // has next image - for (int i = 0; i < (int)append_tokens.size(); ++i) + for (int i = 0; i < (int) append_tokens.size(); ++i) { - llama_batch_add(batch, append_tokens[i], slot.n_past, {slot.id}, true); + llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true); slot.n_past += 1; } } @@ -1569,7 +1534,7 @@ struct llama_server_context condition_tasks.notify_one(); } - int split_multiprompt_task(task_server &multiprompt_task) + int split_multiprompt_task(task_server& multiprompt_task) { int prompt_count = multiprompt_task.data.at("prompt").size(); assert(prompt_count > 1); @@ -1599,60 +1564,55 @@ struct llama_server_context queue_tasks.erase(queue_tasks.begin()); switch (task.type) { - case TASK_TYPE_COMPLETION: - { - llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); - if (slot == nullptr) - { - LOG_TEE("slot unavailable\n"); - // send error result - send_error(task, "slot unavailable"); - break; - } - - if (task.data.contains("system_prompt")) - { - if (!all_slots_are_idle) + case TASK_TYPE_COMPLETION: { + llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); + if (slot == nullptr) { - send_error(task, "system prompt can only be updated when all slots are idle"); + LOG_TEE("slot unavailable\n"); + // send error result + send_error(task, "slot unavailable"); break; } - process_system_prompt_data(task.data["system_prompt"]); - // reset cache_tokens for all slots - for (llama_client_slot &slot : slots) + if (task.data.contains("system_prompt")) { - slot.cache_tokens.clear(); + if (!all_slots_are_idle) { + send_error(task, "system prompt can only be updated when all slots are idle"); + break; + } + process_system_prompt_data(task.data["system_prompt"]); + + // reset cache_tokens for all slots + for (llama_client_slot &slot : slots) + { + slot.cache_tokens.clear(); + } } - } - slot->reset(); + slot->reset(); - slot->infill = task.infill_mode; - slot->embedding = task.embedding_mode; - slot->task_id = task.id; - slot->multitask_id = task.multitask_id; + slot->infill = task.infill_mode; + slot->embedding = task.embedding_mode; + slot->task_id = task.id; + slot->multitask_id = task.multitask_id; - if (!launch_slot_with_data(slot, task.data)) - { - // send error result - send_error(task, "internal_error"); - break; - } - } - break; - case TASK_TYPE_CANCEL: - { // release slot linked with the task id - for (auto &slot : slots) - { - if (slot.task_id == task.target_id) + if (!launch_slot_with_data(slot, task.data)) { - slot.release(); + // send error result + send_error(task, "internal_error"); break; } - } - } - break; + } break; + case TASK_TYPE_CANCEL: { // release slot linked with the task id + for (auto & slot : slots) + { + if (slot.task_id == task.target_id) + { + slot.release(); + break; + } + } + } break; } } @@ -1671,12 +1631,13 @@ struct llama_server_context // collect json results into one json result std::vector result_jsons; - for (auto &subres : queue_iterator->results) + for (auto& subres : queue_iterator->results) { result_jsons.push_back(subres.result_json); aggregate_result.error = aggregate_result.error && subres.error; } - aggregate_result.result_json = json{"results", result_jsons}; + aggregate_result.result_json = json{ "results", result_jsons }; + agg_results.push_back(aggregate_result); @@ -1698,8 +1659,7 @@ struct llama_server_context queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end()); } - bool update_slots() - { + bool update_slots() { // attend tasks process_tasks(); @@ -1719,20 +1679,21 @@ struct llama_server_context kv_cache_clear(); } std::unique_lock lock(mutex_tasks); - condition_tasks.wait(lock, [&] - { return !queue_tasks.empty(); }); + condition_tasks.wait(lock, [&]{ + return !queue_tasks.empty(); + }); } for (llama_client_slot &slot : slots) { - if (slot.is_processing() && slot.cache_tokens.size() >= (size_t)slot.n_ctx) + if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx) { // Shift context - const int n_left = slot.n_past - slot.params.n_keep - 1; + const int n_left = slot.n_past - slot.params.n_keep - 1; const int n_discard = n_left / 2; LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard); - llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1, slot.params.n_keep + n_discard + 1); + llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1); llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard); for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) @@ -1747,15 +1708,15 @@ struct llama_server_context slot.truncated = true; LOG_VERBOSE("context shift", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - }); + {"n_ctx", n_ctx}, + {"n_keep", params.n_keep}, + {"n_left", n_left}, + }); } } // decode any currently ongoing sequences - for (auto &slot : slots) + for (auto & slot : slots) { // release the slot if (slot.command == RELEASE) @@ -1764,7 +1725,7 @@ struct llama_server_context slot.command = NONE; slot.t_last_used = ggml_time_us(); - LOG_TEE("slot %d released (%d tokens in cache)\n", slot.id, (int)slot.cache_tokens.size()); + LOG_TEE("slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size()); continue; } @@ -1776,7 +1737,7 @@ struct llama_server_context slot.i_batch = batch.n_tokens; - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, {slot.id}, true); + llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true); slot.n_past += 1; } @@ -1787,7 +1748,7 @@ struct llama_server_context // assign workload to the slots if (params.cont_batching || batch.n_tokens == 0) { - for (auto &slot : slots) + for (auto & slot : slots) { const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get().empty()) || !slot.images.empty(); @@ -1822,8 +1783,7 @@ struct llama_server_context auto suffix_tokens = tokenize(slot.params.input_suffix, false); const int space_token = 29871; // TODO: this should not be hardcoded - if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) - { + if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) { suffix_tokens.erase(suffix_tokens.begin()); } @@ -1836,7 +1796,7 @@ struct llama_server_context } else { - prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt + prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt } slot.num_prompt_tokens = prompt_tokens.size(); @@ -1858,11 +1818,11 @@ struct llama_server_context new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end()); LOG_VERBOSE("input truncated", { - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - }); + {"n_ctx", slot.n_ctx}, + {"n_keep", slot.params.n_keep}, + {"n_left", n_left}, + {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + }); slot.truncated = true; prompt_tokens = new_tokens; @@ -1891,7 +1851,7 @@ struct llama_server_context LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); } - LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int)system_tokens.size() + slot.n_past); + LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past); llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1); @@ -1905,18 +1865,18 @@ struct llama_server_context } LOG_VERBOSE("prompt ingested", { - {"n_past", slot.n_past}, - {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, - {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())}, - }); + {"n_past", slot.n_past}, + {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, + {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())}, + }); const bool has_images = process_images(slot); // process the prefix of first image std::vector prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens; - for (; slot.n_past < (int)prefix_tokens.size(); ++slot.n_past) + for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) { - llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, {slot.id}, false); + llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false); } if (has_images && !ingest_images(slot, n_batch)) @@ -1932,7 +1892,7 @@ struct llama_server_context } slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.i_batch = batch.n_tokens - 1; } } } @@ -1943,20 +1903,20 @@ struct llama_server_context return true; } - for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t)(batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); llama_batch batch_view = - { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, 0, 0, // unused - }; + { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; const int ret = llama_decode(ctx, batch_view); if (ret != 0) @@ -1976,9 +1936,9 @@ struct llama_server_context continue; } - for (auto &slot : slots) + for (auto & slot : slots) { - if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) + if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { continue; } @@ -2004,7 +1964,7 @@ struct llama_server_context slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3; } - llama_token_data_array cur_p = {slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false}; + llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false }; result.tok = id; const int32_t n_probs = slot.sparams.n_probs; @@ -2108,7 +2068,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, } static void server_params_parse(int argc, char **argv, server_params &sparams, - gpt_params ¶ms, llama_server_context &llama) + gpt_params ¶ms, llama_server_context& llama) { gpt_params default_params; server_params default_sparams; @@ -2162,19 +2122,16 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } std::ifstream key_file(argv[i]); - if (!key_file) - { + if (!key_file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); invalid_param = true; break; } std::string key; - while (std::getline(key_file, key)) - { - if (key.size() > 0) - { - sparams.api_keys.push_back(key); - } + while (std::getline(key_file, key)) { + if (key.size() > 0) { + sparams.api_keys.push_back(key); + } } key_file.close(); } @@ -2228,23 +2185,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } std::string value(argv[i]); - /**/ if (value == "none") - { - params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; - } - else if (value == "linear") - { - params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; - } - else if (value == "yarn") - { - params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; - } - else - { - invalid_param = true; - break; - } + /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; } + else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; } + else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; } + else { invalid_param = true; break; } } else if (arg == "--rope-freq-base") { @@ -2266,8 +2210,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--yarn-ext-factor") { - if (++i >= argc) - { + if (++i >= argc) { invalid_param = true; break; } @@ -2275,8 +2218,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--yarn-attn-factor") { - if (++i >= argc) - { + if (++i >= argc) { invalid_param = true; break; } @@ -2284,8 +2226,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--yarn-beta-fast") { - if (++i >= argc) - { + if (++i >= argc) { invalid_param = true; break; } @@ -2293,8 +2234,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--yarn-beta-slow") { - if (++i >= argc) - { + if (++i >= argc) { invalid_param = true; break; } @@ -2363,8 +2303,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--split-mode" || arg == "-sm") { - if (++i >= argc) - { + if (++i >= argc) { invalid_param = true; break; } @@ -2381,8 +2320,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, { params.split_mode = LLAMA_SPLIT_ROW; } - else - { + else { invalid_param = true; break; } @@ -2459,7 +2397,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } - const char *lora_adapter = argv[i]; + const char * lora_adapter = argv[i]; if (++i >= argc) { invalid_param = true; @@ -2513,8 +2451,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } params.n_parallel = std::stoi(argv[i]); - } - else if (arg == "-n" || arg == "--n-predict") + } else if (arg == "-n" || arg == "--n-predict") { if (++i >= argc) { @@ -2522,8 +2459,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } params.n_predict = std::stoi(argv[i]); - } - else if (arg == "-spf" || arg == "--system-prompt-file") + } else if (arg == "-spf" || arg == "--system-prompt-file") { if (++i >= argc) { @@ -2531,8 +2467,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } std::ifstream file(argv[i]); - if (!file) - { + if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); invalid_param = true; break; @@ -2541,10 +2476,11 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, std::copy( std::istreambuf_iterator(file), std::istreambuf_iterator(), - std::back_inserter(systm_content)); + std::back_inserter(systm_content) + ); llama.process_system_prompt_data(json::parse(systm_content)); } - else if (arg == "--mmproj") + else if(arg == "--mmproj") { if (++i >= argc) { @@ -2560,14 +2496,12 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--override-kv") { - if (++i >= argc) - { + if (++i >= argc) { invalid_param = true; break; } - char *sep = strchr(argv[i], '='); - if (sep == nullptr || sep - argv[i] >= 128) - { + char * sep = strchr(argv[i], '='); + if (sep == nullptr || sep - argv[i] >= 128) { fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]); invalid_param = true; break; @@ -2576,39 +2510,27 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, std::strncpy(kvo.key, argv[i], sep - argv[i]); kvo.key[sep - argv[i]] = 0; sep++; - if (strncmp(sep, "int:", 4) == 0) - { + if (strncmp(sep, "int:", 4) == 0) { sep += 4; kvo.tag = LLAMA_KV_OVERRIDE_INT; kvo.int_value = std::atol(sep); - } - else if (strncmp(sep, "float:", 6) == 0) - { + } else if (strncmp(sep, "float:", 6) == 0) { sep += 6; kvo.tag = LLAMA_KV_OVERRIDE_FLOAT; kvo.float_value = std::atof(sep); - } - else if (strncmp(sep, "bool:", 5) == 0) - { + } else if (strncmp(sep, "bool:", 5) == 0) { sep += 5; kvo.tag = LLAMA_KV_OVERRIDE_BOOL; - if (std::strcmp(sep, "true") == 0) - { + if (std::strcmp(sep, "true") == 0) { kvo.bool_value = true; - } - else if (std::strcmp(sep, "false") == 0) - { + } else if (std::strcmp(sep, "false") == 0) { kvo.bool_value = false; - } - else - { + } else { fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); invalid_param = true; break; } - } - else - { + } else { fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; break; @@ -2622,8 +2544,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, exit(1); } } - if (!params.kv_overrides.empty()) - { + if (!params.kv_overrides.empty()) { params.kv_overrides.emplace_back(llama_model_kv_override()); params.kv_overrides.back().key[0] = 0; } @@ -2645,8 +2566,7 @@ static std::string random_string() std::string result(32, ' '); - for (int i = 0; i < 32; ++i) - { + for (int i = 0; i < 32; ++i) { result[i] = str[generator() % str.size()]; } @@ -2664,10 +2584,9 @@ std::string format_chatml(std::vector messages) { std::ostringstream chatml_msgs; - for (auto it = messages.begin(); it != messages.end(); ++it) - { + for (auto it = messages.begin(); it != messages.end(); ++it) { chatml_msgs << "<|im_start|>" - << json_value(*it, "role", std::string("user")) << '\n'; + << json_value(*it, "role", std::string("user")) << '\n'; chatml_msgs << json_value(*it, "content", std::string("")) << "<|im_end|>\n"; } @@ -2693,39 +2612,35 @@ json oaicompat_completion_params_parse( // // https://platform.openai.com/docs/api-reference/chat/create llama_sampling_params default_sparams; - llama_params["model"] = json_value(body, "model", std::string("unknown")); - llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' - llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); - llama_params["temperature"] = json_value(body, "temperature", 0.0); - llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); - llama_params["top_p"] = json_value(body, "top_p", 1.0); - llama_params["n_predict"] = json_value(body, "max_tokens", -1); - llama_params["logit_bias"] = json_value(body, "logit_bias", json::object()); + llama_params["model"] = json_value(body, "model", std::string("unknown")); + llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' + llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); + llama_params["temperature"] = json_value(body, "temperature", 0.0); + llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); + llama_params["top_p"] = json_value(body, "top_p", 1.0); + llama_params["n_predict"] = json_value(body, "max_tokens", -1); + llama_params["logit_bias"] = json_value(body, "logit_bias",json::object()); llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0); - llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0); - llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED); - llama_params["stream"] = json_value(body, "stream", false); - llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat); - llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); - llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); - llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl); - llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p); - llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n); - llama_params["ignore_eos"] = json_value(body, "ignore_eos", false); - llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z); + llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0); + llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED); + llama_params["stream"] = json_value(body, "stream", false); + llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat); + llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); + llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); + llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl); + llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p); + llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n); + llama_params["ignore_eos"] = json_value(body, "ignore_eos", false); + llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z); - if (body.count("grammar") != 0) - { + if (body.count("grammar") != 0) { llama_params["grammar"] = json_value(body, "grammar", json::object()); } // Handle 'stop' field - if (body.contains("stop") && body["stop"].is_string()) - { + if (body.contains("stop") && body["stop"].is_string()) { llama_params["stop"] = json::array({body["stop"].get()}); - } - else - { + } else { llama_params["stop"] = json_value(body, "stop", json::array()); } @@ -2739,48 +2654,45 @@ static json format_final_response_oaicompat(const json &request, const task_resu { json result = response.result_json; - bool stopped_word = result.count("stopped_word") != 0; - bool stopped_eos = json_value(result, "stopped_eos", false); + bool stopped_word = result.count("stopped_word") != 0; + bool stopped_eos = json_value(result, "stopped_eos", false); int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - std::string content = json_value(result, "content", std::string("")); + int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); + std::string content = json_value(result, "content", std::string("")); std::string finish_reason = "length"; - if (stopped_word || stopped_eos) - { + if (stopped_word || stopped_eos) { finish_reason = "stop"; } json choices = streaming ? json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}) + {"index", 0}, + {"delta", json::object()}}}) : json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", content}, - {"role", "assistant"}}}}}); + {"index", 0}, + {"message", json{{"content", content}, + {"role", "assistant"}}}}}); std::time_t t = std::time(0); json res = json{{"choices", choices}, - {"created", t}, - {"model", - json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, - {"usage", - json{{"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, - {"id", gen_chatcmplid()}}; + {"created", t}, + {"model", + json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, + {"usage", + json{{"completion_tokens", num_tokens_predicted}, + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, + {"id", gen_chatcmplid()}}; - if (server_verbose) - { + if (server_verbose) { res["__verbose"] = result; } - if (result.contains("completion_probabilities")) - { + if (result.contains("completion_probabilities")) { res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); } @@ -2788,30 +2700,26 @@ static json format_final_response_oaicompat(const json &request, const task_resu } // return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat(const task_result &response) -{ +static std::vector format_partial_response_oaicompat(const task_result &response) { json result = response.result_json; - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) - { + if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { return std::vector({response.result_json}); } bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); + bool stopped_word = json_value(result, "stopped_word", false); + bool stopped_eos = json_value(result, "stopped_eos", false); + bool stopped_limit = json_value(result, "stopped_limit", false); std::string content = json_value(result, "content", std::string("")); std::string finish_reason; - if (stopped_word || stopped_eos) - { + if (stopped_word || stopped_eos) { finish_reason = "stop"; } - if (stopped_limit) - { + if (stopped_limit) { finish_reason = "length"; } @@ -2819,54 +2727,46 @@ static std::vector format_partial_response_oaicompat(const task_result &re json choices; - if (!finish_reason.empty()) - { + if (!finish_reason.empty()) { choices = json::array({json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}}}); - } - else - { - if (first) - { - if (content.empty()) - { + } else { + if (first) { + if (content.empty()) { choices = json::array({json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"role", "assistant"}}}}}); - } - else - { + } else { // We have to send this as two updates to conform to openai behavior json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"}}}}})}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}}}})}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"content", content}}} + }})}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; return std::vector({initial_ret, second_ret}); } - } - else - { + } else { // Some idiosyncrasy in task processing logic makes several trailing calls // with empty content, we ignore these at the calee site. - if (content.empty()) - { + if (content.empty()) { return std::vector({json::object()}); } @@ -2874,9 +2774,9 @@ static std::vector format_partial_response_oaicompat(const task_result &re {"finish_reason", nullptr}, {"index", 0}, {"delta", - json{ - {"content", content}, - }}, + json{ + {"content", content}, + }}, }}); } } @@ -2891,13 +2791,15 @@ static std::vector format_partial_response_oaicompat(const task_result &re } static json format_partial_response( - llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector &probs) -{ - json res = json{ - {"content", content}, - {"stop", false}, - {"slot_id", slot->id}, - {"multimodal", llama.multimodal}}; + llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector &probs +) { + json res = json + { + {"content", content }, + {"stop", false}, + {"slot_id", slot->id }, + {"multimodal", llama.multimodal } + }; if (slot->sparams.n_probs > 0) { @@ -2919,6 +2821,7 @@ static json format_detokenized_response(std::string content) {"content", content}}; } + static void log_server_request(const httplib::Request &req, const httplib::Response &res) { LOG_INFO("request", { @@ -2938,23 +2841,22 @@ static void log_server_request(const httplib::Request &req, const httplib::Respo struct token_translator { - llama_context *ctx; - std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); } + llama_context * ctx; + std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); } std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); } }; static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, llama_client_slot *slot) { - auto >ps = slot->generated_token_probs; + auto & gtps = slot->generated_token_probs; auto translator = token_translator{llama.ctx}; - auto add_strlen = [=](size_t sum, const completion_token_output &cto) - { return sum + translator(cto).size(); }; + auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); }; const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen); if (slot->generated_text.capacity() < slot->generated_text.size() + len) { slot->generated_text.reserve(slot->generated_text.size() + len); } - for (const completion_token_output &cto : gtps) + for (const completion_token_output & cto : gtps) { slot->generated_text += translator(cto); } @@ -2998,15 +2900,14 @@ int main(int argc, char **argv) svr.set_default_headers({{"Server", "llama.cpp"}}); // CORS preflight - svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) - { + svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "POST"); - res.set_header("Access-Control-Allow-Headers", "*"); }); + res.set_header("Access-Control-Allow-Headers", "*"); + }); - svr.Get("/health", [&](const httplib::Request &, httplib::Response &res) - { + svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) { server_state current_state = state.load(); switch(current_state) { case SERVER_STATE_READY: @@ -3021,12 +2922,13 @@ int main(int argc, char **argv) res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); res.status = 500; // HTTP Internal Server Error break; - } }); + } + }); svr.set_logger(log_server_request); svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) - { + { const char fmt[] = "500 Internal Server Error\n%s"; char buf[BUFSIZ]; try @@ -3042,10 +2944,11 @@ int main(int argc, char **argv) snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); } res.set_content(buf, "text/plain; charset=utf-8"); - res.status = 500; }); + res.status = 500; + }); svr.set_error_handler([](const httplib::Request &, httplib::Response &res) - { + { if (res.status == 401) { res.set_content("Unauthorized", "text/plain; charset=utf-8"); @@ -3058,10 +2961,11 @@ int main(int argc, char **argv) { res.set_content("File Not Found", "text/plain; charset=utf-8"); res.status = 404; - } }); + } + }); // set timeouts and change hostname and port - svr.set_read_timeout(sparams.read_timeout); + svr.set_read_timeout (sparams.read_timeout); svr.set_write_timeout(sparams.write_timeout); if (!svr.bind_to_port(sparams.hostname, sparams.port)) @@ -3080,57 +2984,49 @@ int main(int argc, char **argv) log_data["hostname"] = sparams.hostname; log_data["port"] = std::to_string(sparams.port); - if (sparams.api_keys.size() == 1) - { + if (sparams.api_keys.size() == 1) { log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4); - } - else if (sparams.api_keys.size() > 1) - { + } else if (sparams.api_keys.size() > 1) { log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; } LOG_INFO("HTTP server listening", log_data); // run the HTTP server in a thread - see comment below std::thread t([&]() - { + { if (!svr.listen_after_bind()) { state.store(SERVER_STATE_ERROR); return 1; } - return 0; }); + return 0; + }); // load the model if (!llama.load_model(params)) { state.store(SERVER_STATE_ERROR); return 1; - } - else - { + } else { llama.initialize(); state.store(SERVER_STATE_READY); LOG_INFO("model loaded", {}); } // Middleware for API key validation - auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool - { + auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { // If API key is not set, skip validation - if (sparams.api_keys.empty()) - { + if (sparams.api_keys.empty()) { return true; } // Check for API key in the header auto auth_header = req.get_header_value("Authorization"); std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) - { + if (auth_header.substr(0, prefix.size()) == prefix) { std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) - { + if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) { return true; // API key is valid } } @@ -3148,37 +3044,42 @@ int main(int argc, char **argv) svr.Get("/", [](const httplib::Request &, httplib::Response &res) { res.set_content(reinterpret_cast(&index_html), index_html_len, "text/html; charset=utf-8"); - return false; }); + return false; + }); // this is only called if no index.js is found in the public --path svr.Get("/index.js", [](const httplib::Request &, httplib::Response &res) { res.set_content(reinterpret_cast(&index_js), index_js_len, "text/javascript; charset=utf-8"); - return false; }); + return false; + }); // this is only called if no index.html is found in the public --path svr.Get("/completion.js", [](const httplib::Request &, httplib::Response &res) { res.set_content(reinterpret_cast(&completion_js), completion_js_len, "application/javascript; charset=utf-8"); - return false; }); + return false; + }); // this is only called if no index.html is found in the public --path svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response &res) { res.set_content(reinterpret_cast(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8"); - return false; }); + return false; + }); - svr.Get("/props", [&llama](const httplib::Request &req, httplib::Response &res) + svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "user_name", llama.name_user.c_str() }, { "assistant_name", llama.name_assistant.c_str() } }; - res.set_content(data.dump(), "application/json; charset=utf-8"); }); + res.set_content(data.dump(), "application/json; charset=utf-8"); + }); svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; @@ -3244,9 +3145,10 @@ int main(int argc, char **argv) }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } }); + } + }); - svr.Get("/v1/models", [¶ms](const httplib::Request &req, httplib::Response &res) + svr.Get("/v1/models", [¶ms](const httplib::Request& req, httplib::Response& res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); std::time_t t = std::time(0); @@ -3263,11 +3165,13 @@ int main(int argc, char **argv) }} }; - res.set_content(models.dump(), "application/json; charset=utf-8"); }); + res.set_content(models.dump(), "application/json; charset=utf-8"); + }); + // TODO: add mount point without "/v1" prefix -- how? svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; @@ -3337,10 +3241,11 @@ int main(int argc, char **argv) }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } }); + } + }); svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; @@ -3400,18 +3305,20 @@ int main(int argc, char **argv) }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } }); + } + }); svr.Get("/model.json", [&llama](const httplib::Request &, httplib::Response &res) { const json data = llama.get_model_props(); - return res.set_content(data.dump(), "application/json; charset=utf-8"); }); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }); svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response &res) { return res.set_content("", "application/json; charset=utf-8"); }); svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::vector tokens; @@ -3420,10 +3327,11 @@ int main(int argc, char **argv) tokens = llama.tokenize(body["content"], false); } const json data = format_tokenizer_response(tokens); - return res.set_content(data.dump(), "application/json; charset=utf-8"); }); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }); svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::string content; @@ -3434,10 +3342,11 @@ int main(int argc, char **argv) } const json data = format_detokenized_response(content); - return res.set_content(data.dump(), "application/json; charset=utf-8"); }); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }); svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); json prompt; @@ -3461,11 +3370,12 @@ int main(int argc, char **argv) const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1); task_result result = llama.next_result(task_id); - return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); }); + return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); + }); // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? // "Bus error: 10" - this is on macOS, it does not crash on Linux - // std::thread t2([&]() + //std::thread t2([&]() { bool running = true; while (running)