fix stopping strings
This commit is contained in:
parent
342604bb81
commit
e9b1f0bf5c
1 changed files with 78 additions and 16 deletions
|
@ -20,6 +20,33 @@ static size_t common_part(const std::vector<llama_token> & a, const std::vector<
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum stop_type {
|
||||||
|
STOP_FULL,
|
||||||
|
STOP_PARTIAL,
|
||||||
|
};
|
||||||
|
|
||||||
|
bool ends_with(const std::string &str, const std::string &suffix)
|
||||||
|
{
|
||||||
|
return str.size() >= suffix.size() &&
|
||||||
|
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t find_partial_stop_string(const std::string &stop, const std::string &text)
|
||||||
|
{
|
||||||
|
if (!text.empty()) {
|
||||||
|
const char text_last_char = text.back();
|
||||||
|
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||||
|
if (stop[char_index] == text_last_char) {
|
||||||
|
const std::string current_partial = stop.substr(0, char_index + 1);
|
||||||
|
if (ends_with(text, current_partial)) {
|
||||||
|
return text.size() - char_index - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::string::npos;
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_server_context
|
struct llama_server_context
|
||||||
{
|
{
|
||||||
bool stream = false;
|
bool stream = false;
|
||||||
|
@ -248,6 +275,31 @@ struct llama_server_context
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
|
||||||
|
const stop_type type)
|
||||||
|
{
|
||||||
|
size_t stop_pos = std::string::npos;
|
||||||
|
for (const std::string &word : params.antiprompt) {
|
||||||
|
size_t pos;
|
||||||
|
if (type == STOP_FULL) {
|
||||||
|
const size_t tmp = word.size() + last_token_size;
|
||||||
|
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
|
||||||
|
pos = text.find(word, from_pos);
|
||||||
|
} else {
|
||||||
|
pos = find_partial_stop_string(word, text);
|
||||||
|
}
|
||||||
|
if (pos != std::string::npos &&
|
||||||
|
(stop_pos == std::string::npos || pos < stop_pos)) {
|
||||||
|
if (type == STOP_FULL) {
|
||||||
|
stopping_word = word;
|
||||||
|
has_next_token = false;
|
||||||
|
}
|
||||||
|
stop_pos = pos;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stop_pos;
|
||||||
|
}
|
||||||
|
|
||||||
std::string doCompletion()
|
std::string doCompletion()
|
||||||
{
|
{
|
||||||
llama_token token = nextToken();
|
llama_token token = nextToken();
|
||||||
|
@ -272,16 +324,6 @@ struct llama_server_context
|
||||||
stopping_word.c_str());
|
stopping_word.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const std::string& word : params.antiprompt) {
|
|
||||||
size_t i = generated_text.find(word, generated_text.size() - (word.size() + token_text.size()));
|
|
||||||
if (i != std::string::npos) {
|
|
||||||
generated_text.erase(generated_text.begin() + i, generated_text.end());
|
|
||||||
stopping_word = word;
|
|
||||||
has_next_token = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return token_text;
|
return token_text;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -711,7 +753,14 @@ int main(int argc, char **argv)
|
||||||
|
|
||||||
if (!llama.stream) {
|
if (!llama.stream) {
|
||||||
while (llama.has_next_token) {
|
while (llama.has_next_token) {
|
||||||
llama.doCompletion();
|
const std::string token_text = llama.doCompletion();
|
||||||
|
const size_t stop_pos = llama.findStoppingStrings(
|
||||||
|
llama.generated_text, token_text.size(), STOP_FULL);
|
||||||
|
|
||||||
|
if (stop_pos != std::string::npos) {
|
||||||
|
llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
|
||||||
|
llama.generated_text.end());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
json data = {{"content", llama.generated_text},
|
json data = {{"content", llama.generated_text},
|
||||||
|
@ -724,7 +773,7 @@ int main(int argc, char **argv)
|
||||||
|
|
||||||
llama_print_timings(llama.ctx);
|
llama_print_timings(llama.ctx);
|
||||||
|
|
||||||
return res.set_content(
|
res.set_content(
|
||||||
data.dump(llama.json_indent, ' ', false, json::error_handler_t::replace),
|
data.dump(llama.json_indent, ' ', false, json::error_handler_t::replace),
|
||||||
"application/json");
|
"application/json");
|
||||||
} else {
|
} else {
|
||||||
|
@ -733,7 +782,7 @@ int main(int argc, char **argv)
|
||||||
int32_t multibyte_pending = 0;
|
int32_t multibyte_pending = 0;
|
||||||
|
|
||||||
while (llama.has_next_token) {
|
while (llama.has_next_token) {
|
||||||
std::string token_text = llama.doCompletion();
|
const std::string token_text = llama.doCompletion();
|
||||||
|
|
||||||
if (multibyte_pending > 0) {
|
if (multibyte_pending > 0) {
|
||||||
multibyte_pending -= token_text.size();
|
multibyte_pending -= token_text.size();
|
||||||
|
@ -761,8 +810,22 @@ int main(int argc, char **argv)
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t pos = std::min(sent_count, llama.generated_text.size());
|
size_t pos = std::min(sent_count, llama.generated_text.size());
|
||||||
std::string to_send = llama.generated_text.substr(pos);
|
|
||||||
|
const char *str_test = llama.generated_text.c_str() + pos;
|
||||||
|
size_t stop_pos =
|
||||||
|
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
|
||||||
|
if (stop_pos != std::string::npos) {
|
||||||
|
llama.generated_text.erase(
|
||||||
|
llama.generated_text.begin() + pos + stop_pos,
|
||||||
|
llama.generated_text.end());
|
||||||
|
pos = std::min(sent_count, llama.generated_text.size());
|
||||||
|
} else {
|
||||||
|
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
|
||||||
|
STOP_PARTIAL);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string to_send = llama.generated_text.substr(pos, stop_pos);
|
||||||
sent_count += to_send.size();
|
sent_count += to_send.size();
|
||||||
|
|
||||||
json data;
|
json data;
|
||||||
|
@ -808,7 +871,6 @@ int main(int argc, char **argv)
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
svr.Post("/tokenize", [&llama](const Request &req, Response &res)
|
svr.Post("/tokenize", [&llama](const Request &req, Response &res)
|
||||||
{
|
{
|
||||||
json body = json::parse(req.body);
|
json body = json::parse(req.body);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue