Push laziness down to grammar impl

This commit is contained in:
Olivier Chafik 2025-01-22 01:25:54 +00:00
parent 77f4098c83
commit dbf841b0d2
16 changed files with 224 additions and 423 deletions

View file

@ -58,7 +58,6 @@ TEST_TARGETS = \
tests/test-grammar-integration \
tests/test-grammar-parser \
tests/test-json-schema-to-grammar \
tests/test-minja \
tests/test-llama-grammar \
tests/test-log \
tests/test-model-load-cancel \

View file

@ -158,7 +158,8 @@ struct common_params_sampling {
};
std::string grammar; // optional BNF-like grammar to constrain sampling
std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar
std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@ -687,215 +688,6 @@ struct common_control_vector_load_info {
// On error, returns {-1, empty}
common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos);
//
// Antiprompt utils
//
class llama_antiprompts {
public:
struct llama_antiprompt {
std::string value;
bool is_grammar_trigger;
};
std::vector<std::string> stop_words;
std::vector<std::string> grammar_triggers;
private:
// The AhoCorasick algorithm allows efficient string matching with multiple patterns.
// See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
struct TrieNode {
std::unordered_map<char, TrieNode*> children;
TrieNode* fail = nullptr;
int output = -1;
size_t depth = 0;
~TrieNode() {
clear();
}
void clear() {
for (auto & pair : children) {
delete pair.second;
}
children.clear();
fail = nullptr;
output = -1;
depth = 0;
}
};
TrieNode root;
std::vector<llama_antiprompt> antiprompts;
std::unordered_map<llama_token, size_t> stop_tokens; // Single token antiprompts (and their index in antiprompts), if any.
void build_trie() {
// root = std::unique_ptr<TrieNode>(new TrieNode());
for (size_t i = 0; i < antiprompts.size(); ++i) {
TrieNode* node = &root;
const auto & pattern = antiprompts[i].value;
for (size_t j = 0; j < pattern.length(); ++j) {
char c = pattern[j];
auto it = node->children.find(c);
if (it != node->children.end()) {
node = it->second;
} else {
node = node->children[c] = new TrieNode();
}
if (node->depth == 0) {
node->depth = j + 1;
}
}
node->output = i;
}
}
void build_failure_and_dict_links() {
std::queue<TrieNode*> q;
for (auto& child : root.children) {
child.second->fail = &root;
q.push(child.second);
}
while (!q.empty()) {
auto node = q.front();
q.pop();
for (auto & pair : node->children) {
auto & c = pair.first;
auto & child = pair.second;
auto f = node->fail;
while (f != &root && f->children.find(c) == f->children.end()) {
f = f->fail;
}
child->fail = (f == &root && f->children.find(c) == f->children.end())
? &root : f->children[c];
if (child->fail->output != -1) {
child->output = child->fail->output;
}
q.push(child);
}
}
}
public:
bool empty() const {
return antiprompts.empty() && stop_tokens.empty();
}
void clear() {
root.clear();
antiprompts.clear();
stop_tokens.clear();
}
void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
build(
[&](const std::string & text) {
return common_tokenize(ctx, text, /* special= */ true);
},
stop_words,
grammar_triggers
);
}
void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
clear();
this->stop_words = stop_words;
this->grammar_triggers = grammar_triggers;
for (const std::string & stop_word : stop_words) {
antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
}
for (const std::string & trigger : grammar_triggers) {
antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
}
for (size_t i = 0, n = antiprompts.size(); i < n; i++) {
const auto & antiprompt = antiprompts[i];
std::vector<llama_token> tokens = tokenizer(antiprompt.value);
if (tokens.size() == 1) {
stop_tokens[tokens[0]] = i;
}
}
build_trie();
build_failure_and_dict_links();
}
struct MatchResult {
size_t pos;
std::string pattern;
bool is_partial;
size_t matchLength;
bool is_grammar_trigger;
bool operator==(const MatchResult & other) const {
return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger;
}
operator std::string() const {
return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}";
}
};
MatchResult findSingleTokenMatch(llama_token token) const {
auto it = stop_tokens.find(token);
if (it != stop_tokens.end()) {
const auto & antiprompt = antiprompts[it->second];
return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger};
}
return {std::string::npos, "", false, 0, false};
}
MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
TrieNode* current = &root;
MatchResult partialMatch{std::string::npos, "", true, 0, false};
auto text_length = text.length();
for (size_t i = offset; i < text_length; ++i) {
char c = text[i];
while (current != &root && current->children.find(c) == current->children.end()) {
current = current->fail;
}
auto it = current->children.find(c);
if (it != current->children.end()) {
current = it->second;
}
if (current->output != -1) {
const auto & antiprompt = antiprompts[current->output];
return {
i - antiprompt.value.length() + 1,
antiprompt.value,
false,
antiprompt.value.length(),
antiprompt.is_grammar_trigger,
};
}
// Update partial match if we're at a deeper node
if (current->depth > partialMatch.matchLength) {
partialMatch.pos = i - current->depth + 1;
partialMatch.pattern = ""; // We don't know which pattern it partially matches
partialMatch.matchLength = current->depth;
partialMatch.is_grammar_trigger = false;
}
}
// If we've found a partial match and haven't returned a full match, return the partial match
if (partialMatch.pos != std::string::npos) {
if (partialMatch.pos + partialMatch.matchLength == text_length) {
return partialMatch;
}
}
return {std::string::npos, "", false, 0, false};
}
};
//
// Split utils
//

View file

@ -144,15 +144,6 @@ std::string common_params_sampling::print() const {
return std::string(result);
}
bool common_sampler_trigger_grammar(const struct llama_vocab * vocab, common_sampler * gsmpl, const std::string & trigger) {
if (!llama_sampler_is_grammar_empty(gsmpl->grmr)) {
return false;
}
gsmpl->grmr = llama_sampler_init_grammar(vocab, gsmpl->params.grammar.c_str(), "root");
llama_sampler_accept_str(gsmpl->grmr, trigger.c_str());
return true;
}
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
@ -160,9 +151,22 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
lparams.no_perf = params.no_perf;
std::vector<const char *> c_trigger_words;
c_trigger_words.reserve(params.grammar_trigger_words.size());
for (const auto & str : params.grammar_trigger_words) {
c_trigger_words.push_back(str.c_str());
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar_trigger_words.empty() ? params.grammar.c_str() : "", "root"),
/* .grmr = */ llama_sampler_init_grammar(
vocab,
params.grammar.c_str(),
"root",
c_trigger_words.data(),
c_trigger_words.size(),
params.grammar_trigger_tokens.data(),
params.grammar_trigger_tokens.size()
),
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},

View file

@ -100,7 +100,5 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx,
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
bool common_sampler_trigger_grammar(const struct llama_vocab * vocab, common_sampler * gsmpl, const std::string & trigger);
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);

View file

@ -1,3 +0,0 @@
squid/ssl_cert/
squid/ssl_db/
squid/cache/

View file

