From d2419b325588e4086819e5be412b274679ee527a Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 4 Dec 2024 18:58:16 +0100
Subject: [PATCH] many fixes

---
 examples/server/server.cpp                     | 26 ++++++++++++-------
 examples/server/server.hpp                     | 24 ++++++++++++-----
 examples/server/tests/tests.sh                 |  4 +++
 .../server/tests/unit/test_chat_completion.py  | 15 +++++------
 4 files changed, 45 insertions(+), 24 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a673fb415..c26bc0867 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1172,6 +1172,8 @@ struct server_context {
         res.n_decoded       = slot.n_decoded;
         res.n_prompt_tokens = slot.n_prompt_tokens;
         res.content         = tkn.text_to_send;
+        res.stop            = slot.stop;
+        res.truncated       = slot.truncated;

         if (slot.params.sampling.n_probs > 0) {
             const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
@@ -1186,7 +1188,8 @@ struct server_context {
             }
         }

-        if (slot.params.timings_per_token) {
+        // populate timings if this is final response or timings_per_token is enabled
+        if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
             res.timings = slot.get_timings();
         }

@@ -1195,6 +1198,7 @@ struct server_context {

     void send_final_response(server_slot & slot) {
         if (slot.params.stream) {
+            // if in stream mode, send the last partial response
             return send_partial_response(slot, {0, "", {}});
         }

@@ -1209,6 +1213,8 @@ struct server_context {
         res.n_tokens_cached  = slot.n_past;
         res.content          = slot.generated_text;
         res.stop             = slot.stop;
+        res.truncated        = slot.truncated;
+        res.timings          = slot.get_timings();

         res.generation_params = slot.params; // copy the parameters

@@ -1439,6 +1445,8 @@ struct server_context {
                 break;
             }

+            SRV_ERR("received partial result, %s\n", result.to_json().dump().c_str());
+
            if (result.stop != STOP_TYPE_NONE) {
                if (++n_finished == id_tasks.size()) {
                    break;
@@ -1533,7 +1541,7 @@ struct server_context {
                res.id                  = task.id;
                res.n_idle_slots        = n_idle_slots;
                res.n_processing_slots  = n_processing_slots;
-               res.n_tasks_deferred     = queue_tasks.queue_tasks_deferred.size();
+               res.n_tasks_deferred    = queue_tasks.queue_tasks_deferred.size();
                res.t_start             = metrics.t_start;

                res.kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
@@ -1627,13 +1635,13 @@ struct server_context {
                const double t_restore_ms = (t_end - t_start) / 1000.0;

                server_task_result_slot_save_load result;
-               result.id       = task.id;
-               result.id_slot  = id_slot;
-               result.filename = filename;
-               result.is_save  = false;
-               result.n_saved  = token_count;
-               result.n_read   = nread;
-               result.t_ms     = t_restore_ms;
+               result.id         = task.id;
+               result.id_slot    = id_slot;
+               result.filename   = filename;
+               result.is_save    = false;
+               result.n_restored = token_count;
+               result.n_read     = nread;
+               result.t_ms       = t_restore_ms;
                queue_results.send(result);
            } break;
        case SERVER_TASK_TYPE_SLOT_ERASE:
diff --git a/examples/server/server.hpp b/examples/server/server.hpp
index 6197ae565..3e2fd2f52 100644
--- a/examples/server/server.hpp
+++ b/examples/server/server.hpp
@@ -15,6 +15,7 @@

 using json = nlohmann::ordered_json;

+// cast a shared_ptr to a specific type using copy constructor
 #define copy_cast_ptr(TYPEOUT, ptr) *(static_cast<TYPEOUT*>(ptr.get()))

 enum stop_type {
@@ -281,23 +282,34 @@ struct server_task_result_cmpl_partial : server_task_result {
     server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
     int index = 0;
     std::string content;
+
+    bool truncated;
     int32_t n_decoded;
     int32_t n_prompt_tokens;
+    stop_type stop = STOP_TYPE_NONE;

     std::vector<completion_token_output> probs_output;
     result_timings timings;

     json to_json() {
+        bool is_stop = stop != STOP_TYPE_NONE;
+        // non-OAI-compat JSON
         json res = json {
-            {"index",   index},
-            {"content", content},
-            {"stop",    stop != STOP_TYPE_NONE},
-            {"id_slot", id_slot},
+            {"index",            index},
+            {"content",          content},
+            {"stop_type",        stop_type_to_str(stop)},
+            {"stop",             is_stop},
+            {"id_slot",          id_slot},
+            {"tokens_predicted", n_decoded},
+            {"tokens_evaluated", n_prompt_tokens},
         };
-        // populate the timings object when timings_per_token is set
+        // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
         if (timings.prompt_n > 0) {
             res.push_back({"timings", timings.to_json()});
         }
+        if (is_stop) {
+            res.push_back({"truncated", truncated});
+        }
         return res;
     }
@@ -464,7 +476,7 @@ struct server_task_result_slot_erase : server_task_result {
             { "n_erased", n_erased },
         };
     }
-    
+
     static server_task_result_slot_erase from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
         return copy_cast_ptr(server_task_result_slot_erase, result_ptr);
     }
diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh
index 1e285dcda..1e0777de3 100755
--- a/examples/server/tests/tests.sh
+++ b/examples/server/tests/tests.sh
@@ -1,5 +1,9 @@
 #!/bin/bash

+# make sure we are in the right directory
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+cd $SCRIPT_DIR
+
 set -eu

 if [ $# -lt 1 ]
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 8a439f9ef..486c1f87a 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -12,13 +12,13 @@ def create_server():


 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
-        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
     ]
 )
-def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
@@ -35,10 +35,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
     choice = res.body["choices"][0]
     assert "assistant" == choice["message"]["role"]
     assert match_regex(re_content, choice["message"]["content"])
-    if truncated:
-        assert choice["finish_reason"] == "length"
-    else:
-        assert choice["finish_reason"] == "stop"
+    assert choice["finish_reason"] == finish_reason


 @pytest.mark.parametrize(
@@ -93,7 +90,7 @@ def test_chat_completion_with_openai_library():
         temperature=0.8,
     )
     print(res)
-    assert res.choices[0].finish_reason == "stop"
+    assert res.choices[0].finish_reason == "length"
     assert res.choices[0].message.content is not None
     assert match_regex("(Suddenly)+", res.choices[0].message.content)
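
Not part of the patch: for a quick manual check of the new partial-response fields outside the pytest suite, the Python sketch below streams a completion and prints the fields added above (stop, stop_type, truncated, tokens_predicted, tokens_evaluated, timings). It assumes a llama-server instance already listening on http://localhost:8080 (the default port), the third-party requests package, and the server's usual SSE-style "data: {...}" framing on /completion; prompt and n_predict values are arbitrary.

import json
import requests

# stream a short completion from a locally running llama-server (assumed URL)
with requests.post(
    "http://localhost:8080/completion",
    json={"prompt": "Hello", "n_predict": 8, "stream": True},
    stream=True,
) as res:
    for raw in res.iter_lines():
        if not raw:
            continue  # skip keep-alive blank lines
        line = raw.decode("utf-8")
        if not line.startswith("data: "):
            continue
        chunk = json.loads(line[len("data: "):])
        # every chunk now carries stop, stop_type, tokens_predicted, tokens_evaluated
        print(chunk.get("tokens_predicted"), chunk.get("stop_type"), repr(chunk.get("content")))
        if chunk.get("stop"):
            # the final chunk additionally carries truncated and the timings object
            print("truncated:", chunk.get("truncated"))
            print("timings:", chunk.get("timings"))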