Fix crash in server example caused by out-of-bounds access during no-show-words scanning
- Replaced the scanning code with a lookahead-based strategy
parent b50b570ed9
commit 2b496429c3
1 changed file with 66 additions and 62 deletions
@@ -1,5 +1,5 @@
-#include <httplib.h>
-#include <json.hpp>
+#include "httplib.h"
+#include "json.hpp"
 #include "common.h"
 #include "llama.h"
@@ -27,7 +27,8 @@ struct llama_server_context
     std::vector<llama_token> llama_token_newline;
     std::vector<llama_token> embd_inp;
     std::vector<std::vector<llama_token>> no_show_words;
-    std::vector<llama_token> tokens_predicted;
+    std::deque<llama_token> tokens_predicted;
+    std::vector<llama_token>::size_type n_read_ahead = 0;
 
     llama_context *ctx;
     gpt_params params;
@@ -294,64 +295,54 @@ struct llama_server_context
         return result;
     }
 
-    std::string doCompletion()
-    {
-        llama_token token = nextToken();
-        if (token == -1) {
-            return "";
-        }
-        tokens_predicted.clear();
-        tokens_predicted.push_back(token);
-
-        // Avoid add the no show words to the response
-        for (std::vector<llama_token> word_tokens : no_show_words)
-        {
-            size_t match_token = 1;
-            if (tokens_predicted.front() == word_tokens.front())
-            {
-                bool execute_matching = true;
-                if (tokens_predicted.size() > 1) { // if previus tokens had been tested
-                    for (size_t i = 1; i < word_tokens.size(); i++)
-                    {
-                        if (i >= tokens_predicted.size()) {
-                            match_token = i;
-                            break;
-                        }
-                        if (tokens_predicted[i] == word_tokens[i])
-                        {
-                            continue;
-                        }
-                        else
-                        {
-                            execute_matching = false;
-                            break;
-                        }
-                    }
-                }
-                while (execute_matching) {
-                    if (match_token == word_tokens.size()) {
-                        return "";
-                    }
-                    token = nextToken();
-                    tokens_predicted.push_back(token);
-                    if (token == word_tokens[match_token])
-                    { // the token follow the sequence
-                        match_token++;
-                    }
-                    else if (match_token < word_tokens.size())
-                    { // no complete all word sequence
-                        break;
-                    }
-                }
-            }
-        }
-        if(as_loop) {
+    std::string doCompletion() {
+        if (as_loop) {
             generated_text = "";
         }
-        for (llama_token tkn : tokens_predicted)
-        {
-            generated_text += llama_token_to_str(ctx, tkn);
-        }
+
+        // Avoid adding the no-show words to the response
+        bool removed_no_show_words;
+        bool past_end_of_tokens = false;
+        do {
+            removed_no_show_words = false;
+
+            // Fill predicted tokens up to `n_read_ahead` tokens if possible
+            while (tokens_predicted.size() < n_read_ahead) {
+                llama_token token = nextToken();
+                if (token == -1) {
+                    past_end_of_tokens = true;
+                    break;
+                }
+                tokens_predicted.push_back(token);
+            }
+
+            // Remove sequences of no_show_words in `tokens_predicted`
+            for (const auto &no_show : no_show_words) {
+                const auto &occurrence =
+                    std::search(tokens_predicted.begin(), tokens_predicted.end(),
+                                no_show.begin(), no_show.end());
+                if (occurrence != tokens_predicted.end()) {
+                    tokens_predicted.erase(occurrence, occurrence + no_show.size());
+                    removed_no_show_words = true;
+                }
+            }
+
+            // Continue until end of tokens or as long as sequences have been removed
+        } while (removed_no_show_words && !past_end_of_tokens);
+
+        if (past_end_of_tokens) {
+            // If at the end of tokens, return everything and clear the buffer
+            for (llama_token tkn : tokens_predicted) {
+                generated_text += llama_token_to_str(ctx, tkn);
+            }
+            tokens_predicted.clear();
+        } else {
+            // Else just pick the first token and append it
+            generated_text += llama_token_to_str(ctx, tokens_predicted[0]);
+            tokens_predicted.pop_front();
+        }
         return generated_text;
     }
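
The rewritten doCompletion() above replaces incremental prefix matching (the source of the out-of-bounds crash) with a fill-then-scrub loop: top the buffer up to n_read_ahead tokens, erase any complete no-show sequence that std::search finds, and repeat until nothing more is removed. The standalone sketch below mirrors that loop under stated assumptions: int tokens, a canned stream in place of nextToken(), and hypothetical names such as next_token, none of which are the server's API:

    #include <algorithm>
    #include <cstddef>
    #include <deque>
    #include <iostream>
    #include <vector>

    int main() {
        // Canned "model output"; 7 8 9 is the stop sequence to hide.
        const std::vector<int> stream = {1, 2, 7, 8, 9, 3, 4};
        std::size_t pos = 0;
        auto next_token = [&]() { return pos < stream.size() ? stream[pos++] : -1; };

        const std::vector<std::vector<int>> no_show_words = {{7, 8, 9}};
        const std::size_t n_read_ahead = 3; // length of the longest no-show sequence

        std::deque<int> tokens_predicted;
        bool past_end_of_tokens = false;

        while (!past_end_of_tokens || !tokens_predicted.empty()) {
            bool removed;
            do {
                removed = false;
                // Fill the buffer up to n_read_ahead tokens if possible.
                while (tokens_predicted.size() < n_read_ahead) {
                    int token = next_token();
                    if (token == -1) { past_end_of_tokens = true; break; }
                    tokens_predicted.push_back(token);
                }
                // Erase any complete no-show sequence from the buffer.
                for (const auto &no_show : no_show_words) {
                    auto hit = std::search(tokens_predicted.begin(), tokens_predicted.end(),
                                           no_show.begin(), no_show.end());
                    if (hit != tokens_predicted.end()) {
                        tokens_predicted.erase(hit, hit + no_show.size());
                        removed = true;
                    }
                }
            } while (removed && !past_end_of_tokens);

            if (tokens_predicted.empty()) break;
            std::cout << "emit " << tokens_predicted.front() << '\n'; // prints 1 2 3 4
            tokens_predicted.pop_front();
        }
        return 0;
    }

Running it emits 1, 2, 3, 4: the sequence 7 8 9 is swallowed before the front of the buffer is ever released, which is the property the lookahead is meant to guarantee.
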
@@ -476,6 +467,15 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
         {
             params.embedding = true;
         }
+        else if (arg == "--keep")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.n_keep = std::stoi(argv[i]);
+        }
         else if (arg == "-h" || arg == "--help")
         {
             server_print_usage(argc, argv, default_params);
@@ -622,18 +622,22 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
     if (!body["stop"].is_null())
     {
         std::vector<std::string> stop_words = body["stop"].get<std::vector<std::string>>();
-        for (std::string stop_word : stop_words)
+        for (const std::string& stop_word : stop_words)
         {
             llama.params.antiprompt.push_back(stop_word);
-            llama.no_show_words.push_back(::llama_tokenize(llama.ctx, stop_word, false));
+            auto tokens = ::llama_tokenize(llama.ctx, stop_word, false);
+            llama.n_read_ahead = std::max(llama.n_read_ahead, tokens.size());
+            llama.no_show_words.push_back(tokens);
         }
     }
     if (!body["exclude"].is_null())
     {
         std::vector<std::string> no_show_words = body["exclude"].get<std::vector<std::string>>();
-        for (std::string no_show : no_show_words)
+        for (const std::string& no_show : no_show_words)
         {
-            llama.no_show_words.push_back(::llama_tokenize(llama.ctx, no_show, false));
+            auto tokens = ::llama_tokenize(llama.ctx, no_show, false);
+            llama.n_read_ahead = std::max(llama.n_read_ahead, tokens.size());
+            llama.no_show_words.push_back(tokens);
         }
     }
     return true;
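
The std::max bookkeeping in this last hunk is what sizes the lookahead: n_read_ahead must be at least the token length of the longest stop/exclude word, otherwise a prefix of that word could be emitted before the full match becomes visible in the buffer. A small sketch of the sizing rule, using a made-up one-token-per-character tokenizer in place of ::llama_tokenize:

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // Stand-in for ::llama_tokenize: pretend each character is one token.
    static std::vector<int> fake_tokenize(const std::string &word) {
        return std::vector<int>(word.begin(), word.end());
    }

    int main() {
        std::size_t n_read_ahead = 0;
        std::vector<std::vector<int>> no_show_words;

        for (const std::string &word : {"###", "</s>", "User:"}) {
            auto tokens = fake_tokenize(word);
            // Read-ahead must cover the longest hidden sequence.
            n_read_ahead = std::max(n_read_ahead, tokens.size());
            no_show_words.push_back(std::move(tokens));
        }
        std::cout << "n_read_ahead = " << n_read_ahead << '\n'; // 5, from "User:"
        return 0;
    }

A real tokenizer will of course yield different lengths, but the invariant is the same: buffer capacity must be at least the longest hidden sequence.
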