grammar
: trigger words + refactor of antiprompts
This commit is contained in:
parent
70392f1f81
commit
5b6d5040d5
12 changed files with 436 additions and 108 deletions
6
Makefile
6
Makefile
|
@ -44,6 +44,7 @@ BUILD_TARGETS = \
|
|||
|
||||
# Binaries only useful for tests
|
||||
TEST_TARGETS = \
|
||||
tests/test-antiprompts \
|
||||
tests/test-arg-parser \
|
||||
tests/test-autorelease \
|
||||
tests/test-backend-ops \
|
||||
|
@ -1567,6 +1568,11 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
|
|||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-antiprompts: tests/test-antiprompts.cpp \
|
||||
$(OBJ_ALL)
|
||||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-grad0: tests/test-grad0.cpp \
|
||||
$(OBJ_GGML)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
|
|
198
common/common.h
198
common/common.h
|
@ -4,9 +4,11 @@
|
|||
|
||||
#include "llama.h"
|
||||
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
|
||||
#ifdef _WIN32
|
||||
#define DIRECTORY_SEPARATOR '\\'
|
||||
|
@ -134,6 +136,7 @@ struct gpt_sampler_params {
|
|||
};
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar
|
||||
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
|
||||
|
@ -533,6 +536,201 @@ struct llama_control_vector_load_info {
|
|||
// On error, returns {-1, empty}
|
||||
llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
|
||||
|
||||
//
|
||||
// Antiprompt utils
|
||||
//
|
||||
|
||||
class llama_antiprompts {
|
||||
public:
|
||||
|
||||
struct llama_antiprompt {
|
||||
std::string value;
|
||||
bool is_grammar_trigger;
|
||||
};
|
||||
|
||||
std::vector<std::string> stop_words;
|
||||
std::vector<std::string> grammar_trigger_words;
|
||||
|
||||
private:
|
||||
// The Aho–Corasick algorithm allows efficient string matching with multiple patterns.
|
||||
// See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
|
||||
struct TrieNode {
|
||||
std::unordered_map<char, TrieNode> children;
|
||||
TrieNode* fail = nullptr;
|
||||
int output = -1;
|
||||
size_t depth = 0;
|
||||
|
||||
void clear() {
|
||||
children.clear();
|
||||
fail = nullptr;
|
||||
output = -1;
|
||||
depth = 0;
|
||||
}
|
||||
};
|
||||
|
||||
TrieNode root;
|
||||
std::vector<llama_antiprompt> antiprompts;
|
||||
std::unordered_map<llama_token, size_t> stop_tokens; // Single token antiprompts (and their index in antiprompts), if any.
|
||||
|
||||
void build_trie() {
|
||||
// root = std::unique_ptr<TrieNode>(new TrieNode());
|
||||
for (size_t i = 0; i < antiprompts.size(); ++i) {
|
||||
TrieNode* node = &root;
|
||||
const auto & pattern = antiprompts[i].value;
|
||||
for (size_t j = 0; j < pattern.length(); ++j) {
|
||||
char c = pattern[j];
|
||||
auto & child = node->children[c];
|
||||
if (child.depth == 0) {
|
||||
child.depth = j + 1;
|
||||
}
|
||||
node = &child;
|
||||
}
|
||||
node->output = i;
|
||||
}
|
||||
}
|
||||
|
||||
void build_failure_and_dict_links() {
|
||||
std::queue<TrieNode*> q;
|
||||
for (auto& child : root.children) {
|
||||
child.second.fail = &root;
|
||||
q.push(&child.second);
|
||||
}
|
||||
|
||||
while (!q.empty()) {
|
||||
auto node = q.front();
|
||||
q.pop();
|
||||
|
||||
for (auto & pair : node->children) {
|
||||
auto & c = pair.first;
|
||||
auto & child = pair.second;
|
||||
auto f = node->fail;
|
||||
|
||||
while (f != &root && f->children.find(c) == f->children.end()) {
|
||||
f = f->fail;
|
||||
}
|
||||
|
||||
child.fail = (f == &root && f->children.find(c) == f->children.end())
|
||||
? &root : &f->children[c];
|
||||
|
||||
if (child.fail->output != -1) {
|
||||
child.output = child.fail->output;
|
||||
}
|
||||
|
||||
q.push(&child);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
bool empty() const {
|
||||
return antiprompts.empty() && stop_tokens.empty();
|
||||
}
|
||||
void clear() {
|
||||
root.clear();
|
||||
antiprompts.clear();
|
||||
stop_tokens.clear();
|
||||
}
|
||||
|
||||
void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
|
||||
build(
|
||||
[&](const std::string & text) {
|
||||
return llama_tokenize(ctx, text, /* special= */ true);
|
||||
},
|
||||
stop_words,
|
||||
grammar_trigger_words
|
||||
);
|
||||
}
|
||||
|
||||
void build(const std::function<std::vector<llama_token>(const std::string)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
|
||||
clear();
|
||||
this->stop_words = stop_words;
|
||||
this->grammar_trigger_words = grammar_trigger_words;
|
||||
|
||||
for (const std::string & stop_word : stop_words) {
|
||||
antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
|
||||
}
|
||||
for (const std::string & trigger : grammar_trigger_words) {
|
||||
antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
|
||||
}
|
||||
|
||||
for (size_t i = 0, n = antiprompts.size(); i < n; i++) {
|
||||
const auto & antiprompt = antiprompts[i];
|
||||
std::vector<llama_token> tokens = tokenizer(antiprompt.value);
|
||||
if (tokens.size() == 1) {
|
||||
stop_tokens[tokens[0]] = i;
|
||||
}
|
||||
}
|
||||
|
||||
build_trie();
|
||||
build_failure_and_dict_links();
|
||||
}
|
||||
|
||||
struct MatchResult {
|
||||
size_t pos;
|
||||
std::string pattern;
|
||||
bool is_partial;
|
||||
size_t matchLength;
|
||||
bool is_grammar_trigger;
|
||||
|
||||
bool operator==(const MatchResult & other) const {
|
||||
return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger;
|
||||
}
|
||||
operator std::string() const {
|
||||
return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}";
|
||||
}
|
||||
};
|
||||
|
||||
MatchResult findSingleTokenMatch(llama_token token) const {
|
||||
auto it = stop_tokens.find(token);
|
||||
if (it != stop_tokens.end()) {
|
||||
const auto & antiprompt = antiprompts[it->second];
|
||||
return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger};
|
||||
}
|
||||
return {std::string::npos, "", false, 0, false};
|
||||
}
|
||||
|
||||
MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
|
||||
TrieNode* current = &root;
|
||||
MatchResult partialMatch{std::string::npos, "", true, 0, false};
|
||||
|
||||
for (size_t i = offset; i < text.length(); ++i) {
|
||||
char c = text[i];
|
||||
while (current != &root && current->children.find(c) == current->children.end()) {
|
||||
current = current->fail;
|
||||
}
|
||||
auto it = current->children.find(c);
|
||||
if (it != current->children.end()) {
|
||||
current = &it->second;
|
||||
}
|
||||
if (current->output != -1) {
|
||||
const auto & antiprompt = antiprompts[current->output];
|
||||
return {
|
||||
i - antiprompt.value.length() + 1,
|
||||
antiprompt.value,
|
||||
false,
|
||||
antiprompt.value.length(),
|
||||
antiprompt.is_grammar_trigger,
|
||||
};
|
||||
}
|
||||
// Update partial match if we're at a deeper node
|
||||
if (current->depth > partialMatch.matchLength) {
|
||||
partialMatch.pos = i - current->depth + 1;
|
||||
partialMatch.pattern = ""; // We don't know which pattern it partially matches
|
||||
partialMatch.matchLength = current->depth;
|
||||
partialMatch.is_grammar_trigger = false;
|
||||
}
|
||||
}
|
||||
|
||||
// If we've found a partial match and haven't returned a full match, return the partial match
|
||||
if (partialMatch.pos != std::string::npos) {
|
||||
return partialMatch;
|
||||
}
|
||||
|
||||
return {std::string::npos, "", false, 0, false};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Split utils
|
||||
//
|
||||
|
|
|
@ -139,6 +139,15 @@ std::string gpt_sampler_params::print() const {
|
|||
return std::string(result);
|
||||
}
|
||||
|
||||
bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger) {
|
||||
if (gsmpl->grmr) {
|
||||
return false;
|
||||
}
|
||||
gsmpl->grmr = llama_sampler_init_grammar(model, gsmpl->params.grammar.c_str(), "root");
|
||||
llama_sampler_accept_str(gsmpl->grmr, trigger.c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
|
||||
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
||||
|
||||
|
@ -146,7 +155,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
|||
|
||||
auto * result = new gpt_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
|
||||
/* .grmr = */ params.grammar_trigger_words.empty() ? llama_sampler_init_grammar(model, params.grammar.c_str(), "root") : nullptr,
|
||||
/* .chain = */ llama_sampler_chain_init(lparams),
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
|
@ -226,7 +235,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
|||
|
||||
void gpt_sampler_free(struct gpt_sampler * gsmpl) {
|
||||
if (gsmpl) {
|
||||
if (gsmpl->grmr) {
|
||||
llama_sampler_free(gsmpl->grmr);
|
||||
}
|
||||
|
||||
llama_sampler_free(gsmpl->chain);
|
||||
|
||||
|
|
|
@ -79,5 +79,7 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n
|
|||
char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
|
||||
std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
|
||||
|
||||
bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger);
|
||||
|
||||
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
||||
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);
|
||||
|
|
|
@ -36,7 +36,7 @@ static llama_model ** g_model;
|
|||
static gpt_sampler ** g_smpl;
|
||||
static gpt_params * g_params;
|
||||
static std::vector<llama_token> * g_input_tokens;
|
||||
static std::ostringstream * g_output_ss;
|
||||
static std::string * g_output_s;
|
||||
static std::vector<llama_token> * g_output_tokens;
|
||||
static bool is_interacting = false;
|
||||
static bool need_insert_eot = false;
|
||||
|
@ -115,7 +115,7 @@ static void sigint_handler(int signo) {
|
|||
console::cleanup();
|
||||
LOG("\n");
|
||||
gpt_perf_print(*g_ctx, *g_smpl);
|
||||
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
|
||||
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, *g_output_s, *g_output_tokens);
|
||||
|
||||
// make sure all logs are flushed
|
||||
LOG("Interrupted by user\n");
|
||||
|
@ -507,7 +507,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
|
||||
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
|
||||
std::ostringstream output_ss; g_output_ss = &output_ss;
|
||||
std::string output_s; g_output_s = &output_s;
|
||||
size_t last_partial_stop = std::string::npos;
|
||||
std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
|
||||
|
||||
// the first thing we will do is to output the prompt, so set color accordingly
|
||||
|
@ -516,13 +517,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
// tokenized antiprompts
|
||||
std::vector<std::vector<llama_token>> antiprompt_ids;
|
||||
|
||||
antiprompt_ids.reserve(params.antiprompt.size());
|
||||
for (const std::string & antiprompt : params.antiprompt) {
|
||||
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
|
||||
}
|
||||
llama_antiprompts antiprompts;
|
||||
antiprompts.build(ctx, params.antiprompt, {});
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
int enc_input_size = embd_inp.size();
|
||||
|
@ -727,7 +723,7 @@ int main(int argc, char ** argv) {
|
|||
} else {
|
||||
// Outgoing Generated Tokens
|
||||
output_tokens.push_back(id);
|
||||
output_ss << token_str;
|
||||
output_s.append(token_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -740,44 +736,34 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// if not currently processing queued inputs;
|
||||
if ((int) embd_inp.size() <= n_consumed) {
|
||||
// check for reverse prompt in the last n_prev tokens
|
||||
if (!params.antiprompt.empty()) {
|
||||
const int n_prev = 32;
|
||||
const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev);
|
||||
|
||||
// check for reverse prompt
|
||||
if (!antiprompts.empty()) {
|
||||
is_antiprompt = false;
|
||||
// Check if each of the reverse prompts appears at the end of the output.
|
||||
// If we're not running interactively, the reverse prompt might be tokenized with some following characters
|
||||
// so we'll compensate for that by widening the search window a bit.
|
||||
for (std::string & antiprompt : params.antiprompt) {
|
||||
size_t extra_padding = params.interactive ? 0 : 2;
|
||||
size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
|
||||
? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
|
||||
: 0;
|
||||
|
||||
if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
}
|
||||
is_antiprompt = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// check for reverse prompt using special tokens
|
||||
llama_token last_token = gpt_sampler_last(smpl);
|
||||
for (std::vector<llama_token> ids : antiprompt_ids) {
|
||||
if (ids.size() == 1 && last_token == ids[0]) {
|
||||
auto match = antiprompts.findSingleTokenMatch(last_token);
|
||||
if (match.pos != std::string::npos) {
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
}
|
||||
is_antiprompt = true;
|
||||
break;
|
||||
} else {
|
||||
match = antiprompts.findFirstMatch(output_s, last_partial_stop == std::string::npos ? 0 : last_partial_stop);
|
||||
if (match.pos != std::string::npos) {
|
||||
if (match.is_partial) {
|
||||
last_partial_stop = match.pos;
|
||||
} else {
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
}
|
||||
is_antiprompt = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (is_antiprompt) {
|
||||
LOG_DBG("found antiprompt: %s\n", last_output.c_str());
|
||||
LOG_DBG("found antiprompt: %s\n", match.pattern.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -786,9 +772,9 @@ int main(int argc, char ** argv) {
|
|||
LOG_DBG("found an EOG token\n");
|
||||
|
||||
if (params.interactive) {
|
||||
if (!params.antiprompt.empty()) {
|
||||
if (!antiprompts.stop_words.empty()) {
|
||||
// tokenize and inject first reverse prompt
|
||||
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
|
||||
const auto first_antiprompt = ::llama_tokenize(ctx, antiprompts.stop_words.front(), false, true);
|
||||
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
|
||||
is_antiprompt = true;
|
||||
}
|
||||
|
@ -882,7 +868,7 @@ int main(int argc, char ** argv) {
|
|||
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
||||
const llama_token token = embd_inp[i];
|
||||
output_tokens.push_back(token);
|
||||
output_ss << llama_token_to_piece(ctx, token);
|
||||
output_s.append(llama_token_to_piece(ctx, token));
|
||||
}
|
||||
|
||||
// reset assistant message
|
||||
|
@ -926,7 +912,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
LOG("\n\n");
|
||||
gpt_perf_print(ctx, smpl);
|
||||
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
|
||||
write_logfile(ctx, params, model, input_tokens, output_s, output_tokens);
|
||||
|
||||
gpt_sampler_free(smpl);
|
||||
|
||||
|
|
|
@ -131,8 +131,6 @@ struct slot_params {
|
|||
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
};
|
||||
|
@ -183,6 +181,8 @@ struct server_slot {
|
|||
std::string oaicompat_model;
|
||||
std::string stopping_word;
|
||||
|
||||
llama_antiprompts antiprompts;
|
||||
|
||||
// sampling
|
||||
json json_schema;
|
||||
|
||||
|
@ -281,34 +281,6 @@ struct server_slot {
|
|||
};
|
||||
}
|
||||
|
||||
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) {
|
||||
size_t stop_pos = std::string::npos;
|
||||
|
||||
for (const std::string & word : params.antiprompt) {
|
||||
size_t pos;
|
||||
|
||||
if (type == STOP_TYPE_FULL) {
|
||||
const size_t tmp = word.size() + last_token_size;
|
||||
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
|
||||
|
||||
pos = text.find(word, from_pos);
|
||||
} else {
|
||||
pos = find_partial_stop_string(word, text);
|
||||
}
|
||||
|
||||
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
|
||||
if (type == STOP_TYPE_FULL) {
|
||||
stopped_word = true;
|
||||
stopping_word = word;
|
||||
has_next_token = false;
|
||||
}
|
||||
stop_pos = pos;
|
||||
}
|
||||
}
|
||||
|
||||
return stop_pos;
|
||||
}
|
||||
|
||||
void print_timings() const {
|
||||
const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
|
||||
const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
|
||||
|
@ -999,16 +971,26 @@ struct server_context {
|
|||
}
|
||||
|
||||
{
|
||||
slot.params.antiprompt.clear();
|
||||
slot.antiprompts.clear();
|
||||
|
||||
const auto & stop = data.find("stop");
|
||||
if (stop != data.end() && stop->is_array()) {
|
||||
for (const auto & word : *stop) {
|
||||
if (!word.empty()) {
|
||||
slot.params.antiprompt.push_back(word);
|
||||
auto copy_string_array = [&](const json & data, const std::string & key, std::vector<std::string> & vec) {
|
||||
const auto & arr = data.find(key);
|
||||
if (arr != data.end() && arr->is_array()) {
|
||||
for (const auto & word : *arr) {
|
||||
if (word.is_string()) {
|
||||
vec.push_back(word);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::string> stop_words;
|
||||
std::vector<std::string> grammar_trigger_words;
|
||||
|
||||
copy_string_array(data, "stop", stop_words);
|
||||
copy_string_array(data, "grammar_trigger_words", grammar_trigger_words);
|
||||
|
||||
slot.antiprompts.build(ctx, stop_words, grammar_trigger_words);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -1110,6 +1092,18 @@ struct server_context {
|
|||
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
|
||||
slot.sampled = result.tok;
|
||||
|
||||
auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
|
||||
if (match.pos != std::string::npos && !match.is_partial) {
|
||||
if (match.is_grammar_trigger) {
|
||||
gpt_sampler_trigger_grammar(model, slot.smpl, llama_token_to_piece(ctx, result.tok, params.special));
|
||||
} else {
|
||||
slot.stopped_word = true;
|
||||
slot.stopping_word = match.pattern;
|
||||
slot.has_next_token = false;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// search stop word and delete it
|
||||
slot.generated_text += token_str;
|
||||
slot.has_next_token = true;
|
||||
|
@ -1139,23 +1133,33 @@ struct server_context {
|
|||
if (!incomplete) {
|
||||
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
||||
|
||||
const std::string str_test = slot.generated_text.substr(pos);
|
||||
bool is_stop_full = false;
|
||||
match = slot.antiprompts.findFirstMatch(slot.generated_text, pos);
|
||||
|
||||
bool is_stop_full = false;
|
||||
bool is_grammar_trigger = false;
|
||||
size_t length = slot.generated_text.size();
|
||||
|
||||
// If there is a lazy grammar trigger word at stop_pos, enable the lazy grammar
|
||||
if (match.is_grammar_trigger && gpt_sampler_trigger_grammar(model, slot.smpl, match.pattern)) {
|
||||
is_grammar_trigger = true;
|
||||
length = pos + match.pos + match.matchLength;
|
||||
} else if (!match.is_grammar_trigger && match.pos != std::string::npos && !match.is_partial) {
|
||||
slot.stopped_word = true;
|
||||
slot.stopping_word = match.pattern;
|
||||
slot.has_next_token = false;
|
||||
|
||||
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
|
||||
if (stop_pos != std::string::npos) {
|
||||
is_stop_full = true;
|
||||
slot.generated_text.erase(
|
||||
slot.generated_text.begin() + pos + stop_pos,
|
||||
slot.generated_text.end());
|
||||
pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
||||
} else {
|
||||
is_stop_full = false;
|
||||
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
|
||||
// length = pos + match.pos;
|
||||
length = match.pos;
|
||||
}
|
||||
|
||||
slot.generated_text.erase(
|
||||
slot.generated_text.begin() + length,
|
||||
slot.generated_text.end());
|
||||
pos = std::min(slot.n_sent_text, length);
|
||||
|
||||
// check if there is any token to predict
|
||||
if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
|
||||
if (match.pos == std::string::npos || (!slot.has_next_token && !is_grammar_trigger && !is_stop_full && match.pos > 0)) {
|
||||
// no send the stop word in the response
|
||||
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
||||
slot.n_sent_text += result.text_to_send.size();
|
||||
|
@ -1243,7 +1247,8 @@ struct server_context {
|
|||
{"mirostat_tau", slot.sparams.mirostat_tau},
|
||||
{"mirostat_eta", slot.sparams.mirostat_eta},
|
||||
{"penalize_nl", slot.sparams.penalize_nl},
|
||||
{"stop", slot.params.antiprompt},
|
||||
{"stop", slot.antiprompts.stop_words},
|
||||
{"grammar_trigger", slot.antiprompts.grammar_trigger_words},
|
||||
{"max_tokens", slot.params.n_predict}, // User configured n_predict
|
||||
{"n_keep", slot.params.n_keep},
|
||||
{"n_discard", slot.params.n_discard},
|
||||
|
|
|
@ -196,20 +196,15 @@ static size_t common_part(const std::string & a, const std::string & b) {
|
|||
return i;
|
||||
}
|
||||
|
||||
static bool ends_with(const std::string & str, const std::string & suffix) {
|
||||
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
|
||||
}
|
||||
|
||||
static size_t find_partial_stop_string(const std::string & stop, const std::string & text) {
|
||||
if (!text.empty() && !stop.empty()) {
|
||||
const char text_last_char = text.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const std::string current_partial = stop.substr(0, char_index + 1);
|
||||
if (ends_with(text, current_partial)) {
|
||||
return text.size() - char_index - 1;
|
||||
}
|
||||
auto it = std::find(stop.rbegin(), stop.rend(), text.back());
|
||||
while (it != stop.rend()) {
|
||||
size_t length = std::distance(it, stop.rend());
|
||||
if (text.length() >= length && 0 == text.compare(text.length() - length, length, stop)) {
|
||||
return text.length() - length;
|
||||
}
|
||||
it = std::find(std::next(it), stop.rend(), text.back());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1121,7 +1121,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|||
}
|
||||
|
||||
const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
}
|
||||
|
||||
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
|
||||
// Note terminating 0 in decoded string
|
||||
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
||||
const auto & code_points = decoded.first;
|
||||
|
|
|
@ -142,3 +142,7 @@ void llama_grammar_apply_impl(
|
|||
void llama_grammar_accept_impl(
|
||||
struct llama_grammar & grammar,
|
||||
llama_token token);
|
||||
|
||||
void llama_grammar_accept_str(
|
||||
struct llama_grammar & grammar,
|
||||
const std::string & piece);
|
||||
|
|
|
@ -193,6 +193,12 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
|||
}
|
||||
}
|
||||
|
||||
void llama_sampler_accept_str(struct llama_sampler * smpl, const char * piece) {
|
||||
if (smpl->iface->accept_str) {
|
||||
smpl->iface->accept_str(smpl, piece);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
||||
GGML_ASSERT(smpl->iface->apply);
|
||||
smpl->iface->apply(smpl, cur_p);
|
||||
|
@ -325,6 +331,7 @@ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_chain_i = {
|
||||
/* .name = */ llama_sampler_chain_name,
|
||||
/* .accept = */ llama_sampler_chain_accept,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_chain_apply,
|
||||
/* .reset = */ llama_sampler_chain_reset,
|
||||
/* .clone = */ llama_sampler_chain_clone,
|
||||
|
@ -399,6 +406,7 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
|
|||
static struct llama_sampler_i llama_sampler_greedy_i = {
|
||||
/* .name = */ llama_sampler_greedy_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_greedy_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ nullptr,
|
||||
|
@ -457,6 +465,7 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_dist_i = {
|
||||
/* .name = */ llama_sampler_dist_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_dist_apply,
|
||||
/* .reset = */ llama_sampler_dist_reset,
|
||||
/* .clone = */ llama_sampler_dist_clone,
|
||||
|
@ -488,6 +497,7 @@ static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_t
|
|||
static struct llama_sampler_i llama_sampler_softmax_i = {
|
||||
/* .name = */ llama_sampler_softmax_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_softmax_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ nullptr,
|
||||
|
@ -528,6 +538,7 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_top_k_i = {
|
||||
/* .name = */ llama_sampler_top_k_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_k_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_k_clone,
|
||||
|
@ -594,6 +605,7 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_top_p_i = {
|
||||
/* .name = */ llama_sampler_top_p_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_p_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_p_clone,
|
||||
|
@ -690,6 +702,7 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_min_p_i = {
|
||||
/* .name = */ llama_sampler_min_p_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_min_p_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_min_p_clone,
|
||||
|
@ -785,6 +798,7 @@ static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_tail_free_i = {
|
||||
/* .name = */ llama_sampler_tail_free_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_tail_free_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_tail_free_clone,
|
||||
|
@ -884,6 +898,7 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_typical_i = {
|
||||
/* .name = */ llama_sampler_typical_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_typical_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_typical_clone,
|
||||
|
@ -929,6 +944,7 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_temp_i = {
|
||||
/* .name = */ llama_sampler_temp_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_temp_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_temp_clone,
|
||||
|
@ -1042,6 +1058,7 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
||||
/* .name = */ llama_sampler_temp_ext_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_temp_ext_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_temp_ext_clone,
|
||||
|
@ -1145,6 +1162,7 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_mirostat_i = {
|
||||
/* .name = */ llama_sampler_mirostat_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_mirostat_apply,
|
||||
/* .reset = */ llama_sampler_mirostat_reset,
|
||||
/* .clone = */ llama_sampler_mirostat_clone,
|
||||
|
@ -1244,6 +1262,7 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
||||
/* .name = */ llama_sampler_mirostat_v2_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_mirostat_v2_apply,
|
||||
/* .reset = */ llama_sampler_mirostat_v2_reset,
|
||||
/* .clone = */ llama_sampler_mirostat_v2_clone,
|
||||
|
@ -1287,6 +1306,13 @@ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama
|
|||
}
|
||||
}
|
||||
|
||||
static void llama_sampler_grammar_accept_str(struct llama_sampler * smpl, const char * piece) {
|
||||
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
||||
if (ctx->grammar) {
|
||||
llama_grammar_accept_str(*ctx->grammar, piece);
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
||||
if (ctx->grammar) {
|
||||
|
@ -1339,6 +1365,7 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_grammar_i = {
|
||||
/* .name = */ llama_sampler_grammar_name,
|
||||
/* .accept = */ llama_sampler_grammar_accept_impl,
|
||||
/* .accept_str = */ llama_sampler_grammar_accept_str,
|
||||
/* .apply = */ llama_sampler_grammar_apply,
|
||||
/* .reset = */ llama_sampler_grammar_reset,
|
||||
/* .clone = */ llama_sampler_grammar_clone,
|
||||
|
@ -1522,6 +1549,7 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_penalties_i = {
|
||||
/* .name = */ llama_sampler_penalties_name,
|
||||
/* .accept = */ llama_sampler_penalties_accept,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_penalties_apply,
|
||||
/* .reset = */ llama_sampler_penalties_reset,
|
||||
/* .clone = */ llama_sampler_penalties_clone,
|
||||
|
@ -1624,6 +1652,7 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler_i llama_sampler_logit_bias_i = {
|
||||
/* .name = */ llama_sampler_logit_bias_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .accept_str = */ nullptr,
|
||||
/* .apply = */ llama_sampler_logit_bias_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_logit_bias_clone,
|
||||
|
|
|
@ -122,6 +122,7 @@ llama_target_and_test(test-grad0.cpp)
|
|||
llama_target_and_test(test-barrier.cpp)
|
||||
# llama_target_and_test(test-opt.cpp) # SLOW
|
||||
llama_target_and_test(test-backend-ops.cpp)
|
||||
llama_target_and_test(test-antiprompts.cpp)
|
||||
|
||||
llama_target_and_test(test-rope.cpp)
|
||||
|
||||
|
|
88
tests/test-antiprompts.cpp
Normal file
88
tests/test-antiprompts.cpp
Normal file
|
@ -0,0 +1,88 @@
|
|||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
template <typename T>
|
||||
void assert_equal(const T & actual, const T & expected) {
|
||||
if (expected == actual) return;
|
||||
printf("Expected: %s, Actual: %s\n", ((std::string)expected).c_str(), ((std::string)actual).c_str());
|
||||
assert(expected == actual);
|
||||
}
|
||||
|
||||
// cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_CURL=1 && cmake --build build -j -t test-jinja -t test-antiprompts && ./build/bin/test-antiprompts
|
||||
int main()
|
||||
{
|
||||
auto tokenizer = [&](const std::string & text) {
|
||||
std::vector<llama_token> tokens;
|
||||
for (size_t i = 0; i < text.length(); ++i) {
|
||||
tokens.push_back(text[i]);
|
||||
}
|
||||
return tokens;
|
||||
};
|
||||
const std::vector<std::string> stop_words { };
|
||||
const std::vector<std::string> grammar_trigger_words { };
|
||||
|
||||
printf("Testing antiprompts\n");
|
||||
|
||||
llama_antiprompts antiprompts;
|
||||
antiprompts.build(tokenizer, {"abc", "bcd"}, {"bca", "x"});
|
||||
|
||||
assert_equal(antiprompts.findSingleTokenMatch('x'), {
|
||||
.pos = 0,
|
||||
.pattern = "x",
|
||||
.is_partial = false,
|
||||
.matchLength = 1,
|
||||
.is_grammar_trigger = true,
|
||||
});
|
||||
assert_equal(antiprompts.findSingleTokenMatch('a'), {
|
||||
.pos = std::string::npos,
|
||||
.pattern = "",
|
||||
.is_partial = false,
|
||||
.matchLength = 0,
|
||||
.is_grammar_trigger = false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" ab", 0), {
|
||||
.pos = 1,
|
||||
.pattern = "",
|
||||
.is_partial = true,
|
||||
.matchLength = 2,
|
||||
.is_grammar_trigger = false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" abc", 0), {
|
||||
.pos = 1,
|
||||
.pattern = "abc",
|
||||
.is_partial = false,
|
||||
.matchLength = 3,
|
||||
.is_grammar_trigger = false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" bc", 0), {
|
||||
.pos = 1,
|
||||
.pattern = "",
|
||||
.is_partial = true,
|
||||
.matchLength = 2,
|
||||
.is_grammar_trigger = false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" bcd", 0), {
|
||||
.pos = 1,
|
||||
.pattern = "bcd",
|
||||
.is_partial = false,
|
||||
.matchLength = 3,
|
||||
.is_grammar_trigger = false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" bca", 0), {
|
||||
.pos = 1,
|
||||
.pattern = "bca",
|
||||
.is_partial = false,
|
||||
.matchLength = 3,
|
||||
.is_grammar_trigger = true,
|
||||
});
|
||||
printf("OK\n");
|
||||
// llama_antiprompts::MatchResult{0, "a", .is_partial = false, . 1, false});
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue