always send partial response to get correct probs of last to_send

Jhen 2023-08-21 13:26:23 +08:00
parent 371cc14815
commit 1e9fe8a954


@@ -1031,7 +1031,7 @@ static json format_final_response(llama_server_context &llama, const std::string
 {
     json res = json{
-        {"content", content},
+        {"content", ""},
         {"stop", true},
         {"model", llama.params.model_alias},
         {"tokens_predicted", llama.num_tokens_predicted},
@@ -1312,24 +1312,45 @@ int main(int argc, char **argv)
                 sent_token_probs_index = probs_stop_pos;
             }
 
-            const json data = llama.has_next_token
-                                  ? format_partial_response(llama, to_send, probs_output)
-                                  // Generation is done, send extra information.
-                                  : format_final_response(llama, to_send, llama.generated_token_probs);
+            {
+                // Always send partial response
+                // so we can get the correct partial response of the last to_send in the client
+                const json data = format_partial_response(llama, to_send, probs_output);
 
-            const std::string str =
-                "data: " +
-                data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                "\n\n";
+                const std::string str =
+                    "data: " +
+                    data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                    "\n\n";
 
-            LOG_VERBOSE("data stream", {
-                { "to_send", str }
-            });
+                LOG_VERBOSE("data stream", {
+                    { "to_send", str }
+                });
 
-            if (!sink.write(str.data(), str.size())) {
-                LOG_VERBOSE("stream closed", {});
-                llama_print_timings(llama.ctx);
-                return false;
-            }
+                if (!sink.write(str.data(), str.size())) {
+                    LOG_VERBOSE("stream closed", {});
+                    llama_print_timings(llama.ctx);
+                    return false;
+                }
+            }
+
+            if (!llama.has_next_token) {
+                // Generation is done, send extra information.
+                const json data = format_final_response(llama, to_send, llama.generated_token_probs);
+
+                const std::string str =
+                    "data: " +
+                    data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                    "\n\n";
+
+                LOG_VERBOSE("data stream", {
+                    { "to_send", str }
+                });
+
+                if (!sink.write(str.data(), str.size())) {
+                    LOG_VERBOSE("stream closed", {});
+                    llama_print_timings(llama.ctx);
+                    return false;
+                }
+            }
         }
     }
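
After this change, every chunk, including the last generated one, is delivered as a partial response that carries its own content and token probabilities, and the final stop event (see the first hunk) carries an empty "content", so clients no longer receive the last chunk's text twice. A minimal sketch of the client-side accumulation this enables, assuming the SSE "data:" payloads have already been parsed into nlohmann::json objects in arrival order (the accumulate helper and its event vector are hypothetical illustrations, not part of the server code):

// Hypothetical client-side helper, not part of this commit.
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

static std::string accumulate(const std::vector<nlohmann::json> & events) {
    std::string full;
    for (const auto & ev : events) {
        // Every event, including the last generated chunk, is a partial
        // response, so its "content" (and probs) can be read uniformly.
        full += ev.value("content", "");
        if (ev.value("stop", false)) {
            // The final event now has an empty "content", so nothing is
            // appended twice before stopping.
            break;
        }
    }
    return full;
}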