@ -76,7 +76,7 @@ int main(int argc, char** argv) {
grammar_str = buffer.str();
}
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
if (grammar == nullptr) {
fprintf(stdout, "Failed to initialize llama_grammar\n");
return 1;

View file

@ -38,7 +38,7 @@ static llama_model ** g_model;
static common_sampler ** g_smpl;
static common_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::string * g_output_s;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool need_insert_eot = false;
@ -494,8 +494,7 @@ int main(int argc, char ** argv) {
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
std::string output_s; g_output_s = &output_s;
size_t last_partial_stop = std::string::npos;
std::ostringstream output_ss; g_output_ss = &output_ss;
std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
// the first thing we will do is to output the prompt, so set color accordingly
@ -504,8 +503,16 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
llama_antiprompts antiprompts;
antiprompts.build(ctx, params.antiprompt, {});
// single-token antiprompts
std::vector<llama_token> antiprompt_single_token;
antiprompt_single_token.reserve(params.antiprompt.size());
for (const std::string & antiprompt : params.antiprompt) {
auto ids = ::common_tokenize(ctx, antiprompt, false, true);
if (ids.size() == 1) {
antiprompt_single_token.push_back(ids[0]);
}
}
if (llama_model_has_encoder(model)) {
int enc_input_size = embd_inp.size();
@ -710,7 +717,7 @@ int main(int argc, char ** argv) {
} else {
// Outgoing Generated Tokens
output_tokens.push_back(id);
output_s.append(token_str);
output_ss << token_str;
}
}
}
@ -723,34 +730,41 @@ int main(int argc, char ** argv) {
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {
// check for reverse prompt
if (!antiprompts.empty()) {
// check for reverse prompt in the last n_prev tokens
if (!params.antiprompt.empty()) {
const int n_prev = 32;
const std::string last_output = common_sampler_prev_str(smpl, ctx, n_prev);
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
// If we're not running interactively, the reverse prompt might be tokenized with some following characters
// so we'll compensate for that by widening the search window a bit.
for (std::string & antiprompt : params.antiprompt) {
size_t extra_padding = params.interactive ? 0 : 2;
size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
: 0;
if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
if (params.interactive) {
is_interacting = true;
}
is_antiprompt = true;
break;
}
}
// check for reverse prompt using special tokens
llama_token last_token = common_sampler_last(smpl);
auto match = antiprompts.findSingleTokenMatch(last_token);
if (match.pos != std::string::npos) {
if (std::find(antiprompt_single_token.begin(), antiprompt_single_token.end(), last_token) != antiprompt_single_token.end()) {
if (params.interactive) {
is_interacting = true;
}
is_antiprompt = true;
} else {
match = antiprompts.findFirstMatch(output_s, last_partial_stop == std::string::npos ? 0 : last_partial_stop);
if (match.pos != std::string::npos) {
if (match.is_partial) {
last_partial_stop = match.pos;
} else {
if (params.interactive) {
is_interacting = true;
}
is_antiprompt = true;
}
}
}
if (is_antiprompt) {
LOG_DBG("found antiprompt: %s\n", match.pattern.c_str());
LOG_DBG("found antiprompt: %s\n", last_output.c_str());
}
}
@ -759,9 +773,9 @@ int main(int argc, char ** argv) {
LOG_DBG("found an EOG token\n");
if (params.interactive) {
if (!antiprompts.stop_words.empty()) {
if (!params.antiprompt.empty()) {
// tokenize and inject first reverse prompt
const auto first_antiprompt = common_tokenize(ctx, antiprompts.stop_words.front(), false, true);
const auto first_antiprompt = common_tokenize(ctx, params.antiprompt.front(), false, true);
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
is_antiprompt = true;
}
@ -855,7 +869,7 @@ int main(int argc, char ** argv) {
for (size_t i = original_size; i < embd_inp.size(); ++i) {
const llama_token token = embd_inp[i];
output_tokens.push_back(token);
output_s.append(common_token_to_piece(ctx, token));
output_ss << common_token_to_piece(ctx, token);
}
// reset assistant message

View file

@ -389,7 +389,15 @@ struct server_task {
{
const auto grammar_trigger_words = data.find("grammar_trigger_words");
if (grammar_trigger_words != data.end()) {
params.sampling.grammar_trigger_words = to_string_vec(*grammar_trigger_words);
auto words = to_string_vec(*grammar_trigger_words);
for (const auto & word : params.sampling.grammar_trigger_words) {
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
continue;
}
params.sampling.grammar_trigger_words.push_back(word);
}
}
}
@ -1224,8 +1232,6 @@ struct server_slot {
std::string stopping_word;
llama_antiprompts antiprompts;
// sampling
json json_schema;
@ -1329,6 +1335,35 @@ struct server_slot {
return timings;
}
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
size_t stop_pos = std::string::npos;
for (const std::string & word : params.antiprompt) {
size_t pos;
if (is_full_stop) {
const size_t tmp = word.size() + last_token_size;
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
pos = text.find(word, from_pos);
} else {
// otherwise, partial stop
pos = find_partial_stop_string(word, text);
}
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
if (is_full_stop) {
stop = STOP_TYPE_WORD;
stopping_word = word;
has_next_token = false;
}
stop_pos = pos;
}
}
return stop_pos;
}
void print_timings() const {
const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
@ -1976,11 +2011,6 @@ struct server_context {
slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
}
{
slot.antiprompts.clear();
slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.sampling.grammar_trigger_words);
}
{
if (slot.smpl != nullptr) {
common_sampler_free(slot.smpl);
@ -2016,25 +2046,13 @@ struct server_context {
}
bool process_token(completion_token_output & result, server_slot & slot) {
auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;
// TODO:
// const std::string token_str = result.text_to_send;
const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special || (match.pos != std::string::npos && match.is_grammar_trigger));
// const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special || (match.pos != std::string::npos && match.is_grammar_trigger));
slot.sampled = result.tok;
if (match.pos != std::string::npos && !match.is_partial) {
if (match.is_grammar_trigger) {
common_sampler_trigger_grammar(vocab, slot.smpl, token_str);
} else {
// slot.stopped_word = true;
slot.stopping_word = match.pattern;
slot.has_next_token = false;
return false;
}
}
// search stop word and delete it
slot.generated_text += token_str;
if (slot.params.return_tokens) {
slot.generated_tokens.push_back(result.tok);
@ -2048,33 +2066,22 @@ struct server_context {
if (!incomplete) {
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
match = slot.antiprompts.findFirstMatch(slot.generated_text, pos);
const std::string str_test = slot.generated_text.substr(pos);
bool send_text = true;
bool is_stop_full = false;
bool is_grammar_trigger = false;
size_t length = slot.generated_text.size();
// If there is a lazy grammar trigger word at stop_pos, enable the lazy grammar
if (match.is_grammar_trigger && common_sampler_trigger_grammar(vocab, slot.smpl, match.pattern)) {
is_grammar_trigger = true;
length = match.pos + match.matchLength;
} else if (!match.is_grammar_trigger && match.pos != std::string::npos && !match.is_partial) {
// slot.stopped_word = true;
slot.stopping_word = match.pattern;
slot.has_next_token = false;
is_stop_full = true;
// length = pos + match.pos;
length = match.pos;
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
if (stop_pos != std::string::npos) {
slot.generated_text.erase(
slot.generated_text.begin() + pos + stop_pos,
slot.generated_text.end());
pos = std::min(slot.n_sent_text, slot.generated_text.size());
} else if (slot.has_next_token) {
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
send_text = stop_pos == std::string::npos;
}
slot.generated_text.erase(
slot.generated_text.begin() + length,
slot.generated_text.end());
pos = std::min(slot.n_sent_text, length);
// check if there is any token to predict
if (match.pos == std::string::npos || (!slot.has_next_token && !is_grammar_trigger && !is_stop_full && match.pos > 0)) {
if (send_text) {
// no send the stop word in the response
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
slot.n_sent_text += result.text_to_send.size();

View file

@ -1199,7 +1199,11 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root);
const char * grammar_root,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(

View file

@ -960,10 +960,26 @@ struct llama_grammar * llama_grammar_init_impl(
// Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
return new llama_grammar {
vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
/* .awaiting_trigger = */ false,
/* .trigger_buffer = */ "",
/* .trigger_tokens = */ {},
/* .trigger_words = */ {},
};
}
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
llama_grammar_parser parser;
// if there is a grammar, parse it
@ -1035,10 +1051,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
}
} while (true);
std::vector<llama_token> vec_trigger_tokens;
std::vector<std::string> vec_trigger_words;
for (size_t i = 0; i < num_trigger_tokens; i++) {
GGML_ASSERT(trigger_tokens != nullptr);
vec_trigger_tokens.push_back(trigger_tokens[i]);
}
for (size_t i = 0; i < num_trigger_words; i++) {
GGML_ASSERT(trigger_words != nullptr);
vec_trigger_words.push_back(trigger_words[i]);
}
// Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
return new llama_grammar {
vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
/* .awaiting_trigger = */ vec_trigger_tokens.size() > 0 || vec_trigger_words.size() > 0,
/* .trigger_buffer = */ "",
std::move(vec_trigger_tokens),
std::move(vec_trigger_words),
};
}
void llama_grammar_free_impl(struct llama_grammar * grammar) {
@ -1055,6 +1092,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
grammar.rules,
grammar.stacks,
grammar.partial_utf8,
grammar.awaiting_trigger,
grammar.trigger_buffer,
grammar.trigger_tokens,
grammar.trigger_words,
};
// redirect elements in stacks to point to new rules
@ -1115,6 +1156,28 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
GGML_ASSERT(grammar.vocab != nullptr);
if (grammar.awaiting_trigger) {
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
grammar.awaiting_trigger = false;
llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token));
return;
} else {
grammar.trigger_buffer += grammar.vocab->token_to_piece(token);
for (const auto & word : grammar.trigger_words) {
auto pos = grammar.trigger_buffer.find(word);
if (pos == std::string::npos) {
continue;
}
grammar.awaiting_trigger = false;
auto constrained_str = grammar.trigger_buffer.substr(pos);
llama_grammar_accept_str(grammar, constrained_str);
grammar.trigger_buffer.clear();
return;
}
return;
}
}
if (grammar.vocab->is_eog(token)) {
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {

View file

@ -3,6 +3,7 @@
#include "llama.h"
#include <map>
#include <set>
#include <string>
#include <vector>
@ -114,6 +115,11 @@ struct llama_grammar {
// buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8;
bool awaiting_trigger;
std::string trigger_buffer;
std::vector<llama_token> trigger_tokens;
std::vector<std::string> trigger_words;
};
//
@ -127,7 +133,14 @@ struct llama_grammar * llama_grammar_init_impl(
size_t n_rules,
size_t start_rule_index);
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);
void llama_grammar_free_impl(struct llama_grammar * grammar);

View file

@ -1465,7 +1465,18 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
return;
}
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
std::vector<const char *> trigger_words;
for (auto & word : ctx->grammar->trigger_words) {
trigger_words.push_back(word.c_str());
}
auto * grammar_new = llama_grammar_init_impl(
ctx->grammar->vocab,
ctx->grammar_str.c_str(),
ctx->grammar_root.c_str(),
trigger_words.data(),
trigger_words.size(),
ctx->grammar->trigger_tokens.data(),
ctx->grammar->trigger_tokens.size());
llama_grammar_free_impl(ctx->grammar);
ctx->grammar = grammar_new;
@ -1474,7 +1485,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr);
auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr, nullptr, 0, nullptr, 0);
// copy the state
{
@ -1511,15 +1522,24 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
/* .free = */ llama_sampler_grammar_free,
};
struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
struct llama_sampler * llama_sampler_init_grammar(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
// struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
auto * ctx = new llama_sampler_grammar;
if (grammar_str != nullptr && grammar_str[0] != '\0') {
*ctx = {
/* .vocab = */ vocab,
/* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root),
/* .vocab = */ vocab,
/* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
};
} else {
*ctx = {

View file

@ -133,7 +133,6 @@ llama_target_and_test(test-chat-template.cpp)
# llama_target_and_test(test-opt.cpp) # SLOW
llama_target_and_test(test-gguf.cpp)
llama_target_and_test(test-backend-ops.cpp)
llama_target_and_test(test-antiprompts.cpp)
llama_target_and_test(test-tool-call.cpp)
llama_target_and_test(test-model-load-cancel.cpp LABEL "model")

View file

@ -1,109 +0,0 @@
#ifdef NDEBUG
#undef NDEBUG
#endif
#include "llama.h"
#include "common.h"
#include <cassert>
template <typename T>
void assert_equal(const T & actual, const T & expected) {
if (expected == actual) return;
printf("Expected: %s, Actual: %s\n", ((std::string)expected).c_str(), ((std::string)actual).c_str());
assert(expected == actual);
}
// cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_CURL=1 && cmake --build build -j -t test-jinja -t test-antiprompts && ./build/bin/test-antiprompts
int main()
{
auto tokenizer = [&](const std::string & text) {
std::vector<llama_token> tokens;
for (size_t i = 0; i < text.length(); ++i) {
tokens.push_back(text[i]);
}
return tokens;
};
const std::vector<std::string> stop_words { };
const std::vector<std::string> grammar_trigger_words { };
printf("Testing antiprompts\n");
llama_antiprompts antiprompts;
antiprompts.build(tokenizer, {"abc", "bcd"}, {"bca", "x"});
assert_equal(antiprompts.findSingleTokenMatch('x'), {
/* .pos = */ 0,
/* .pattern = */ "x",
/* .is_partial = */ false,
/* .matchLength = */ 1,
/* .is_grammar_trigger = */ true,
});
assert_equal(antiprompts.findSingleTokenMatch('a'), {
/* .pos = */ std::string::npos,
/* .pattern = */ "",
/* .is_partial = */ false,
/* .matchLength = */ 0,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" ab", 0), {
/* .pos = */ 1,
/* .pattern = */ "",
/* .is_partial = */ true,
/* .matchLength = */ 2,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" abc", 0), {
/* .pos = */ 1,
/* .pattern = */ "abc",
/* .is_partial = */ false,
/* .matchLength = */ 3,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" ab c", 0), {
/* .pos = */ std::string::npos,
/* .pattern = */ "",
/* .is_partial = */ false,
/* .matchLength = */ 0,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" abc abc", 0), {
/* .pos = */ 1,
/* .pattern = */ "abc",
/* .is_partial = */ false,
/* .matchLength = */ 3,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" ab abc", 0), {
/* .pos = */ 4,
/* .pattern = */ "abc",
/* .is_partial = */ false,
/* .matchLength = */ 3,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" bc", 0), {
/* .pos = */ 1,
/* .pattern = */ "",
/* .is_partial = */ true,
/* .matchLength = */ 2,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" bcd", 0), {
/* .pos = */ 1,
/* .pattern = */ "bcd",
/* .is_partial = */ false,
/* .matchLength = */ 3,
/* .is_grammar_trigger = */ false,
});
assert_equal(antiprompts.findFirstMatch(" bca", 0), {
/* .pos = */ 1,
/* .pattern = */ "bca",
/* .is_partial = */ false,
/* .matchLength = */ 3,
/* .is_grammar_trigger = */ true,
});
printf("OK\n");
// llama_antiprompts::MatchResult{0, "a", .is_partial = false, . 1, false});
return 0;
}

View file

@ -13,7 +13,7 @@
using json = nlohmann::ordered_json;
static llama_grammar * build_grammar(const std::string & grammar_str) {
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
}
static bool test_build_grammar_fails(const std::string & grammar_str) {

View file

@ -37,7 +37,7 @@ static std::string read_file(const std::string &path) {
}
static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"));
return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0));
}
// TODO: extract to common helper (copied from test-grammar-integration.cpp)