diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8cb992470..2c94318b4 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -342,6 +342,11 @@ struct server_task {
             }
         }
 
+        if (params.sampling.n_probs > 0 && params.cache_prompt) {
+            SRV_WRN("cache_prompt is not compatible with n_probs > 0 (current value = %d), disabling cache_prompt.\n", params.sampling.n_probs);
+            params.cache_prompt = false;
+        }
+
         std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
         params.oaicompat_model = json_value(data, "model", model_name);
 
@@ -416,6 +421,7 @@ inline std::string stop_type_to_str(stop_type type) {
 
 struct completion_token_output {
     llama_token tok;
+    float prob;
     std::string text_to_send;
     struct token_prob {
         llama_token tok;
@@ -427,9 +433,13 @@ struct completion_token_output {
     json to_json() const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
+            std::string tok_str(p.tok_str);
+            tok_str.resize(validate_utf8(tok_str));
            probs_for_token.push_back(json {
-                {"tok_str", p.tok_str},
-                {"prob",    p.prob},
+                {"id",      p.tok},
+                {"token",   tok_str},
+                {"bytes",   str_to_bytes(p.tok_str)},
+                {"logprob", p.prob},
             });
         }
         return probs_for_token;
@@ -437,15 +447,27 @@ struct completion_token_output {
 
     static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
         json out = json::array();
-        for (const auto & prob : probs) {
-            const std::string tok_str = prob.text_to_send;
+        for (const auto & it : probs) {
+            std::string tok_str(it.text_to_send);
+            tok_str.resize(validate_utf8(tok_str));
             out.push_back(json {
-                {"content", tok_str},
-                {"probs",   prob.to_json()},
+                {"id",           it.tok},
+                {"token",        tok_str},
+                {"logprob",      it.prob},
+                {"bytes",        str_to_bytes(it.text_to_send)},
+                {"top_logprobs", it.to_json()},
             });
         }
         return out;
     }
+
+    static std::vector<unsigned char> str_to_bytes(const std::string & str) {
+        std::vector<unsigned char> bytes;
+        for (unsigned char c : str) {
+            bytes.push_back(c);
+        }
+        return bytes;
+    }
 };
 
 struct server_task_result_cmpl_final : server_task_result {
@@ -506,7 +528,7 @@ struct server_task_result_cmpl_final : server_task_result {
             {"tokens_cached",       n_tokens_cached},
             {"timings",             timings.to_json()},
         };
-        if (!probs_output.empty()) {
+        if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
         }
         return res;
@@ -518,19 +540,25 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }
 
-        json choices = json::array({json{
+        json choice = json{
             {"finish_reason", finish_reason},
             {"index", 0},
             {"message", json{
                 {"content", content},
                 {"role", "assistant"}
             }
-        }}});
+        }};
+
+        if (!stream && probs_output.size() > 0) {
+            choice["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json(probs_output)},
+            };
+        }
 
         std::time_t t = std::time(0);
 
         json res = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"model", oaicompat_model},
             {"object", "chat.completion"},
@@ -560,12 +588,14 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }
 
-        json choices = json::array({json{{"finish_reason", finish_reason},
-                                          {"index", 0},
-                                          {"delta", json::object()}}});
+        json choice = json{
+            {"finish_reason", finish_reason},
+            {"index", 0},
+            {"delta", json::object()}
+        };
 
         json ret = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"id", oaicompat_cmpl_id},
             {"model", oaicompat_model},
@@ -592,7 +622,7 @@ struct server_task_result_cmpl_partial : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;
 
-    std::vector<completion_token_output> probs_output;
+    completion_token_output prob_output;
     result_timings timings;
 
     // OAI-compat fields
@@ -628,8 +658,8 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (timings.prompt_n > 0) {
             res.push_back({"timings", timings.to_json()});
         }
-        if (!probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        if (!prob_output.probs.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output});
         }
         return res;
     }
@@ -681,6 +711,14 @@ struct server_task_result_cmpl_partial : server_task_result {
             }});
         }
 
+        GGML_ASSERT(choices.size() >= 1);
+
+        if (prob_output.probs.size() > 0) {
+            choices[0]["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json({prob_output})},
+            };
+        }
+
         json ret = json {
             {"choices", choices},
             {"created", t},
@@ -951,7 +989,6 @@ struct server_slot {
 
     // stats
     size_t n_sent_text = 0; // number of sent text character
-    size_t n_sent_token_probs = 0;
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -973,7 +1010,6 @@ struct server_slot {
         stopping_word = "";
         n_past = 0;
         n_sent_text = 0;
-        n_sent_token_probs = 0;
         task_type = SERVER_TASK_TYPE_COMPLETION;
 
         generated_token_probs.clear();
@@ -1713,7 +1749,7 @@ struct server_context {
 
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
+        const std::string token_str = result.text_to_send;
         slot.sampled = result.tok;
 
         // search stop word and delete it
@@ -1721,26 +1757,7 @@ struct server_context {
         slot.has_next_token = true;
 
         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
 
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
@@ -1869,6 +1886,29 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
+    void populate_token_probs(const server_slot & slot, completion_token_output & result) {
+        const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+        const size_t max_probs = cur_p->size;
+
+        // set prob for the sampled token
+        for (size_t i = 0; i < max_probs; ++i) {
+            if (result.tok == cur_p->data[i].id) {
+                result.prob = cur_p->data[i].p;
+                break;
+            }
+        }
+
+        // set probs for the top n tokens
+        for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) {
+            auto tok_id = cur_p->data[i].id;
+            result.probs.push_back({
+                tok_id,
+                tokens_to_output_formatted_string(ctx, tok_id),
+                cur_p->data[i].p,
+            });
+        }
+    }
+
     void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         send_error(task.id, error, type);
     }
@@ -1906,17 +1946,7 @@ struct server_context {
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
-            const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
-
-            const size_t probs_pos      = std::min(slot.n_sent_token_probs,                       slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                res->probs_output = std::vector<completion_token_output>(
-                    slot.generated_token_probs.begin() + probs_pos,
-                    slot.generated_token_probs.begin() + probs_stop_pos);
-            }
+            res->prob_output = tkn; // copy the token probs
         }
 
         // populate timings if this is final response or timings_per_token is enabled
@@ -2747,17 +2777,12 @@ struct server_context {
                 slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
 
                 completion_token_output result;
-                result.tok = id;
+                result.tok          = id;
+                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.prob         = 1.0f; // set later
 
-                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
-
-                for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
-                    auto tok_id = cur_p->data[i].id;
-                    result.probs.push_back({
-                        tok_id,
-                        tokens_to_output_formatted_string(ctx, tok_id),
-                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                    });
+                if (slot.params.sampling.n_probs > 0) {
+                    populate_token_probs(slot, result);
                 }
 
                 if (!process_token(result, slot)) {
@@ -2841,7 +2866,9 @@ struct server_context {
 
                 for (size_t i = 0; i < ids.size(); ++i) {
                     completion_token_output result;
-                    result.tok = ids[i];
+                    result.tok          = ids[i];
+                    result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                    result.prob         = 1.0f; // set later
 
                     if (!process_token(result, slot)) {
                         // release slot because of stop condition
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 6573cc17f..299472fa4 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -92,7 +92,6 @@ def test_chat_completion_with_openai_library():
         seed=42,
         temperature=0.8,
     )
-    print(res)
     assert res.choices[0].finish_reason == "length"
     assert res.choices[0].message.content is not None
     assert match_regex("(Suddenly)+", res.choices[0].message.content)
@@ -163,3 +162,64 @@ def test_chat_completion_with_timings_per_token():
         assert "predicted_per_second" in data["timings"]
         assert "predicted_n" in data["timings"]
         assert data["timings"]["predicted_n"] <= 10
+
+
+def test_logprobs():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=5,
+        logprobs=True,
+        top_logprobs=10,
+    )
+    output_text = res.choices[0].message.content
+    aggregated_text = ''
+    assert res.choices[0].logprobs is not None
+    assert res.choices[0].logprobs.content is not None
+    for token in res.choices[0].logprobs.content:
+        aggregated_text += token.token
+        assert 0.0 <= token.logprob <= 1.0
+        assert token.bytes is not None and len(token.bytes) > 0
+        assert len(token.top_logprobs) > 0
+    assert aggregated_text == output_text
+
+
+def test_logprobs_stream():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=5,
+        logprobs=True,
+        top_logprobs=10,
+        stream=True,
+    )
+    output_text = ''
+    aggregated_text = ''
+    for data in res:
+        choice = data.choices[0]
+        if choice.finish_reason is None:
+            if choice.delta.content:
+                output_text += choice.delta.content
+            assert choice.logprobs is not None
+            assert choice.logprobs.content is not None
+            for token in choice.logprobs.content:
+                aggregated_text += token.token
+                assert 0.0 <= token.logprob <= 1.0
+                assert token.bytes is not None and len(token.bytes) > 0
+                assert token.top_logprobs is not None
+                assert len(token.top_logprobs) > 0
+    assert aggregated_text == output_text
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 7f4f9cd03..4c89ee3ee 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -260,9 +260,40 @@ def test_n_probs():
     assert "completion_probabilities" in res.body
     assert len(res.body["completion_probabilities"]) == 5
     for tok in res.body["completion_probabilities"]:
-        assert "probs" in tok
-        assert len(tok["probs"]) == 10
-        for prob in tok["probs"]:
-            assert "prob" in prob
-            assert "tok_str" in prob
-            assert 0.0 <= prob["prob"] <= 1.0
+        assert "id" in tok and tok["id"] > 0
+        assert "token" in tok and type(tok["token"]) == str
+        assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0
+        assert "bytes" in tok and len(tok["bytes"]) > 0
+        assert len(tok["top_logprobs"]) == 10
+        for prob in tok["top_logprobs"]:
+            assert "id" in prob and prob["id"] > 0
+            assert "token" in prob and type(prob["token"]) == str
+            assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0
+            assert "bytes" in prob and len(prob["bytes"]) > 0
+
+
+def test_n_probs_stream():
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "n_probs": 10,
+        "temperature": 0.0,
+        "n_predict": 5,
+        "stream": True,
+    })
+    for data in res:
+        if data["stop"] == False:
+            assert "completion_probabilities" in data
+            assert len(data["completion_probabilities"]) == 1
+            for tok in data["completion_probabilities"]:
+                assert "id" in tok and tok["id"] > 0
+                assert "token" in tok and type(tok["token"]) == str
+                assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0
+                assert "bytes" in tok and len(tok["bytes"]) > 0
+                assert len(tok["top_logprobs"]) == 10
+                for prob in tok["top_logprobs"]:
+                    assert "id" in prob and prob["id"] > 0
+                    assert "token" in prob and type(prob["token"]) == str
+                    assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0
+                    assert "bytes" in prob and len(prob["bytes"]) > 0
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 8f545aea5..3750cf758 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -170,6 +170,36 @@ static std::vector tokenize_input_prompts(llama_context * ctx, con
     return result;
 }
 
+// return the last index of character that can form a valid string
+// if the last character is potentially cut in half, return the index before the cut
+// if validate_utf8(text) == text.size(), then the whole text is valid utf8
+static size_t validate_utf8(const std::string& text) {
+    size_t len = text.size();
+    if (len == 0) return 0;
+
+    // Check the last few bytes to see if a multi-byte character is cut off
+    for (size_t i = 1; i <= 4 && i <= len; ++i) {
+        unsigned char c = text[len - i];
+        // Check for start of a multi-byte sequence from the end
+        if ((c & 0xE0) == 0xC0) {
+            // 2-byte character start: 110xxxxx
+            // Needs at least 2 bytes
+            if (i < 2) return len - i;
+        } else if ((c & 0xF0) == 0xE0) {
+            // 3-byte character start: 1110xxxx
+            // Needs at least 3 bytes
+            if (i < 3) return len - i;
+        } else if ((c & 0xF8) == 0xF0) {
+            // 4-byte character start: 11110xxx
+            // Needs at least 4 bytes
+            if (i < 4) return len - i;
+        }
+    }
+
+    // If no cut-off multi-byte character is found, return full length
+    return len;
+}
+
 //
 // template utils
 //
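
For reference, a minimal client-side sketch of the reworked n_probs output on the non-streaming /completion endpoint (not part of the patch). The field names (id, token, logprob, bytes, top_logprobs) come from completion_token_output::probs_vector_to_json() above; the host, port, prompt and the use of the requests library are placeholders chosen for illustration.

import requests

# hypothetical local server; adjust host/port to your setup
res = requests.post("http://localhost:8080/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_predict": 5,
    "n_probs": 10,
})
res.raise_for_status()

for tok in res.json()["completion_probabilities"]:
    # the sampled token at this position
    print(tok["id"], repr(tok["token"]), tok["logprob"], bytes(tok["bytes"]))
    # the top-n candidates that were considered at this position
    for cand in tok["top_logprobs"]:
        print("    ", cand["id"], repr(cand["token"]), cand["logprob"])

In streaming mode the same objects arrive incrementally: each chunk carries a single-element completion_probabilities array, and the OAI-compatible chat endpoint exposes them as choices[0].logprobs.content, as exercised by test_n_probs_stream and test_logprobs_stream above.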
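The validate_utf8() helper added to utils.hpp is what keeps partially generated multi-byte characters out of the per-token strings: it returns the length of the longest prefix that does not end in a truncated UTF-8 sequence, and callers either resize the string to that length or compare it against the full size. A rough Python analogue of that contract, with worked byte examples (the name validate_utf8_prefix is made up for this sketch):

def validate_utf8_prefix(data: bytes) -> int:
    # mirrors the loop in utils.hpp: scan the last up-to-4 bytes for a
    # multi-byte lead byte that does not have enough bytes after it
    n = len(data)
    for i in range(1, min(4, n) + 1):
        c = data[n - i]
        if (c & 0xE0) == 0xC0 and i < 2:   # 2-byte lead 110xxxxx, cut off
            return n - i
        if (c & 0xF0) == 0xE0 and i < 3:   # 3-byte lead 1110xxxx, cut off
            return n - i
        if (c & 0xF8) == 0xF0 and i < 4:   # 4-byte lead 11110xxx, cut off
            return n - i
    return n

euro = "€".encode()                                 # b'\xe2\x82\xac', 3 bytes
assert validate_utf8_prefix(b"abc") == 3            # plain ASCII: keep everything
assert validate_utf8_prefix(euro) == 3              # complete character: keep everything
assert validate_utf8_prefix(b"ab" + euro[:2]) == 2  # truncated character: cut it off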