server: fix issue when handling probability output for incomplete tokens for multibyte character generation
This commit is contained in:
parent
ccf254bd44
commit
cf76195223
1 changed files with 73 additions and 62 deletions
|
@ -26,17 +26,6 @@ struct server_params {
|
||||||
int32_t write_timeout = 600;
|
int32_t write_timeout = 600;
|
||||||
};
|
};
|
||||||
|
|
||||||
// completion string output with probabilities
|
|
||||||
struct completion_string_output {
|
|
||||||
struct token_prob {
|
|
||||||
std::string tok_str;
|
|
||||||
float prob;
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<token_prob> probs;
|
|
||||||
std::string tok_str;
|
|
||||||
};
|
|
||||||
|
|
||||||
// completion token output with probabilities
|
// completion token output with probabilities
|
||||||
struct completion_token_output {
|
struct completion_token_output {
|
||||||
struct token_prob {
|
struct token_prob {
|
||||||
|
@ -108,6 +97,36 @@ static void server_log(const char * level, const char * function, int line,
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// format incomplete utf-8 multibyte character for output
|
||||||
|
static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
|
||||||
|
const std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
|
||||||
|
if (out[0] > 127) {
|
||||||
|
out = "byte: \\x" + std::format("{:x}", out[0]);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert a vector of completion_token_output to json
|
||||||
|
static json probs_vector_to_json(const llama_context * ctx, const vector<completion_token_output> probs) {
|
||||||
|
json out = json::array();
|
||||||
|
for (const auto & prob : probs) {
|
||||||
|
json probs_for_token = json::array();
|
||||||
|
for (const auto & p : prob.probs) {
|
||||||
|
std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
|
||||||
|
probs_for_token.push_back(json {
|
||||||
|
{ "tok_str", tok_str },
|
||||||
|
{ "prob", p.prob },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
|
||||||
|
out.push_back(json {
|
||||||
|
{"content", tok_str},
|
||||||
|
{"probs", probs_for_token},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
static bool server_verbose = false;
|
static bool server_verbose = false;
|
||||||
|
|
||||||
#if SERVER_VERBOSE != 1
|
#if SERVER_VERBOSE != 1
|
||||||
|
@ -129,7 +148,7 @@ struct llama_server_context {
|
||||||
bool stream = false;
|
bool stream = false;
|
||||||
bool has_next_token = false;
|
bool has_next_token = false;
|
||||||
std::string generated_text;
|
std::string generated_text;
|
||||||
std::vector<completion_string_output> generated_text_probs;
|
std::vector<completion_token_output> generated_token_probs;
|
||||||
|
|
||||||
size_t num_tokens_predicted = 0;
|
size_t num_tokens_predicted = 0;
|
||||||
size_t n_past = 0;
|
size_t n_past = 0;
|
||||||
|
@ -160,7 +179,7 @@ struct llama_server_context {
|
||||||
num_tokens_predicted = 0;
|
num_tokens_predicted = 0;
|
||||||
generated_text = "";
|
generated_text = "";
|
||||||
generated_text.reserve(params.n_ctx);
|
generated_text.reserve(params.n_ctx);
|
||||||
generated_text_probs.clear();
|
generated_token_probs.clear();
|
||||||
truncated = false;
|
truncated = false;
|
||||||
stopped_eos = false;
|
stopped_eos = false;
|
||||||
stopped_word = false;
|
stopped_word = false;
|
||||||
|
@ -406,22 +425,16 @@ struct llama_server_context {
|
||||||
return stop_pos;
|
return stop_pos;
|
||||||
}
|
}
|
||||||
|
|
||||||
completion_string_output doCompletion() {
|
completion_token_output doCompletion() {
|
||||||
const completion_token_output token_with_probs = nextToken();
|
const completion_token_output token_with_probs = nextToken();
|
||||||
completion_string_output result;
|
|
||||||
|
|
||||||
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
|
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
|
||||||
result.tok_str = token_text;
|
|
||||||
generated_text += token_text;
|
generated_text += token_text;
|
||||||
|
|
||||||
// iterate through token_with_probs.probs, if tok is valid, convert it to string and add to result.prob
|
if (params.n_probs > 0) {
|
||||||
for (const auto & prob : token_with_probs.probs) {
|
generated_token_probs.push_back(token_with_probs);
|
||||||
const std::string prob_text = prob.tok == -1 ? "" : llama_token_to_str(ctx, prob.tok);
|
|
||||||
result.probs.push_back({prob_text, prob.prob});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
generated_text_probs.push_back(result);
|
|
||||||
|
|
||||||
if (multibyte_pending > 0) {
|
if (multibyte_pending > 0) {
|
||||||
multibyte_pending -= token_text.size();
|
multibyte_pending -= token_text.size();
|
||||||
} else if (token_text.size() == 1) {
|
} else if (token_text.size() == 1) {
|
||||||
|
@ -451,7 +464,7 @@ struct llama_server_context {
|
||||||
|
|
||||||
LOG_VERBOSE("next token", {
|
LOG_VERBOSE("next token", {
|
||||||
{ "token", token_with_probs.tok },
|
{ "token", token_with_probs.tok },
|
||||||
{ "token_text", llama_token_to_str(ctx, token_with_probs.tok) },
|
{ "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) },
|
||||||
{ "has_next_token", has_next_token },
|
{ "has_next_token", has_next_token },
|
||||||
{ "n_remain", n_remain },
|
{ "n_remain", n_remain },
|
||||||
{ "num_tokens_predicted", num_tokens_predicted },
|
{ "num_tokens_predicted", num_tokens_predicted },
|
||||||
|
@ -461,7 +474,7 @@ struct llama_server_context {
|
||||||
{ "stopping_word", stopping_word },
|
{ "stopping_word", stopping_word },
|
||||||
});
|
});
|
||||||
|
|
||||||
return result;
|
return token_with_probs;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> getEmbedding() {
|
std::vector<float> getEmbedding() {
|
||||||
|
@ -713,26 +726,10 @@ static json format_embedding_response(llama_server_context & llama) {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_string_output> & probs) {
|
static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
|
||||||
|
|
||||||
json completion_probabilities_json = json::array();
|
json res = json {
|
||||||
for (const auto & prob : probs) {
|
|
||||||
json probs_for_token = json::array();
|
|
||||||
for (const auto & p : prob.probs) {
|
|
||||||
probs_for_token.push_back(json {
|
|
||||||
{ "tok_str", p.tok_str },
|
|
||||||
{ "prob", p.prob },
|
|
||||||
});
|
|
||||||
}
|
|
||||||
completion_probabilities_json.push_back(json {
|
|
||||||
{"content", prob.tok_str},
|
|
||||||
{"probs", probs_for_token},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return json {
|
|
||||||
{ "content", content },
|
{ "content", content },
|
||||||
{ "completion_probabilities", completion_probabilities_json},
|
|
||||||
{ "stop", true },
|
{ "stop", true },
|
||||||
{ "model", llama.params.model_alias },
|
{ "model", llama.params.model_alias },
|
||||||
{ "tokens_predicted", llama.num_tokens_predicted },
|
{ "tokens_predicted", llama.num_tokens_predicted },
|
||||||
|
@ -743,25 +740,25 @@ static json format_final_response(llama_server_context & llama, const std::strin
|
||||||
{ "stopped_word", llama.stopped_word },
|
{ "stopped_word", llama.stopped_word },
|
||||||
{ "stopped_limit", llama.stopped_limit },
|
{ "stopped_limit", llama.stopped_limit },
|
||||||
{ "stopping_word", llama.stopping_word },
|
{ "stopping_word", llama.stopping_word },
|
||||||
};
|
}
|
||||||
|
|
||||||
|
if (llama.params.n_probs > 0) {
|
||||||
|
json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs);
|
||||||
|
res["completion_probabilities"] = completion_probabilities_json;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static json format_partial_response(const std::string & content, const completion_string_output & probs) {
|
static json format_partial_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
|
||||||
json res = json {
|
json res = json {
|
||||||
{ "content", content },
|
{ "content", content },
|
||||||
{ "stop", false },
|
{ "stop", false },
|
||||||
};
|
};
|
||||||
|
|
||||||
// iterate through probs.probs, and add to res
|
if (llama.params.n_probs > 0) {
|
||||||
json probs_json = json::array();
|
json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs);
|
||||||
for (const auto & prob : probs.probs) {
|
res["completion_probabilities"] = completion_probabilities_json;
|
||||||
probs_json.push_back(json {
|
|
||||||
{ "tok_str", prob.tok_str },
|
|
||||||
{ "prob", prob.prob },
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (probs.probs.size() > 0) {
|
|
||||||
res["probs"] = probs_json;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
|
@ -897,8 +894,8 @@ int main(int argc, char ** argv) {
|
||||||
size_t stop_pos = std::string::npos;
|
size_t stop_pos = std::string::npos;
|
||||||
|
|
||||||
while (llama.has_next_token) {
|
while (llama.has_next_token) {
|
||||||
const completion_string_output token_text_with_probs = llama.doCompletion();
|
const completion_token_output token_with_probs = llama.doCompletion();
|
||||||
const std::string token_text = token_text_with_probs.tok_str;
|
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
|
||||||
|
|
||||||
stop_pos = llama.findStoppingStrings(llama.generated_text,
|
stop_pos = llama.findStoppingStrings(llama.generated_text,
|
||||||
token_text.size(), STOP_FULL);
|
token_text.size(), STOP_FULL);
|
||||||
|
@ -912,7 +909,7 @@ int main(int argc, char ** argv) {
|
||||||
llama.generated_text.end());
|
llama.generated_text.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
const json data = format_final_response(llama, llama.generated_text, llama.generated_text_probs);
|
const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
|
||||||
|
|
||||||
llama_print_timings(llama.ctx);
|
llama_print_timings(llama.ctx);
|
||||||
|
|
||||||
|
@ -921,9 +918,11 @@ int main(int argc, char ** argv) {
|
||||||
} else {
|
} else {
|
||||||
const auto chunked_content_provider = [&](size_t, DataSink & sink) {
|
const auto chunked_content_provider = [&](size_t, DataSink & sink) {
|
||||||
size_t sent_count = 0;
|
size_t sent_count = 0;
|
||||||
|
size_t sent_token_probs_index = 0;
|
||||||
|
|
||||||
while (llama.has_next_token) {
|
while (llama.has_next_token) {
|
||||||
const completion_string_output token_text_with_probs = llama.doCompletion();
|
const completion_token_output token_with_probs = llama.doCompletion();
|
||||||
|
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
|
||||||
if (llama.multibyte_pending > 0) {
|
if (llama.multibyte_pending > 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -932,24 +931,36 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
const std::string str_test = llama.generated_text.substr(pos);
|
const std::string str_test = llama.generated_text.substr(pos);
|
||||||
size_t stop_pos =
|
size_t stop_pos =
|
||||||
llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(), STOP_FULL);
|
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
|
||||||
if (stop_pos != std::string::npos) {
|
if (stop_pos != std::string::npos) {
|
||||||
llama.generated_text.erase(
|
llama.generated_text.erase(
|
||||||
llama.generated_text.begin() + pos + stop_pos,
|
llama.generated_text.begin() + pos + stop_pos,
|
||||||
llama.generated_text.end());
|
llama.generated_text.end());
|
||||||
pos = std::min(sent_count, llama.generated_text.size());
|
pos = std::min(sent_count, llama.generated_text.size());
|
||||||
} else {
|
} else {
|
||||||
stop_pos = llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(),
|
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
|
||||||
STOP_PARTIAL);
|
STOP_PARTIAL);
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string to_send = llama.generated_text.substr(pos, stop_pos);
|
const std::string to_send = llama.generated_text.substr(pos, stop_pos);
|
||||||
sent_count += to_send.size();
|
sent_count += to_send.size();
|
||||||
|
|
||||||
|
std::vector<completion_token_output> probs_output = {};
|
||||||
|
|
||||||
|
if (llama.params.n_probs > 0) {
|
||||||
|
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
|
||||||
|
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
|
||||||
|
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
|
||||||
|
if (probs_pos < probs_stop_pos) {
|
||||||
|
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
|
||||||
|
}
|
||||||
|
sent_token_probs_index = probs_stop_pos;
|
||||||
|
}
|
||||||
|
|
||||||
const json data = llama.has_next_token
|
const json data = llama.has_next_token
|
||||||
? format_partial_response(to_send, token_text_with_probs)
|
? format_partial_response(llama, to_send, probs_output)
|
||||||
// Generation is done, send extra information.
|
// Generation is done, send extra information.
|
||||||
: format_final_response(llama, to_send, {token_text_with_probs});
|
: format_final_response(llama, to_send, probs_output);
|
||||||
|
|
||||||
const std::string str =
|
const std::string str =
|
||||||
"data: " +
|
"data: " +
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue