From dbf841b0d29a1d2d9e5a9bacf94c0603959ed67f Mon Sep 17 00:00:00 2001
From: Olivier Chafik
Date: Wed, 22 Jan 2025 01:25:54 +0000
Subject: [PATCH] Push laziness down to grammar impl

---
 Makefile                                   |   1 -
 common/common.h                            | 212 +--------------------
 common/sampling.cpp                        |  24 ++-
 common/sampling.h                          |   2 -
 examples/agent/.gitignore                  |   3 -
 examples/gbnf-validator/gbnf-validator.cpp |   2 +-
 examples/main/main.cpp                     |  66 ++++---
 examples/server/server.cpp                 |  99 +++++-----
 include/llama.h                            |   6 +-
 src/llama-grammar.cpp                      |  69 ++++++-
 src/llama-grammar.h                        |  15 +-
 src/llama-sampling.cpp                     |  34 +++-
 tests/CMakeLists.txt                       |   1 -
 tests/test-antiprompts.cpp                 | 109 ----------
 tests/test-grammar-integration.cpp         |   2 +-
 tests/test-tool-call.cpp                   |   2 +-
 16 files changed, 224 insertions(+), 423 deletions(-)
 delete mode 100644 examples/agent/.gitignore
 delete mode 100644 tests/test-antiprompts.cpp

diff --git a/Makefile b/Makefile
index 400f1d1e4..50dc14fa6 100644
--- a/Makefile
+++ b/Makefile
@@ -58,7 +58,6 @@ TEST_TARGETS = \
 	tests/test-grammar-integration \
 	tests/test-grammar-parser \
 	tests/test-json-schema-to-grammar \
-	tests/test-minja \
 	tests/test-llama-grammar \
 	tests/test-log \
 	tests/test-model-load-cancel \
diff --git a/common/common.h b/common/common.h
index 19c1bada0..75a189de6 100644
--- a/common/common.h
+++ b/common/common.h
@@ -158,7 +158,8 @@ struct common_params_sampling {
     };

     std::string grammar; // optional BNF-like grammar to constrain sampling
-    std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar
+    std::vector<std::string> grammar_trigger_words;  // optional trigger words to enable grammar
+    std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar

     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@@ -687,215 +688,6 @@ struct common_control_vector_load_info {
 // On error, returns {-1, empty}
 common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos);

-//
-// Antiprompt utils
-//
-
-class llama_antiprompts {
-  public:
-
-    struct llama_antiprompt {
-        std::string value;
-        bool is_grammar_trigger;
-    };
-
-    std::vector<std::string> stop_words;
-    std::vector<std::string> grammar_triggers;
-
-  private:
-    // The Aho–Corasick algorithm allows efficient string matching with multiple patterns.
-    // See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
-    struct TrieNode {
-        std::unordered_map<char, TrieNode*> children;
-        TrieNode* fail = nullptr;
-        int output = -1;
-        size_t depth = 0;
-
-        ~TrieNode() {
-            clear();
-        }
-
-        void clear() {
-            for (auto & pair : children) {
-                delete pair.second;
-            }
-            children.clear();
-            fail = nullptr;
-            output = -1;
-            depth = 0;
-        }
-    };
-
-    TrieNode root;
-    std::vector<llama_antiprompt> antiprompts;
-    std::unordered_map<llama_token, size_t> stop_tokens; // Single token antiprompts (and their index in antiprompts), if any.
-
-    void build_trie() {
-        // root = std::unique_ptr<TrieNode>(new TrieNode());
-        for (size_t i = 0; i < antiprompts.size(); ++i) {
-            TrieNode* node = &root;
-            const auto & pattern = antiprompts[i].value;
-            for (size_t j = 0; j < pattern.length(); ++j) {
-                char c = pattern[j];
-                auto it = node->children.find(c);
-                if (it != node->children.end()) {
-                    node = it->second;
-                } else {
-                    node = node->children[c] = new TrieNode();
-                }
-                if (node->depth == 0) {
-                    node->depth = j + 1;
-                }
-            }
-            node->output = i;
-        }
-    }
-
-    void build_failure_and_dict_links() {
-        std::queue<TrieNode*> q;
-        for (auto& child : root.children) {
-            child.second->fail = &root;
-            q.push(child.second);
-        }
-
-        while (!q.empty()) {
-            auto node = q.front();
-            q.pop();
-
-            for (auto & pair : node->children) {
-                auto & c = pair.first;
-                auto & child = pair.second;
-                auto f = node->fail;
-
-                while (f != &root && f->children.find(c) == f->children.end()) {
-                    f = f->fail;
-                }
-
-                child->fail = (f == &root && f->children.find(c) == f->children.end())
-                    ? &root : f->children[c];
-
-                if (child->fail->output != -1) {
-                    child->output = child->fail->output;
-                }
-
-                q.push(child);
-            }
-        }
-    }
-
-  public:
-
-    bool empty() const {
-        return antiprompts.empty() && stop_tokens.empty();
-    }
-    void clear() {
-        root.clear();
-        antiprompts.clear();
-        stop_tokens.clear();
-    }
-
-    void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
-        build(
-            [&](const std::string & text) {
-                return common_tokenize(ctx, text, /* special= */ true);
-            },
-            stop_words,
-            grammar_triggers
-        );
-    }
-
-    void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
-        clear();
-        this->stop_words = stop_words;
-        this->grammar_triggers = grammar_triggers;
-
-        for (const std::string & stop_word : stop_words) {
-            antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
-        }
-        for (const std::string & trigger : grammar_triggers) {
-            antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
-        }
-
-        for (size_t i = 0, n = antiprompts.size(); i < n; i++) {
-            const auto & antiprompt = antiprompts[i];
-            std::vector<llama_token> tokens = tokenizer(antiprompt.value);
-            if (tokens.size() == 1) {
-                stop_tokens[tokens[0]] = i;
-            }
-        }
-
-        build_trie();
-        build_failure_and_dict_links();
-    }
-
-    struct MatchResult {
-        size_t pos;
-        std::string pattern;
-        bool is_partial;
-        size_t matchLength;
-        bool is_grammar_trigger;
-
-        bool operator==(const MatchResult & other) const {
-            return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger;
-        }
-        operator std::string() const {
-            return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}";
-        }
-    };
-
-    MatchResult findSingleTokenMatch(llama_token token) const {
-        auto it = stop_tokens.find(token);
-        if (it != stop_tokens.end()) {
-            const auto & antiprompt = antiprompts[it->second];
-            return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger};
-        }
-        return {std::string::npos, "", false, 0, false};
-    }
-
-    MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
-        TrieNode* current = &root;
-        MatchResult partialMatch{std::string::npos, "", true, 0, false};
-        auto text_length = text.length();
-
-        for (size_t i = offset; i < text_length; ++i) {
-            char c = text[i];
-            while (current != &root && current->children.find(c) == current->children.end()) {
-                current = current->fail;
-            }
-            auto it = current->children.find(c);
-            if (it != current->children.end()) {
-                current = it->second;
-            }
-            if (current->output != -1) {
-                const auto & antiprompt = antiprompts[current->output];
-                return {
-                    i - antiprompt.value.length() + 1,
-                    antiprompt.value,
-                    false,
-                    antiprompt.value.length(),
-                    antiprompt.is_grammar_trigger,
-                };
-            }
-            // Update partial match if we're at a deeper node
-            if (current->depth > partialMatch.matchLength) {
-                partialMatch.pos = i - current->depth + 1;
-                partialMatch.pattern = ""; // We don't know which pattern it partially matches
-                partialMatch.matchLength = current->depth;
-                partialMatch.is_grammar_trigger = false;
-            }
-        }
-
-        // If we've found a partial match and haven't returned a full match, return the partial match
-        if (partialMatch.pos != std::string::npos) {
-            if (partialMatch.pos + partialMatch.matchLength == text_length) {
-                return partialMatch;
-            }
-        }
-
-        return {std::string::npos, "", false, 0, false};
-    }
-};
-
 //
 // Split utils
 //
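For context, the net effect of the two new fields above: callers describe when the grammar should kick in, and the sampler/grammar layers handle the buffering themselves instead of the caller scanning output text. A minimal caller-side sketch, assuming only the API visible in this patch; the grammar text and trigger word are made-up values:

    // Sketch: configure a lazy grammar via the common sampling params.
    common_params_sampling sparams;
    sparams.grammar = "root ::= \"<tool_call>\" [^<]+ \"</tool_call>\""; // made-up GBNF
    // Generation stays unconstrained until this word appears in the output:
    sparams.grammar_trigger_words.push_back("<tool_call>");
    common_sampler * smpl = common_sampler_init(model, sparams);
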
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 66d8052c5..78c4061f2 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -144,15 +144,6 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }

-bool common_sampler_trigger_grammar(const struct llama_vocab * vocab, common_sampler * gsmpl, const std::string & trigger) {
-    if (!llama_sampler_is_grammar_empty(gsmpl->grmr)) {
-        return false;
-    }
-    gsmpl->grmr = llama_sampler_init_grammar(vocab, gsmpl->params.grammar.c_str(), "root");
-    llama_sampler_accept_str(gsmpl->grmr, trigger.c_str());
-    return true;
-}
-
 struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);

@@ -160,9 +151,22 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     lparams.no_perf = params.no_perf;

+    std::vector<const char *> c_trigger_words;
+    c_trigger_words.reserve(params.grammar_trigger_words.size());
+    for (const auto & str : params.grammar_trigger_words) {
+        c_trigger_words.push_back(str.c_str());
+    }
     auto * result = new common_sampler {
         /* .params = */ params,
-        /* .grmr   = */ llama_sampler_init_grammar(vocab, params.grammar_trigger_words.empty() ? params.grammar.c_str() : "", "root"),
+        /* .grmr   = */ llama_sampler_init_grammar(
+            vocab,
+            params.grammar.c_str(),
+            "root",
+            c_trigger_words.data(),
+            c_trigger_words.size(),
+            params.grammar_trigger_tokens.data(),
+            params.grammar_trigger_tokens.size()
+        ),
         /* .chain  = */ llama_sampler_chain_init(lparams),
         /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur    = */ {},
diff --git a/common/sampling.h b/common/sampling.h
index e7c0a3dce..348911b18 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -100,7 +100,5 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx,
 char        common_sampler_type_to_chr(enum common_sampler_type cnstr);
 std::string common_sampler_type_to_str(enum common_sampler_type cnstr);

-bool common_sampler_trigger_grammar(const struct llama_vocab * vocab, common_sampler * gsmpl, const std::string & trigger);
-
 std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
diff --git a/examples/agent/.gitignore b/examples/agent/.gitignore
deleted file mode 100644
index f65f2615f..000000000
--- a/examples/agent/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-squid/ssl_cert/
-squid/ssl_db/
-squid/cache/
diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp
index 17a0e27c4..83cc71817 100644
--- a/examples/gbnf-validator/gbnf-validator.cpp
+++ b/examples/gbnf-validator/gbnf-validator.cpp
@@ -76,7 +76,7 @@ int main(int argc, char** argv) {
         grammar_str = buffer.str();
     }

-    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
+    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
     if (grammar == nullptr) {
         fprintf(stdout, "Failed to initialize llama_grammar\n");
         return 1;
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index e49172bde..821eb0b03 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -38,7 +38,7 @@ static llama_model             ** g_model;
 static common_sampler           ** g_smpl;
 static common_params             * g_params;
 static std::vector<llama_token>  * g_input_tokens;
-static std::string               * g_output_s;
+static std::ostringstream        * g_output_ss;
 static std::vector<llama_token>  * g_output_tokens;
 static bool is_interacting  = false;
 static bool need_insert_eot = false;
@@ -494,8 +494,7 @@ int main(int argc, char ** argv) {

     std::vector<llama_token> input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<llama_token> output_tokens; g_output_tokens = &output_tokens;
-    std::string output_s; g_output_s = &output_s;
-    size_t last_partial_stop = std::string::npos;
+    std::ostringstream output_ss;           g_output_ss     = &output_ss;
     std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode

     // the first thing we will do is to output the prompt, so set color accordingly
@@ -504,8 +503,16 @@ int main(int argc, char ** argv) {

     std::vector<llama_token> embd;

-    llama_antiprompts antiprompts;
-    antiprompts.build(ctx, params.antiprompt, {});
+    // single-token antiprompts
+    std::vector<llama_token> antiprompt_single_token;
+
+    antiprompt_single_token.reserve(params.antiprompt.size());
+    for (const std::string & antiprompt : params.antiprompt) {
+        auto ids = ::common_tokenize(ctx, antiprompt, false, true);
+        if (ids.size() == 1) {
+            antiprompt_single_token.push_back(ids[0]);
+        }
+    }

     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
@@ -710,7 +717,7 @@ int main(int argc, char ** argv) {
                 } else {
                     // Outgoing Generated Tokens
                     output_tokens.push_back(id);
-                    output_s.append(token_str);
+                    output_ss << token_str;
                 }
             }
         }
@@ -723,34 +730,41 @@ int main(int argc, char ** argv) {

         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
-            // check for reverse prompt
-            if (!antiprompts.empty()) {
+            // check for reverse prompt in the last n_prev tokens
+            if (!params.antiprompt.empty()) {
+                const int n_prev = 32;
+                const std::string last_output = common_sampler_prev_str(smpl, ctx, n_prev);
+
                 is_antiprompt = false;
+                // Check if each of the reverse prompts appears at the end of the output.
+                // If we're not running interactively, the reverse prompt might be tokenized with some following characters
+                // so we'll compensate for that by widening the search window a bit.
+                for (std::string & antiprompt : params.antiprompt) {
+                    size_t extra_padding = params.interactive ? 0 : 2;
+                    size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
+                        ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
+                        : 0;
+
+                    if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
+                        if (params.interactive) {
+                            is_interacting = true;
+                        }
+                        is_antiprompt = true;
+                        break;
+                    }
+                }

                 // check for reverse prompt using special tokens
                 llama_token last_token = common_sampler_last(smpl);
-                auto match = antiprompts.findSingleTokenMatch(last_token);
-                if (match.pos != std::string::npos) {
+                if (std::find(antiprompt_single_token.begin(), antiprompt_single_token.end(), last_token) != antiprompt_single_token.end()) {
                     if (params.interactive) {
                         is_interacting = true;
                     }
                     is_antiprompt = true;
-                } else {
-                    match = antiprompts.findFirstMatch(output_s, last_partial_stop == std::string::npos ? 0 : last_partial_stop);
-                    if (match.pos != std::string::npos) {
-                        if (match.is_partial) {
-                            last_partial_stop = match.pos;
-                        } else {
-                            if (params.interactive) {
-                                is_interacting = true;
-                            }
-                            is_antiprompt = true;
-                        }
-                    }
                 }

                 if (is_antiprompt) {
-                    LOG_DBG("found antiprompt: %s\n", match.pattern.c_str());
+                    LOG_DBG("found antiprompt: %s\n", last_output.c_str());
                 }
             }

@@ -759,9 +773,9 @@ int main(int argc, char ** argv) {
                 LOG_DBG("found an EOG token\n");

                 if (params.interactive) {
-                    if (!antiprompts.stop_words.empty()) {
+                    if (!params.antiprompt.empty()) {
                         // tokenize and inject first reverse prompt
-                        const auto first_antiprompt = common_tokenize(ctx, antiprompts.stop_words.front(), false, true);
+                        const auto first_antiprompt = common_tokenize(ctx, params.antiprompt.front(), false, true);
                         embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                         is_antiprompt = true;
                     }
@@ -855,7 +869,7 @@ int main(int argc, char ** argv) {
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
                     const llama_token token = embd_inp[i];
                     output_tokens.push_back(token);
-                    output_s.append(common_token_to_piece(ctx, token));
+                    output_ss << common_token_to_piece(ctx, token);
                 }

                 // reset assistant message
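The widened search window in the hunk above exists because detokenization is not stable across token boundaries. An illustrative sketch of the failure mode being compensated for; the token ids are invented, not from the patch:

    // Why main.cpp pads its antiprompt search window by 2 characters:
    // the same text can detokenize from different token sequences, so the
    // stop word may land slightly before the end of the decoded tail.
    std::vector<llama_token> a = common_tokenize(ctx, "User:",   false, true); // e.g. {1792, 25}
    std::vector<llama_token> b = common_tokenize(ctx, "User:\n", false, true); // ":" and "\n" may fuse
    // A strict suffix comparison against "User:" could then miss the match,
    // hence last_output.find(antiprompt, search_start_pos) with extra_padding.
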
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 214a93a9c..10e8a1bdb 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -389,7 +389,15 @@ struct server_task {
         {
             const auto grammar_trigger_words = data.find("grammar_trigger_words");
             if (grammar_trigger_words != data.end()) {
-                params.sampling.grammar_trigger_words = to_string_vec(*grammar_trigger_words);
+                auto words = to_string_vec(*grammar_trigger_words);
+                for (const auto & word : words) {
+                    auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                        continue;
+                    }
+                    params.sampling.grammar_trigger_words.push_back(word);
+                }
             }
         }
@@ -1224,8 +1232,6 @@ struct server_slot {

     std::string stopping_word;

-    llama_antiprompts antiprompts;
-
     // sampling
     json json_schema;
@@ -1329,6 +1335,35 @@ struct server_slot {
         return timings;
     }

+    size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
+        size_t stop_pos = std::string::npos;
+
+        for (const std::string & word : params.antiprompt) {
+            size_t pos;
+
+            if (is_full_stop) {
+                const size_t tmp      = word.size() + last_token_size;
+                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
+
+                pos = text.find(word, from_pos);
+            } else {
+                // otherwise, partial stop
+                pos = find_partial_stop_string(word, text);
+            }
+
+            if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
+                if (is_full_stop) {
+                    stop           = STOP_TYPE_WORD;
+                    stopping_word  = word;
+                    has_next_token = false;
+                }
+                stop_pos = pos;
+            }
+        }
+
+        return stop_pos;
+    }
+
     void print_timings() const {
         const double t_prompt        = t_prompt_processing / n_prompt_tokens_processed;
         const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
@@ -1976,11 +2011,6 @@ struct server_context {
             slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
         }

-        {
-            slot.antiprompts.clear();
-            slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.sampling.grammar_trigger_words);
-        }
-
         {
             if (slot.smpl != nullptr) {
                 common_sampler_free(slot.smpl);
@@ -2016,25 +2046,13 @@ struct server_context {
     }

     bool process_token(completion_token_output & result, server_slot & slot) {
-        auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
-
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special || (match.pos != std::string::npos && match.is_grammar_trigger));
+        const std::string token_str = result.text_to_send;
+        // TODO:
+        //  const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special || (match.pos != std::string::npos && match.is_grammar_trigger));
         slot.sampled = result.tok;

-        if (match.pos != std::string::npos && !match.is_partial) {
-            if (match.is_grammar_trigger) {
-                common_sampler_trigger_grammar(vocab, slot.smpl, token_str);
-            } else {
-                // slot.stopped_word   = true;
-                slot.stopping_word  = match.pattern;
-                slot.has_next_token = false;
-                return false;
-            }
-        }
-
-        // search stop word and delete it
         slot.generated_text += token_str;
         if (slot.params.return_tokens) {
             slot.generated_tokens.push_back(result.tok);
@@ -2048,33 +2066,22 @@ struct server_context {
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

-            match = slot.antiprompts.findFirstMatch(slot.generated_text, pos);
+            const std::string str_test = slot.generated_text.substr(pos);
             bool send_text = true;

-            bool is_stop_full = false;
-            bool is_grammar_trigger = false;
-            size_t length = slot.generated_text.size();
-
-            // If there is a lazy grammar trigger word at stop_pos, enable the lazy grammar
-            if (match.is_grammar_trigger && common_sampler_trigger_grammar(vocab, slot.smpl, match.pattern)) {
-                is_grammar_trigger = true;
-                length = match.pos + match.matchLength;
-            } else if (!match.is_grammar_trigger && match.pos != std::string::npos && !match.is_partial) {
-                // slot.stopped_word = true;
-                slot.stopping_word = match.pattern;
-                slot.has_next_token = false;
-
-                is_stop_full = true;
-                // length = pos + match.pos;
-                length = match.pos;
+            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
+            if (stop_pos != std::string::npos) {
+                slot.generated_text.erase(
+                    slot.generated_text.begin() + pos + stop_pos,
+                    slot.generated_text.end());
+                pos = std::min(slot.n_sent_text, slot.generated_text.size());
+            } else if (slot.has_next_token) {
+                stop_pos  = slot.find_stopping_strings(str_test, token_str.size(), false);
+                send_text = stop_pos == std::string::npos;
             }

-            slot.generated_text.erase(
-                slot.generated_text.begin() + length,
-                slot.generated_text.end());
-            pos = std::min(slot.n_sent_text, length);
-
             // check if there is any token to predict
-            if (match.pos == std::string::npos || (!slot.has_next_token && !is_grammar_trigger && !is_stop_full && match.pos > 0)) {
+            if (send_text) {
                 // do not send the stop word in the response
                 result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
                 slot.n_sent_text += result.text_to_send.size();
diff --git a/include/llama.h b/include/llama.h
index 4e63cd61a..f6217d98c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1199,7 +1199,11 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
             const struct llama_vocab * vocab,
             const char * grammar_str,
-            const char * grammar_root);
+            const char * grammar_root,
+            const char ** trigger_words,
+            size_t num_trigger_words,
+            const llama_token * trigger_tokens,
+            size_t num_trigger_tokens);

     /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
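A caller-side sketch of the extended C API declared above; the grammar string and trigger word are illustrative values, not from the patch:

    // Sketch: constructing a lazy grammar sampler through the extended C API.
    const char * triggers[] = { "<tool_call>" };  // made-up trigger word
    struct llama_sampler * smpl = llama_sampler_init_grammar(
        vocab,
        "root ::= \"<tool_call>\" [^<]* \"</tool_call>\"",  // made-up grammar
        "root",
        triggers, 1,   // grammar stays dormant until "<tool_call>" is generated
        nullptr,  0);  // no token-id triggers in this sketch
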
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index bc6c255b3..b02c4e3cc 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -960,10 +960,26 @@ struct llama_grammar * llama_grammar_init_impl(
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
+    return new llama_grammar {
+        vocab,
+        std::move(vec_rules),
+        std::move(stacks),
+        /* .partial_utf8     = */ {},
+        /* .awaiting_trigger = */ false,
+        /* .trigger_buffer   = */ "",
+        /* .trigger_tokens   = */ {},
+        /* .trigger_words    = */ {},
+    };
 }

-struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens) {
     llama_grammar_parser parser;

     // if there is a grammar, parse it
@@ -1035,10 +1051,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
         }
     } while (true);

+    std::vector<llama_token> vec_trigger_tokens;
+    std::vector<std::string> vec_trigger_words;
+    for (size_t i = 0; i < num_trigger_tokens; i++) {
+        GGML_ASSERT(trigger_tokens != nullptr);
+        vec_trigger_tokens.push_back(trigger_tokens[i]);
+    }
+    for (size_t i = 0; i < num_trigger_words; i++) {
+        GGML_ASSERT(trigger_words != nullptr);
+        vec_trigger_words.push_back(trigger_words[i]);
+    }
+
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
+    return new llama_grammar {
+        vocab,
+        std::move(vec_rules),
+        std::move(stacks),
+        /* .partial_utf8     = */ {},
+        /* .awaiting_trigger = */ vec_trigger_tokens.size() > 0 || vec_trigger_words.size() > 0,
+        /* .trigger_buffer   = */ "",
+        std::move(vec_trigger_tokens),
+        std::move(vec_trigger_words),
+    };
 }

 void llama_grammar_free_impl(struct llama_grammar * grammar) {
@@ -1055,6 +1092,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
         grammar.rules,
         grammar.stacks,
         grammar.partial_utf8,
+        grammar.awaiting_trigger,
+        grammar.trigger_buffer,
+        grammar.trigger_tokens,
+        grammar.trigger_words,
     };

     // redirect elements in stacks to point to new rules
@@ -1115,6 +1156,28 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);

+    if (grammar.awaiting_trigger) {
+        if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
+            grammar.awaiting_trigger = false;
+            llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token));
+            return;
+        } else {
+            grammar.trigger_buffer += grammar.vocab->token_to_piece(token);
+            for (const auto & word : grammar.trigger_words) {
+                auto pos = grammar.trigger_buffer.find(word);
+                if (pos == std::string::npos) {
+                    continue;
+                }
+                grammar.awaiting_trigger = false;
+                auto constrained_str = grammar.trigger_buffer.substr(pos);
+                llama_grammar_accept_str(grammar, constrained_str);
+                grammar.trigger_buffer.clear();
+                return;
+            }
+            return;
+        }
+    }
+
     if (grammar.vocab->is_eog(token)) {
         for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
index e2425b8f3..d96a685e2 100644
--- a/src/llama-grammar.h
+++ b/src/llama-grammar.h
@@ -3,6 +3,7 @@
 #include "llama.h"

 #include <map>
+#include <string>
 #include <vector>

@@ -114,6 +115,11 @@ struct llama_grammar {

     // buffer for partially generated UTF-8 sequence from accepted tokens
     llama_partial_utf8 partial_utf8;
+
+    bool                     awaiting_trigger;
+    std::string              trigger_buffer;
+    std::vector<llama_token> trigger_tokens;
+    std::vector<std::string> trigger_words;
 };

 //
@@ -127,7 +133,14 @@ struct llama_grammar * llama_grammar_init_impl(
         size_t n_rules,
         size_t start_rule_index);

-struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens);

 void llama_grammar_free_impl(struct llama_grammar * grammar);
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 22cf5d76c..387ec6567 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1465,7 +1465,18 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
         return;
     }

-    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+    std::vector<const char *> trigger_words;
+    for (auto & word : ctx->grammar->trigger_words) {
+        trigger_words.push_back(word.c_str());
+    }
+    auto * grammar_new = llama_grammar_init_impl(
+        ctx->grammar->vocab,
+        ctx->grammar_str.c_str(),
+        ctx->grammar_root.c_str(),
+        trigger_words.data(),
+        trigger_words.size(),
+        ctx->grammar->trigger_tokens.data(),
+        ctx->grammar->trigger_tokens.size());

     llama_grammar_free_impl(ctx->grammar);
     ctx->grammar = grammar_new;
@@ -1474,7 +1485,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;

-    auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr);
+    auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr, nullptr, 0, nullptr, 0);

     // copy the state
     {
@@ -1511,15 +1522,24 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
     /* .free   = */ llama_sampler_grammar_free,
 };

-struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+struct llama_sampler * llama_sampler_init_grammar(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens) {
     auto * ctx = new llama_sampler_grammar;

     if (grammar_str != nullptr && grammar_str[0] != '\0') {
         *ctx = {
-            /* .vocab        = */ vocab,
-            /* .grammar_str  = */ grammar_str,
-            /* .grammar_root = */ grammar_root,
-            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root),
+            /* .vocab        = */ vocab,
+            /* .grammar_str  = */ grammar_str,
+            /* .grammar_root = */ grammar_root,
+            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
         };
     } else {
         *ctx = {
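To make the new control flow concrete, here is a minimal standalone sketch of the trigger state machine that llama_grammar_accept_impl now implements. Names are simplified and this is an illustration of the mechanism, not the patch's code:

    #include <string>
    #include <vector>

    // Sketch: a lazy "gate" in front of a grammar. While awaiting a trigger,
    // pieces are buffered; once a trigger word appears, everything from the
    // trigger onward is handed to the grammar and the gate stays open.
    struct lazy_gate {
        bool awaiting_trigger = true;
        std::string buffer;
        std::vector<std::string> trigger_words;

        // Returns the text that should be fed to the grammar, if any.
        std::string accept_piece(const std::string & piece) {
            if (!awaiting_trigger) return piece; // grammar already active
            buffer += piece;
            for (const auto & word : trigger_words) {
                auto pos = buffer.find(word);
                if (pos != std::string::npos) {
                    awaiting_trigger = false;
                    std::string constrained = buffer.substr(pos); // from the trigger onward
                    buffer.clear();
                    return constrained;
                }
            }
            return ""; // still free-form: nothing reaches the grammar
        }
    };

This mirrors the design choice above: the grammar, not the caller, decides when constrained decoding starts, so main.cpp and server.cpp no longer need the Aho-Corasick matcher.
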
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index cee622c59..b1c43da98 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -133,7 +133,6 @@ llama_target_and_test(test-chat-template.cpp)
 # llama_target_and_test(test-opt.cpp) # SLOW
 llama_target_and_test(test-gguf.cpp)
 llama_target_and_test(test-backend-ops.cpp)
-llama_target_and_test(test-antiprompts.cpp)
 llama_target_and_test(test-tool-call.cpp)

 llama_target_and_test(test-model-load-cancel.cpp LABEL "model")
diff --git a/tests/test-antiprompts.cpp b/tests/test-antiprompts.cpp
deleted file mode 100644
index 4fa688a39..000000000
--- a/tests/test-antiprompts.cpp
+++ /dev/null
@@ -1,109 +0,0 @@
-#ifdef NDEBUG
-#undef NDEBUG
-#endif
-
-#include "llama.h"
-#include "common.h"
-
-#include <cassert>
-
-template <class T>
-void assert_equal(const T & actual, const T & expected) {
-    if (expected == actual) return;
-    printf("Expected: %s, Actual: %s\n", ((std::string)expected).c_str(), ((std::string)actual).c_str());
-    assert(expected == actual);
-}
-
-// cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_CURL=1 && cmake --build build -j -t test-jinja -t test-antiprompts && ./build/bin/test-antiprompts
-int main()
-{
-    auto tokenizer = [&](const std::string & text) {
-        std::vector<llama_token> tokens;
-        for (size_t i = 0; i < text.length(); ++i) {
-            tokens.push_back(text[i]);
-        }
-        return tokens;
-    };
-    const std::vector<std::string> stop_words { };
-    const std::vector<std::string> grammar_trigger_words { };
-
-    printf("Testing antiprompts\n");
-
-    llama_antiprompts antiprompts;
-    antiprompts.build(tokenizer, {"abc", "bcd"}, {"bca", "x"});
-
-    assert_equal(antiprompts.findSingleTokenMatch('x'), {
-        /* .pos = */ 0,
-        /* .pattern = */ "x",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 1,
-        /* .is_grammar_trigger = */ true,
-    });
-    assert_equal(antiprompts.findSingleTokenMatch('a'), {
-        /* .pos = */ std::string::npos,
-        /* .pattern = */ "",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 0,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" ab", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "",
-        /* .is_partial = */ true,
-        /* .matchLength = */ 2,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" abc", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "abc",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 3,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" ab c", 0), {
-        /* .pos = */ std::string::npos,
-        /* .pattern = */ "",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 0,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" abc abc", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "abc",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 3,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" ab abc", 0), {
-        /* .pos = */ 4,
-        /* .pattern = */ "abc",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 3,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" bc", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "",
-        /* .is_partial = */ true,
-        /* .matchLength = */ 2,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" bcd", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "bcd",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 3,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" bca", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "bca",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 3,
-        /* .is_grammar_trigger = */ true,
-    });
-    printf("OK\n");
-    // llama_antiprompts::MatchResult{0, "a", .is_partial = false, . 1, false});
-
-    return 0;
-}
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
index e1bdbb925..60169dfd6 100644
--- a/tests/test-grammar-integration.cpp
+++ b/tests/test-grammar-integration.cpp
@@ -13,7 +13,7 @@ using json = nlohmann::ordered_json;

 static llama_grammar * build_grammar(const std::string & grammar_str) {
-    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
+    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
 }

 static bool test_build_grammar_fails(const std::string & grammar_str) {
diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp
index 95762395b..b25d6c91e 100644
--- a/tests/test-tool-call.cpp
+++ b/tests/test-tool-call.cpp
@@ -37,7 +37,7 @@ static std::string read_file(const std::string &path) {
 }

 static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
-    return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"));
+    return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0));
 }

 // TODO: extract to common helper (copied from test-grammar-integration.cpp)
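
A closing illustration: with triggers handled inside the grammar, a test in the spirit of the deleted test-antiprompts.cpp could exercise laziness directly through the internal API, following the null-vocab pattern used by test-grammar-integration.cpp. The grammar string and the assertion below are assumptions for the sketch, not part of this patch:

    #include "llama-grammar.h"
    #include <cassert>

    // Sketch of a lazy-grammar unit test (illustrative only).
    int main() {
        const char * trigger_words[] = { "<tool_call>" };  // made-up trigger
        llama_grammar * grammar = llama_grammar_init_impl(
            /* vocab          = */ nullptr,
            "root ::= \"<tool_call>\" [a-z]+",             // made-up grammar
            "root",
            trigger_words, 1,
            /* trigger_tokens = */ nullptr, 0);

        // With at least one trigger registered, the grammar starts dormant.
        assert(grammar->awaiting_trigger);

        llama_grammar_free_impl(grammar);
        return 0;
    }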