avoid unnecessary empty data event & send rest of partial tokens on stop

This commit is contained in:
Jhen 2023-08-23 15:52:49 +08:00
parent 3fc1127e2f
commit 3f436ea3f3

View file

@ -1330,39 +1330,42 @@ int main(int argc, char **argv)
size_t pos = std::min(sent_count, llama.generated_text.size()); size_t pos = std::min(sent_count, llama.generated_text.size());
const std::string str_test = llama.generated_text.substr(pos); const std::string str_test = llama.generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos = size_t stop_pos =
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
if (stop_pos != std::string::npos) { if (stop_pos != std::string::npos) {
is_stop_full = true;
llama.generated_text.erase( llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos, llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end()); llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size()); pos = std::min(sent_count, llama.generated_text.size());
} else { } else {
is_stop_full = false;
stop_pos = llama.findStoppingStrings(str_test, token_text.size(), stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
STOP_PARTIAL); STOP_PARTIAL);
} }
const std::string to_send = stop_pos == std::string::npos if (
? llama.generated_text.substr(pos, std::string::npos) stop_pos == std::string::npos ||
: ""; // just don't send anything if we're not done // Send rest of the text if we are at the end of the generation
(!llama.has_next_token && !is_stop_full && stop_pos > 0)
) {
const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
sent_count += to_send.size(); sent_count += to_send.size();
std::vector<completion_token_output> probs_output = {}; std::vector<completion_token_output> probs_output = {};
if (llama.params.n_probs > 0) { if (llama.params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false); const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
if (probs_pos < probs_stop_pos) { if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
} }
sent_token_probs_index = probs_stop_pos;
}
{
// Always send partial response
// so we can get the correct partial response of the last to_send in the client
const json data = format_partial_response(llama, to_send, probs_output); const json data = format_partial_response(llama, to_send, probs_output);
const std::string str = const std::string str =