server : fix logprobs, make it openai-compatible

Xuan Son Nguyen 2024-12-11 14:38:57 +01:00
parent 43041d2eb3
commit 74dc729c0b
4 changed files with 217 additions and 69 deletions
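
For context: this change reshapes the server's per-token probability output into the OpenAI-style logprobs format. A minimal sketch (not part of the commit) of exercising it with the OpenAI Python client, mirroring the new tests below; the base_url host and port are placeholders:

    from openai import OpenAI

    # point the client at a running llama-server instance (placeholder address)
    client = OpenAI(api_key="dummy", base_url="http://localhost:8080")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[{"role": "user", "content": "What is the best book"}],
        max_tokens=5,
        logprobs=True,      # ask for per-token probabilities
        top_logprobs=10,    # and the 10 most likely alternatives per token
    )
    for tok in res.choices[0].logprobs.content:
        print(tok.token, tok.logprob, tok.bytes)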


@@ -342,6 +342,11 @@ struct server_task {
             }
         }

+        if (params.sampling.n_probs > 0 && params.cache_prompt) {
+            SRV_WRN("cache_prompt is not compatible with n_probs > 0 (current value = %d), disabling cache_prompt.\n", params.sampling.n_probs);
+            params.cache_prompt = false;
+        }
+
         std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
         params.oaicompat_model = json_value(data, "model", model_name);
@@ -416,6 +421,7 @@ inline std::string stop_type_to_str(stop_type type) {
 struct completion_token_output {
     llama_token tok;
+    float prob;
     std::string text_to_send;
     struct token_prob {
         llama_token tok;
@@ -427,9 +433,13 @@ struct completion_token_output {
     json to_json() const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
+            std::string tok_str(p.tok_str);
+            tok_str.resize(validate_utf8(tok_str));
             probs_for_token.push_back(json {
-                {"tok_str", p.tok_str},
-                {"prob",    p.prob},
+                {"id",      p.tok},
+                {"token",   tok_str},
+                {"bytes",   str_to_bytes(p.tok_str)},
+                {"logprob", p.prob},
             });
         }
         return probs_for_token;
@@ -437,15 +447,27 @@ struct completion_token_output {
     static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
         json out = json::array();
-        for (const auto & prob : probs) {
-            const std::string tok_str = prob.text_to_send;
+        for (const auto & it : probs) {
+            std::string tok_str(it.text_to_send);
+            tok_str.resize(validate_utf8(tok_str));
             out.push_back(json {
-                {"content", tok_str},
-                {"probs",   prob.to_json()},
+                {"id",           it.tok},
+                {"token",        tok_str},
+                {"logprob",      it.prob},
+                {"bytes",        str_to_bytes(it.text_to_send)},
+                {"top_logprobs", it.to_json()},
             });
         }
         return out;
     }
+
+    static std::vector<unsigned char> str_to_bytes(const std::string & str) {
+        std::vector<unsigned char> bytes;
+        for (unsigned char c : str) {
+            bytes.push_back(c);
+        }
+        return bytes;
+    }
 };

 struct server_task_result_cmpl_final : server_task_result {
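
With probs_vector_to_json and str_to_bytes above, each reported token should serialize roughly as follows (a sketch with made-up values; the bytes field holds the raw UTF-8 bytes of the token text):

    token_entry = {
        "id": 15043,                        # token id (illustrative)
        "token": "Hello",                   # UTF-8-validated token text
        "logprob": 0.92,                    # the token's probability, as stored in prob
        "bytes": [72, 101, 108, 108, 111],  # str_to_bytes("Hello")
        "top_logprobs": [                   # up to n_probs candidates with the same fields
            {"id": 15043, "token": "Hello", "bytes": [72, 101, 108, 108, 111], "logprob": 0.92},
            # ...
        ],
    }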
@@ -506,7 +528,7 @@ struct server_task_result_cmpl_final : server_task_result {
             {"tokens_cached",       n_tokens_cached},
             {"timings",             timings.to_json()},
         };
-        if (!probs_output.empty()) {
+        if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
         }
         return res;
@@ -518,19 +540,25 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }

-        json choices = json::array({json{
+        json choice = json{
             {"finish_reason", finish_reason},
             {"index", 0},
             {"message", json{
                 {"content", content},
                 {"role", "assistant"}
             }
-        }}});
+        }};
+
+        if (!stream && probs_output.size() > 0) {
+            choice["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json(probs_output)},
+            };
+        }

         std::time_t t = std::time(0);
         json res = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"model", oaicompat_model},
             {"object", "chat.completion"},
@@ -560,12 +588,14 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }

-        json choices = json::array({json{{"finish_reason", finish_reason},
-                                         {"index", 0},
-                                         {"delta", json::object()}}});
+        json choice = json{
+            {"finish_reason", finish_reason},
+            {"index", 0},
+            {"delta", json::object()}
+        };

         json ret = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"id", oaicompat_cmpl_id},
             {"model", oaicompat_model},
@@ -592,7 +622,7 @@ struct server_task_result_cmpl_partial : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;

-    std::vector<completion_token_output> probs_output;
+    completion_token_output prob_output;
     result_timings timings;

     // OAI-compat fields
@@ -628,8 +658,8 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (timings.prompt_n > 0) {
             res.push_back({"timings", timings.to_json()});
         }
-        if (!probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        if (!prob_output.probs.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output});
         }
         return res;
     }
@@ -681,6 +711,14 @@ struct server_task_result_cmpl_partial : server_task_result {
             }});
         }

+        GGML_ASSERT(choices.size() >= 1);
+
+        if (prob_output.probs.size() > 0) {
+            choices[0]["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json({prob_output})},
+            };
+        }
+
         json ret = json {
             {"choices", choices},
             {"created", t},
@@ -951,7 +989,6 @@ struct server_slot {
     // stats
     size_t n_sent_text = 0; // number of sent text character
-    size_t n_sent_token_probs = 0;

     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -973,7 +1010,6 @@ struct server_slot {
         stopping_word = "";
         n_past        = 0;
         n_sent_text   = 0;
-        n_sent_token_probs = 0;
         task_type     = SERVER_TASK_TYPE_COMPLETION;

         generated_token_probs.clear();
@@ -1713,7 +1749,7 @@ struct server_context {
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
+        const std::string token_str = result.text_to_send;
         slot.sampled = result.tok;

         // search stop word and delete it
@@ -1721,26 +1757,7 @@ struct server_context {
         slot.has_next_token = true;

         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
@@ -1869,6 +1886,29 @@ struct server_context {
         return slot.has_next_token; // continue
     }

+    void populate_token_probs(const server_slot & slot, completion_token_output & result) {
+        const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+        const size_t max_probs = cur_p->size;
+
+        // set prob for the sampled token
+        for (size_t i = 0; i < max_probs; ++i) {
+            if (result.tok == cur_p->data[i].id) {
+                result.prob = cur_p->data[i].p;
+                break;
+            }
+        }
+
+        // set probs for the top n tokens
+        for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) {
+            auto tok_id = cur_p->data[i].id;
+            result.probs.push_back({
+                tok_id,
+                tokens_to_output_formatted_string(ctx, tok_id),
+                cur_p->data[i].p,
+            });
+        }
+    }
+
     void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         send_error(task.id, error, type);
     }
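
A rough Python rendering of what populate_token_probs above does, assuming candidates is the sampler's candidate list ordered by probability, with items of the form (token_id, piece, prob), and sampled_id is the token that was drawn. The C++ code presets result.prob to 1.0f and only overrides it when the sampled token appears among the candidates; default_prob plays that role here:

    def populate_token_probs(candidates, sampled_id, n_probs, default_prob=1.0):
        # probability of the token that was actually sampled;
        # keep the caller's default if it is not among the candidates
        prob = next((p for tid, _, p in candidates if tid == sampled_id), default_prob)
        # the top-n candidates reported alongside it
        top = candidates[:n_probs]
        return prob, top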
@@ -1906,17 +1946,7 @@ struct server_context {
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
-            const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
-            const size_t probs_pos      = std::min(slot.n_sent_token_probs,                       slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                res->probs_output = std::vector<completion_token_output>(
-                        slot.generated_token_probs.begin() + probs_pos,
-                        slot.generated_token_probs.begin() + probs_stop_pos);
-            }
+            res->prob_output = tkn; // copy the token probs
         }

         // populate timings if this is final response or timings_per_token is enabled
@@ -2747,17 +2777,12 @@ struct server_context {
             slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

             completion_token_output result;
             result.tok = id;
+            result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+            result.prob = 1.0f; // set later

-            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
-
-            for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
-                auto tok_id = cur_p->data[i].id;
-                result.probs.push_back({
-                    tok_id,
-                    tokens_to_output_formatted_string(ctx, tok_id),
-                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                });
+            if (slot.params.sampling.n_probs > 0) {
+                populate_token_probs(slot, result);
             }

             if (!process_token(result, slot)) {
@@ -2841,7 +2866,9 @@ struct server_context {
             for (size_t i = 0; i < ids.size(); ++i) {
                 completion_token_output result;
                 result.tok = ids[i];
+                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.prob = 1.0f; // set later

                 if (!process_token(result, slot)) {
                     // release slot because of stop condition


@@ -92,7 +92,6 @@ def test_chat_completion_with_openai_library():
         seed=42,
         temperature=0.8,
     )
-    print(res)
     assert res.choices[0].finish_reason == "length"
     assert res.choices[0].message.content is not None
     assert match_regex("(Suddenly)+", res.choices[0].message.content)
@@ -163,3 +162,64 @@ def test_chat_completion_with_timings_per_token():
         assert "predicted_per_second" in data["timings"]
         assert "predicted_n" in data["timings"]
         assert data["timings"]["predicted_n"] <= 10
+
+
+def test_logprobs():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=5,
+        logprobs=True,
+        top_logprobs=10,
+    )
+    output_text = res.choices[0].message.content
+    aggregated_text = ''
+    assert res.choices[0].logprobs is not None
+    assert res.choices[0].logprobs.content is not None
+    for token in res.choices[0].logprobs.content:
+        aggregated_text += token.token
+        assert 0.0 <= token.logprob <= 1.0
+        assert token.bytes is not None and len(token.bytes) > 0
+        assert len(token.top_logprobs) > 0
+    assert aggregated_text == output_text
+
+
+def test_logprobs_stream():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=5,
+        logprobs=True,
+        top_logprobs=10,
+        stream=True,
+    )
+    output_text = ''
+    aggregated_text = ''
+    for data in res:
+        choice = data.choices[0]
+        if choice.finish_reason is None:
+            if choice.delta.content:
+                output_text += choice.delta.content
+            assert choice.logprobs is not None
+            assert choice.logprobs.content is not None
+            for token in choice.logprobs.content:
+                aggregated_text += token.token
+                assert 0.0 <= token.logprob <= 1.0
+                assert token.bytes is not None and len(token.bytes) > 0
+                assert token.top_logprobs is not None
+                assert len(token.top_logprobs) > 0
+    assert aggregated_text == output_text


@@ -260,9 +260,40 @@ def test_n_probs():
     assert "completion_probabilities" in res.body
     assert len(res.body["completion_probabilities"]) == 5
     for tok in res.body["completion_probabilities"]:
-        assert "probs" in tok
-        assert len(tok["probs"]) == 10
-        for prob in tok["probs"]:
-            assert "prob" in prob
-            assert "tok_str" in prob
-            assert 0.0 <= prob["prob"] <= 1.0
+        assert "id" in tok and tok["id"] > 0
+        assert "token" in tok and type(tok["token"]) == str
+        assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0
+        assert "bytes" in tok and len(tok["bytes"]) > 0
+        assert len(tok["top_logprobs"]) == 10
+        for prob in tok["top_logprobs"]:
+            assert "id" in prob and prob["id"] > 0
+            assert "token" in prob and type(prob["token"]) == str
+            assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0
+            assert "bytes" in prob and len(prob["bytes"]) > 0
+
+
+def test_n_probs_stream():
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "n_probs": 10,
+        "temperature": 0.0,
+        "n_predict": 5,
+        "stream": True,
+    })
+    for data in res:
+        if data["stop"] == False:
+            assert "completion_probabilities" in data
+            assert len(data["completion_probabilities"]) == 1
+            for tok in data["completion_probabilities"]:
+                assert "id" in tok and tok["id"] > 0
+                assert "token" in tok and type(tok["token"]) == str
+                assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0
+                assert "bytes" in tok and len(tok["bytes"]) > 0
+                assert len(tok["top_logprobs"]) == 10
+                for prob in tok["top_logprobs"]:
+                    assert "id" in prob and prob["id"] > 0
+                    assert "token" in prob and type(prob["token"]) == str
+                    assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0
+                    assert "bytes" in prob and len(prob["bytes"]) > 0


@@ -170,6 +170,36 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
     return result;
 }

+// return the last index of character that can form a valid string
+// if the last character is potentially cut in half, return the index before the cut
+// if validate_utf8(text) == text.size(), then the whole text is valid utf8
+static size_t validate_utf8(const std::string& text) {
+    size_t len = text.size();
+    if (len == 0) return 0;
+
+    // Check the last few bytes to see if a multi-byte character is cut off
+    for (size_t i = 1; i <= 4 && i <= len; ++i) {
+        unsigned char c = text[len - i];
+        // Check for start of a multi-byte sequence from the end
+        if ((c & 0xE0) == 0xC0) {
+            // 2-byte character start: 110xxxxx
+            // Needs at least 2 bytes
+            if (i < 2) return len - i;
+        } else if ((c & 0xF0) == 0xE0) {
+            // 3-byte character start: 1110xxxx
+            // Needs at least 3 bytes
+            if (i < 3) return len - i;
+        } else if ((c & 0xF8) == 0xF0) {
+            // 4-byte character start: 11110xxx
+            // Needs at least 4 bytes
+            if (i < 4) return len - i;
+        }
+    }
+
+    // If no cut-off multi-byte character is found, return full length
+    return len;
+}
+
 //
 // template utils
 //
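
For intuition, a rough Python equivalent of validate_utf8 operating on bytes, with a quick check on a truncated multi-byte character (a sketch, not part of the commit):

    def validate_utf8(text: bytes) -> int:
        # return the longest prefix length that does not end in the middle
        # of a multi-byte UTF-8 character
        n = len(text)
        for i in range(1, min(4, n) + 1):
            c = text[n - i]
            if (c & 0xE0) == 0xC0 and i < 2:  # 2-byte start with only 1 byte present
                return n - i
            if (c & 0xF0) == 0xE0 and i < 3:  # 3-byte start with fewer than 3 bytes
                return n - i
            if (c & 0xF8) == 0xF0 and i < 4:  # 4-byte start with fewer than 4 bytes
                return n - i
        return n

    assert validate_utf8("é".encode("utf-8")[:1]) == 0  # lead byte only: cut before it
    assert validate_utf8("é".encode("utf-8")) == 2      # complete 2-byte character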