grammar: trigger words + refactor of antiprompts

ochafik 2024-09-25 15:51:37 +01:00
parent 70392f1f81
commit 5b6d5040d5
12 changed files with 436 additions and 108 deletions

View file

@@ -44,6 +44,7 @@ BUILD_TARGETS = \
# Binaries only useful for tests
TEST_TARGETS = \
+	tests/test-antiprompts \
	tests/test-arg-parser \
	tests/test-autorelease \
	tests/test-backend-ops \
@@ -1567,6 +1568,11 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
	$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
+tests/test-antiprompts: tests/test-antiprompts.cpp \
+	$(OBJ_ALL)
+	$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-grad0: tests/test-grad0.cpp \
	$(OBJ_GGML)
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

View file

@@ -4,9 +4,11 @@
#include "llama.h"

+#include <queue>
#include <string>
#include <vector>
#include <sstream>
+#include <unordered_map>

#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
@@ -134,6 +136,7 @@ struct gpt_sampler_params {
    };

    std::string grammar; // optional BNF-like grammar to constrain sampling
+    std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar

    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@@ -533,6 +536,201 @@ struct llama_control_vector_load_info {
// On error, returns {-1, empty}
llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
//
// Antiprompt utils
//
class llama_antiprompts {
public:
struct llama_antiprompt {
std::string value;
bool is_grammar_trigger;
};
std::vector<std::string> stop_words;
std::vector<std::string> grammar_trigger_words;
private:
// The AhoCorasick algorithm allows efficient string matching with multiple patterns.
// See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
struct TrieNode {
std::unordered_map<char, TrieNode> children;
TrieNode* fail = nullptr;
int output = -1;
size_t depth = 0;
void clear() {
children.clear();
fail = nullptr;
output = -1;
depth = 0;
}
};
TrieNode root;
std::vector<llama_antiprompt> antiprompts;
std::unordered_map<llama_token, size_t> stop_tokens; // Single token antiprompts (and their index in antiprompts), if any.
void build_trie() {
// root = std::unique_ptr<TrieNode>(new TrieNode());
for (size_t i = 0; i < antiprompts.size(); ++i) {
TrieNode* node = &root;
const auto & pattern = antiprompts[i].value;
for (size_t j = 0; j < pattern.length(); ++j) {
char c = pattern[j];
auto & child = node->children[c];
if (child.depth == 0) {
child.depth = j + 1;
}
node = &child;
}
node->output = i;
}
}
void build_failure_and_dict_links() {
std::queue<TrieNode*> q;
for (auto& child : root.children) {
child.second.fail = &root;
q.push(&child.second);
}
while (!q.empty()) {
auto node = q.front();
q.pop();
for (auto & pair : node->children) {
auto & c = pair.first;
auto & child = pair.second;
auto f = node->fail;
while (f != &root && f->children.find(c) == f->children.end()) {
f = f->fail;
}
child.fail = (f == &root && f->children.find(c) == f->children.end())
? &root : &f->children[c];
if (child.fail->output != -1) {
child.output = child.fail->output;
}
q.push(&child);
}
}
}
public:
bool empty() const {
return antiprompts.empty() && stop_tokens.empty();
}
void clear() {
root.clear();
antiprompts.clear();
stop_tokens.clear();
}
void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
build(
[&](const std::string & text) {
return llama_tokenize(ctx, text, /* special= */ true);
},
stop_words,
grammar_trigger_words
);
}
void build(const std::function<std::vector<llama_token>(const std::string)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
clear();
this->stop_words = stop_words;
this->grammar_trigger_words = grammar_trigger_words;
for (const std::string & stop_word : stop_words) {
antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
}
for (const std::string & trigger : grammar_trigger_words) {
antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
}
for (size_t i = 0, n = antiprompts.size(); i < n; i++) {
const auto & antiprompt = antiprompts[i];
std::vector<llama_token> tokens = tokenizer(antiprompt.value);
if (tokens.size() == 1) {
stop_tokens[tokens[0]] = i;
}
}
build_trie();
build_failure_and_dict_links();
}
struct MatchResult {
size_t pos;
std::string pattern;
bool is_partial;
size_t matchLength;
bool is_grammar_trigger;
bool operator==(const MatchResult & other) const {
return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger;
}
operator std::string() const {
return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}";
}
};
MatchResult findSingleTokenMatch(llama_token token) const {
auto it = stop_tokens.find(token);
if (it != stop_tokens.end()) {
const auto & antiprompt = antiprompts[it->second];
return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger};
}
return {std::string::npos, "", false, 0, false};
}
MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
TrieNode* current = &root;
MatchResult partialMatch{std::string::npos, "", true, 0, false};
for (size_t i = offset; i < text.length(); ++i) {
char c = text[i];
while (current != &root && current->children.find(c) == current->children.end()) {
current = current->fail;
}
auto it = current->children.find(c);
if (it != current->children.end()) {
current = &it->second;
}
if (current->output != -1) {
const auto & antiprompt = antiprompts[current->output];
return {
i - antiprompt.value.length() + 1,
antiprompt.value,
false,
antiprompt.value.length(),
antiprompt.is_grammar_trigger,
};
}
// Update partial match if we're at a deeper node
if (current->depth > partialMatch.matchLength) {
partialMatch.pos = i - current->depth + 1;
partialMatch.pattern = ""; // We don't know which pattern it partially matches
partialMatch.matchLength = current->depth;
partialMatch.is_grammar_trigger = false;
}
}
// If we've found a partial match and haven't returned a full match, return the partial match
if (partialMatch.pos != std::string::npos) {
return partialMatch;
}
return {std::string::npos, "", false, 0, false};
}
};
//
// Split utils
//
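For orientation, a minimal usage sketch of the class above (illustrative only: the one-token-per-byte tokenizer and the word lists are made up for the example, they are not taken from the patch):

    llama_antiprompts antiprompts;
    antiprompts.build(
        [](const std::string & text) {
            // stand-in tokenizer: one token per byte; real callers pass llama_tokenize()
            return std::vector<llama_token>(text.begin(), text.end());
        },
        /* stop_words= */            {"</s>"},
        /* grammar_trigger_words= */ {"<tool_call>"});

    // per-token check, as done once per sampled token
    auto tok_match = antiprompts.findSingleTokenMatch('<');

    // string scan over the accumulated output, optionally resuming from an earlier partial match
    auto match = antiprompts.findFirstMatch("... some generated text <tool", /* offset= */ 0);
    if (match.pos != std::string::npos && match.is_partial) {
        // hold back output from match.pos until the match resolves
    } else if (match.pos != std::string::npos && match.is_grammar_trigger) {
        // enable the lazy grammar and keep generating
    }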

View file

@@ -139,6 +139,15 @@ std::string gpt_sampler_params::print() const {
    return std::string(result);
}
bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger) {
if (gsmpl->grmr) {
return false;
}
gsmpl->grmr = llama_sampler_init_grammar(model, gsmpl->params.grammar.c_str(), "root");
llama_sampler_accept_str(gsmpl->grmr, trigger.c_str());
return true;
}
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@@ -146,7 +155,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
    auto * result = new gpt_sampler {
        /* .params = */ params,
-        /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .grmr = */ params.grammar_trigger_words.empty() ? llama_sampler_init_grammar(model, params.grammar.c_str(), "root") : nullptr,
        /* .chain = */ llama_sampler_chain_init(lparams),
        /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur = */ {},
@@ -226,7 +235,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
void gpt_sampler_free(struct gpt_sampler * gsmpl) {
    if (gsmpl) {
-        llama_sampler_free(gsmpl->grmr);
+        if (gsmpl->grmr) {
+            llama_sampler_free(gsmpl->grmr);
+        }

        llama_sampler_free(gsmpl->chain);
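The intent of the lazy-grammar path, sketched from the caller's side (match, model and smpl stand for objects the calling code already has; this mirrors how the server uses the function further down, it is not new API):

    // With grammar_trigger_words set, gpt_sampler_init() leaves grmr == nullptr, so
    // sampling starts unconstrained. When the output hits a trigger word, the caller
    // instantiates the grammar and feeds it the trigger text it has just matched.
    if (match.is_grammar_trigger) {
        if (gpt_sampler_trigger_grammar(model, smpl, match.pattern)) {
            // grammar is active from here on; later calls return false (already enabled)
        }
    } else if (match.pos != std::string::npos && !match.is_partial) {
        // plain stop word: end generation
    }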

View file

@@ -79,5 +79,7 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n
char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);

+bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger);
+
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);

View file

@@ -36,7 +36,7 @@ static llama_model ** g_model;
static gpt_sampler ** g_smpl;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
-static std::ostringstream * g_output_ss;
+static std::string * g_output_s;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool need_insert_eot = false;
@@ -115,7 +115,7 @@ static void sigint_handler(int signo) {
        console::cleanup();
        LOG("\n");
        gpt_perf_print(*g_ctx, *g_smpl);
-        write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
+        write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, *g_output_s, *g_output_tokens);

        // make sure all logs are flushed
        LOG("Interrupted by user\n");
@@ -507,7 +507,8 @@ int main(int argc, char ** argv) {
    std::vector<int> input_tokens; g_input_tokens = &input_tokens;
    std::vector<int> output_tokens; g_output_tokens = &output_tokens;
-    std::ostringstream output_ss; g_output_ss = &output_ss;
+    std::string output_s; g_output_s = &output_s;
+    size_t last_partial_stop = std::string::npos;
    std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode

    // the first thing we will do is to output the prompt, so set color accordingly
@@ -516,13 +517,8 @@
    std::vector<llama_token> embd;

-    // tokenized antiprompts
-    std::vector<std::vector<llama_token>> antiprompt_ids;
-    antiprompt_ids.reserve(params.antiprompt.size());
-    for (const std::string & antiprompt : params.antiprompt) {
-        antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
-    }
+    llama_antiprompts antiprompts;
+    antiprompts.build(ctx, params.antiprompt, {});

    if (llama_model_has_encoder(model)) {
        int enc_input_size = embd_inp.size();
@@ -727,7 +723,7 @@
            } else {
                // Outgoing Generated Tokens
                output_tokens.push_back(id);
-                output_ss << token_str;
+                output_s.append(token_str);
            }
        }
    }
@@ -740,44 +736,34 @@
        // if not currently processing queued inputs;
        if ((int) embd_inp.size() <= n_consumed) {
-            // check for reverse prompt in the last n_prev tokens
-            if (!params.antiprompt.empty()) {
-                const int n_prev = 32;
-                const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev);
+            // check for reverse prompt
+            if (!antiprompts.empty()) {
                is_antiprompt = false;

-                // Check if each of the reverse prompts appears at the end of the output.
-                // If we're not running interactively, the reverse prompt might be tokenized with some following characters
-                // so we'll compensate for that by widening the search window a bit.
-                for (std::string & antiprompt : params.antiprompt) {
-                    size_t extra_padding = params.interactive ? 0 : 2;
-                    size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
-                        ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
-                        : 0;
-
-                    if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
-                        if (params.interactive) {
-                            is_interacting = true;
-                        }
-                        is_antiprompt = true;
-                        break;
-                    }
-                }
-
                // check for reverse prompt using special tokens
                llama_token last_token = gpt_sampler_last(smpl);
-                for (std::vector<llama_token> ids : antiprompt_ids) {
-                    if (ids.size() == 1 && last_token == ids[0]) {
-                        if (params.interactive) {
-                            is_interacting = true;
-                        }
-                        is_antiprompt = true;
-                        break;
-                    }
-                }
+                auto match = antiprompts.findSingleTokenMatch(last_token);
+                if (match.pos != std::string::npos) {
+                    if (params.interactive) {
+                        is_interacting = true;
+                    }
+                    is_antiprompt = true;
+                } else {
+                    match = antiprompts.findFirstMatch(output_s, last_partial_stop == std::string::npos ? 0 : last_partial_stop);
+                    if (match.pos != std::string::npos) {
+                        if (match.is_partial) {
+                            last_partial_stop = match.pos;
+                        } else {
+                            if (params.interactive) {
+                                is_interacting = true;
+                            }
+                            is_antiprompt = true;
+                        }
+                    }
+                }

                if (is_antiprompt) {
-                    LOG_DBG("found antiprompt: %s\n", last_output.c_str());
+                    LOG_DBG("found antiprompt: %s\n", match.pattern.c_str());
                }
            }
@@ -786,9 +772,9 @@
            LOG_DBG("found an EOG token\n");

            if (params.interactive) {
-                if (!params.antiprompt.empty()) {
+                if (!antiprompts.stop_words.empty()) {
                    // tokenize and inject first reverse prompt
-                    const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
+                    const auto first_antiprompt = ::llama_tokenize(ctx, antiprompts.stop_words.front(), false, true);
                    embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                    is_antiprompt = true;
                }
@@ -882,7 +868,7 @@
            for (size_t i = original_size; i < embd_inp.size(); ++i) {
                const llama_token token = embd_inp[i];
                output_tokens.push_back(token);
-                output_ss << llama_token_to_piece(ctx, token);
+                output_s.append(llama_token_to_piece(ctx, token));
            }

            // reset assistant message
@@ -926,7 +912,7 @@
    LOG("\n\n");
    gpt_perf_print(ctx, smpl);
-    write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
+    write_logfile(ctx, params, model, input_tokens, output_s, output_tokens);

    gpt_sampler_free(smpl);

View file

@@ -131,8 +131,6 @@ struct slot_params {
    int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
    int32_t n_predict = -1; // new tokens to predict

-    std::vector<std::string> antiprompt;
-
    json input_prefix;
    json input_suffix;
};
@@ -183,6 +181,8 @@ struct server_slot {
    std::string oaicompat_model;
    std::string stopping_word;

+    llama_antiprompts antiprompts;
+
    // sampling
    json json_schema;
@@ -281,34 +281,6 @@
        };
    }

-    size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) {
-        size_t stop_pos = std::string::npos;
-
-        for (const std::string & word : params.antiprompt) {
-            size_t pos;
-
-            if (type == STOP_TYPE_FULL) {
-                const size_t tmp = word.size() + last_token_size;
-                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
-
-                pos = text.find(word, from_pos);
-            } else {
-                pos = find_partial_stop_string(word, text);
-            }
-
-            if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
-                if (type == STOP_TYPE_FULL) {
-                    stopped_word = true;
-                    stopping_word = word;
-                    has_next_token = false;
-                }
-                stop_pos = pos;
-            }
-        }
-
-        return stop_pos;
-    }
-
    void print_timings() const {
        const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
        const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
@@ -999,16 +971,26 @@
        }

        {
-            slot.params.antiprompt.clear();
-
-            const auto & stop = data.find("stop");
-            if (stop != data.end() && stop->is_array()) {
-                for (const auto & word : *stop) {
-                    if (!word.empty()) {
-                        slot.params.antiprompt.push_back(word);
-                    }
-                }
-            }
+            slot.antiprompts.clear();
+
+            auto copy_string_array = [&](const json & data, const std::string & key, std::vector<std::string> & vec) {
+                const auto & arr = data.find(key);
+                if (arr != data.end() && arr->is_array()) {
+                    for (const auto & word : *arr) {
+                        if (word.is_string()) {
+                            vec.push_back(word);
+                        }
+                    }
+                }
+            };
+
+            std::vector<std::string> stop_words;
+            std::vector<std::string> grammar_trigger_words;
+
+            copy_string_array(data, "stop", stop_words);
+            copy_string_array(data, "grammar_trigger_words", grammar_trigger_words);
+
+            slot.antiprompts.build(ctx, stop_words, grammar_trigger_words);
        }

        {
@@ -1110,6 +1092,18 @@
        const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
        slot.sampled = result.tok;

+        auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
+        if (match.pos != std::string::npos && !match.is_partial) {
+            if (match.is_grammar_trigger) {
+                gpt_sampler_trigger_grammar(model, slot.smpl, llama_token_to_piece(ctx, result.tok, params.special));
+            } else {
+                slot.stopped_word = true;
+                slot.stopping_word = match.pattern;
+                slot.has_next_token = false;
+                return false;
+            }
+        }
+
        // search stop word and delete it
        slot.generated_text += token_str;
        slot.has_next_token = true;
@@ -1139,23 +1133,33 @@
        if (!incomplete) {
            size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

-            const std::string str_test = slot.generated_text.substr(pos);
-            bool is_stop_full = false;
-
-            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
-            if (stop_pos != std::string::npos) {
-                is_stop_full = true;
-                slot.generated_text.erase(
-                    slot.generated_text.begin() + pos + stop_pos,
-                    slot.generated_text.end());
-                pos = std::min(slot.n_sent_text, slot.generated_text.size());
-            } else {
-                is_stop_full = false;
-                stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
-            }
+            match = slot.antiprompts.findFirstMatch(slot.generated_text, pos);
+
+            bool is_stop_full = false;
+            bool is_grammar_trigger = false;
+            size_t length = slot.generated_text.size();
+
+            // If there is a lazy grammar trigger word at stop_pos, enable the lazy grammar
+            if (match.is_grammar_trigger && gpt_sampler_trigger_grammar(model, slot.smpl, match.pattern)) {
+                is_grammar_trigger = true;
+                length = pos + match.pos + match.matchLength;
+            } else if (!match.is_grammar_trigger && match.pos != std::string::npos && !match.is_partial) {
+                slot.stopped_word = true;
+                slot.stopping_word = match.pattern;
+                slot.has_next_token = false;
+
+                is_stop_full = true;
+                // length = pos + match.pos;
+                length = match.pos;
+            }
+
+            slot.generated_text.erase(
+                slot.generated_text.begin() + length,
+                slot.generated_text.end());
+            pos = std::min(slot.n_sent_text, length);

            // check if there is any token to predict
-            if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
+            if (match.pos == std::string::npos || (!slot.has_next_token && !is_grammar_trigger && !is_stop_full && match.pos > 0)) {
                // no send the stop word in the response
                result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
                slot.n_sent_text += result.text_to_send.size();
@@ -1243,7 +1247,8 @@
            {"mirostat_tau", slot.sparams.mirostat_tau},
            {"mirostat_eta", slot.sparams.mirostat_eta},
            {"penalize_nl", slot.sparams.penalize_nl},
-            {"stop", slot.params.antiprompt},
+            {"stop", slot.antiprompts.stop_words},
+            {"grammar_trigger", slot.antiprompts.grammar_trigger_words},
            {"max_tokens", slot.params.n_predict}, // User configured n_predict
            {"n_keep", slot.params.n_keep},
            {"n_discard", slot.params.n_discard},

View file

@@ -196,20 +196,15 @@ static size_t common_part(const std::string & a, const std::string & b) {
    return i;
}

-static bool ends_with(const std::string & str, const std::string & suffix) {
-    return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
-}
-
-static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
+static size_t find_partial_stop_string(const std::string & stop, const std::string & text) {
    if (!text.empty() && !stop.empty()) {
-        const char text_last_char = text.back();
-        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
-            if (stop[char_index] == text_last_char) {
-                const std::string current_partial = stop.substr(0, char_index + 1);
-                if (ends_with(text, current_partial)) {
-                    return text.size() - char_index - 1;
-                }
+        auto it = std::find(stop.rbegin(), stop.rend(), text.back());
+        while (it != stop.rend()) {
+            size_t length = std::distance(it, stop.rend());
+            if (text.length() >= length && 0 == text.compare(text.length() - length, length, stop)) {
+                return text.length() - length;
            }
+            it = std::find(std::next(it), stop.rend(), text.back());
        }
    }
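To make the helper's contract concrete (intended semantics, matching the loop being replaced; the values below are a hypothetical illustration, not taken from the patch):

    // find_partial_stop_string(stop, text) reports where a prefix of `stop`
    // starts matching the tail of `text`, so the caller can withhold those
    // characters until the match either completes or falls apart.
    //
    //   find_partial_stop_string("</s>", "Hello <")   -> 6   ("<"   may grow into "</s>")
    //   find_partial_stop_string("</s>", "Hello </s") -> 6   ("</s" may grow into "</s>")
    //   find_partial_stop_string("</s>", "Hello!")    -> std::string::npos
    const size_t hold_from = find_partial_stop_string("</s>", generated_text);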

View file

@@ -1121,7 +1121,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
    }

    const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);

+    llama_grammar_accept_str(grammar, piece);
+}
+
+void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
    // Note terminating 0 in decoded string
    const auto decoded = decode_utf8(piece, grammar.partial_utf8);
    const auto & code_points = decoded.first;

View file

@@ -142,3 +142,7 @@ void llama_grammar_apply_impl(
void llama_grammar_accept_impl(
        struct llama_grammar & grammar,
        llama_token token);
+
+void llama_grammar_accept_str(
+        struct llama_grammar & grammar,
+        const std::string & piece);

View file

@@ -193,6 +193,12 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
    }
}

+void llama_sampler_accept_str(struct llama_sampler * smpl, const char * piece) {
+    if (smpl->iface->accept_str) {
+        smpl->iface->accept_str(smpl, piece);
+    }
+}
+
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
    GGML_ASSERT(smpl->iface->apply);
    smpl->iface->apply(smpl, cur_p);
@@ -325,6 +331,7 @@ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_chain_i = {
    /* .name = */ llama_sampler_chain_name,
    /* .accept = */ llama_sampler_chain_accept,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_chain_apply,
    /* .reset = */ llama_sampler_chain_reset,
    /* .clone = */ llama_sampler_chain_clone,
@@ -399,6 +406,7 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
static struct llama_sampler_i llama_sampler_greedy_i = {
    /* .name = */ llama_sampler_greedy_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_greedy_apply,
    /* .reset = */ nullptr,
    /* .clone = */ nullptr,
@@ -457,6 +465,7 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_dist_i = {
    /* .name = */ llama_sampler_dist_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_dist_apply,
    /* .reset = */ llama_sampler_dist_reset,
    /* .clone = */ llama_sampler_dist_clone,
@@ -488,6 +497,7 @@ static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_t
static struct llama_sampler_i llama_sampler_softmax_i = {
    /* .name = */ llama_sampler_softmax_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_softmax_apply,
    /* .reset = */ nullptr,
    /* .clone = */ nullptr,
@@ -528,6 +538,7 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_top_k_i = {
    /* .name = */ llama_sampler_top_k_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_top_k_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_top_k_clone,
@@ -594,6 +605,7 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_top_p_i = {
    /* .name = */ llama_sampler_top_p_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_top_p_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_top_p_clone,
@@ -690,6 +702,7 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_min_p_i = {
    /* .name = */ llama_sampler_min_p_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_min_p_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_min_p_clone,
@@ -785,6 +798,7 @@ static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_tail_free_i = {
    /* .name = */ llama_sampler_tail_free_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_tail_free_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_tail_free_clone,
@@ -884,6 +898,7 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_typical_i = {
    /* .name = */ llama_sampler_typical_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_typical_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_typical_clone,
@@ -929,6 +944,7 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_temp_i = {
    /* .name = */ llama_sampler_temp_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_temp_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_temp_clone,
@@ -1042,6 +1058,7 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_temp_ext_i = {
    /* .name = */ llama_sampler_temp_ext_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_temp_ext_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_temp_ext_clone,
@@ -1145,6 +1162,7 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_mirostat_i = {
    /* .name = */ llama_sampler_mirostat_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_mirostat_apply,
    /* .reset = */ llama_sampler_mirostat_reset,
    /* .clone = */ llama_sampler_mirostat_clone,
@@ -1244,6 +1262,7 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
    /* .name = */ llama_sampler_mirostat_v2_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_mirostat_v2_apply,
    /* .reset = */ llama_sampler_mirostat_v2_reset,
    /* .clone = */ llama_sampler_mirostat_v2_clone,
@@ -1287,6 +1306,13 @@ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama
    }
}

+static void llama_sampler_grammar_accept_str(struct llama_sampler * smpl, const char * piece) {
+    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+    if (ctx->grammar) {
+        llama_grammar_accept_str(*ctx->grammar, piece);
+    }
+}
+
static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
    if (ctx->grammar) {
@@ -1339,6 +1365,7 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_grammar_i = {
    /* .name = */ llama_sampler_grammar_name,
    /* .accept = */ llama_sampler_grammar_accept_impl,
+    /* .accept_str = */ llama_sampler_grammar_accept_str,
    /* .apply = */ llama_sampler_grammar_apply,
    /* .reset = */ llama_sampler_grammar_reset,
    /* .clone = */ llama_sampler_grammar_clone,
@@ -1522,6 +1549,7 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_penalties_i = {
    /* .name = */ llama_sampler_penalties_name,
    /* .accept = */ llama_sampler_penalties_accept,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_penalties_apply,
    /* .reset = */ llama_sampler_penalties_reset,
    /* .clone = */ llama_sampler_penalties_clone,
@@ -1624,6 +1652,7 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
static struct llama_sampler_i llama_sampler_logit_bias_i = {
    /* .name = */ llama_sampler_logit_bias_name,
    /* .accept = */ nullptr,
+    /* .accept_str = */ nullptr,
    /* .apply = */ llama_sampler_logit_bias_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_logit_bias_clone,

View file

@@ -122,6 +122,7 @@ llama_target_and_test(test-grad0.cpp)
llama_target_and_test(test-barrier.cpp)
# llama_target_and_test(test-opt.cpp) # SLOW
llama_target_and_test(test-backend-ops.cpp)
+llama_target_and_test(test-antiprompts.cpp)
llama_target_and_test(test-rope.cpp)

View file

@@ -0,0 +1,88 @@
#ifdef NDEBUG
#undef NDEBUG
#endif
#include "llama.h"
#include "common.h"
#include <cassert>
template <typename T>
void assert_equal(const T & actual, const T & expected) {
if (expected == actual) return;
printf("Expected: %s, Actual: %s\n", ((std::string)expected).c_str(), ((std::string)actual).c_str());
assert(expected == actual);
}
// cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_CURL=1 && cmake --build build -j -t test-jinja -t test-antiprompts && ./build/bin/test-antiprompts
int main()
{
auto tokenizer = [&](const std::string & text) {
std::vector<llama_token> tokens;
for (size_t i = 0; i < text.length(); ++i) {
tokens.push_back(text[i]);
}
return tokens;
};
const std::vector<std::string> stop_words { };
const std::vector<std::string> grammar_trigger_words { };
printf("Testing antiprompts\n");
llama_antiprompts antiprompts;
antiprompts.build(tokenizer, {"abc", "bcd"}, {"bca", "x"});
assert_equal(antiprompts.findSingleTokenMatch('x'), {
.pos = 0,
.pattern = "x",
.is_partial = false,
.matchLength = 1,
.is_grammar_trigger = true,
});
assert_equal(antiprompts.findSingleTokenMatch('a'), {
.pos = std::string::npos,
.pattern = "",
.is_partial = false,
.matchLength = 0,
.is_grammar_trigger = false,
});
assert_equal(antiprompts.findFirstMatch(" ab", 0), {
.pos = 1,
.pattern = "",
.is_partial = true,
.matchLength = 2,
.is_grammar_trigger = false,
});
assert_equal(antiprompts.findFirstMatch(" abc", 0), {
.pos = 1,
.pattern = "abc",
.is_partial = false,
.matchLength = 3,
.is_grammar_trigger = false,
});
assert_equal(antiprompts.findFirstMatch(" bc", 0), {
.pos = 1,
.pattern = "",
.is_partial = true,
.matchLength = 2,
.is_grammar_trigger = false,
});
assert_equal(antiprompts.findFirstMatch(" bcd", 0), {
.pos = 1,
.pattern = "bcd",
.is_partial = false,
.matchLength = 3,
.is_grammar_trigger = false,
});
assert_equal(antiprompts.findFirstMatch(" bca", 0), {
.pos = 1,
.pattern = "bca",
.is_partial = false,
.matchLength = 3,
.is_grammar_trigger = true,
});
printf("OK\n");
// llama_antiprompts::MatchResult{0, "a", .is_partial = false, . 1, false});
return 0;
}