Merge pull request #9 from anon998/stopping-strings
Fix stopping strings.
This commit is contained in:
commit
5f6e16da36
1 changed files with 78 additions and 16 deletions
|
@ -20,6 +20,33 @@ static size_t common_part(const std::vector<llama_token> & a, const std::vector<
|
|||
return i;
|
||||
}
|
||||
|
||||
enum stop_type {
|
||||
STOP_FULL,
|
||||
STOP_PARTIAL,
|
||||
};
|
||||
|
||||
bool ends_with(const std::string &str, const std::string &suffix)
|
||||
{
|
||||
return str.size() >= suffix.size() &&
|
||||
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
|
||||
}
|
||||
|
||||
size_t find_partial_stop_string(const std::string &stop, const std::string &text)
|
||||
{
|
||||
if (!text.empty()) {
|
||||
const char text_last_char = text.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const std::string current_partial = stop.substr(0, char_index + 1);
|
||||
if (ends_with(text, current_partial)) {
|
||||
return text.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
struct llama_server_context
|
||||
{
|
||||
bool stream = false;
|
||||
|
@ -248,6 +275,31 @@ struct llama_server_context
|
|||
return result;
|
||||
}
|
||||
|
||||
size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
|
||||
const stop_type type)
|
||||
{
|
||||
size_t stop_pos = std::string::npos;
|
||||
for (const std::string &word : params.antiprompt) {
|
||||
size_t pos;
|
||||
if (type == STOP_FULL) {
|
||||
const size_t tmp = word.size() + last_token_size;
|
||||
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
|
||||
pos = text.find(word, from_pos);
|
||||
} else {
|
||||
pos = find_partial_stop_string(word, text);
|
||||
}
|
||||
if (pos != std::string::npos &&
|
||||
(stop_pos == std::string::npos || pos < stop_pos)) {
|
||||
if (type == STOP_FULL) {
|
||||
stopping_word = word;
|
||||
has_next_token = false;
|
||||
}
|
||||
stop_pos = pos;
|
||||
}
|
||||
}
|
||||
return stop_pos;
|
||||
}
|
||||
|
||||
std::string doCompletion()
|
||||
{
|
||||
llama_token token = nextToken();
|
||||
|
@ -272,16 +324,6 @@ struct llama_server_context
|
|||
stopping_word.c_str());
|
||||
}
|
||||
|
||||
for (const std::string& word : params.antiprompt) {
|
||||
size_t i = generated_text.find(word, generated_text.size() - (word.size() + token_text.size()));
|
||||
if (i != std::string::npos) {
|
||||
generated_text.erase(generated_text.begin() + i, generated_text.end());
|
||||
stopping_word = word;
|
||||
has_next_token = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return token_text;
|
||||
}
|
||||
|
||||
|
@ -711,7 +753,14 @@ int main(int argc, char **argv)
|
|||
|
||||
if (!llama.stream) {
|
||||
while (llama.has_next_token) {
|
||||
llama.doCompletion();
|
||||
const std::string token_text = llama.doCompletion();
|
||||
const size_t stop_pos = llama.findStoppingStrings(
|
||||
llama.generated_text, token_text.size(), STOP_FULL);
|
||||
|
||||
if (stop_pos != std::string::npos) {
|
||||
llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
|
||||
llama.generated_text.end());
|
||||
}
|
||||
}
|
||||
|
||||
json data = {{"content", llama.generated_text},
|
||||
|
@ -724,7 +773,7 @@ int main(int argc, char **argv)
|
|||
|
||||
llama_print_timings(llama.ctx);
|
||||
|
||||
return res.set_content(
|
||||
res.set_content(
|
||||
data.dump(llama.json_indent, ' ', false, json::error_handler_t::replace),
|
||||
"application/json");
|
||||
} else {
|
||||
|
@ -733,7 +782,7 @@ int main(int argc, char **argv)
|
|||
int32_t multibyte_pending = 0;
|
||||
|
||||
while (llama.has_next_token) {
|
||||
std::string token_text = llama.doCompletion();
|
||||
const std::string token_text = llama.doCompletion();
|
||||
|
||||
if (multibyte_pending > 0) {
|
||||
multibyte_pending -= token_text.size();
|
||||
|
@ -761,8 +810,22 @@ int main(int argc, char **argv)
|
|||
continue;
|
||||
}
|
||||
|
||||
const size_t pos = std::min(sent_count, llama.generated_text.size());
|
||||
std::string to_send = llama.generated_text.substr(pos);
|
||||
size_t pos = std::min(sent_count, llama.generated_text.size());
|
||||
|
||||
const char *str_test = llama.generated_text.c_str() + pos;
|
||||
size_t stop_pos =
|
||||
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
|
||||
if (stop_pos != std::string::npos) {
|
||||
llama.generated_text.erase(
|
||||
llama.generated_text.begin() + pos + stop_pos,
|
||||
llama.generated_text.end());
|
||||
pos = std::min(sent_count, llama.generated_text.size());
|
||||
} else {
|
||||
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
|
||||
STOP_PARTIAL);
|
||||
}
|
||||
|
||||
std::string to_send = llama.generated_text.substr(pos, stop_pos);
|
||||
sent_count += to_send.size();
|
||||
|
||||
json data;
|
||||
|
@ -808,7 +871,6 @@ int main(int argc, char **argv)
|
|||
}
|
||||
});
|
||||
|
||||
|
||||
svr.Post("/tokenize", [&llama](const Request &req, Response &res)
|
||||
{
|
||||
json body = json::parse(req.body);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue