diff --git a/common/common.h b/common/common.h
index 0373fd3ea..95d20401d 100644
--- a/common/common.h
+++ b/common/common.h
@@ -215,7 +215,7 @@ struct common_params {
     struct common_params_speculative speculative;
 
     std::string model = ""; // model path // NOLINT
-    std::string model_alias = "unknown"; // model alias // NOLINT
+    std::string model_alias = ""; // model alias // NOLINT
     std::string model_url = ""; // model url to download // NOLINT
     std::string hf_token = ""; // HF token // NOLINT
     std::string hf_repo = ""; // HF repo // NOLINT
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b58f10186..95d4bfd37 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -250,29 +250,29 @@ struct completion_token_output {
     std::string text_to_send;
     struct token_prob {
         llama_token tok;
+        std::string tok_str;
         float prob;
     };
     std::vector<token_prob> probs;
 
-    json to_json(const llama_context * ctx) const {
+    json to_json() const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
-            const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
             probs_for_token.push_back(json {
-                {"tok_str", tok_str},
+                {"tok_str", p.tok_str},
                 {"prob", p.prob},
             });
         }
         return probs_for_token;
     }
 
-    static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
+    static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
        json out = json::array();
         for (const auto & prob : probs) {
-            const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
+            const std::string tok_str = prob.text_to_send;
             out.push_back(json {
                 {"content", tok_str},
-                {"probs", prob.to_json(ctx)},
+                {"probs", prob.to_json()},
             });
         }
         return out;
@@ -309,7 +309,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
     virtual json to_json() override {
         // non-OAI-compat JSON
-        return json {
+        json res = json {
             {"index", index},
             {"content", content},
             {"id_slot", id_slot},
@@ -326,6 +326,10 @@
             {"tokens_cached", n_tokens_cached},
             {"timings", timings.to_json()},
         };
+        if (!probs_output.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        }
+        return res;
     }
 
     virtual json to_json_oai_compat() override {
@@ -362,12 +366,6 @@
         if (verbose) {
             res["__verbose"] = to_json();
         }
-
-        // TODO: fix this
-        // if (result.contains("completion_probabilities")) {
-        //     res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-        // }
-
         if (timings.prompt_n >= 0) {
             res.push_back({"timings", timings.to_json()});
         }
@@ -418,6 +416,9 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (timings.prompt_n > 0) {
             res.push_back({"timings", timings.to_json()});
         }
+        if (!probs_output.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        }
         if (is_stop) {
             res.push_back({"truncated", truncated});
         }
@@ -2786,9 +2787,11 @@ struct server_context {
             const auto * cur_p = common_sampler_get_candidates(slot.smpl);
 
             for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
+                auto tok_id = cur_p->data[i].id;
                 result.probs.push_back({
-                    cur_p->data[i].id,
-                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                    tok_id,
+                    tokens_to_output_formatted_string(ctx, tok_id),
+                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
                 });
             }
 
@@ -2920,10 +2923,6 @@ int main(int argc, char ** argv) {
     // struct that contains llama context and inference
     server_context ctx_server;
 
-    if (params.model_alias == "unknown") {
-        params.model_alias = params.model;
-    }
-
     llama_backend_init();
     llama_numa_init(params.numa);
 
diff --git a/examples/server/tests/README.md b/examples/server/tests/README.md
index 2930a2e0d..fa3d0a2f5 100644
--- a/examples/server/tests/README.md
+++ b/examples/server/tests/README.md
@@ -44,4 +44,10 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
 DEBUG=1 ./tests.sh -s -v -x
 ```
 
+Hint: You can compile and run the tests in a single command, which is useful for local development:
+
+```shell
+cmake --build build -j --target llama-server && ./examples/server/tests/tests.sh
+```
+
 To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 486c1f87a..11bf712b6 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -14,7 +14,7 @@ def create_server():
 @pytest.mark.parametrize(
     "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
-        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
         ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
     ]
 )
@@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
         ],
     })
     assert res.status_code == 200
+    assert res.body["model"] == (model if model is not None else server.model_alias)
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
     choice = res.body["choices"][0]
@@ -39,17 +40,17 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
 
 
 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
-        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
     ]
 )
-def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
+    server.model_alias = None
     server.start()
     res = server.make_stream_request("POST", "/chat/completions", data={
-        "model": model,
         "max_tokens": max_tokens,
         "messages": [
             {"role": "system", "content": system_prompt},
@@ -60,16 +61,13 @@ def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, r
     content = ""
     for data in res:
         choice = data["choices"][0]
+        assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
choice["finish_reason"] in ["stop", "length"]: assert data["usage"]["prompt_tokens"] == n_prompt assert data["usage"]["completion_tokens"] == n_predicted assert "content" not in choice["delta"] assert match_regex(re_content, content) - # FIXME: not sure why this is incorrect in stream mode - # if truncated: - # assert choice["finish_reason"] == "length" - # else: - # assert choice["finish_reason"] == "stop" + assert choice["finish_reason"] == finish_reason else: assert choice["finish_reason"] is None content += choice["delta"]["content"] diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 2fa30dd03..1c3aa77de 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -51,6 +51,24 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp content += data["content"] +def test_completion_stream_vs_non_stream(): + global server + server.start() + res_stream = server.make_stream_request("POST", "/completion", data={ + "n_predict": 8, + "prompt": "I believe the meaning of life is", + "stream": True, + }) + res_non_stream = server.make_request("POST", "/completion", data={ + "n_predict": 8, + "prompt": "I believe the meaning of life is", + }) + content_stream = "" + for data in res_stream: + content_stream += data["content"] + assert content_stream == res_non_stream.body["content"] + + @pytest.mark.parametrize("n_slots", [1, 2]) def test_consistent_result_same_seed(n_slots: int): global server @@ -221,3 +239,24 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int): assert len(res.body["content"]) > 10 # FIXME: the result is not deterministic when using other slot than slot 0 # assert match_regex(re_content, res.body["content"]) + + +def test_n_probs(): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + }) + assert res.status_code == 200 + assert "completion_probabilities" in res.body + assert len(res.body["completion_probabilities"]) == 5 + for tok in res.body["completion_probabilities"]: + assert "probs" in tok + assert len(tok["probs"]) == 10 + for prob in tok["probs"]: + assert "prob" in prob + assert "tok_str" in prob + assert 0.0 <= prob["prob"] <= 1.0