diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 93f999298..679301005 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -61,29 +61,28 @@ static bool server_verbose = false; } while (0) #endif -#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__) +#define LOG_ERROR(MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) +#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) json oaicompat_completion_params_parse(const json &body); std::string format_chatml(std::vector messages); - // // base64 utils (TODO: move to common in the future) // static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static std::vector base64_decode(const std::string & encoded_string) +static std::vector base64_decode(const std::string &encoded_string) { int i = 0; int j = 0; @@ -98,17 +97,18 @@ static std::vector base64_decode(const std::string & encoded_string) while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; in_++; + char_array_4[i++] = encoded_string[in_]; + in_++; if (i == 4) { - for (i = 0; i <4; i++) + for (i = 0; i < 4; i++) { char_array_4[i] = base64_chars.find(char_array_4[i]); } - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; for (i = 0; (i < 3); i++) { @@ -120,19 +120,19 @@ static std::vector base64_decode(const std::string & encoded_string) if (i) { - for (j = i; j <4; j++) + for (j = i; j < 4; j++) { char_array_4[j] = 0; } - for (j = 0; j <4; j++) + for (j = 0; j < 4; j++) { char_array_4[j] = base64_chars.find(char_array_4[j]); } - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; for (j = 0; (j < i - 1); j++) { @@ -147,18 +147,21 @@ static std::vector base64_decode(const std::string & encoded_string) // parallel // -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded - SERVER_STATE_ERROR // An error occurred, load_model failed +enum server_state +{ + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded + SERVER_STATE_ERROR // An error occurred, load_model failed }; -enum task_type { +enum task_type +{ TASK_TYPE_COMPLETION, TASK_TYPE_CANCEL, }; -struct task_server { +struct task_server +{ int id; int target_id; task_type type; @@ -168,7 +171,8 @@ struct task_server { int multitask_id = -1; }; -struct task_result { +struct task_result +{ int id; int multitask_id = -1; bool stop; @@ -176,7 +180,8 @@ struct task_result { json result_json; }; -struct task_multi { +struct task_multi +{ int id; std::set subtasks_remaining{}; std::vector results{}; @@ -198,12 +203,12 @@ enum slot_command struct slot_params { - bool stream = true; + bool stream = true; bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt - uint32_t seed = -1; // RNG seed - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_predict = -1; // new tokens to predict + uint32_t seed = -1; // RNG seed + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_predict = -1; // new tokens to predict std::vector antiprompt; @@ -216,10 +221,10 @@ struct slot_image int32_t id; bool request_encode_image = false; - float * image_embedding = nullptr; + float *image_embedding = nullptr; int32_t image_tokens = 0; - clip_image_u8 * img_data; + clip_image_u8 *img_data; std::string prefix_prompt; // before of this image }; @@ -295,13 +300,12 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) static void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) { - nlohmann::ordered_json log - { + nlohmann::ordered_json log{ {"timestamp", time(nullptr)}, - {"level", level}, - {"function", function}, - {"line", line}, - {"message", message}, + {"level", level}, + {"function", function}, + {"line", line}, + {"message", message}, }; if (!extra.empty()) @@ -340,16 +344,15 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector 0; // no budget } - bool available() const { + bool available() const + { return state == IDLE && command == NONE; } - bool is_processing() const { + bool is_processing() const + { return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING; } - void add_token_string(const completion_token_output &token) { + void add_token_string(const completion_token_output &token) + { if (command == RELEASE) { return; @@ -490,7 +499,8 @@ struct llama_client_slot generated_token_probs.push_back(token); } - void release() { + void release() + { if (state == IDLE || state == PROCESSING) { t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3; @@ -498,27 +508,28 @@ struct llama_client_slot } } - json get_formated_timings() { - return json - { - {"prompt_n", num_prompt_tokens_processed}, - {"prompt_ms", t_prompt_processing}, - {"prompt_per_token_ms", t_prompt_processing / num_prompt_tokens_processed}, - {"prompt_per_second", 1e3 / t_prompt_processing * num_prompt_tokens_processed}, + json get_formated_timings() + { + return json{ + {"prompt_n", num_prompt_tokens_processed}, + {"prompt_ms", t_prompt_processing}, + {"prompt_per_token_ms", t_prompt_processing / num_prompt_tokens_processed}, + {"prompt_per_second", 1e3 / t_prompt_processing * num_prompt_tokens_processed}, - {"predicted_n", n_decoded}, - {"predicted_ms", t_token_generation}, + {"predicted_n", n_decoded}, + {"predicted_ms", t_token_generation}, {"predicted_per_token_ms", t_token_generation / n_decoded}, - {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, + {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, }; } - void print_timings() const { + void print_timings() const + { LOG_TEE("\n"); LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing / num_prompt_tokens_processed, 1e3 / t_prompt_processing * num_prompt_tokens_processed); + __func__, t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing / num_prompt_tokens_processed, 1e3 / t_prompt_processing * num_prompt_tokens_processed); LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, t_token_generation, n_decoded,t_token_generation / n_decoded, 1e3 / t_token_generation * n_decoded); + __func__, t_token_generation, n_decoded, t_token_generation / n_decoded, 1e3 / t_token_generation * n_decoded); LOG_TEE("%s: total time = %10.2f ms\n", __func__, t_prompt_processing + t_token_generation); } }; @@ -534,21 +545,21 @@ struct llama_server_context llama_batch batch; - bool multimodal = false; - bool clean_kv_cache = true; + bool multimodal = false; + bool clean_kv_cache = true; bool all_slots_are_idle = false; - bool add_bos_token = true; + bool add_bos_token = true; int32_t id_gen; - int32_t n_ctx; // total context for all clients / slots + int32_t n_ctx; // total context for all clients / slots // system prompt bool system_need_update = false; - std::string system_prompt; + std::string system_prompt; std::vector system_tokens; - std::string name_user; // this should be the antiprompt + std::string name_user; // this should be the antiprompt std::string name_assistant; // slots / clients @@ -556,7 +567,7 @@ struct llama_server_context std::vector queue_tasks; std::vector queue_results; - std::vector queue_multitasks; + std::vector queue_multitasks; std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks std::condition_variable condition_tasks; std::mutex mutex_results; @@ -579,16 +590,19 @@ struct llama_server_context bool load_model(const gpt_params ¶ms_) { params = params_; - if (!params.mmproj.empty()) { + if (!params.mmproj.empty()) + { multimodal = true; LOG_TEE("Multi Modal Mode Enabled"); - clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1); - if(clp_ctx == nullptr) { + clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/1); + if (clp_ctx == nullptr) + { LOG_ERROR("unable to load clip model", {{"model", params.mmproj}}); return false; } - if (params.n_ctx < 2048) { // request larger context for the image embedding + if (params.n_ctx < 2048) + { // request larger context for the image embedding params.n_ctx = 2048; } } @@ -600,10 +614,12 @@ struct llama_server_context return false; } - if (multimodal) { + if (multimodal) + { const int n_embd_clip = clip_n_mmproj_embd(clp_ctx); - const int n_embd_llm = llama_n_embd(model); - if (n_embd_clip != n_embd_llm) { + const int n_embd_llm = llama_n_embd(model); + if (n_embd_clip != n_embd_llm) + { LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm); llama_free(ctx); llama_free_model(model); @@ -618,7 +634,8 @@ struct llama_server_context return true; } - void initialize() { + void initialize() + { id_gen = 0; // create slots @@ -646,7 +663,7 @@ struct llama_server_context system_tokens.clear(); } - std::vector tokenize(const json & json_prompt, bool add_bos) const + std::vector tokenize(const json &json_prompt, bool add_bos) const { // TODO: currently, we tokenize using special tokens by default // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) @@ -660,7 +677,7 @@ struct llama_server_context if (json_prompt.is_array()) { bool first = true; - for (const auto& p : json_prompt) + for (const auto &p : json_prompt) { if (p.is_string()) { @@ -696,11 +713,12 @@ struct llama_server_context return prompt_tokens; } - llama_client_slot* get_slot(int id) { + llama_client_slot *get_slot(int id) + { int64_t t_last = ggml_time_us(); llama_client_slot *last_used = nullptr; - for (llama_client_slot & slot : slots) + for (llama_client_slot &slot : slots) { if (slot.id == id && slot.available()) { @@ -717,39 +735,43 @@ struct llama_server_context return last_used; } - bool launch_slot_with_data(llama_client_slot* &slot, json data) { + bool launch_slot_with_data(llama_client_slot *&slot, json data) + { slot_params default_params; llama_sampling_params default_sparams; - if (data.count("__oaicompat") != 0) { + if (data.count("__oaicompat") != 0) + { slot->oaicompat = true; slot->oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - } else { + } + else + { slot->oaicompat = false; slot->oaicompat_model = ""; } - slot->params.stream = json_value(data, "stream", false); - slot->params.cache_prompt = json_value(data, "cache_prompt", false); - slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict); - slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); - slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); - slot->params.seed = json_value(data, "seed", default_params.seed); - slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); - slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + slot->params.stream = json_value(data, "stream", false); + slot->params.cache_prompt = json_value(data, "cache_prompt", false); + slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict); + slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); + slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); + slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p); + slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); + slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); + slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); + slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); + slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); + slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); + slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); + slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); + slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); + slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); + slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); + slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); + slot->params.seed = json_value(data, "seed", default_params.seed); + slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); + slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); // infill if (data.count("input_prefix") != 0) @@ -888,7 +910,8 @@ struct llama_server_context std::string prompt = slot->prompt.get(); size_t pos = 0, begin_prefix = 0; std::string pattern = "[img-"; - while ((pos = prompt.find(pattern, pos)) != std::string::npos) { + while ((pos = prompt.find(pattern, pos)) != std::string::npos) + { size_t end_prefix = pos; pos += pattern.length(); size_t end_pos = prompt.find("]", pos); @@ -901,19 +924,23 @@ struct llama_server_context bool found = false; for (slot_image &img : slot->images) { - if (img.id == img_id) { + if (img.id == img_id) + { found = true; img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix); begin_prefix = end_pos + 1; break; } } - if (!found) { + if (!found) + { LOG_TEE("ERROR: Image with id: %i, not found.\n", img_id); slot->images.clear(); return false; } - } catch (const std::invalid_argument& e) { + } + catch (const std::invalid_argument &e) + { LOG_TEE("Invalid image number id in prompt\n"); slot->images.clear(); return false; @@ -942,22 +969,24 @@ struct llama_server_context return true; } - void kv_cache_clear() { + void kv_cache_clear() + { // clear the entire KV cache llama_kv_cache_clear(ctx); clean_kv_cache = false; } - void update_system_prompt() { + void update_system_prompt() + { system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token); llama_batch_clear(batch); kv_cache_clear(); - for (int i = 0; i < (int) system_tokens.size(); ++i) + for (int i = 0; i < (int)system_tokens.size(); ++i) { - llama_batch_add(batch, system_tokens[i], i, { 0 }, false); + llama_batch_add(batch, system_tokens[i], i, {0}, false); } if (llama_decode(ctx, batch) != 0) @@ -976,7 +1005,8 @@ struct llama_server_context system_need_update = false; } - void notify_system_prompt_changed() { + void notify_system_prompt_changed() + { // release all slots for (llama_client_slot &slot : slots) { @@ -986,9 +1016,10 @@ struct llama_server_context system_need_update = true; } - void process_system_prompt_data(const json &sys_props) { - system_prompt = sys_props.value("prompt", ""); - name_user = sys_props.value("anti_prompt", ""); + void process_system_prompt_data(const json &sys_props) + { + system_prompt = sys_props.value("prompt", ""); + name_user = sys_props.value("anti_prompt", ""); name_assistant = sys_props.value("assistant_name", ""); if (slots.size() > 0) @@ -1031,7 +1062,8 @@ struct llama_server_context return stop_pos; } - bool process_token(completion_token_output &result, llama_client_slot &slot) { + bool process_token(completion_token_output &result, llama_client_slot &slot) + { // remember which tokens were sampled - used for repetition penalties during sampling const std::string token_str = llama_token_to_piece(ctx, result.tok); slot.sampled = result.tok; @@ -1152,8 +1184,8 @@ struct llama_server_context { continue; } - clip_image_f32 * img_res = clip_image_f32_init(); - if (!clip_image_preprocess(clp_ctx, img.img_data, img_res, /*pad2square =*/ true)) + clip_image_f32 *img_res = clip_image_f32_init(); + if (!clip_image_preprocess(clp_ctx, img.img_data, img_res, /*pad2square =*/true)) { LOG_TEE("Error processing the given image"); clip_free(clp_ctx); @@ -1180,7 +1212,7 @@ struct llama_server_context return slot.images.size() > 0; } - void send_error(task_server& task, const std::string &error) + void send_error(task_server &task, const std::string &error) { LOG_TEE("task %i - error: %s\n", task.id, error.c_str()); std::unique_lock lock(mutex_results); @@ -1189,12 +1221,12 @@ struct llama_server_context res.multitask_id = task.multitask_id; res.stop = false; res.error = true; - res.result_json = { { "content", error } }; + res.result_json = {{"content", error}}; queue_results.push_back(res); condition_results.notify_all(); } - void add_multi_task(int id, std::vector& sub_ids) + void add_multi_task(int id, std::vector &sub_ids) { std::lock_guard lock(mutex_tasks); task_multi multi; @@ -1204,10 +1236,10 @@ struct llama_server_context condition_tasks.notify_one(); } - void update_multi_task(int multitask_id, int subtask_id, task_result& result) + void update_multi_task(int multitask_id, int subtask_id, task_result &result) { std::lock_guard lock(mutex_tasks); - for (auto& multitask : queue_multitasks) + for (auto &multitask : queue_multitasks) { if (multitask.id == multitask_id) { @@ -1228,34 +1260,34 @@ struct llama_server_context const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); - return json { - {"n_ctx", slot.n_ctx}, - {"model", params.model_alias}, - {"seed", slot.params.seed}, - {"temperature", slot.sparams.temp}, - {"top_k", slot.sparams.top_k}, - {"top_p", slot.sparams.top_p}, - {"min_p", slot.sparams.min_p}, - {"tfs_z", slot.sparams.tfs_z}, - {"typical_p", slot.sparams.typical_p}, - {"repeat_last_n", slot.sparams.penalty_last_n}, - {"repeat_penalty", slot.sparams.penalty_repeat}, - {"presence_penalty", slot.sparams.penalty_present}, + return json{ + {"n_ctx", slot.n_ctx}, + {"model", params.model_alias}, + {"seed", slot.params.seed}, + {"temperature", slot.sparams.temp}, + {"top_k", slot.sparams.top_k}, + {"top_p", slot.sparams.top_p}, + {"min_p", slot.sparams.min_p}, + {"tfs_z", slot.sparams.tfs_z}, + {"typical_p", slot.sparams.typical_p}, + {"repeat_last_n", slot.sparams.penalty_last_n}, + {"repeat_penalty", slot.sparams.penalty_repeat}, + {"presence_penalty", slot.sparams.penalty_present}, {"frequency_penalty", slot.sparams.penalty_freq}, {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, - {"mirostat", slot.sparams.mirostat}, - {"mirostat_tau", slot.sparams.mirostat_tau}, - {"mirostat_eta", slot.sparams.mirostat_eta}, - {"penalize_nl", slot.sparams.penalize_nl}, - {"stop", slot.params.antiprompt}, - {"n_predict", slot.params.n_predict}, - {"n_keep", params.n_keep}, - {"ignore_eos", ignore_eos}, - {"stream", slot.params.stream}, - {"logit_bias", slot.sparams.logit_bias}, - {"n_probs", slot.sparams.n_probs}, - {"grammar", slot.sparams.grammar}, + {"mirostat", slot.sparams.mirostat}, + {"mirostat_tau", slot.sparams.mirostat_tau}, + {"mirostat_eta", slot.sparams.mirostat_eta}, + {"penalize_nl", slot.sparams.penalize_nl}, + {"stop", slot.params.antiprompt}, + {"n_predict", slot.params.n_predict}, + {"n_keep", params.n_keep}, + {"ignore_eos", ignore_eos}, + {"stream", slot.params.stream}, + {"logit_bias", slot.sparams.logit_bias}, + {"n_probs", slot.sparams.n_probs}, + {"grammar", slot.sparams.grammar}, }; } @@ -1268,19 +1300,17 @@ struct llama_server_context res.error = false; res.stop = false; - res.result_json = json - { - {"content", tkn.text_to_send}, - {"stop", false}, - {"slot_id", slot.id}, - {"multimodal", multimodal} - }; + res.result_json = json{ + {"content", tkn.text_to_send}, + {"stop", false}, + {"slot_id", slot.id}, + {"multimodal", multimodal}}; if (slot.sparams.n_probs > 0) { std::vector probs_output = {}; const std::vector to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false); - size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size()); + size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size()); size_t probs_stop_pos = std::min(slot.sent_token_probs_index + to_send_toks.size(), slot.generated_token_probs.size()); if (probs_pos < probs_stop_pos) { @@ -1309,24 +1339,22 @@ struct llama_server_context res.error = false; res.stop = true; - res.result_json = json - { - {"content", !slot.params.stream ? slot.generated_text : ""}, - {"slot_id", slot.id}, - {"stop", true}, - {"model", params.model_alias}, - {"tokens_predicted", slot.n_decoded}, - {"tokens_evaluated", slot.num_prompt_tokens}, + res.result_json = json{ + {"content", !slot.params.stream ? slot.generated_text : ""}, + {"slot_id", slot.id}, + {"stop", true}, + {"model", params.model_alias}, + {"tokens_predicted", slot.n_decoded}, + {"tokens_evaluated", slot.num_prompt_tokens}, {"generation_settings", get_formated_generation(slot)}, - {"prompt", slot.prompt}, - {"truncated", slot.truncated}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()} - }; + {"prompt", slot.prompt}, + {"truncated", slot.truncated}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + {"tokens_cached", slot.n_past}, + {"timings", slot.get_formated_timings()}}; if (slot.sparams.n_probs > 0) { @@ -1339,8 +1367,8 @@ struct llama_server_context else { probs = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); } res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs); } @@ -1379,8 +1407,7 @@ struct llama_server_context LOG_WARNING("embedding disabled", { {"params.embedding", params.embedding}, }); - res.result_json = json - { + res.result_json = json{ {"embedding", std::vector(n_embd, 0.0f)}, }; } @@ -1388,9 +1415,8 @@ struct llama_server_context { const float *data = llama_get_embeddings(ctx); std::vector embedding(data, data + n_embd); - res.result_json = json - { - {"embedding", embedding }, + res.result_json = json{ + {"embedding", embedding}, }; } queue_results.push_back(res); @@ -1427,11 +1453,10 @@ struct llama_server_context while (true) { std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - return !queue_results.empty(); - }); + condition_results.wait(lock, [&] + { return !queue_results.empty(); }); - for (int i = 0; i < (int) queue_results.size(); i++) + for (int i = 0; i < (int)queue_results.size(); i++) { // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result if (queue_results[i].multitask_id == task_id) @@ -1452,7 +1477,7 @@ struct llama_server_context } // never reached - //return task_result{-1, false, false, {}}; + // return task_result{-1, false, false, {}}; } // for multiple images processing @@ -1460,22 +1485,22 @@ struct llama_server_context { int image_idx = 0; - while (image_idx < (int) slot.images.size()) + while (image_idx < (int)slot.images.size()) { slot_image &img = slot.images[image_idx]; // process prefix prompt - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) + for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, (int32_t)(batch.n_tokens - i)); llama_batch batch_view = { n_tokens, - batch.token + i, + batch.token + i, nullptr, - batch.pos + i, + batch.pos + i, batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, + batch.seq_id + i, + batch.logits + i, 0, 0, 0, // unused }; if (llama_decode(ctx, batch_view)) @@ -1495,7 +1520,18 @@ struct llama_server_context } const int n_embd = llama_n_embd(model); - llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, }; + llama_batch batch_img = { + n_eval, + nullptr, + (img.image_embedding + i * n_embd), + nullptr, + nullptr, + nullptr, + nullptr, + slot.n_past, + 1, + 0, + }; if (llama_decode(ctx, batch_img)) { LOG_TEE("%s : failed to eval image\n", __func__); @@ -1508,14 +1544,13 @@ struct llama_server_context llama_batch_clear(batch); // append prefix of next image - const auto json_prompt = (image_idx >= (int) slot.images.size()) ? - slot.params.input_suffix : // no more images, then process suffix prompt - (json)(slot.images[image_idx].prefix_prompt); + const auto json_prompt = (image_idx >= (int)slot.images.size()) ? slot.params.input_suffix : // no more images, then process suffix prompt + (json)(slot.images[image_idx].prefix_prompt); std::vector append_tokens = tokenize(json_prompt, false); // has next image - for (int i = 0; i < (int) append_tokens.size(); ++i) + for (int i = 0; i < (int)append_tokens.size(); ++i) { - llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true); + llama_batch_add(batch, append_tokens[i], slot.n_past, {slot.id}, true); slot.n_past += 1; } } @@ -1534,7 +1569,7 @@ struct llama_server_context condition_tasks.notify_one(); } - int split_multiprompt_task(task_server& multiprompt_task) + int split_multiprompt_task(task_server &multiprompt_task) { int prompt_count = multiprompt_task.data.at("prompt").size(); assert(prompt_count > 1); @@ -1564,55 +1599,60 @@ struct llama_server_context queue_tasks.erase(queue_tasks.begin()); switch (task.type) { - case TASK_TYPE_COMPLETION: { - llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); - if (slot == nullptr) + case TASK_TYPE_COMPLETION: + { + llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); + if (slot == nullptr) + { + LOG_TEE("slot unavailable\n"); + // send error result + send_error(task, "slot unavailable"); + break; + } + + if (task.data.contains("system_prompt")) + { + if (!all_slots_are_idle) { - LOG_TEE("slot unavailable\n"); - // send error result - send_error(task, "slot unavailable"); + send_error(task, "system prompt can only be updated when all slots are idle"); break; } + process_system_prompt_data(task.data["system_prompt"]); - if (task.data.contains("system_prompt")) + // reset cache_tokens for all slots + for (llama_client_slot &slot : slots) { - if (!all_slots_are_idle) { - send_error(task, "system prompt can only be updated when all slots are idle"); - break; - } - process_system_prompt_data(task.data["system_prompt"]); - - // reset cache_tokens for all slots - for (llama_client_slot &slot : slots) - { - slot.cache_tokens.clear(); - } + slot.cache_tokens.clear(); } + } - slot->reset(); + slot->reset(); - slot->infill = task.infill_mode; - slot->embedding = task.embedding_mode; - slot->task_id = task.id; - slot->multitask_id = task.multitask_id; + slot->infill = task.infill_mode; + slot->embedding = task.embedding_mode; + slot->task_id = task.id; + slot->multitask_id = task.multitask_id; - if (!launch_slot_with_data(slot, task.data)) + if (!launch_slot_with_data(slot, task.data)) + { + // send error result + send_error(task, "internal_error"); + break; + } + } + break; + case TASK_TYPE_CANCEL: + { // release slot linked with the task id + for (auto &slot : slots) + { + if (slot.task_id == task.target_id) { - // send error result - send_error(task, "internal_error"); + slot.release(); break; } - } break; - case TASK_TYPE_CANCEL: { // release slot linked with the task id - for (auto & slot : slots) - { - if (slot.task_id == task.target_id) - { - slot.release(); - break; - } - } - } break; + } + } + break; } } @@ -1631,13 +1671,12 @@ struct llama_server_context // collect json results into one json result std::vector result_jsons; - for (auto& subres : queue_iterator->results) + for (auto &subres : queue_iterator->results) { result_jsons.push_back(subres.result_json); aggregate_result.error = aggregate_result.error && subres.error; } - aggregate_result.result_json = json{ "results", result_jsons }; - + aggregate_result.result_json = json{"results", result_jsons}; agg_results.push_back(aggregate_result); @@ -1659,7 +1698,8 @@ struct llama_server_context queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end()); } - bool update_slots() { + bool update_slots() + { // attend tasks process_tasks(); @@ -1679,21 +1719,20 @@ struct llama_server_context kv_cache_clear(); } std::unique_lock lock(mutex_tasks); - condition_tasks.wait(lock, [&]{ - return !queue_tasks.empty(); - }); + condition_tasks.wait(lock, [&] + { return !queue_tasks.empty(); }); } for (llama_client_slot &slot : slots) { - if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx) + if (slot.is_processing() && slot.cache_tokens.size() >= (size_t)slot.n_ctx) { // Shift context - const int n_left = slot.n_past - slot.params.n_keep - 1; + const int n_left = slot.n_past - slot.params.n_keep - 1; const int n_discard = n_left / 2; LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard); - llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1); + llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1, slot.params.n_keep + n_discard + 1); llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard); for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) @@ -1708,15 +1747,15 @@ struct llama_server_context slot.truncated = true; LOG_VERBOSE("context shift", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - }); + {"n_ctx", n_ctx}, + {"n_keep", params.n_keep}, + {"n_left", n_left}, + }); } } // decode any currently ongoing sequences - for (auto & slot : slots) + for (auto &slot : slots) { // release the slot if (slot.command == RELEASE) @@ -1725,7 +1764,7 @@ struct llama_server_context slot.command = NONE; slot.t_last_used = ggml_time_us(); - LOG_TEE("slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size()); + LOG_TEE("slot %d released (%d tokens in cache)\n", slot.id, (int)slot.cache_tokens.size()); continue; } @@ -1737,7 +1776,7 @@ struct llama_server_context slot.i_batch = batch.n_tokens; - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true); + llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, {slot.id}, true); slot.n_past += 1; } @@ -1748,7 +1787,7 @@ struct llama_server_context // assign workload to the slots if (params.cont_batching || batch.n_tokens == 0) { - for (auto & slot : slots) + for (auto &slot : slots) { const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get().empty()) || !slot.images.empty(); @@ -1783,7 +1822,8 @@ struct llama_server_context auto suffix_tokens = tokenize(slot.params.input_suffix, false); const int space_token = 29871; // TODO: this should not be hardcoded - if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) { + if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) + { suffix_tokens.erase(suffix_tokens.begin()); } @@ -1796,7 +1836,7 @@ struct llama_server_context } else { - prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt + prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt } slot.num_prompt_tokens = prompt_tokens.size(); @@ -1818,11 +1858,11 @@ struct llama_server_context new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end()); LOG_VERBOSE("input truncated", { - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - }); + {"n_ctx", slot.n_ctx}, + {"n_keep", slot.params.n_keep}, + {"n_left", n_left}, + {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + }); slot.truncated = true; prompt_tokens = new_tokens; @@ -1851,7 +1891,7 @@ struct llama_server_context LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); } - LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past); + LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int)system_tokens.size() + slot.n_past); llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1); @@ -1865,18 +1905,18 @@ struct llama_server_context } LOG_VERBOSE("prompt ingested", { - {"n_past", slot.n_past}, - {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, - {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())}, - }); + {"n_past", slot.n_past}, + {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, + {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())}, + }); const bool has_images = process_images(slot); // process the prefix of first image std::vector prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens; - for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) + for (; slot.n_past < (int)prefix_tokens.size(); ++slot.n_past) { - llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false); + llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, {slot.id}, false); } if (has_images && !ingest_images(slot, n_batch)) @@ -1892,7 +1932,7 @@ struct llama_server_context } slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.i_batch = batch.n_tokens - 1; } } } @@ -1903,20 +1943,20 @@ struct llama_server_context return true; } - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) + for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, (int32_t)(batch.n_tokens - i)); llama_batch batch_view = - { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, 0, 0, // unused - }; + { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; const int ret = llama_decode(ctx, batch_view); if (ret != 0) @@ -1936,9 +1976,9 @@ struct llama_server_context continue; } - for (auto & slot : slots) + for (auto &slot : slots) { - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) + if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { continue; } @@ -1964,7 +2004,7 @@ struct llama_server_context slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3; } - llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false }; + llama_token_data_array cur_p = {slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false}; result.tok = id; const int32_t n_probs = slot.sparams.n_probs; @@ -2012,6 +2052,10 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); + printf(" --grp-attn-n N\n"); + printf(" group-attention factor (default: %d)\n", params.grp_attn_n); + printf(" --grp-attn-w N\n"); + printf(" group-attention width (default: %.1f)\n", (double)params.grp_attn_w); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); @@ -2064,7 +2108,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, } static void server_params_parse(int argc, char **argv, server_params &sparams, - gpt_params ¶ms, llama_server_context& llama) + gpt_params ¶ms, llama_server_context &llama) { gpt_params default_params; server_params default_sparams; @@ -2118,16 +2162,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } std::ifstream key_file(argv[i]); - if (!key_file) { + if (!key_file) + { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); invalid_param = true; break; } std::string key; - while (std::getline(key_file, key)) { - if (key.size() > 0) { - sparams.api_keys.push_back(key); - } + while (std::getline(key_file, key)) + { + if (key.size() > 0) + { + sparams.api_keys.push_back(key); + } } key_file.close(); } @@ -2181,10 +2228,23 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } std::string value(argv[i]); - /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; } - else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; } - else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; } - else { invalid_param = true; break; } + /**/ if (value == "none") + { + params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; + } + else if (value == "linear") + { + params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; + } + else if (value == "yarn") + { + params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; + } + else + { + invalid_param = true; + break; + } } else if (arg == "--rope-freq-base") { @@ -2206,7 +2266,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--yarn-ext-factor") { - if (++i >= argc) { + if (++i >= argc) + { invalid_param = true; break; } @@ -2214,7 +2275,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--yarn-attn-factor") { - if (++i >= argc) { + if (++i >= argc) + { invalid_param = true; break; } @@ -2222,7 +2284,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--yarn-beta-fast") { - if (++i >= argc) { + if (++i >= argc) + { invalid_param = true; break; } @@ -2230,12 +2293,31 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--yarn-beta-slow") { - if (++i >= argc) { + if (++i >= argc) + { invalid_param = true; break; } params.yarn_beta_slow = std::stof(argv[i]); } + else if (arg == "--grp-attn-n") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + params.grp_attn_n = std::stoi(argv[i]); + } + else if (arg == "--grp-attn-w") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + params.grp_attn_w = std::stoi(argv[i]); + } else if (arg == "--threads" || arg == "-t") { if (++i >= argc) @@ -2281,7 +2363,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--split-mode" || arg == "-sm") { - if (++i >= argc) { + if (++i >= argc) + { invalid_param = true; break; } @@ -2298,7 +2381,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, { params.split_mode = LLAMA_SPLIT_ROW; } - else { + else + { invalid_param = true; break; } @@ -2375,7 +2459,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } - const char * lora_adapter = argv[i]; + const char *lora_adapter = argv[i]; if (++i >= argc) { invalid_param = true; @@ -2429,7 +2513,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } params.n_parallel = std::stoi(argv[i]); - } else if (arg == "-n" || arg == "--n-predict") + } + else if (arg == "-n" || arg == "--n-predict") { if (++i >= argc) { @@ -2437,7 +2522,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } params.n_predict = std::stoi(argv[i]); - } else if (arg == "-spf" || arg == "--system-prompt-file") + } + else if (arg == "-spf" || arg == "--system-prompt-file") { if (++i >= argc) { @@ -2445,7 +2531,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } std::ifstream file(argv[i]); - if (!file) { + if (!file) + { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); invalid_param = true; break; @@ -2454,11 +2541,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, std::copy( std::istreambuf_iterator(file), std::istreambuf_iterator(), - std::back_inserter(systm_content) - ); + std::back_inserter(systm_content)); llama.process_system_prompt_data(json::parse(systm_content)); } - else if(arg == "--mmproj") + else if (arg == "--mmproj") { if (++i >= argc) { @@ -2474,12 +2560,14 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--override-kv") { - if (++i >= argc) { + if (++i >= argc) + { invalid_param = true; break; } - char * sep = strchr(argv[i], '='); - if (sep == nullptr || sep - argv[i] >= 128) { + char *sep = strchr(argv[i], '='); + if (sep == nullptr || sep - argv[i] >= 128) + { fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]); invalid_param = true; break; @@ -2488,27 +2576,39 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, std::strncpy(kvo.key, argv[i], sep - argv[i]); kvo.key[sep - argv[i]] = 0; sep++; - if (strncmp(sep, "int:", 4) == 0) { + if (strncmp(sep, "int:", 4) == 0) + { sep += 4; kvo.tag = LLAMA_KV_OVERRIDE_INT; kvo.int_value = std::atol(sep); - } else if (strncmp(sep, "float:", 6) == 0) { + } + else if (strncmp(sep, "float:", 6) == 0) + { sep += 6; kvo.tag = LLAMA_KV_OVERRIDE_FLOAT; kvo.float_value = std::atof(sep); - } else if (strncmp(sep, "bool:", 5) == 0) { + } + else if (strncmp(sep, "bool:", 5) == 0) + { sep += 5; kvo.tag = LLAMA_KV_OVERRIDE_BOOL; - if (std::strcmp(sep, "true") == 0) { + if (std::strcmp(sep, "true") == 0) + { kvo.bool_value = true; - } else if (std::strcmp(sep, "false") == 0) { + } + else if (std::strcmp(sep, "false") == 0) + { kvo.bool_value = false; - } else { + } + else + { fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); invalid_param = true; break; } - } else { + } + else + { fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; break; @@ -2522,7 +2622,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, exit(1); } } - if (!params.kv_overrides.empty()) { + if (!params.kv_overrides.empty()) + { params.kv_overrides.emplace_back(llama_model_kv_override()); params.kv_overrides.back().key[0] = 0; } @@ -2544,7 +2645,8 @@ static std::string random_string() std::string result(32, ' '); - for (int i = 0; i < 32; ++i) { + for (int i = 0; i < 32; ++i) + { result[i] = str[generator() % str.size()]; } @@ -2562,9 +2664,10 @@ std::string format_chatml(std::vector messages) { std::ostringstream chatml_msgs; - for (auto it = messages.begin(); it != messages.end(); ++it) { + for (auto it = messages.begin(); it != messages.end(); ++it) + { chatml_msgs << "<|im_start|>" - << json_value(*it, "role", std::string("user")) << '\n'; + << json_value(*it, "role", std::string("user")) << '\n'; chatml_msgs << json_value(*it, "content", std::string("")) << "<|im_end|>\n"; } @@ -2590,35 +2693,39 @@ json oaicompat_completion_params_parse( // // https://platform.openai.com/docs/api-reference/chat/create llama_sampling_params default_sparams; - llama_params["model"] = json_value(body, "model", std::string("unknown")); - llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' - llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); - llama_params["temperature"] = json_value(body, "temperature", 0.0); - llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); - llama_params["top_p"] = json_value(body, "top_p", 1.0); - llama_params["n_predict"] = json_value(body, "max_tokens", -1); - llama_params["logit_bias"] = json_value(body, "logit_bias",json::object()); + llama_params["model"] = json_value(body, "model", std::string("unknown")); + llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' + llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); + llama_params["temperature"] = json_value(body, "temperature", 0.0); + llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); + llama_params["top_p"] = json_value(body, "top_p", 1.0); + llama_params["n_predict"] = json_value(body, "max_tokens", -1); + llama_params["logit_bias"] = json_value(body, "logit_bias", json::object()); llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0); - llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0); - llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED); - llama_params["stream"] = json_value(body, "stream", false); - llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat); - llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); - llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); - llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl); - llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p); - llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n); - llama_params["ignore_eos"] = json_value(body, "ignore_eos", false); - llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z); + llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0); + llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED); + llama_params["stream"] = json_value(body, "stream", false); + llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat); + llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); + llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); + llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl); + llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p); + llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n); + llama_params["ignore_eos"] = json_value(body, "ignore_eos", false); + llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z); - if (body.count("grammar") != 0) { + if (body.count("grammar") != 0) + { llama_params["grammar"] = json_value(body, "grammar", json::object()); } // Handle 'stop' field - if (body.contains("stop") && body["stop"].is_string()) { + if (body.contains("stop") && body["stop"].is_string()) + { llama_params["stop"] = json::array({body["stop"].get()}); - } else { + } + else + { llama_params["stop"] = json_value(body, "stop", json::array()); } @@ -2632,45 +2739,48 @@ static json format_final_response_oaicompat(const json &request, const task_resu { json result = response.result_json; - bool stopped_word = result.count("stopped_word") != 0; - bool stopped_eos = json_value(result, "stopped_eos", false); + bool stopped_word = result.count("stopped_word") != 0; + bool stopped_eos = json_value(result, "stopped_eos", false); int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - std::string content = json_value(result, "content", std::string("")); + int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); + std::string content = json_value(result, "content", std::string("")); std::string finish_reason = "length"; - if (stopped_word || stopped_eos) { + if (stopped_word || stopped_eos) + { finish_reason = "stop"; } json choices = streaming ? json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}) + {"index", 0}, + {"delta", json::object()}}}) : json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", content}, - {"role", "assistant"}}}}}); + {"index", 0}, + {"message", json{{"content", content}, + {"role", "assistant"}}}}}); std::time_t t = std::time(0); json res = json{{"choices", choices}, - {"created", t}, - {"model", - json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, - {"usage", - json{{"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, - {"id", gen_chatcmplid()}}; + {"created", t}, + {"model", + json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, + {"usage", + json{{"completion_tokens", num_tokens_predicted}, + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, + {"id", gen_chatcmplid()}}; - if (server_verbose) { + if (server_verbose) + { res["__verbose"] = result; } - if (result.contains("completion_probabilities")) { + if (result.contains("completion_probabilities")) + { res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); } @@ -2678,26 +2788,30 @@ static json format_final_response_oaicompat(const json &request, const task_resu } // return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat(const task_result &response) { +static std::vector format_partial_response_oaicompat(const task_result &response) +{ json result = response.result_json; - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { + if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) + { return std::vector({response.result_json}); } bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); + bool stopped_word = json_value(result, "stopped_word", false); + bool stopped_eos = json_value(result, "stopped_eos", false); + bool stopped_limit = json_value(result, "stopped_limit", false); std::string content = json_value(result, "content", std::string("")); std::string finish_reason; - if (stopped_word || stopped_eos) { + if (stopped_word || stopped_eos) + { finish_reason = "stop"; } - if (stopped_limit) { + if (stopped_limit) + { finish_reason = "length"; } @@ -2705,46 +2819,54 @@ static std::vector format_partial_response_oaicompat(const task_result &re json choices; - if (!finish_reason.empty()) { + if (!finish_reason.empty()) + { choices = json::array({json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}}}); - } else { - if (first) { - if (content.empty()) { + } + else + { + if (first) + { + if (content.empty()) + { choices = json::array({json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"role", "assistant"}}}}}); - } else { + } + else + { // We have to send this as two updates to conform to openai behavior json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"}}}}})}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"content", content}}}}})}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; return std::vector({initial_ret, second_ret}); } - } else { + } + else + { // Some idiosyncrasy in task processing logic makes several trailing calls // with empty content, we ignore these at the calee site. - if (content.empty()) { + if (content.empty()) + { return std::vector({json::object()}); } @@ -2752,9 +2874,9 @@ static std::vector format_partial_response_oaicompat(const task_result &re {"finish_reason", nullptr}, {"index", 0}, {"delta", - json{ - {"content", content}, - }}, + json{ + {"content", content}, + }}, }}); } } @@ -2769,15 +2891,13 @@ static std::vector format_partial_response_oaicompat(const task_result &re } static json format_partial_response( - llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector &probs -) { - json res = json - { - {"content", content }, - {"stop", false}, - {"slot_id", slot->id }, - {"multimodal", llama.multimodal } - }; + llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector &probs) +{ + json res = json{ + {"content", content}, + {"stop", false}, + {"slot_id", slot->id}, + {"multimodal", llama.multimodal}}; if (slot->sparams.n_probs > 0) { @@ -2799,7 +2919,6 @@ static json format_detokenized_response(std::string content) {"content", content}}; } - static void log_server_request(const httplib::Request &req, const httplib::Response &res) { LOG_INFO("request", { @@ -2819,22 +2938,23 @@ static void log_server_request(const httplib::Request &req, const httplib::Respo struct token_translator { - llama_context * ctx; - std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); } + llama_context *ctx; + std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); } std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); } }; static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, llama_client_slot *slot) { - auto & gtps = slot->generated_token_probs; + auto >ps = slot->generated_token_probs; auto translator = token_translator{llama.ctx}; - auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); }; + auto add_strlen = [=](size_t sum, const completion_token_output &cto) + { return sum + translator(cto).size(); }; const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen); if (slot->generated_text.capacity() < slot->generated_text.size() + len) { slot->generated_text.reserve(slot->generated_text.size() + len); } - for (const completion_token_output & cto : gtps) + for (const completion_token_output &cto : gtps) { slot->generated_text += translator(cto); } @@ -2878,14 +2998,15 @@ int main(int argc, char **argv) svr.set_default_headers({{"Server", "llama.cpp"}}); // CORS preflight - svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) { + svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "POST"); - res.set_header("Access-Control-Allow-Headers", "*"); - }); + res.set_header("Access-Control-Allow-Headers", "*"); }); - svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) { + svr.Get("/health", [&](const httplib::Request &, httplib::Response &res) + { server_state current_state = state.load(); switch(current_state) { case SERVER_STATE_READY: @@ -2900,13 +3021,12 @@ int main(int argc, char **argv) res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); res.status = 500; // HTTP Internal Server Error break; - } - }); + } }); svr.set_logger(log_server_request); svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) - { + { const char fmt[] = "500 Internal Server Error\n%s"; char buf[BUFSIZ]; try @@ -2922,11 +3042,10 @@ int main(int argc, char **argv) snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); } res.set_content(buf, "text/plain; charset=utf-8"); - res.status = 500; - }); + res.status = 500; }); svr.set_error_handler([](const httplib::Request &, httplib::Response &res) - { + { if (res.status == 401) { res.set_content("Unauthorized", "text/plain; charset=utf-8"); @@ -2939,11 +3058,10 @@ int main(int argc, char **argv) { res.set_content("File Not Found", "text/plain; charset=utf-8"); res.status = 404; - } - }); + } }); // set timeouts and change hostname and port - svr.set_read_timeout (sparams.read_timeout); + svr.set_read_timeout(sparams.read_timeout); svr.set_write_timeout(sparams.write_timeout); if (!svr.bind_to_port(sparams.hostname, sparams.port)) @@ -2962,49 +3080,57 @@ int main(int argc, char **argv) log_data["hostname"] = sparams.hostname; log_data["port"] = std::to_string(sparams.port); - if (sparams.api_keys.size() == 1) { + if (sparams.api_keys.size() == 1) + { log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4); - } else if (sparams.api_keys.size() > 1) { + } + else if (sparams.api_keys.size() > 1) + { log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; } LOG_INFO("HTTP server listening", log_data); // run the HTTP server in a thread - see comment below std::thread t([&]() - { + { if (!svr.listen_after_bind()) { state.store(SERVER_STATE_ERROR); return 1; } - return 0; - }); + return 0; }); // load the model if (!llama.load_model(params)) { state.store(SERVER_STATE_ERROR); return 1; - } else { + } + else + { llama.initialize(); state.store(SERVER_STATE_READY); LOG_INFO("model loaded", {}); } // Middleware for API key validation - auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { + auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool + { // If API key is not set, skip validation - if (sparams.api_keys.empty()) { + if (sparams.api_keys.empty()) + { return true; } // Check for API key in the header auto auth_header = req.get_header_value("Authorization"); std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { + if (auth_header.substr(0, prefix.size()) == prefix) + { std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) { + if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) + { return true; // API key is valid } } @@ -3022,42 +3148,37 @@ int main(int argc, char **argv) svr.Get("/", [](const httplib::Request &, httplib::Response &res) { res.set_content(reinterpret_cast(&index_html), index_html_len, "text/html; charset=utf-8"); - return false; - }); + return false; }); // this is only called if no index.js is found in the public --path svr.Get("/index.js", [](const httplib::Request &, httplib::Response &res) { res.set_content(reinterpret_cast(&index_js), index_js_len, "text/javascript; charset=utf-8"); - return false; - }); + return false; }); // this is only called if no index.html is found in the public --path svr.Get("/completion.js", [](const httplib::Request &, httplib::Response &res) { res.set_content(reinterpret_cast(&completion_js), completion_js_len, "application/javascript; charset=utf-8"); - return false; - }); + return false; }); // this is only called if no index.html is found in the public --path svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response &res) { res.set_content(reinterpret_cast(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8"); - return false; - }); + return false; }); - svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res) + svr.Get("/props", [&llama](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "user_name", llama.name_user.c_str() }, { "assistant_name", llama.name_assistant.c_str() } }; - res.set_content(data.dump(), "application/json; charset=utf-8"); - }); + res.set_content(data.dump(), "application/json; charset=utf-8"); }); svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; @@ -3123,10 +3244,9 @@ int main(int argc, char **argv) }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }); + } }); - svr.Get("/v1/models", [¶ms](const httplib::Request& req, httplib::Response& res) + svr.Get("/v1/models", [¶ms](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); std::time_t t = std::time(0); @@ -3143,13 +3263,11 @@ int main(int argc, char **argv) }} }; - res.set_content(models.dump(), "application/json; charset=utf-8"); - }); - + res.set_content(models.dump(), "application/json; charset=utf-8"); }); // TODO: add mount point without "/v1" prefix -- how? svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; @@ -3219,11 +3337,10 @@ int main(int argc, char **argv) }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }); + } }); svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; @@ -3283,20 +3400,18 @@ int main(int argc, char **argv) }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }); + } }); svr.Get("/model.json", [&llama](const httplib::Request &, httplib::Response &res) { const json data = llama.get_model_props(); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }); + return res.set_content(data.dump(), "application/json; charset=utf-8"); }); svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response &res) { return res.set_content("", "application/json; charset=utf-8"); }); svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::vector tokens; @@ -3305,11 +3420,10 @@ int main(int argc, char **argv) tokens = llama.tokenize(body["content"], false); } const json data = format_tokenizer_response(tokens); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }); + return res.set_content(data.dump(), "application/json; charset=utf-8"); }); svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::string content; @@ -3320,11 +3434,10 @@ int main(int argc, char **argv) } const json data = format_detokenized_response(content); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }); + return res.set_content(data.dump(), "application/json; charset=utf-8"); }); svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res) - { + { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); json prompt; @@ -3348,12 +3461,11 @@ int main(int argc, char **argv) const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1); task_result result = llama.next_result(task_id); - return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); - }); + return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); }); // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? // "Bus error: 10" - this is on macOS, it does not crash on Linux - //std::thread t2([&]() + // std::thread t2([&]() { bool running = true; while (running)