Push laziness down to grammar impl

parent 77f4098c83
commit dbf841b0d2

16 changed files with 224 additions and 423 deletions
Makefile (1 change)

@@ -58,7 +58,6 @@ TEST_TARGETS = \
 	tests/test-grammar-integration \
 	tests/test-grammar-parser \
 	tests/test-json-schema-to-grammar \
-	tests/test-minja \
 	tests/test-llama-grammar \
 	tests/test-log \
 	tests/test-model-load-cancel \
common/common.h (212 changes)

@@ -158,7 +158,8 @@ struct common_params_sampling {
     };

     std::string grammar; // optional BNF-like grammar to constrain sampling
     std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar
+    std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to enable grammar

     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
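The new grammar_trigger_tokens field mirrors grammar_trigger_words, but holds triggers already reduced to a single token. A minimal sketch of how a caller might populate these params (the grammar string and trigger word below are invented for illustration, not part of this commit):

    common_params_sampling sparams;
    sparams.grammar = "root ::= ...";                       // hypothetical GBNF grammar
    sparams.grammar_trigger_words.push_back("<tool_call>"); // sampling stays unconstrained until this appears
    // a trigger known to map to exactly one token would go into sparams.grammar_trigger_tokens instead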
@@ -687,215 +688,6 @@ struct common_control_vector_load_info {
 // On error, returns {-1, empty}
 common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos);

-//
-// Antiprompt utils
-//
-
-class llama_antiprompts {
-  public:
-    struct llama_antiprompt {
-        std::string value;
-        bool is_grammar_trigger;
-    };
-
-    std::vector<std::string> stop_words;
-    std::vector<std::string> grammar_triggers;
-
-  private:
-    // The Aho–Corasick algorithm allows efficient string matching with multiple patterns.
-    // See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
-    struct TrieNode {
-        std::unordered_map<char, TrieNode*> children;
-        TrieNode* fail = nullptr;
-        int output = -1;
-        size_t depth = 0;
-
-        ~TrieNode() {
-            clear();
-        }
-
-        void clear() {
-            for (auto & pair : children) {
-                delete pair.second;
-            }
-            children.clear();
-            fail = nullptr;
-            output = -1;
-            depth = 0;
-        }
-    };
-
-    TrieNode root;
-    std::vector<llama_antiprompt> antiprompts;
-    std::unordered_map<llama_token, size_t> stop_tokens; // Single token antiprompts (and their index in antiprompts), if any.
-
-    void build_trie() {
-        // root = std::unique_ptr<TrieNode>(new TrieNode());
-        for (size_t i = 0; i < antiprompts.size(); ++i) {
-            TrieNode* node = &root;
-            const auto & pattern = antiprompts[i].value;
-            for (size_t j = 0; j < pattern.length(); ++j) {
-                char c = pattern[j];
-                auto it = node->children.find(c);
-                if (it != node->children.end()) {
-                    node = it->second;
-                } else {
-                    node = node->children[c] = new TrieNode();
-                }
-                if (node->depth == 0) {
-                    node->depth = j + 1;
-                }
-            }
-            node->output = i;
-        }
-    }
-
-    void build_failure_and_dict_links() {
-        std::queue<TrieNode*> q;
-        for (auto& child : root.children) {
-            child.second->fail = &root;
-            q.push(child.second);
-        }
-
-        while (!q.empty()) {
-            auto node = q.front();
-            q.pop();
-
-            for (auto & pair : node->children) {
-                auto & c = pair.first;
-                auto & child = pair.second;
-                auto f = node->fail;
-
-                while (f != &root && f->children.find(c) == f->children.end()) {
-                    f = f->fail;
-                }
-
-                child->fail = (f == &root && f->children.find(c) == f->children.end())
-                    ? &root : f->children[c];
-
-                if (child->fail->output != -1) {
-                    child->output = child->fail->output;
-                }
-
-                q.push(child);
-            }
-        }
-    }
-
-  public:
-
-    bool empty() const {
-        return antiprompts.empty() && stop_tokens.empty();
-    }
-    void clear() {
-        root.clear();
-        antiprompts.clear();
-        stop_tokens.clear();
-    }
-
-    void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
-        build(
-            [&](const std::string & text) {
-                return common_tokenize(ctx, text, /* special= */ true);
-            },
-            stop_words,
-            grammar_triggers
-        );
-    }
-
-    void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
-        clear();
-        this->stop_words = stop_words;
-        this->grammar_triggers = grammar_triggers;
-
-        for (const std::string & stop_word : stop_words) {
-            antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
-        }
-        for (const std::string & trigger : grammar_triggers) {
-            antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
-        }
-
-        for (size_t i = 0, n = antiprompts.size(); i < n; i++) {
-            const auto & antiprompt = antiprompts[i];
-            std::vector<llama_token> tokens = tokenizer(antiprompt.value);
-            if (tokens.size() == 1) {
-                stop_tokens[tokens[0]] = i;
-            }
-        }
-
-        build_trie();
-        build_failure_and_dict_links();
-    }
-
-    struct MatchResult {
-        size_t pos;
-        std::string pattern;
-        bool is_partial;
-        size_t matchLength;
-        bool is_grammar_trigger;
-
-        bool operator==(const MatchResult & other) const {
-            return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger;
-        }
-        operator std::string() const {
-            return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}";
-        }
-    };
-
-    MatchResult findSingleTokenMatch(llama_token token) const {
-        auto it = stop_tokens.find(token);
-        if (it != stop_tokens.end()) {
-            const auto & antiprompt = antiprompts[it->second];
-            return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger};
-        }
-        return {std::string::npos, "", false, 0, false};
-    }
-
-    MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
-        TrieNode* current = &root;
-        MatchResult partialMatch{std::string::npos, "", true, 0, false};
-        auto text_length = text.length();
-
-        for (size_t i = offset; i < text_length; ++i) {
-            char c = text[i];
-            while (current != &root && current->children.find(c) == current->children.end()) {
-                current = current->fail;
-            }
-            auto it = current->children.find(c);
-            if (it != current->children.end()) {
-                current = it->second;
-            }
-            if (current->output != -1) {
-                const auto & antiprompt = antiprompts[current->output];
-                return {
-                    i - antiprompt.value.length() + 1,
-                    antiprompt.value,
-                    false,
-                    antiprompt.value.length(),
-                    antiprompt.is_grammar_trigger,
-                };
-            }
-            // Update partial match if we're at a deeper node
-            if (current->depth > partialMatch.matchLength) {
-                partialMatch.pos = i - current->depth + 1;
-                partialMatch.pattern = ""; // We don't know which pattern it partially matches
-                partialMatch.matchLength = current->depth;
-                partialMatch.is_grammar_trigger = false;
-            }
-        }
-
-        // If we've found a partial match and haven't returned a full match, return the partial match
-        if (partialMatch.pos != std::string::npos) {
-            if (partialMatch.pos + partialMatch.matchLength == text_length) {
-                return partialMatch;
-            }
-        }
-
-        return {std::string::npos, "", false, 0, false};
-    }
-};
-
 //
 // Split utils
 //
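For context on what is being removed: build() indexed stop words and grammar triggers together in one Aho–Corasick trie, and findFirstMatch() reported either a full match or a partial match ending the text. A rough usage sketch against the deleted API (the tokenizer callback and strings are invented):

    llama_antiprompts ap;
    ap.build(tokenizer, /* stop_words= */ {"</s>"}, /* grammar_triggers= */ {"<tool_call>"});

    auto m = ap.findFirstMatch("hello </", 0);
    // m.is_partial == true, m.pos == 6: "</" may still grow into "</s>",
    // so a streaming caller would hold back the tail rather than emit it.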
common/sampling.cpp

@@ -144,15 +144,6 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }

-bool common_sampler_trigger_grammar(const struct llama_vocab * vocab, common_sampler * gsmpl, const std::string & trigger) {
-    if (!llama_sampler_is_grammar_empty(gsmpl->grmr)) {
-        return false;
-    }
-    gsmpl->grmr = llama_sampler_init_grammar(vocab, gsmpl->params.grammar.c_str(), "root");
-    llama_sampler_accept_str(gsmpl->grmr, trigger.c_str());
-    return true;
-}
-
 struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);

@@ -160,9 +151,22 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {

     lparams.no_perf = params.no_perf;

+    std::vector<const char *> c_trigger_words;
+    c_trigger_words.reserve(params.grammar_trigger_words.size());
+    for (const auto & str : params.grammar_trigger_words) {
+        c_trigger_words.push_back(str.c_str());
+    }
+
     auto * result = new common_sampler {
         /* .params = */ params,
-        /* .grmr   = */ llama_sampler_init_grammar(vocab, params.grammar_trigger_words.empty() ? params.grammar.c_str() : "", "root"),
+        /* .grmr   = */ llama_sampler_init_grammar(
+            vocab,
+            params.grammar.c_str(),
+            "root",
+            c_trigger_words.data(),
+            c_trigger_words.size(),
+            params.grammar_trigger_tokens.data(),
+            params.grammar_trigger_tokens.size()
+        ),
         /* .chain  = */ llama_sampler_chain_init(lparams),
         /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur    = */ {},
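A note on lifetimes here: c_trigger_words only holds borrowed const char * views, which is safe because llama_grammar_init_impl copies each trigger word into its own std::vector<std::string> during the call, so the temporary pointer array can go out of scope immediately afterwards.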
common/sampling.h

@@ -100,7 +100,5 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx,
 char        common_sampler_type_to_chr(enum common_sampler_type cnstr);
 std::string common_sampler_type_to_str(enum common_sampler_type cnstr);

-bool common_sampler_trigger_grammar(const struct llama_vocab * vocab, common_sampler * gsmpl, const std::string & trigger);
-
 std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
examples/agent/.gitignore (vendored, 3 changes)

@@ -1,3 +0,0 @@
-squid/ssl_cert/
-squid/ssl_db/
-squid/cache/
examples/gbnf-validator/gbnf-validator.cpp

@@ -76,7 +76,7 @@ int main(int argc, char** argv) {
         grammar_str = buffer.str();
     }

-    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
+    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
     if (grammar == nullptr) {
         fprintf(stdout, "Failed to initialize llama_grammar\n");
         return 1;
examples/main/main.cpp

@@ -38,7 +38,7 @@ static llama_model ** g_model;
 static common_sampler          ** g_smpl;
 static common_params            * g_params;
 static std::vector<llama_token> * g_input_tokens;
-static std::string              * g_output_s;
+static std::ostringstream       * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting  = false;
 static bool need_insert_eot = false;
@@ -494,8 +494,7 @@ int main(int argc, char ** argv) {

     std::vector<int>   input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int>   output_tokens; g_output_tokens = &output_tokens;
-    std::string        output_s;      g_output_s      = &output_s;
-    size_t last_partial_stop = std::string::npos;
+    std::ostringstream output_ss;     g_output_ss     = &output_ss;
     std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode

     // the first thing we will do is to output the prompt, so set color accordingly
@@ -504,8 +503,16 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;

-    llama_antiprompts antiprompts;
-    antiprompts.build(ctx, params.antiprompt, {});
+    // single-token antiprompts
+    std::vector<llama_token> antiprompt_single_token;
+
+    antiprompt_single_token.reserve(params.antiprompt.size());
+    for (const std::string & antiprompt : params.antiprompt) {
+        auto ids = ::common_tokenize(ctx, antiprompt, false, true);
+        if (ids.size() == 1) {
+            antiprompt_single_token.push_back(ids[0]);
+        }
+    }

     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
@@ -710,7 +717,7 @@ int main(int argc, char ** argv) {
             } else {
                 // Outgoing Generated Tokens
                 output_tokens.push_back(id);
-                output_s.append(token_str);
+                output_ss << token_str;
             }
         }
     }
@@ -723,34 +730,41 @@ int main(int argc, char ** argv) {

         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
-            // check for reverse prompt
-            if (!antiprompts.empty()) {
+            // check for reverse prompt in the last n_prev tokens
+            if (!params.antiprompt.empty()) {
+                const int n_prev = 32;
+                const std::string last_output = common_sampler_prev_str(smpl, ctx, n_prev);
+
                 is_antiprompt = false;
+                // Check if each of the reverse prompts appears at the end of the output.
+                // If we're not running interactively, the reverse prompt might be tokenized with some following characters
+                // so we'll compensate for that by widening the search window a bit.
+                for (std::string & antiprompt : params.antiprompt) {
+                    size_t extra_padding = params.interactive ? 0 : 2;
+                    size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
+                        ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
+                        : 0;
+
+                    if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
+                        if (params.interactive) {
+                            is_interacting = true;
+                        }
+                        is_antiprompt = true;
+                        break;
+                    }
+                }

                 // check for reverse prompt using special tokens
                 llama_token last_token = common_sampler_last(smpl);
-                auto match = antiprompts.findSingleTokenMatch(last_token);
-                if (match.pos != std::string::npos) {
+                if (std::find(antiprompt_single_token.begin(), antiprompt_single_token.end(), last_token) != antiprompt_single_token.end()) {
                     if (params.interactive) {
                         is_interacting = true;
                     }
                     is_antiprompt = true;
-                } else {
-                    match = antiprompts.findFirstMatch(output_s, last_partial_stop == std::string::npos ? 0 : last_partial_stop);
-                    if (match.pos != std::string::npos) {
-                        if (match.is_partial) {
-                            last_partial_stop = match.pos;
-                        } else {
-                            if (params.interactive) {
-                                is_interacting = true;
-                            }
-                            is_antiprompt = true;
-                        }
-                    }
                 }

                 if (is_antiprompt) {
-                    LOG_DBG("found antiprompt: %s\n", match.pattern.c_str());
+                    LOG_DBG("found antiprompt: %s\n", last_output.c_str());
                 }
             }
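The restored search-window logic deserves a gloss: when not running interactively, the final tokens may glue extra characters onto the reverse prompt, so the scan starts up to antiprompt.length() + 2 characters before the end of last_output. For example (invented values), with antiprompt "User:" the suffix "User:\n" still matches because the window covers the two trailing characters.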
@@ -759,9 +773,9 @@ int main(int argc, char ** argv) {
             LOG_DBG("found an EOG token\n");

             if (params.interactive) {
-                if (!antiprompts.stop_words.empty()) {
+                if (!params.antiprompt.empty()) {
                     // tokenize and inject first reverse prompt
-                    const auto first_antiprompt = common_tokenize(ctx, antiprompts.stop_words.front(), false, true);
+                    const auto first_antiprompt = common_tokenize(ctx, params.antiprompt.front(), false, true);
                     embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                     is_antiprompt = true;
                 }
@@ -855,7 +869,7 @@ int main(int argc, char ** argv) {
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
                     const llama_token token = embd_inp[i];
                     output_tokens.push_back(token);
-                    output_s.append(common_token_to_piece(ctx, token));
+                    output_ss << common_token_to_piece(ctx, token);
                 }

                 // reset assistant message
examples/server/server.cpp

@@ -389,7 +389,15 @@ struct server_task {
         {
             const auto grammar_trigger_words = data.find("grammar_trigger_words");
             if (grammar_trigger_words != data.end()) {
-                params.sampling.grammar_trigger_words = to_string_vec(*grammar_trigger_words);
+                auto words = to_string_vec(*grammar_trigger_words);
+                for (const auto & word : words) {
+                    auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                        continue;
+                    }
+                    params.sampling.grammar_trigger_words.push_back(word);
+                }
             }
         }
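In request terms, this reads an optional field such as {"grammar_trigger_words": ["<tool_call>"]} (the value is invented for illustration): triggers that tokenize to a single token are promoted to grammar_trigger_tokens, and only the multi-token ones are kept as strings for substring matching.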
@@ -1224,8 +1232,6 @@ struct server_slot {

     std::string stopping_word;

-    llama_antiprompts antiprompts;
-
     // sampling
     json json_schema;

@@ -1329,6 +1335,35 @@ struct server_slot {
         return timings;
     }

+    size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
+        size_t stop_pos = std::string::npos;
+
+        for (const std::string & word : params.antiprompt) {
+            size_t pos;
+
+            if (is_full_stop) {
+                const size_t tmp      = word.size() + last_token_size;
+                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
+
+                pos = text.find(word, from_pos);
+            } else {
+                // otherwise, partial stop
+                pos = find_partial_stop_string(word, text);
+            }
+
+            if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
+                if (is_full_stop) {
+                    stop           = STOP_TYPE_WORD;
+                    stopping_word  = word;
+                    has_next_token = false;
+                }
+                stop_pos = pos;
+            }
+        }
+
+        return stop_pos;
+    }
+
     void print_timings() const {
         const double t_prompt        = t_prompt_processing / n_prompt_tokens_processed;
         const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
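find_stopping_strings distinguishes a full stop (a stop word wholly present near the end) from a partial one (a prefix of a stop word ends the text). A worked example with an invented antiprompt:

    // slot.params.antiprompt = {"###"};
    // full stop:    text = "foo ###" -> find("###") hits at 4; stop = STOP_TYPE_WORD, has_next_token = false
    // partial stop: text = "foo ##"  -> find_partial_stop_string reports the trailing "##",
    //               so the caller withholds those bytes from the stream until disambiguated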
@@ -1976,11 +2011,6 @@ struct server_context {
             slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
         }

-        {
-            slot.antiprompts.clear();
-            slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.sampling.grammar_trigger_words);
-        }
-
         {
             if (slot.smpl != nullptr) {
                 common_sampler_free(slot.smpl);

@@ -2016,25 +2046,13 @@ struct server_context {
     }

     bool process_token(completion_token_output & result, server_slot & slot) {
-        auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
-
         // remember which tokens were sampled - used for repetition penalties during sampling
+        const std::string token_str = result.text_to_send;
+        // TODO:
         // const std::string token_str = result.text_to_send;
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special || (match.pos != std::string::npos && match.is_grammar_trigger));
+        // const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special || (match.pos != std::string::npos && match.is_grammar_trigger));
         slot.sampled = result.tok;

-        if (match.pos != std::string::npos && !match.is_partial) {
-            if (match.is_grammar_trigger) {
-                common_sampler_trigger_grammar(vocab, slot.smpl, token_str);
-            } else {
-                // slot.stopped_word = true;
-                slot.stopping_word = match.pattern;
-                slot.has_next_token = false;
-                return false;
-            }
-        }
-
-        // search stop word and delete it
         slot.generated_text += token_str;
         if (slot.params.return_tokens) {
             slot.generated_tokens.push_back(result.tok);

@@ -2048,33 +2066,22 @@ struct server_context {
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

-            match = slot.antiprompts.findFirstMatch(slot.generated_text, pos);
-
-            bool is_stop_full = false;
-            bool is_grammar_trigger = false;
-            size_t length = slot.generated_text.size();
-
-            // If there is a lazy grammar trigger word at stop_pos, enable the lazy grammar
-            if (match.is_grammar_trigger && common_sampler_trigger_grammar(vocab, slot.smpl, match.pattern)) {
-                is_grammar_trigger = true;
-                length = match.pos + match.matchLength;
-            } else if (!match.is_grammar_trigger && match.pos != std::string::npos && !match.is_partial) {
-                // slot.stopped_word = true;
-                slot.stopping_word = match.pattern;
-                slot.has_next_token = false;
-
-                is_stop_full = true;
-                // length = pos + match.pos;
-                length = match.pos;
-            }
-
-            slot.generated_text.erase(
-                slot.generated_text.begin() + length,
-                slot.generated_text.end());
-            pos = std::min(slot.n_sent_text, length);
+            const std::string str_test = slot.generated_text.substr(pos);
+            bool send_text = true;
+
+            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
+            if (stop_pos != std::string::npos) {
+                slot.generated_text.erase(
+                    slot.generated_text.begin() + pos + stop_pos,
+                    slot.generated_text.end());
+                pos = std::min(slot.n_sent_text, slot.generated_text.size());
+            } else if (slot.has_next_token) {
+                stop_pos  = slot.find_stopping_strings(str_test, token_str.size(), false);
+                send_text = stop_pos == std::string::npos;
+            }

             // check if there is any token to predict
-            if (match.pos == std::string::npos || (!slot.has_next_token && !is_grammar_trigger && !is_stop_full && match.pos > 0)) {
+            if (send_text) {
                 // no send the stop word in the response
                 result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
                 slot.n_sent_text += result.text_to_send.size();
include/llama.h

@@ -1199,7 +1199,11 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
             const struct llama_vocab * vocab,
             const char * grammar_str,
-            const char * grammar_root);
+            const char * grammar_root,
+            const char ** trigger_words,
+            size_t num_trigger_words,
+            const llama_token * trigger_tokens,
+            size_t num_trigger_tokens);

     /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
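A sketch of calling the extended API (the grammar string and trigger word are invented; passing zero-length trigger arrays preserves the old eager behaviour):

    const char * triggers[] = { "<tool_call>" };
    struct llama_sampler * grmr = llama_sampler_init_grammar(
        vocab, "root ::= ...", "root",
        triggers, 1,   // grammar stays dormant until a trigger word is seen
        nullptr, 0);   // no single-token triggers in this example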
src/llama-grammar.cpp

@@ -960,10 +960,26 @@ struct llama_grammar * llama_grammar_init_impl(
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
+    return new llama_grammar {
+        vocab,
+        std::move(vec_rules),
+        std::move(stacks),
+        /* .partial_utf8 = */     {},
+        /* .awaiting_trigger = */ false,
+        /* .trigger_buffer = */   "",
+        /* .trigger_tokens = */   {},
+        /* .trigger_words = */    {},
+    };
 }

-struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens) {
     llama_grammar_parser parser;

     // if there is a grammar, parse it

@@ -1035,10 +1051,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
         }
     } while (true);

+    std::vector<llama_token> vec_trigger_tokens;
+    std::vector<std::string> vec_trigger_words;
+    for (size_t i = 0; i < num_trigger_tokens; i++) {
+        GGML_ASSERT(trigger_tokens != nullptr);
+        vec_trigger_tokens.push_back(trigger_tokens[i]);
+    }
+    for (size_t i = 0; i < num_trigger_words; i++) {
+        GGML_ASSERT(trigger_words != nullptr);
+        vec_trigger_words.push_back(trigger_words[i]);
+    }
+
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
+    return new llama_grammar {
+        vocab,
+        std::move(vec_rules),
+        std::move(stacks),
+        /* .partial_utf8 = */     {},
+        /* .awaiting_trigger = */ vec_trigger_tokens.size() > 0 || vec_trigger_words.size() > 0,
+        /* .trigger_buffer = */   "",
+        std::move(vec_trigger_tokens),
+        std::move(vec_trigger_words),
+    };
 }

 void llama_grammar_free_impl(struct llama_grammar * grammar) {

@@ -1055,6 +1092,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
         grammar.rules,
         grammar.stacks,
         grammar.partial_utf8,
+        grammar.awaiting_trigger,
+        grammar.trigger_buffer,
+        grammar.trigger_tokens,
+        grammar.trigger_words,
     };

     // redirect elements in stacks to point to new rules

@@ -1115,6 +1156,28 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
 void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
     GGML_ASSERT(grammar.vocab != nullptr);

+    if (grammar.awaiting_trigger) {
+        if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
+            grammar.awaiting_trigger = false;
+            llama_grammar_accept_str(grammar, grammar.vocab->token_to_piece(token));
+            return;
+        } else {
+            grammar.trigger_buffer += grammar.vocab->token_to_piece(token);
+            for (const auto & word : grammar.trigger_words) {
+                auto pos = grammar.trigger_buffer.find(word);
+                if (pos == std::string::npos) {
+                    continue;
+                }
+                grammar.awaiting_trigger = false;
+                auto constrained_str = grammar.trigger_buffer.substr(pos);
+                llama_grammar_accept_str(grammar, constrained_str);
+                grammar.trigger_buffer.clear();
+                return;
+            }
+            return;
+        }
+    }
+
     if (grammar.vocab->is_eog(token)) {
         for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {
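This is where the laziness now lives: until a trigger fires, accepted tokens only accumulate in trigger_buffer and the grammar stacks are untouched. A worked trace, assuming a single trigger word "<tool>" (invented for illustration):

    // tokens arrive as: "Sure, ", "<to", "ol>{"
    // 1. "Sure, " -> trigger_buffer = "Sure, "           (no match, grammar stays idle)
    // 2. "<to"    -> trigger_buffer = "Sure, <to"        (still no match)
    // 3. "ol>{"   -> trigger_buffer = "Sure, <tool>{"
    //    find("<tool>") hits at pos 6; awaiting_trigger flips to false and the
    //    constrained suffix "<tool>{" is replayed through llama_grammar_accept_str.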
src/llama-grammar.h

@@ -3,6 +3,7 @@
 #include "llama.h"

 #include <map>
+#include <set>
 #include <string>
 #include <vector>

@@ -114,6 +115,11 @@ struct llama_grammar {

     // buffer for partially generated UTF-8 sequence from accepted tokens
     llama_partial_utf8 partial_utf8;
+
+    bool                     awaiting_trigger;
+    std::string              trigger_buffer;
+    std::vector<llama_token> trigger_tokens;
+    std::vector<std::string> trigger_words;
 };

 //

@@ -127,7 +133,14 @@ struct llama_grammar * llama_grammar_init_impl(
         size_t n_rules,
         size_t start_rule_index);

-struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens);

 void llama_grammar_free_impl(struct llama_grammar * grammar);
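One detail worth noting: llama_grammar is aggregate-initialized elsewhere (see llama_grammar_clone_impl above), so the declaration order of the four new members must match the positional order used in every initializer list that constructs the struct.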
src/llama-sampling.cpp

@@ -1465,7 +1465,18 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
         return;
     }

-    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+    std::vector<const char *> trigger_words;
+    for (auto & word : ctx->grammar->trigger_words) {
+        trigger_words.push_back(word.c_str());
+    }
+    auto * grammar_new = llama_grammar_init_impl(
+        ctx->grammar->vocab,
+        ctx->grammar_str.c_str(),
+        ctx->grammar_root.c_str(),
+        trigger_words.data(),
+        trigger_words.size(),
+        ctx->grammar->trigger_tokens.data(),
+        ctx->grammar->trigger_tokens.size());

     llama_grammar_free_impl(ctx->grammar);
     ctx->grammar = grammar_new;

@@ -1474,7 +1485,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;

-    auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr);
+    auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr, nullptr, 0, nullptr, 0);

     // copy the state
     {

@@ -1511,15 +1522,24 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
     /* .free   = */ llama_sampler_grammar_free,
 };

-struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+struct llama_sampler * llama_sampler_init_grammar(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_words,
+        size_t num_trigger_words,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens) {
+// struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
     auto * ctx = new llama_sampler_grammar;

     if (grammar_str != nullptr && grammar_str[0] != '\0') {
         *ctx = {
             /* .vocab        = */ vocab,
             /* .grammar_str  = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root),
+            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
         };
     } else {
         *ctx = {
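Note that reset() rebuilds the grammar with the trigger words and tokens copied out of the old instance, so a reset lazy sampler returns to the dormant awaiting_trigger state rather than staying armed with whatever the previous generation had triggered.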
tests/CMakeLists.txt

@@ -133,7 +133,6 @@ llama_target_and_test(test-chat-template.cpp)
 # llama_target_and_test(test-opt.cpp) # SLOW
 llama_target_and_test(test-gguf.cpp)
 llama_target_and_test(test-backend-ops.cpp)
-llama_target_and_test(test-antiprompts.cpp)
 llama_target_and_test(test-tool-call.cpp)

 llama_target_and_test(test-model-load-cancel.cpp LABEL "model")
tests/test-antiprompts.cpp (deleted)

@@ -1,109 +0,0 @@
-#ifdef NDEBUG
-#undef NDEBUG
-#endif
-
-#include "llama.h"
-#include "common.h"
-
-#include <cassert>
-
-template <typename T>
-void assert_equal(const T & actual, const T & expected) {
-    if (expected == actual) return;
-    printf("Expected: %s, Actual: %s\n", ((std::string)expected).c_str(), ((std::string)actual).c_str());
-    assert(expected == actual);
-}
-
-// cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_CURL=1 && cmake --build build -j -t test-jinja -t test-antiprompts && ./build/bin/test-antiprompts
-int main()
-{
-    auto tokenizer = [&](const std::string & text) {
-        std::vector<llama_token> tokens;
-        for (size_t i = 0; i < text.length(); ++i) {
-            tokens.push_back(text[i]);
-        }
-        return tokens;
-    };
-    const std::vector<std::string> stop_words { };
-    const std::vector<std::string> grammar_trigger_words { };
-
-    printf("Testing antiprompts\n");
-
-    llama_antiprompts antiprompts;
-    antiprompts.build(tokenizer, {"abc", "bcd"}, {"bca", "x"});
-
-    assert_equal(antiprompts.findSingleTokenMatch('x'), {
-        /* .pos = */ 0,
-        /* .pattern = */ "x",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 1,
-        /* .is_grammar_trigger = */ true,
-    });
-    assert_equal(antiprompts.findSingleTokenMatch('a'), {
-        /* .pos = */ std::string::npos,
-        /* .pattern = */ "",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 0,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" ab", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "",
-        /* .is_partial = */ true,
-        /* .matchLength = */ 2,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" abc", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "abc",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 3,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" ab c", 0), {
-        /* .pos = */ std::string::npos,
-        /* .pattern = */ "",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 0,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" abc abc", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "abc",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 3,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" ab abc", 0), {
-        /* .pos = */ 4,
-        /* .pattern = */ "abc",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 3,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" bc", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "",
-        /* .is_partial = */ true,
-        /* .matchLength = */ 2,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" bcd", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "bcd",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 3,
-        /* .is_grammar_trigger = */ false,
-    });
-    assert_equal(antiprompts.findFirstMatch(" bca", 0), {
-        /* .pos = */ 1,
-        /* .pattern = */ "bca",
-        /* .is_partial = */ false,
-        /* .matchLength = */ 3,
-        /* .is_grammar_trigger = */ true,
-    });
-    printf("OK\n");
-    // llama_antiprompts::MatchResult{0, "a", .is_partial = false, . 1, false});
-
-    return 0;
-}
tests/test-grammar-integration.cpp

@@ -13,7 +13,7 @@
 using json = nlohmann::ordered_json;

 static llama_grammar * build_grammar(const std::string & grammar_str) {
-    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
+    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0);
 }

 static bool test_build_grammar_fails(const std::string & grammar_str) {
tests/test-tool-call.cpp

@@ -37,7 +37,7 @@ static std::string read_file(const std::string &path) {
 }

 static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
-    return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"));
+    return std::unique_ptr<llama_grammar>(llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", nullptr, 0, nullptr, 0));
 }

 // TODO: extract to common helper (copied from test-grammar-integration.cpp)