Merge pull request #14 from anon998/do-completion-update

Trim partial stopping strings when not streaming and move multibyte check.
This commit is contained in:
Randall Fitzgerald 2023-06-02 07:30:53 -04:00 committed by GitHub
commit f5d5e7020d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -67,6 +67,7 @@ struct llama_server_context
bool verbose = false; bool verbose = false;
int json_indent = -1; int json_indent = -1;
int32_t multibyte_pending = 0;
~llama_server_context() ~llama_server_context()
{ {
@@ -82,6 +83,7 @@ struct llama_server_context
generated_text = ""; generated_text = "";
generated_text.reserve(params.n_ctx); generated_text.reserve(params.n_ctx);
stopping_word = ""; stopping_word = "";
multibyte_pending = 0;
n_remain = 0; n_remain = 0;
n_past = 0; n_past = 0;
@@ -300,12 +302,32 @@ struct llama_server_context
std::string doCompletion() std::string doCompletion()
{ {
llama_token token = nextToken(); llama_token token = nextToken();
if (token == -1) {
return ""; std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token);
generated_text += token_text;
if (multibyte_pending > 0) {
multibyte_pending -= token_text.size();
} else if (token_text.size() == 1) {
const char c = token_text[0];
// 2-byte characters: 110xxxxx 10xxxxxx
if ((c & 0xE0) == 0xC0) {
multibyte_pending = 1;
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
} else if ((c & 0xF0) == 0xE0) {
multibyte_pending = 2;
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
} else if ((c & 0xF8) == 0xF0) {
multibyte_pending = 3;
} else {
multibyte_pending = 0;
}
} }
std::string token_text = llama_token_to_str(ctx, token); if (multibyte_pending > 0 && !has_next_token) {
generated_text += token_text; has_next_token = true;
n_remain++;
}
if (verbose) { if (verbose) {
fprintf(stderr, fprintf(stderr,
@@ -761,15 +783,21 @@ int main(int argc, char **argv)
llama.beginCompletion(); llama.beginCompletion();
if (!llama.stream) { if (!llama.stream) {
size_t stop_pos = std::string::npos;
while (llama.has_next_token) { while (llama.has_next_token) {
const std::string token_text = llama.doCompletion(); const std::string token_text = llama.doCompletion();
const size_t stop_pos = llama.findStoppingStrings(
llama.generated_text, token_text.size(), STOP_FULL);
if (stop_pos != std::string::npos) { stop_pos = llama.findStoppingStrings(llama.generated_text,
llama.generated_text.erase(llama.generated_text.begin() + stop_pos, token_text.size(), STOP_FULL);
llama.generated_text.end()); }
}
if (stop_pos == std::string::npos) {
stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
}
if (stop_pos != std::string::npos) {
llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
llama.generated_text.end());
} }
json data = {{"content", llama.generated_text}, json data = {{"content", llama.generated_text},
@@ -788,34 +816,10 @@ int main(int argc, char **argv)
} else { } else {
const auto chunked_content_provider = [&](size_t, DataSink &sink) { const auto chunked_content_provider = [&](size_t, DataSink &sink) {
size_t sent_count = 0; size_t sent_count = 0;
int32_t multibyte_pending = 0;
while (llama.has_next_token) { while (llama.has_next_token) {
const std::string token_text = llama.doCompletion(); const std::string token_text = llama.doCompletion();
if (llama.multibyte_pending > 0) {
if (multibyte_pending > 0) {
multibyte_pending -= token_text.size();
} else if (token_text.size() == 1) {
const char c = token_text[0];
// 2-byte characters: 110xxxxx 10xxxxxx
if ((c & 0xE0) == 0xC0) {
multibyte_pending = 1;
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
} else if ((c & 0xF0) == 0xE0) {
multibyte_pending = 2;
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
} else if ((c & 0xF8) == 0xF0) {
multibyte_pending = 3;
} else {
multibyte_pending = 0;
}
}
if (multibyte_pending > 0) {
if (!llama.has_next_token) {
llama.has_next_token = true;
llama.n_remain++;
}
continue; continue;
} }