avoid unnecessary empty data event & send rest of partial tokens on stop
parent 3fc1127e2f
commit 3f436ea3f3
1 changed file with 19 additions and 16 deletions
@@ -1330,39 +1330,42 @@ int main(int argc, char **argv)
                size_t pos = std::min(sent_count, llama.generated_text.size());

                const std::string str_test = llama.generated_text.substr(pos);
+               bool is_stop_full = false;
                size_t stop_pos =
                    llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
                if (stop_pos != std::string::npos) {
+                   is_stop_full = true;
                    llama.generated_text.erase(
                        llama.generated_text.begin() + pos + stop_pos,
                        llama.generated_text.end());
                    pos = std::min(sent_count, llama.generated_text.size());
                } else {
+                   is_stop_full = false;
                    stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
                                                         STOP_PARTIAL);
                }

-               const std::string to_send = stop_pos == std::string::npos
-                                               ? llama.generated_text.substr(pos, std::string::npos)
-                                               : ""; // just don't send anything if we're not done
+               if (
+                   stop_pos == std::string::npos ||
+                   // Send rest of the text if we are at the end of the generation
+                   (!llama.has_next_token && !is_stop_full && stop_pos > 0)
+               ) {
+                   const std::string to_send = llama.generated_text.substr(pos, std::string::npos);

                    sent_count += to_send.size();

                    std::vector<completion_token_output> probs_output = {};

                    if (llama.params.n_probs > 0) {
                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
                        if (probs_pos < probs_stop_pos) {
                            probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
                        }
+                       sent_token_probs_index = probs_stop_pos;
                    }
-                   sent_token_probs_index = probs_stop_pos;
-               }

-               {
-                   // Always send partial response
-                   // so we can get the correct partial response of the last to_send in the client
                    const json data = format_partial_response(llama, to_send, probs_output);

                    const std::string str =
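In short: the old code set to_send to an empty string whenever the tail of generated_text might still be the start of a stop word, and then emitted a data event anyway; the new code enters the send path only when there is text that is safe to stream, and it also flushes the held-back tail once generation ends without the stop word ever completing. Below is a minimal, self-contained sketch of that streaming policy, not the server code itself: find_partial_stop, the hard-coded stop word, and the token list are illustrative stand-ins for llama.findStoppingStrings(..., STOP_PARTIAL) and the real token stream, and the full-stop path (STOP_FULL / is_stop_full) is omitted for brevity.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for llama.findStoppingStrings(..., STOP_PARTIAL):
// returns where a partial match of `stop` begins at the end of `text`,
// or std::string::npos if the pending text cannot be a stop-word prefix.
static size_t find_partial_stop(const std::string & text, const std::string & stop) {
    for (size_t len = std::min(text.size(), stop.size()); len > 0; --len) {
        if (text.compare(text.size() - len, len, stop, 0, len) == 0) {
            return text.size() - len;
        }
    }
    return std::string::npos;
}

int main() {
    const std::string stop = "###";            // assumed stop word
    // Tokens as they might arrive; generation ends while " ##" is still
    // only a partial match of the stop word.
    const std::vector<std::string> tokens = {"Hello", ", ", "wor", "ld", " ##"};

    std::string generated;   // full generated text so far
    size_t sent_count = 0;   // characters already streamed to the client

    for (size_t i = 0; i < tokens.size(); ++i) {
        generated += tokens[i];
        const bool has_next_token = (i + 1 < tokens.size());

        const std::string pending  = generated.substr(sent_count);
        const size_t      stop_pos = find_partial_stop(pending, stop);

        // New policy: emit a chunk only when nothing pending can still become
        // the stop word, or flush the remainder because generation is over and
        // the stop word never completed. Previously an empty chunk was emitted
        // whenever a partial match was pending.
        if (stop_pos == std::string::npos || (!has_next_token && stop_pos > 0)) {
            const std::string to_send = pending;
            sent_count += to_send.size();
            std::cout << "chunk: \"" << to_send << "\"\n";  // stands in for one SSE data event
        }
        // otherwise: hold the text back; it may still turn into a full stop word
    }
    return 0;
}

Running this streams "Hello", ", ", "wor", "ld" as they arrive and, because generation stops while " ##" is still only a partial match of "###", that remainder is flushed as a final chunk instead of being dropped, mirroring the "send rest of partial tokens on stop" half of the commit; the "avoid unnecessary empty data event" half corresponds to skipping the send block entirely while a partial match is pending.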