Push laziness down to grammar impl

parent 77f4098c83
commit dbf841b0d2

16 changed files with 224 additions and 423 deletions
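In API terms, the laziness moves out of common's antiprompt scanning and into the grammar itself: llama_sampler_init_grammar() now takes trigger words and trigger tokens directly, and the grammar stays dormant until one of them is produced. A minimal usage sketch against the new signature (the vocab/grammar variables and the trigger word are illustrative placeholders, not part of this commit):

    // sketch: a grammar sampler that only engages once "<tool_call>" is generated
    const char * trigger_words[] = { "<tool_call>" };
    struct llama_sampler * smpl = llama_sampler_init_grammar(
        vocab,        // const struct llama_vocab *, e.g. from llama_model_get_vocab(model)
        grammar_str,  // GBNF text; decoding stays unconstrained until a trigger fires
        "root",
        trigger_words,
        /* num_trigger_words  = */ 1,
        /* trigger_tokens     = */ nullptr,
        /* num_trigger_tokens = */ 0);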
Makefile (1 change)
@@ -58,7 +58,6 @@ TEST_TARGETS = \
 	tests/test-grammar-integration \
 	tests/test-grammar-parser \
 	tests/test-json-schema-to-grammar \
 	tests/test-minja \
 	tests/test-llama-grammar \
 	tests/test-log \
 	tests/test-model-load-cancel \
common/common.h (212 changes)
@@ -158,7 +158,8 @@ struct common_params_sampling {
     };
 
     std::string grammar; // optional BNF-like grammar to constrain sampling
     std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar
+    std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar
 
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
 
@@ -687,215 +688,6 @@ struct common_control_vector_load_info {
 // On error, returns {-1, empty}
 common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos);
 
-//
-// Antiprompt utils
-//
-
-class llama_antiprompts {
-public:
-    struct llama_antiprompt {
-        std::string value;
-        bool is_grammar_trigger;
-    };
-
-    std::vector<std::string> stop_words;
-    std::vector<std::string> grammar_triggers;
-
-private:
-    // The Aho–Corasick algorithm allows efficient string matching with multiple patterns.
-    // See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
-    struct TrieNode {
-        std::unordered_map<char, TrieNode*> children;
-        TrieNode* fail = nullptr;
-        int output = -1;
-        size_t depth = 0;
-
-        ~TrieNode() {
-            clear();
-        }
-
-        void clear() {
-            for (auto & pair : children) {
-                delete pair.second;
-            }
-            children.clear();
-            fail = nullptr;
-            output = -1;
-            depth = 0;
-        }
-    };
-
-    TrieNode root;
-    std::vector<llama_antiprompt> antiprompts;
-    std::unordered_map<llama_token, size_t> stop_tokens; // Single token antiprompts (and their index in antiprompts), if any.
-
-    void build_trie() {
-        // root = std::unique_ptr<TrieNode>(new TrieNode());
-        for (size_t i = 0; i < antiprompts.size(); ++i) {
-            TrieNode* node = &root;
-            const auto & pattern = antiprompts[i].value;
-            for (size_t j = 0; j < pattern.length(); ++j) {
-                char c = pattern[j];
-                auto it = node->children.find(c);
-                if (it != node->children.end()) {
-                    node = it->second;
-                } else {
-                    node = node->children[c] = new TrieNode();
-                }
-                if (node->depth == 0) {
-                    node->depth = j + 1;
-                }
-            }
-            node->output = i;
-        }
-    }
-
-    void build_failure_and_dict_links() {
-        std::queue<TrieNode*> q;
-        for (auto& child : root.children) {
-            child.second->fail = &root;
-            q.push(child.second);
-        }
-
-        while (!q.empty()) {
-            auto node = q.front();
-            q.pop();
-
-            for (auto & pair : node->children) {
-                auto & c = pair.first;
-                auto & child = pair.second;
-                auto f = node->fail;
-
-                while (f != &root && f->children.find(c) == f->children.end()) {
-                    f = f->fail;
-                }
-
-                child->fail = (f == &root && f->children.find(c) == f->children.end())
-                    ? &root : f->children[c];
-
-                if (child->fail->output != -1) {
-                    child->output = child->fail->output;
-                }
-
-                q.push(child);
-            }
-        }
-    }
-
-public:
-
-    bool empty() const {
-        return antiprompts.empty() && stop_tokens.empty();
-    }
-    void clear() {
-        root.clear();
-        antiprompts.clear();
-        stop_tokens.clear();
-    }
-
-    void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
-        build(
-            [&](const std::string & text) {
-                return common_tokenize(ctx, text, /* special= */ true);
-            },
-            stop_words,
-            grammar_triggers
-        );
-    }
-
-    void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
-        clear();
-        this->stop_words = stop_words;
-        this->grammar_triggers = grammar_triggers;
-
-        for (const std::string & stop_word : stop_words) {
-            antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
-        }
-        for (const std::string & trigger : grammar_triggers) {
-            antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
-        }
-
-        for (size_t i = 0, n = antiprompts.size(); i < n; i++) {
-            const auto & antiprompt = antiprompts[i];
-            std::vector<llama_token> tokens = tokenizer(antiprompt.value);
-            if (tokens.size() == 1) {
-                stop_tokens[tokens[0]] = i;
-            }
-        }
-
-        build_trie();
-        build_failure_and_dict_links();
-    }
-
-    struct MatchResult {
-        size_t pos;
-        std::string pattern;
-        bool is_partial;
-        size_t matchLength;
-        bool is_grammar_trigger;
-
-        bool operator==(const MatchResult & other) const {
-            return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger;
-        }
-        operator std::string() const {
-            return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}";
-        }
-    };
-
-    MatchResult findSingleTokenMatch(llama_token token) const {
-        auto it = stop_tokens.find(token);
-        if (it != stop_tokens.end()) {
-            const auto & antiprompt = antiprompts[it->second];
-            return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger};
-        }
-        return {std::string::npos, "", false, 0, false};
-    }
-
-    MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
-        TrieNode* current = &root;
-        MatchResult partialMatch{std::string::npos, "", true, 0, false};
-        auto text_length = text.length();
-
-        for (size_t i = offset; i < text_length; ++i) {
-            char c = text[i];
-            while (current != &root && current->children.find(c) == current->children.end()) {
-                current = current->fail;
-            }
-            auto it = current->children.find(c);
-            if (it != current->children.end()) {
-                current = it->second;
-            }
-            if (current->output != -1) {
-                const auto & antiprompt = antiprompts[current->output];
-                return {
-                    i - antiprompt.value.length() + 1,
-                    antiprompt.value,
-                    false,
-                    antiprompt.value.length(),
-                    antiprompt.is_grammar_trigger,
-                };
-            }
-            // Update partial match if we're at a deeper node
-            if (current->depth > partialMatch.matchLength) {
-                partialMatch.pos = i - current->depth + 1;
-                partialMatch.pattern = ""; // We don't know which pattern it partially matches
-                partialMatch.matchLength = current->depth;
-                partialMatch.is_grammar_trigger = false;
-            }
-        }
-
-        // If we've found a partial match and haven't returned a full match, return the partial match
-        if (partialMatch.pos != std::string::npos) {
-            if (partialMatch.pos + partialMatch.matchLength == text_length) {
-                return partialMatch;
-            }
-        }
-
-        return {std::string::npos, "", false, 0, false};
-    }
-};
-
 //
 // Split utils
 //
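The class removed above matched all stop words and trigger words in one pass with Aho–Corasick, reporting either the earliest full match of any pattern or a partial match when a pattern prefix was still open at the end of the text (the contract the deleted tests/test-antiprompts.cpp at the bottom of this commit asserts). A naive standalone equivalent of that contract, for illustration only (not the deleted implementation, and quadratic rather than linear):

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    struct Match { size_t pos; std::string pattern; bool is_partial; };

    static Match find_first_match(const std::string & text, const std::vector<std::string> & patterns) {
        // full match: the earliest occurrence of any complete pattern
        for (size_t i = 0; i < text.size(); i++) {
            for (const auto & p : patterns) {
                if (text.compare(i, p.size(), p) == 0) {
                    return { i, p, /* is_partial= */ false };
                }
            }
        }
        // partial match: a pattern prefix that is still open at the end of the text,
        // i.e. more output could still turn it into a full match
        for (size_t i = 0; i < text.size(); i++) {
            for (const auto & p : patterns) {
                const size_t tail = text.size() - i;
                if (tail < p.size() && text.compare(i, tail, p, 0, tail) == 0) {
                    return { i, "", /* is_partial= */ true };
                }
            }
        }
        return { std::string::npos, "", false };
    }

    int main() {
        // mirrors two expectations from the deleted test-antiprompts.cpp
        Match m = find_first_match(" abc", {"abc", "bcd"});
        std::cout << m.pos << " '" << m.pattern << "' " << m.is_partial << "\n"; // 1 'abc' 0
        m = find_first_match(" ab", {"abc", "bcd"});
        std::cout << m.pos << " '" << m.pattern << "' " << m.is_partial << "\n"; // 1 '' 1
    }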
common/sampling.cpp

@@ -144,15 +144,6 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }
 
-bool common_sampler_trigger_grammar(const struct llama_vocab * vocab, common_sampler * gsmpl, const std::string & trigger) {
-    if (!llama_sampler_is_grammar_empty(gsmpl->grmr)) {
-        return false;
-    }
-    gsmpl->grmr = llama_sampler_init_grammar(vocab, gsmpl->params.grammar.c_str(), "root");
-    llama_sampler_accept_str(gsmpl->grmr, trigger.c_str());
-    return true;
-}
-
 struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
@@ -160,9 +151,22 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
 
     lparams.no_perf = params.no_perf;
 
+    std::vector<const char *> c_trigger_words;
+    c_trigger_words.reserve(params.grammar_trigger_words.size());
+    for (const auto & str : params.grammar_trigger_words) {
+        c_trigger_words.push_back(str.c_str());
+    }
     auto * result = new common_sampler {
         /* .params = */ params,
-        /* .grmr   = */ llama_sampler_init_grammar(vocab, params.grammar_trigger_words.empty() ? params.grammar.c_str() : "", "root"),
+        /* .grmr   = */ llama_sampler_init_grammar(
+            vocab,
+            params.grammar.c_str(),
+            "root",
+            c_trigger_words.data(),
+            c_trigger_words.size(),
+            params.grammar_trigger_tokens.data(),
+            params.grammar_trigger_tokens.size()
+        ),
         /* .chain  = */ llama_sampler_chain_init(lparams),
         /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur    = */ {},
common/sampling.h

@@ -100,7 +100,5 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
 char        common_sampler_type_to_chr(enum common_sampler_type cnstr);
 std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
 
-bool common_sampler_trigger_grammar(const struct llama_vocab * vocab, common_sampler * gsmpl, const std::string & trigger);
-
 std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
examples/agent/.gitignore (3 changes, vendored)
@@ -1,3 +0,0 @@
-squid/ssl_cert/
-squid/ssl_db/
-squid/cache/
examples/gbnf-validator/gbnf-validator.cpp

@@ -76,7 +76,7 @@ int main(int argc, char** argv) {
         grammar_str = buffer.str();
     }
 
-    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
+    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
     if (grammar == nullptr) {
         fprintf(stdout, "Failed to initialize llama_grammar\n");
         return 1;
examples/main/main.cpp

@@ -38,7 +38,7 @@ static llama_model ** g_model;
 static common_sampler ** g_smpl;
 static common_params * g_params;
 static std::vector<llama_token> * g_input_tokens;
-static std::string * g_output_s;
+static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 static bool need_insert_eot = false;
@@ -494,8 +494,7 @@ int main(int argc, char ** argv) {
 
     std::vector<int> input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
-    std::string output_s; g_output_s = &output_s;
-    size_t last_partial_stop = std::string::npos;
+    std::ostringstream output_ss; g_output_ss = &output_ss;
     std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
 
     // the first thing we will do is to output the prompt, so set color accordingly
@@ -504,8 +503,16 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    llama_antiprompts antiprompts;
-    antiprompts.build(ctx, params.antiprompt, {});
+    // single-token antiprompts
+    std::vector<llama_token> antiprompt_single_token;
+
+    antiprompt_single_token.reserve(params.antiprompt.size());
+    for (const std::string & antiprompt : params.antiprompt) {
+        auto ids = ::common_tokenize(ctx, antiprompt, false, true);
+        if (ids.size() == 1) {
+            antiprompt_single_token.push_back(ids[0]);
+        }
+    }
 
     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
@@ -710,7 +717,7 @@ int main(int argc, char ** argv) {
             } else {
                 // Outgoing Generated Tokens
                 output_tokens.push_back(id);
-                output_s.append(token_str);
+                output_ss << token_str;
             }
         }
     }
@@ -723,34 +730,41 @@ int main(int argc, char ** argv) {
 
         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
-            // check for reverse prompt
-            if (!antiprompts.empty()) {
+            // check for reverse prompt in the last n_prev tokens
+            if (!params.antiprompt.empty()) {
+                const int n_prev = 32;
+                const std::string last_output = common_sampler_prev_str(smpl, ctx, n_prev);
+
+                is_antiprompt = false;
+                // Check if each of the reverse prompts appears at the end of the output.
+                // If we're not running interactively, the reverse prompt might be tokenized with some following characters
+                // so we'll compensate for that by widening the search window a bit.
+                for (std::string & antiprompt : params.antiprompt) {
+                    size_t extra_padding = params.interactive ? 0 : 2;
+                    size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
+                        ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
+                        : 0;
+
+                    if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
+                        if (params.interactive) {
+                            is_interacting = true;
+                        }
+                        is_antiprompt = true;
+                        break;
+                    }
+                }
+
                 // check for reverse prompt using special tokens
                 llama_token last_token = common_sampler_last(smpl);
-                auto match = antiprompts.findSingleTokenMatch(last_token);
-                if (match.pos != std::string::npos) {
+                if (std::find(antiprompt_single_token.begin(), antiprompt_single_token.end(), last_token) != antiprompt_single_token.end()) {
                     if (params.interactive) {
                         is_interacting = true;
                     }
                     is_antiprompt = true;
-                } else {
-                    match = antiprompts.findFirstMatch(output_s, last_partial_stop == std::string::npos ? 0 : last_partial_stop);
-                    if (match.pos != std::string::npos) {
-                        if (match.is_partial) {
-                            last_partial_stop = match.pos;
-                        } else {
-                            if (params.interactive) {
-                                is_interacting = true;
-                            }
-                            is_antiprompt = true;
-                        }
-                    }
                 }
 
                 if (is_antiprompt) {
-                    LOG_DBG("found antiprompt: %s\n", match.pattern.c_str());
+                    LOG_DBG("found antiprompt: %s\n", last_output.c_str());
                 }
             }
 
@@ -759,9 +773,9 @@ int main(int argc, char ** argv) {
             LOG_DBG("found an EOG token\n");
 
             if (params.interactive) {
-                if (!antiprompts.stop_words.empty()) {
+                if (!params.antiprompt.empty()) {
                     // tokenize and inject first reverse prompt
-                    const auto first_antiprompt = common_tokenize(ctx, antiprompts.stop_words.front(), false, true);
+                    const auto first_antiprompt = common_tokenize(ctx, params.antiprompt.front(), false, true);
                     embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                     is_antiprompt = true;
                 }
@@ -855,7 +869,7 @@ int main(int argc, char ** argv) {
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
                     const llama_token token = embd_inp[i];
                     output_tokens.push_back(token);
-                    output_s.append(common_token_to_piece(ctx, token));
+                    output_ss << common_token_to_piece(ctx, token);
                 }
 
                 // reset assistant message
examples/server/server.cpp

@@ -389,7 +389,15 @@ struct server_task {
         {
             const auto grammar_trigger_words = data.find("grammar_trigger_words");
             if (grammar_trigger_words != data.end()) {
-                params.sampling.grammar_trigger_words = to_string_vec(*grammar_trigger_words);
+                auto words = to_string_vec(*grammar_trigger_words);
+                for (const auto & word : words) {
+                    auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                        continue;
+                    }
+                    params.sampling.grammar_trigger_words.push_back(word);
+                }
             }
         }
 
@@ -1224,8 +1232,6 @@ struct server_slot {
 
     std::string stopping_word;
 
-    llama_antiprompts antiprompts;
-
     // sampling
     json json_schema;
 
@@ -1329,6 +1335,35 @@ struct server_slot {
         return timings;
     }
 
+    size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
+        size_t stop_pos = std::string::npos;
+
+        for (const std::string & word : params.antiprompt) {
+            size_t pos;
+
+            if (is_full_stop) {
+                const size_t tmp      = word.size() + last_token_size;
+                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
+
+                pos = text.find(word, from_pos);
+            } else {
+                // otherwise, partial stop
+                pos = find_partial_stop_string(word, text);
+            }
+
+            if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
+                if (is_full_stop) {
+                    stop           = STOP_TYPE_WORD;
+                    stopping_word  = word;
+                    has_next_token = false;
+                }
+                stop_pos = pos;
+            }
+        }
+
+        return stop_pos;
+    }
+
     void print_timings() const {
         const double t_prompt        =       t_prompt_processing / n_prompt_tokens_processed;
         const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
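For the partial case, find_stopping_strings() defers to the server's find_partial_stop_string() helper; the assumed contract is that it reports where a proper prefix of the stop word starts at the very end of the text, so the server can hold that tail back instead of streaming it out. A self-contained sketch of that check (an assumed reimplementation for illustration, not the server's exact helper):

    #include <cassert>
    #include <cstddef>
    #include <string>

    // returns the position in `text` where a proper prefix of `stop` begins
    // and runs to the end of `text`, or npos if the tail cannot start `stop`
    static size_t partial_stop_pos(const std::string & stop, const std::string & text) {
        if (text.empty() || stop.empty()) {
            return std::string::npos;
        }
        const char text_last_char = text.back();
        for (size_t n = stop.size() - 1; n >= 1; n--) {  // n = prefix length to test
            if (stop[n - 1] != text_last_char) {
                continue;                                // prefix must end in the last char
            }
            const std::string prefix = stop.substr(0, n);
            if (text.size() >= n &&
                text.compare(text.size() - n, n, prefix) == 0) {
                return text.size() - n;
            }
        }
        return std::string::npos;
    }

    int main() {
        assert(partial_stop_pos("</s>", "hello <")  == 6);                 // "<" may grow into "</s>"
        assert(partial_stop_pos("</s>", "hello </") == 6);                 // "</" likewise
        assert(partial_stop_pos("</s>", "hello")    == std::string::npos); // nothing to hold back
    }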
@@ -1976,11 +2011,6 @@ struct server_context {
             slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
         }
 
-        {
-            slot.antiprompts.clear();
-            slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.sampling.grammar_trigger_words);
-        }
-
         {
             if (slot.smpl != nullptr) {
                 common_sampler_free(slot.smpl);
@@ -2016,25 +2046,13 @@ struct server_context {
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
-        auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
-
         // remember which tokens were sampled - used for repetition penalties during sampling
-        // TODO:
-        // const std::string token_str = result.text_to_send;
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special || (match.pos != std::string::npos && match.is_grammar_trigger));
+        const std::string token_str = result.text_to_send;
+        // const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special || (match.pos != std::string::npos && match.is_grammar_trigger));
         slot.sampled = result.tok;
 
-        if (match.pos != std::string::npos && !match.is_partial) {
-            if (match.is_grammar_trigger) {
-                common_sampler_trigger_grammar(vocab, slot.smpl, token_str);
-            } else {
-                // slot.stopped_word = true;
-                slot.stopping_word = match.pattern;
-                slot.has_next_token = false;
-                return false;
-            }
-        }
-
         // search stop word and delete it
         slot.generated_text += token_str;
         if (slot.params.return_tokens) {
             slot.generated_tokens.push_back(result.tok);
@@ -2048,33 +2066,22 @@ struct server_context {
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
-            match = slot.antiprompts.findFirstMatch(slot.generated_text, pos);
-
-            bool is_stop_full = false;
-            bool is_grammar_trigger = false;
-            size_t length = slot.generated_text.size();
-
-            // If there is a lazy grammar trigger word at stop_pos, enable the lazy grammar
-            if (match.is_grammar_trigger && common_sampler_trigger_grammar(vocab, slot.smpl, match.pattern)) {
-                is_grammar_trigger = true;
-                length = match.pos + match.matchLength;
-            } else if (!match.is_grammar_trigger && match.pos != std::string::npos && !match.is_partial) {
-                // slot.stopped_word = true;
-                slot.stopping_word = match.pattern;
-                slot.has_next_token = false;
-
-                is_stop_full = true;
-                // length = pos + match.pos;
-                length = match.pos;
+            const std::string str_test = slot.generated_text.substr(pos);
+            bool send_text = true;
+
+            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
+            if (stop_pos != std::string::npos) {
+                slot.generated_text.erase(
+                    slot.generated_text.begin() + pos + stop_pos,
+                    slot.generated_text.end());
+                pos = std::min(slot.n_sent_text, slot.generated_text.size());
+            } else if (slot.has_next_token) {
+                stop_pos  = slot.find_stopping_strings(str_test, token_str.size(), false);
+                send_text = stop_pos == std::string::npos;
             }
 
-            slot.generated_text.erase(
-                slot.generated_text.begin() + length,
-                slot.generated_text.end());
-            pos = std::min(slot.n_sent_text, length);
-
             // check if there is any token to predict
-            if (match.pos == std::string::npos || (!slot.has_next_token && !is_grammar_trigger && !is_stop_full && match.pos > 0)) {
+            if (send_text) {
                 // no send the stop word in the response
                 result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
                 slot.n_sent_text += result.text_to_send.size();
include/llama.h

@@ -1199,7 +1199,11 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
             const struct llama_vocab * vocab,
             const char * grammar_str,
-            const char * grammar_root);
+            const char * grammar_root,
+            const char ** trigger_words,
+            size_t num_trigger_words,
+            const llama_token * trigger_tokens,
+            size_t num_trigger_tokens);
 
     /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
src/llama-grammar.cpp

@@ -960,10 +960,26 @@ struct llama_grammar * llama_grammar_init_impl(
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
+    return new llama_grammar {
+        vocab,
+        std::move(vec_rules),
+        std::move(stacks),
+        /* .partial_utf8 =     */ {},
+        /* .awaiting_trigger = */ false,
+        /* .trigger_buffer =   */ "",
+        /* .trigger_tokens =   */ {},
+        /* .trigger_words =    */ {},
+    };
 }
 
-struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens) {
     llama_grammar_parser parser;
 
     // if there is a grammar, parse it
@@ -1035,10 +1051,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
         }
     } while (true);
 
+    std::vector<llama_token> vec_trigger_tokens;
+    std::vector<std::string> vec_trigger_words;
+    for (size_t i = 0; i < num_trigger_tokens; i++) {
+        GGML_ASSERT(trigger_tokens != nullptr);
+        vec_trigger_tokens.push_back(trigger_tokens[i]);
+    }
+    for (size_t i = 0; i < num_trigger_words; i++) {
+        GGML_ASSERT(trigger_words != nullptr);
+        vec_trigger_words.push_back(trigger_words[i]);
+    }
+
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
+    return new llama_grammar {
+        vocab,
+        std::move(vec_rules),
+        std::move(stacks),
+        /* .partial_utf8 =     */ {},
+        /* .awaiting_trigger = */ vec_trigger_tokens.size() > 0 || vec_trigger_words.size() > 0,
+        /* .trigger_buffer =   */ "",
+        std::move(vec_trigger_tokens),
+        std::move(vec_trigger_words),
+    };
 }
 
 void llama_grammar_free_impl(struct llama_grammar * grammar) {
@@ -1055,6 +1092,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
         grammar.rules,
         grammar.stacks,
         grammar.partial_utf8,
+        grammar.awaiting_trigger,
+        grammar.trigger_buffer,
+        grammar.trigger_tokens,
+        grammar.trigger_words,
     };
 
     // redirect elements in stacks to point to new rules
@@ -1115,6 +1156,28 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);
 
+    if (grammar.awaiting_trigger) {
+        if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
+            grammar.awaiting_trigger = false;
+            llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token));
+            return;
+        } else {
+            grammar.trigger_buffer += grammar.vocab->token_to_piece(token);
+            for (const auto & word : grammar.trigger_words) {
+                auto pos = grammar.trigger_buffer.find(word);
+                if (pos == std::string::npos) {
+                    continue;
+                }
+                grammar.awaiting_trigger = false;
+                auto constrained_str = grammar.trigger_buffer.substr(pos);
+                llama_grammar_accept_str(grammar, constrained_str);
+                grammar.trigger_buffer.clear();
+                return;
+            }
+            return;
+        }
+    }
+
     if (grammar.vocab->is_eog(token)) {
         for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {
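This block is where the laziness now lives: while awaiting_trigger is set, tokens are not matched against the grammar at all; their pieces accumulate in trigger_buffer until a trigger token arrives or a trigger word shows up in the buffer, and only the text from the trigger onward is replayed into the grammar. A standalone trace of that buffering behavior (illustrative code outside the llama.cpp API; the pieces and the trigger word are made up):

    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
        const std::vector<std::string> trigger_words = { "<tool_call>" };
        std::string trigger_buffer;
        bool awaiting_trigger = true;

        // token pieces arrive one at a time; nothing is constrained yet
        const std::vector<std::string> pieces = {
            "I'll", " call", " a", " tool", ": <tool_", "call>{\"name\""
        };
        for (const std::string & piece : pieces) {
            if (!awaiting_trigger) {
                break; // from here on, pieces would be fed to the grammar directly
            }
            trigger_buffer += piece;
            for (const std::string & word : trigger_words) {
                const auto pos = trigger_buffer.find(word);
                if (pos == std::string::npos) {
                    continue;
                }
                awaiting_trigger = false;
                // only the text from the trigger onward reaches the grammar
                std::cout << "grammar sees: " << trigger_buffer.substr(pos) << "\n";
                trigger_buffer.clear();
                break;
            }
        }
        // prints: grammar sees: <tool_call>{"name"
    }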
src/llama-grammar.h

@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <map>
 #include <set>
 #include <string>
 #include <vector>
 
@@ -114,6 +115,11 @@ struct llama_grammar {
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
     llama_partial_utf8 partial_utf8;
 
+    bool                     awaiting_trigger;
+    std::string              trigger_buffer;
+    std::vector<llama_token> trigger_tokens;
+    std::vector<std::string> trigger_words;
 };
 
 //
@@ -127,7 +133,14 @@ struct llama_grammar * llama_grammar_init_impl(
         size_t n_rules,
         size_t start_rule_index);
 
-struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens);
 
 void llama_grammar_free_impl(struct llama_grammar * grammar);
 
src/llama-sampling.cpp

@@ -1465,7 +1465,18 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
         return;
     }
 
-    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+    std::vector<const char *> trigger_words;
+    for (auto & word : ctx->grammar->trigger_words) {
+        trigger_words.push_back(word.c_str());
+    }
+    auto * grammar_new = llama_grammar_init_impl(
+        ctx->grammar->vocab,
+        ctx->grammar_str.c_str(),
+        ctx->grammar_root.c_str(),
+        trigger_words.data(),
+        trigger_words.size(),
+        ctx->grammar->trigger_tokens.data(),
+        ctx->grammar->trigger_tokens.size());
 
     llama_grammar_free_impl(ctx->grammar);
     ctx->grammar = grammar_new;
@@ -1474,7 +1485,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
 
-    auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr);
+    auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr, nullptr, 0, nullptr, 0);
 
     // copy the state
     {
@@ -1511,15 +1522,24 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
     /* .free   = */ llama_sampler_grammar_free,
 };
 
-struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+struct llama_sampler * llama_sampler_init_grammar(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens) {
+// struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
     auto * ctx = new llama_sampler_grammar;
 
     if (grammar_str != nullptr && grammar_str[0] != '\0') {
         *ctx = {
             /* .vocab        = */ vocab,
             /* .grammar_str  = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root),
+            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
         };
     } else {
         *ctx = {
@ -133,7 +133,6 @@ llama_target_and_test(test-chat-template.cpp)
|
|||
# llama_target_and_test(test-opt.cpp) # SLOW
|
||||
llama_target_and_test(test-gguf.cpp)
|
||||
llama_target_and_test(test-backend-ops.cpp)
|
||||
llama_target_and_test(test-antiprompts.cpp)
|
||||
llama_target_and_test(test-tool-call.cpp)
|
||||
|
||||
llama_target_and_test(test-model-load-cancel.cpp LABEL "model")
|
||||
|
|
|
@ -1,109 +0,0 @@
|
|||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
template <typename T>
|
||||
void assert_equal(const T & actual, const T & expected) {
|
||||
if (expected == actual) return;
|
||||
printf("Expected: %s, Actual: %s\n", ((std::string)expected).c_str(), ((std::string)actual).c_str());
|
||||
assert(expected == actual);
|
||||
}
|
||||
|
||||
// cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_CURL=1 && cmake --build build -j -t test-jinja -t test-antiprompts && ./build/bin/test-antiprompts
|
||||
int main()
|
||||
{
|
||||
auto tokenizer = [&](const std::string & text) {
|
||||
std::vector<llama_token> tokens;
|
||||
for (size_t i = 0; i < text.length(); ++i) {
|
||||
tokens.push_back(text[i]);
|
||||
}
|
||||
return tokens;
|
||||
};
|
||||
const std::vector<std::string> stop_words { };
|
||||
const std::vector<std::string> grammar_trigger_words { };
|
||||
|
||||
printf("Testing antiprompts\n");
|
||||
|
||||
llama_antiprompts antiprompts;
|
||||
antiprompts.build(tokenizer, {"abc", "bcd"}, {"bca", "x"});
|
||||
|
||||
assert_equal(antiprompts.findSingleTokenMatch('x'), {
|
||||
/* .pos = */ 0,
|
||||
/* .pattern = */ "x",
|
||||
/* .is_partial = */ false,
|
||||
/* .matchLength = */ 1,
|
||||
/* .is_grammar_trigger = */ true,
|
||||
});
|
||||
assert_equal(antiprompts.findSingleTokenMatch('a'), {
|
||||
/* .pos = */ std::string::npos,
|
||||
/* .pattern = */ "",
|
||||
/* .is_partial = */ false,
|
||||
/* .matchLength = */ 0,
|
||||
/* .is_grammar_trigger = */ false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" ab", 0), {
|
||||
/* .pos = */ 1,
|
||||
/* .pattern = */ "",
|
||||
/* .is_partial = */ true,
|
||||
/* .matchLength = */ 2,
|
||||
/* .is_grammar_trigger = */ false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" abc", 0), {
|
||||
/* .pos = */ 1,
|
||||
/* .pattern = */ "abc",
|
||||
/* .is_partial = */ false,
|
||||
/* .matchLength = */ 3,
|
||||
/* .is_grammar_trigger = */ false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" ab c", 0), {
|
||||
/* .pos = */ std::string::npos,
|
||||
/* .pattern = */ "",
|
||||
/* .is_partial = */ false,
|
||||
/* .matchLength = */ 0,
|
||||
/* .is_grammar_trigger = */ false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" abc abc", 0), {
|
||||
/* .pos = */ 1,
|
||||
/* .pattern = */ "abc",
|
||||
/* .is_partial = */ false,
|
||||
/* .matchLength = */ 3,
|
||||
/* .is_grammar_trigger = */ false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" ab abc", 0), {
|
||||
/* .pos = */ 4,
|
||||
/* .pattern = */ "abc",
|
||||
/* .is_partial = */ false,
|
||||
/* .matchLength = */ 3,
|
||||
/* .is_grammar_trigger = */ false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" bc", 0), {
|
||||
/* .pos = */ 1,
|
||||
/* .pattern = */ "",
|
||||
/* .is_partial = */ true,
|
||||
/* .matchLength = */ 2,
|
||||
/* .is_grammar_trigger = */ false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" bcd", 0), {
|
||||
/* .pos = */ 1,
|
||||
/* .pattern = */ "bcd",
|
||||
/* .is_partial = */ false,
|
||||
/* .matchLength = */ 3,
|
||||
/* .is_grammar_trigger = */ false,
|
||||
});
|
||||
assert_equal(antiprompts.findFirstMatch(" bca", 0), {
|
||||
/* .pos = */ 1,
|
||||
/* .pattern = */ "bca",
|
||||
/* .is_partial = */ false,
|
||||
/* .matchLength = */ 3,
|
||||
/* .is_grammar_trigger = */ true,
|
||||
});
|
||||
printf("OK\n");
|
||||
// llama_antiprompts::MatchResult{0, "a", .is_partial = false, . 1, false});
|
||||
|
||||
return 0;
|
||||
}
|
|
tests/test-grammar-integration.cpp

@@ -13,7 +13,7 @@
 using json = nlohmann::ordered_json;
 
 static llama_grammar * build_grammar(const std::string & grammar_str) {
-    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
+    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
 }
 
 static bool test_build_grammar_fails(const std::string & grammar_str) {
tests/test-tool-call.cpp

@@ -37,7 +37,7 @@ static std::string read_file(const std::string &path) {
 }
 
 static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
-    return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"));
+    return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0));
 }
 
 // TODO: extract to common helper (copied from test-grammar-integration.cpp)