From 1011a51b8780a1b53ece91201583ad0c756a7e88 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 4 Dec 2024 14:16:01 +0100 Subject: [PATCH] move all response types to struct --- examples/server/server.cpp | 385 +++++++++++++++---------------- examples/server/server.hpp | 457 ++++++++++++++++++++++++++++++------- examples/server/utils.hpp | 1 + 3 files changed, 559 insertions(+), 284 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1482ecbee..de073b085 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -33,6 +33,9 @@ using json = nlohmann::ordered_json; +// using shared_ptr for polymorphism of server_task_result +using task_result_ptr = std::unique_ptr; + struct server_slot { int id; int id_task = -1; @@ -79,9 +82,7 @@ struct server_slot { bool has_next_token = true; bool has_new_line = false; bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; + stop_type stop; bool oaicompat = false; @@ -115,9 +116,7 @@ struct server_slot { generated_text = ""; has_new_line = false; truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; + stop = STOP_TYPE_NONE; stopping_word = ""; n_past = 0; n_sent_text = 0; @@ -203,7 +202,7 @@ struct server_slot { if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { if (is_full_stop) { - stopped_word = true; + stop = STOP_TYPE_WORD; stopping_word = word; has_next_token = false; } @@ -428,8 +427,8 @@ struct server_response { // for keeping track of all tasks waiting for the result std::unordered_set waiting_task_ids; - // the main result queue - std::vector queue_results; + // the main result queue (using ptr for polymorphism) + std::vector queue_results; std::mutex mutex_results; std::condition_variable condition_results; @@ -469,7 +468,7 @@ struct server_response { } // This function blocks the thread until there is a response for one of the id_tasks - server_task_result recv(const std::unordered_set & id_tasks) { + task_result_ptr recv(const std::unordered_set & id_tasks) { while (true) { std::unique_lock lock(mutex_results); condition_results.wait(lock, [&]{ @@ -477,8 +476,8 @@ struct server_response { }); for (int i = 0; i < (int) queue_results.size(); i++) { - if (id_tasks.find(queue_results[i].id) != id_tasks.end()) { - server_task_result res = queue_results[i]; + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + task_result_ptr res = std::move(queue_results[i]); queue_results.erase(queue_results.begin() + i); return res; } @@ -489,7 +488,7 @@ struct server_response { } // single-task version of recv() - server_task_result recv(int id_task) { + task_result_ptr recv(int id_task) { std::unordered_set id_tasks = {id_task}; return recv(id_tasks); } @@ -501,9 +500,9 @@ struct server_response { std::unique_lock lock(mutex_results); for (const auto & id_task : waiting_task_ids) { if (result.id == id_task) { - SRV_DBG("task id = %d moved to result queue\n", result.id); + SRV_DBG("task id = %d pushed to result queue\n", result.id); - queue_results.push_back(std::move(result)); + queue_results.push_back(std::make_unique(result)); condition_results.notify_all(); return; } @@ -694,7 +693,7 @@ struct server_context { slots.push_back(slot); } - default_generation_settings_for_props = get_formated_generation(slots.front()); + default_generation_settings_for_props = slots[0].params.to_json(); default_generation_settings_for_props["seed"] = -1; // the update_slots() logic 
will always submit a maximum of n_batch or n_parallel tokens @@ -797,7 +796,7 @@ struct server_context { slot.oaicompat_model = ""; } - slot.timings_per_token = json_value(data, "timings_per_token", false); + slot.params.timings_per_token = json_value(data, "timings_per_token", false); slot.params.stream = json_value(data, "stream", false); slot.params.cache_prompt = json_value(data, "cache_prompt", true); @@ -1056,7 +1055,7 @@ struct server_context { // check the limits if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); @@ -1065,7 +1064,7 @@ struct server_context { if (slot.has_new_line) { // if we have already seen a new line, we stop after a certain time limit if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); @@ -1085,7 +1084,7 @@ struct server_context { } if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // cut the last line @@ -1114,7 +1113,7 @@ struct server_context { // if context shift is disabled, we stop when it reaches the context limit if (slot.n_past >= slot.n_ctx) { slot.truncated = true; - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", @@ -1122,7 +1121,7 @@ struct server_context { } if (llama_token_is_eog(model, result.tok)) { - slot.stopped_eos = true; + slot.stop = STOP_TYPE_EOS; slot.has_next_token = false; SLT_DBG(slot, "%s", "stopped by EOS\n"); @@ -1132,7 +1131,7 @@ struct server_context { if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { slot.truncated = true; - slot.stopped_limit = true; + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction SLT_WRN(slot, @@ -1201,35 +1200,12 @@ struct server_context { res.has_new_line = slot.has_new_line; res.n_tokens_cached = slot.n_past; res.content = slot.generated_text; + res.stop = slot.stop; - res.params = slot.params; // copy the parameters - - - - res.error = false; - res.stop = true; - res.data = json { - {"content", !slot.params.stream ? 
slot.generated_text : ""}, - {"id_slot", slot.id}, - {"stop", true}, - {"model", params_base.model_alias}, - {"tokens_predicted", slot.n_decoded}, - {"tokens_evaluated", slot.n_prompt_tokens}, - {"generation_settings", get_formated_generation(slot)}, - {"prompt", common_detokenize(ctx, slot.prompt_tokens)}, - {"has_new_line", slot.has_new_line}, - {"truncated", slot.truncated}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()}, - {"index", slot.index}, - }; + res.generation_params = slot.params; // copy the parameters if (slot.params.sampling.n_probs > 0) { - if (!slot.params.stream && slot.stopped_word) { + if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); @@ -1399,25 +1375,34 @@ struct server_context { } // receive the results from task(s) created by create_tasks_inference - void receive_cmpl_results( + template + void receive_multi_results( const std::unordered_set & id_tasks, - const std::function&)> & result_handler, + const std::function&)> & result_handler, const std::function & error_handler) { - // TODO: currently, there is no way to detect the client has cancelled the request - std::vector results(id_tasks.size()); + std::vector results(id_tasks.size()); for (size_t i = 0; i < id_tasks.size(); i++) { - server_task_result result = queue_results.recv(id_tasks); + task_result_ptr result_raw = queue_results.recv(id_tasks); - if (result.error) { - error_handler(result.data); + if (result_raw->type == RESULT_TYPE_ERROR) { + auto result = server_task_result_error::from_ptr(result_raw); + error_handler(format_error_response(result.err_msg, result.err_type)); cancel_tasks(id_tasks); return; } - const size_t idx = result.data["index"]; - GGML_ASSERT(idx < results.size() && "index out of range"); - - results[idx] = result; + if ( + result_raw->type == RESULT_TYPE_CMPL_FINAL + || result_raw->type == RESULT_TYPE_EMBD + || result_raw->type == RESULT_TYPE_RERANK + ) { + auto result = T::from_ptr(result_raw); + const size_t idx = result.index; + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = result; + } else { + GGML_ASSERT(false && "unexpected result type"); + } } result_handler(results); } @@ -1425,23 +1410,27 @@ struct server_context { // receive the results from task(s) created by create_tasks_inference, in stream mode void receive_cmpl_results_stream( const std::unordered_set & id_tasks, const - std::function & result_handler, const + std::function & result_handler, const std::function & error_handler) { size_t n_finished = 0; while (true) { - server_task_result result = queue_results.recv(id_tasks); + task_result_ptr result_raw = queue_results.recv(id_tasks); + + if (result_raw->type == RESULT_TYPE_ERROR) { + auto result = server_task_result_error::from_ptr(result_raw); + error_handler(format_error_response(result.err_msg, result.err_type)); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT(result_raw->type == RESULT_TYPE_CMPL_PARTIAL); + auto result = server_task_result_cmpl_partial::from_ptr(result_raw); if (!result_handler(result)) { cancel_tasks(id_tasks); break; } - if (result.error) { - error_handler(result.data); - cancel_tasks(id_tasks); - break; - } - - if (result.stop) { + if 
(result.stop != STOP_TYPE_NONE) { if (++n_finished == id_tasks.size()) { break; } @@ -1508,7 +1497,7 @@ struct server_context { int n_processing_slots = 0; for (server_slot & slot : slots) { - json slot_data = get_formated_generation(slot); + json slot_data = slot.params.to_json(); slot_data["id"] = slot.id; slot_data["id_task"] = slot.id_task; slot_data["is_processing"] = slot.is_processing(); @@ -1518,9 +1507,6 @@ struct server_context { {"has_new_line", slot.has_new_line}, {"n_remain", slot.n_remaining}, {"n_decoded", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, {"stopping_word", slot.stopping_word}, }; @@ -1534,34 +1520,28 @@ struct server_context { } SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); - server_task_result res; - res.id = task.id; - res.stop = true; - res.error = false; - res.data = { - { "idle", n_idle_slots }, - { "processing", n_processing_slots }, - { "deferred", queue_tasks.queue_tasks_deferred.size() }, - { "t_start", metrics.t_start}, + server_task_result_metrics res; + res.id = task.id; + res.n_idle_slots = n_idle_slots; + res.n_processing_slots = n_processing_slots; + res.n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res.t_start = metrics.t_start; - { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total}, - { "t_tokens_generation_total", metrics.t_tokens_generation_total}, - { "n_tokens_predicted_total", metrics.n_tokens_predicted_total}, - { "t_prompt_processing_total", metrics.t_prompt_processing_total}, + res.kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx); + res.kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx); - { "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed}, - { "t_prompt_processing", metrics.t_prompt_processing}, - { "n_tokens_predicted", metrics.n_tokens_predicted}, - { "t_tokens_generation", metrics.t_tokens_generation}, + res.n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res.t_prompt_processing_total = metrics.t_prompt_processing_total; + res.n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res.t_tokens_generation_total = metrics.t_tokens_generation_total; - { "n_decode_total", metrics.n_decode_total}, - { "n_busy_slots_total", metrics.n_busy_slots_total}, + res.n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res.t_prompt_processing = metrics.t_prompt_processing; + res.n_tokens_predicted = metrics.n_tokens_predicted; + res.t_tokens_generation = metrics.t_tokens_generation; - { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)}, - { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)}, - - { "slots", slots_data }, - }; + res.n_decode_total = metrics.n_decode_total; + res.n_busy_slots_total = metrics.n_busy_slots_total; if (json_value(task.data, "reset_bucket", false)) { metrics.reset_bucket(); @@ -1594,19 +1574,14 @@ struct server_context { const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_saved", token_count }, // tokens saved - { "n_written", nwrite }, // bytes written - { "timings", { - { "save_ms", t_save_ms } - } } - }; + server_task_result_slot_save_load result; + result.id = task.id; + result.id_slot = id_slot; + result.filename = filename; + 
result.is_save = true; + result.n_saved = token_count; + result.n_written = nwrite; + result.t_ms = t_save_ms; queue_results.send(result); } break; case SERVER_TASK_TYPE_SLOT_RESTORE: @@ -1642,19 +1617,14 @@ struct server_context { const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", token_count }, // tokens restored - { "n_read", nread }, // bytes read - { "timings", { - { "restore_ms", t_restore_ms } - } } - }; + server_task_result_slot_save_load result; + result.id = task.id; + result.id_slot = id_slot; + result.filename = filename; + result.is_save = false; + result.n_saved = token_count; + result.n_read = nread; + result.t_ms = t_restore_ms; queue_results.send(result); } break; case SERVER_TASK_TYPE_SLOT_ERASE: @@ -1677,24 +1647,17 @@ struct server_context { llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); slot->cache_tokens.clear(); - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "n_erased", n_erased } - }; + server_task_result_slot_erase result; + result.id = task.id; + result.id_slot = id_slot; + result.n_erased = n_erased; queue_results.send(result); } break; case SERVER_TASK_TYPE_SET_LORA: { common_lora_adapters_apply(ctx, loras); - server_task_result result; + server_task_result_apply_lora result; result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{ "success", true }}; queue_results.send(result); } break; } @@ -2456,19 +2419,26 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.post(task, true); // high-priority task // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); + task_result_ptr result_raw = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); + if (result_raw->type != RESULT_TYPE_METRICS) { + SRV_ERR("Unexpected result type: %d\n", result_raw->type); + res_error(res, format_error_response("Unexpected result type", ERROR_TYPE_SERVER)); + return; + } + + auto result = server_task_result_metrics::from_ptr(result_raw); + // optionally return "fail_on_no_slot" error - const int n_idle_slots = result.data.at("idle"); if (req.has_param("fail_on_no_slot")) { - if (n_idle_slots == 0) { + if (result.n_idle_slots == 0) { res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); return; } } - res_ok(res, result.data.at("slots")); + res_ok(res, result.slots_data); }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { @@ -2488,73 +2458,68 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.post(task, true); // high-priority task // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); + task_result_ptr result_raw = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); + if (result_raw->type == RESULT_TYPE_ERROR) { + auto result = server_task_result_error::from_ptr(result_raw); + res_error(res, format_error_response(result.err_msg, result.err_type)); + return; + } - json data = result.data; - - const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed"); - const uint64_t t_prompt_processing = data.at("t_prompt_processing"); - - const uint64_t n_tokens_predicted = 
data.at("n_tokens_predicted"); - const uint64_t t_tokens_generation = data.at("t_tokens_generation"); - - const uint64_t n_decode_total = data.at("n_decode_total"); - const uint64_t n_busy_slots_total = data.at("n_busy_slots_total"); - - const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells"); + GGML_ASSERT(result_raw->type == RESULT_TYPE_METRICS); + auto result = server_task_result_metrics::from_ptr(result_raw); // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names json all_metrics_def = json { {"counter", {{ {"name", "prompt_tokens_total"}, {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) data.at("n_prompt_tokens_processed_total")} + {"value", (uint64_t) result.n_prompt_tokens_processed_total} }, { {"name", "prompt_seconds_total"}, {"help", "Prompt process time"}, - {"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3} + {"value", (uint64_t) result.t_prompt_processing_total / 1.e3} }, { {"name", "tokens_predicted_total"}, {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) data.at("n_tokens_predicted_total")} + {"value", (uint64_t) result.n_tokens_predicted_total} }, { {"name", "tokens_predicted_seconds_total"}, {"help", "Predict process time"}, - {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3} + {"value", (uint64_t) result.t_tokens_generation_total / 1.e3} }, { {"name", "n_decode_total"}, {"help", "Total number of llama_decode() calls"}, - {"value", n_decode_total} + {"value", result.n_decode_total} }, { {"name", "n_busy_slots_per_decode"}, {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) n_busy_slots_total / (float) n_decode_total} + {"value", (float) result.n_busy_slots_total / (float) result.n_decode_total} }}}, {"gauge", {{ {"name", "prompt_tokens_seconds"}, {"help", "Average prompt throughput in tokens/s."}, - {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.} + {"value", result.n_prompt_tokens_processed ? 1.e3 / result.t_prompt_processing * result.n_prompt_tokens_processed : 0.} },{ {"name", "predicted_tokens_seconds"}, {"help", "Average generation throughput in tokens/s."}, - {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.} + {"value", result.n_tokens_predicted ? 1.e3 / result.t_tokens_generation * result.n_tokens_predicted : 0.} },{ {"name", "kv_cache_usage_ratio"}, {"help", "KV-cache usage. 1 means 100 percent usage."}, - {"value", 1. * kv_cache_used_cells / params.n_ctx} + {"value", 1. 
* result.kv_cache_used_cells / params.n_ctx} },{ {"name", "kv_cache_tokens"}, {"help", "KV-cache tokens."}, - {"value", (uint64_t) data.at("kv_cache_tokens_count")} + {"value", (uint64_t) result.kv_cache_tokens_count} },{ {"name", "requests_processing"}, {"help", "Number of request processing."}, - {"value", (uint64_t) data.at("processing")} + {"value", (uint64_t) result.n_processing_slots} },{ {"name", "requests_deferred"}, {"help", "Number of request deferred."}, - {"value", (uint64_t) data.at("deferred")} + {"value", (uint64_t) result.n_tasks_deferred} }}} }; @@ -2575,8 +2540,7 @@ int main(int argc, char ** argv) { } } - const int64_t t_start = data.at("t_start"); - res.set_header("Process-Start-Time-Unix", std::to_string(t_start)); + res.set_header("Process-Start-Time-Unix", std::to_string(result.t_start)); res.set_content(prometheus.str(), "text/plain; version=0.0.4"); res.status = 200; // HTTP OK @@ -2602,14 +2566,18 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - server_task_result result = ctx_server.queue_results.recv(id_task); + task_result_ptr result_raw = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - if (result.error) { - res_error(res, result.data); - } else { - res_ok(res, result.data); + if (result_raw->type == RESULT_TYPE_ERROR) { + auto result = server_task_result_error::from_ptr(result_raw); + res_error(res, format_error_response(result.err_msg, result.err_type)); + return; } + + GGML_ASSERT(result_raw->type == RESULT_TYPE_SLOT_SAVE_LOAD); + auto result = server_task_result_slot_save_load::from_ptr(result_raw); + res_ok(res, result.to_json()); }; const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { @@ -2632,14 +2600,18 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - server_task_result result = ctx_server.queue_results.recv(id_task); + task_result_ptr result_raw = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - if (result.error) { - res_error(res, result.data); - } else { - res_ok(res, result.data); + if (result_raw->type == RESULT_TYPE_ERROR) { + auto result = server_task_result_error::from_ptr(result_raw); + res_error(res, format_error_response(result.err_msg, result.err_type)); + return; } + + GGML_ASSERT(result_raw->type == RESULT_TYPE_SLOT_SAVE_LOAD); + auto result = server_task_result_slot_save_load::from_ptr(result_raw); + res_ok(res, result.to_json()); }; const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { @@ -2652,14 +2624,18 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - server_task_result result = ctx_server.queue_results.recv(id_task); + task_result_ptr result_raw = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - if (result.error) { - res_error(res, result.data); - } else { - res_ok(res, result.data); + if (result_raw->type == RESULT_TYPE_ERROR) { + auto result = server_task_result_error::from_ptr(result_raw); + res_error(res, format_error_response(result.err_msg, result.err_type)); + return; } + + GGML_ASSERT(result_raw->type 
== RESULT_TYPE_SLOT_ERASE); + auto result = server_task_result_slot_erase::from_ptr(result_raw); + res_ok(res, result.to_json()); }; const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { @@ -2728,15 +2704,15 @@ int main(int argc, char ** argv) { const auto task_ids = server_task::get_list_id(tasks); if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { if (results.size() == 1) { // single result - res_ok(res, results[0].data); + res_ok(res, results[0].to_json()); } else { // multiple results (multitask) json arr = json::array(); - for (const auto & res : results) { - arr.push_back(res.data); + for (auto & res : results) { + arr.push_back(res.to_json()); } res_ok(res, arr); } @@ -2747,8 +2723,8 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_ids(task_ids); } else { const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) { - ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { - return server_sent_event(sink, "data", result.data); + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool { + return server_sent_event(sink, "data", result.to_json()); }, [&](const json & error_data) { server_sent_event(sink, "error", error_data); }); @@ -2837,9 +2813,9 @@ int main(int argc, char ** argv) { const auto completion_id = gen_chatcmplid(); if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [&](const std::vector & results) { + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { // multitask is never support in chat completion, there is only one result - json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose); + json result_oai = format_final_response_oaicompat(data, results[0].to_json(), completion_id, /*.streaming =*/ false, verbose); res_ok(res, result_oai); }, [&](const json & error_data) { res_error(res, error_data); @@ -2848,8 +2824,8 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_ids(task_ids); } else { const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) { - ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { - std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool { + std::vector result_array = format_partial_response_oaicompat(result.to_json(), completion_id); for (auto & event_data : result_array) { if (event_data.empty()) { continue; // skip the stop token @@ -2974,9 +2950,10 @@ int main(int argc, char ** argv) { // get the result std::unordered_set task_ids = server_task::get_list_id(tasks); - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { - for (const auto & res : results) { - responses.push_back(res.data); + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + for (auto & res : results) { + GGML_ASSERT(res.type == RESULT_TYPE_EMBD); + responses.push_back(res.to_json()); } }, [&](const json & error_data) { res_error(res, error_data); @@ -3052,9 +3029,10 @@ int main(int argc, char ** argv) { // 
get the result std::unordered_set task_ids = server_task::get_list_id(tasks); - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { - for (const auto & res : results) { - responses.push_back(res.data); + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + for (auto & res : results) { + GGML_ASSERT(res.type == RESULT_TYPE_RERANK); + responses.push_back(res.to_json()); } }, [&](const json & error_data) { res_error(res, error_data); @@ -3110,11 +3088,18 @@ int main(int argc, char ** argv) { const int id_task = ctx_server.queue_tasks.post(task); ctx_server.queue_results.add_waiting_task_id(id_task); - server_task_result result = ctx_server.queue_results.recv(id_task); + task_result_ptr result_raw = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - res_ok(res, result.data); - res.status = 200; // HTTP OK + if (result_raw->type == RESULT_TYPE_ERROR) { + auto result = server_task_result_error::from_ptr(result_raw); + res_error(res, format_error_response(result.err_msg, result.err_type)); + return; + } + + GGML_ASSERT(result_raw->type == RESULT_TYPE_APPLY_LORA); + auto result = server_task_result_apply_lora::from_ptr(result_raw); + res_ok(res, result.to_json()); }; // diff --git a/examples/server/server.hpp b/examples/server/server.hpp index a9287bf6d..081ad2069 100644 --- a/examples/server/server.hpp +++ b/examples/server/server.hpp @@ -15,6 +15,8 @@ using json = nlohmann::ordered_json; +#define copy_cast_ptr(TYPEOUT, ptr) *(static_cast(ptr.get())) + enum stop_type { STOP_TYPE_NONE, STOP_TYPE_EOS, @@ -65,6 +67,19 @@ enum error_type { ERROR_TYPE_NOT_SUPPORTED, // custom error }; +enum result_type { + RESULT_TYPE_CMPL_FINAL, + RESULT_TYPE_CMPL_PARTIAL, + RESULT_TYPE_EMBD, + RESULT_TYPE_RERANK, + RESULT_TYPE_METRICS, + RESULT_TYPE_SLOT_SAVE_LOAD, + RESULT_TYPE_SLOT_ERASE, + RESULT_TYPE_APPLY_LORA, + RESULT_TYPE_ERROR, + RESULT_TYPE_UNKNOWN, // will throw an error +}; + struct server_task { int id = -1; // to be filled by server_queue int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL @@ -87,90 +102,6 @@ struct server_task { } }; -struct result_timings { - int32_t prompt_n; - double prompt_ms; - double prompt_per_token_ms; - double prompt_per_second; - - int32_t predicted_n; - double predicted_ms; - double predicted_per_token_ms; - double predicted_per_second; -}; - -enum result_type { - RESULT_TYPE_CMPL_FINAL, - RESULT_TYPE_CMPL_PARTIAL, - RESULT_TYPE_EMBD, - RESULT_TYPE_RERANK, - RESULT_TYPE_ERROR, - RESULT_TYPE_UNKNOWN, // will throw an error -}; - -struct server_task_result { - result_type type = RESULT_TYPE_UNKNOWN; - int id = -1; - int id_slot = -1; -}; - -struct server_task_result_cmpl_final : server_task_result { - result_type type = RESULT_TYPE_CMPL_FINAL; - int index = 0; - std::string content; - bool stream; - bool timings_per_token; - result_timings timings; - - int32_t n_decoded; - int32_t n_prompt_tokens; - int32_t has_new_line; - int32_t stopping_word; - int32_t n_tokens_cached; - stop_type stop = STOP_TYPE_NONE; - std::vector probs_output; - - slot_params params; -}; - -struct completion_token_output { - llama_token tok; - std::string text_to_send; - struct token_prob { - llama_token tok; - float prob; - }; - std::vector probs; -}; - -struct server_task_result_cmpl_partial : server_task_result { - result_type type = RESULT_TYPE_CMPL_PARTIAL; - int index = 0; - std::string content; - stop_type stop = STOP_TYPE_NONE; - std::vector probs_output; - result_timings timings; -}; - -struct 
server_task_result_embd : server_task_result { - result_type type = RESULT_TYPE_EMBD; - int index = 0; - std::vector embedding; -}; - -struct server_task_result_rerank : server_task_result { - result_type type = RESULT_TYPE_RERANK; - int index = 0; - float score; -}; - -struct server_task_result_error : server_task_result { - result_type type = RESULT_TYPE_ERROR; - int index = 0; - error_type err_type; - std::string err_msg; -}; - struct slot_params { bool stream = true; bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt @@ -188,4 +119,362 @@ struct slot_params { struct common_params_sampling sampling; struct common_params_speculative speculative; + + // params only used in to_json() + int32_t n_ctx; + uint32_t seed_cur; + bool can_speculative; + + json to_json() { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto & sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + return json { + {"n_ctx", n_ctx}, + {"n_predict", n_predict}, // Server configured n_predict + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"penalize_nl", sampling.penalize_nl}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, // User configured n_predict + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + //{"logit_bias", sampling.logit_bias}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"samplers", samplers}, + {"speculative", can_speculative}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + }; + } +}; + +struct result_timings { + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + json to_json() { + return { + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + } +}; + +struct server_task_result { + result_type type = RESULT_TYPE_UNKNOWN; + int id = -1; + int id_slot = -1; + server_task_result() = default; + server_task_result(result_type type) : type(type) {} +}; + +inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + 
case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + +struct completion_token_output { + llama_token tok; + std::string text_to_send; + struct token_prob { + llama_token tok; + float prob; + }; + std::vector probs; +}; + +struct server_task_result_cmpl_final : server_task_result { + server_task_result_cmpl_final() : server_task_result(RESULT_TYPE_CMPL_FINAL) {} + int index = 0; + std::string content; + bool stream; + bool timings_per_token; + result_timings timings; + std::string model_alias; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t has_new_line; + int32_t stopping_word; + int32_t n_tokens_cached; + stop_type stop = STOP_TYPE_NONE; + std::vector probs_output; + + slot_params generation_params; + + json to_json() { + // non-OAI-compat JSON + return json { + {"index", index}, + {"content", content}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", model_alias}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + } + + static server_task_result_cmpl_final from_ptr(std::unique_ptr & result_ptr) { + return copy_cast_ptr(server_task_result_cmpl_final, result_ptr); + } +}; + +struct server_task_result_cmpl_partial : server_task_result { + server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {} + int index = 0; + std::string content; + stop_type stop = STOP_TYPE_NONE; + std::vector probs_output; + result_timings timings; + + json to_json() { + json res = json { + {"index", index}, + {"content", content}, + {"stop", stop != STOP_TYPE_NONE}, + {"id_slot", id_slot}, + }; + // populate the timings object when timings_per_token is set + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + return res; + } + + static server_task_result_cmpl_partial from_ptr(std::unique_ptr & result_ptr) { + return copy_cast_ptr(server_task_result_cmpl_partial, result_ptr); + } +}; + +struct server_task_result_embd : server_task_result { + server_task_result_embd() : server_task_result(RESULT_TYPE_EMBD) {} + result_type type = RESULT_TYPE_EMBD; + int index = 0; + std::vector embedding; + + json to_json() { + return json { + {"index", index}, + {"embedding", embedding}, + }; + } + + static server_task_result_embd from_ptr(std::unique_ptr & result_ptr) { + return copy_cast_ptr(server_task_result_embd, result_ptr); + } +}; + +struct server_task_result_rerank : server_task_result { + server_task_result_rerank() : server_task_result(RESULT_TYPE_RERANK) {} + int index = 0; + float score; + + json to_json() { + return json { + {"index", index}, + {"score", score}, + }; + } + + static server_task_result_rerank from_ptr(std::unique_ptr & result_ptr) { + return copy_cast_ptr(server_task_result_rerank, result_ptr); + } +}; + +struct server_task_result_error : server_task_result { + server_task_result_error() : server_task_result(RESULT_TYPE_ERROR) {} + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + static server_task_result_error from_ptr(std::unique_ptr & result_ptr) { + return copy_cast_ptr(server_task_result_error, result_ptr); + } +}; + +struct server_task_result_metrics : server_task_result { + 
server_task_result_metrics() : server_task_result(RESULT_TYPE_METRICS) {} + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + int32_t kv_cache_tokens_count; + int32_t kv_cache_used_cells; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // TODO: get rid of this json object and use to_json() instead + json slots_data = json::array(); + + json to_json() { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, + + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, + + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, + + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, + + { "kv_cache_tokens_count", kv_cache_tokens_count }, + { "kv_cache_used_cells", kv_cache_used_cells }, + + { "slots", slots_data }, + }; + } + + static server_task_result_metrics from_ptr(std::unique_ptr & result_ptr) { + return copy_cast_ptr(server_task_result_metrics, result_ptr); + } +}; + +struct server_task_result_slot_save_load : server_task_result { + server_task_result_slot_save_load() : server_task_result(RESULT_TYPE_SLOT_SAVE_LOAD) {} + std::string filename; + bool is_save; // true = save, false = load + + size_t n_saved; + size_t n_written; + + size_t n_restored; + size_t n_read; + + double t_ms; + + json to_json() { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_saved }, + { "n_written", n_written }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } else { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_restored }, + { "n_read", n_read }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; + } + } + + static server_task_result_slot_save_load from_ptr(std::unique_ptr & result_ptr) { + return copy_cast_ptr(server_task_result_slot_save_load, result_ptr); + } +}; + +struct server_task_result_slot_erase : server_task_result { + server_task_result_slot_erase() : server_task_result(RESULT_TYPE_SLOT_ERASE) {} + size_t n_erased; + + json to_json() { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; + } + + static server_task_result_slot_erase from_ptr(std::unique_ptr & result_ptr) { + return copy_cast_ptr(server_task_result_slot_erase, result_ptr); + } +}; + +struct server_task_result_apply_lora : server_task_result { + server_task_result_apply_lora() : server_task_result(RESULT_TYPE_APPLY_LORA) {} + json to_json() { + return json {{ "success", true }}; + } + + static server_task_result_apply_lora from_ptr(std::unique_ptr & result_ptr) { + return copy_cast_ptr(server_task_result_apply_lora, result_ptr); + } }; diff --git 
a/examples/server/utils.hpp b/examples/server/utils.hpp
index d65773add..b01a7757f 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
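
Usage sketch (illustrative only, not part of the diff): the refactor replaces the old json `data` payload with typed result structs that travel through the queue behind a polymorphic task_result_ptr. A producer fills a concrete struct and send()s it; a consumer recv()s the base pointer, branches on the `type` tag, and recovers the concrete struct with the matching from_ptr() helper before serializing via to_json(). The snippet below mirrors the slot-erase path and assumes the usual handler context (`task`, `id_slot`, `n_erased`, `ctx_server`, `res`, `id_task`) is in scope, as in server.cpp:

    // producer side, inside server_context::process_single_task()
    server_task_result_slot_erase result;
    result.id       = task.id;
    result.id_slot  = id_slot;
    result.n_erased = n_erased;
    queue_results.send(result);                 // stored in the queue behind a task_result_ptr

    // consumer side, inside the HTTP handler
    task_result_ptr raw = ctx_server.queue_results.recv(id_task);
    ctx_server.queue_results.remove_waiting_task_id(id_task);

    if (raw->type == RESULT_TYPE_ERROR) {
        // error results carry a message and error_type instead of a payload
        auto err = server_task_result_error::from_ptr(raw);
        res_error(res, format_error_response(err.err_msg, err.err_type));
        return;
    }

    GGML_ASSERT(raw->type == RESULT_TYPE_SLOT_ERASE);
    auto erased = server_task_result_slot_erase::from_ptr(raw);
    res_ok(res, erased.to_json());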