grammar: trigger words + refactor of antiprompts
This commit is contained in:
parent
70392f1f81
commit
5b6d5040d5
12 changed files with 436 additions and 108 deletions
6
Makefile
6
Makefile
|
@ -44,6 +44,7 @@ BUILD_TARGETS = \
|
||||||
|
|
||||||
# Binaries only useful for tests
|
# Binaries only useful for tests
|
||||||
TEST_TARGETS = \
|
TEST_TARGETS = \
|
||||||
|
tests/test-antiprompts \
|
||||||
tests/test-arg-parser \
|
tests/test-arg-parser \
|
||||||
tests/test-autorelease \
|
tests/test-autorelease \
|
||||||
tests/test-backend-ops \
|
tests/test-backend-ops \
|
||||||
|
@ -1567,6 +1568,11 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
|
||||||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
|
tests/test-antiprompts: tests/test-antiprompts.cpp \
|
||||||
|
$(OBJ_ALL)
|
||||||
|
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
|
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
tests/test-grad0: tests/test-grad0.cpp \
|
tests/test-grad0: tests/test-grad0.cpp \
|
||||||
$(OBJ_GGML)
|
$(OBJ_GGML)
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
|
|
198
common/common.h
198
common/common.h
|
@ -4,9 +4,11 @@
|
||||||
|
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <queue>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#define DIRECTORY_SEPARATOR '\\'
|
#define DIRECTORY_SEPARATOR '\\'
|
||||||
|
@ -134,6 +136,7 @@ struct gpt_sampler_params {
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||||
|
std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar
|
||||||
|
|
||||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||||
|
|
||||||
|
@ -533,6 +536,201 @@ struct llama_control_vector_load_info {
|
||||||
// On error, returns {-1, empty}
|
// On error, returns {-1, empty}
|
||||||
llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
|
llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Antiprompt utils
|
||||||
|
//
|
||||||
|
|
||||||
|
class llama_antiprompts {
|
||||||
|
public:
|
||||||
|
|
||||||
|
struct llama_antiprompt {
|
||||||
|
std::string value;
|
||||||
|
bool is_grammar_trigger;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::string> stop_words;
|
||||||
|
std::vector<std::string> grammar_trigger_words;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The Aho–Corasick algorithm allows efficient string matching with multiple patterns.
|
||||||
|
// See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
|
||||||
|
struct TrieNode {
|
||||||
|
std::unordered_map<char, TrieNode> children;
|
||||||
|
TrieNode* fail = nullptr;
|
||||||
|
int output = -1;
|
||||||
|
size_t depth = 0;
|
||||||
|
|
||||||
|
void clear() {
|
||||||
|
children.clear();
|
||||||
|
fail = nullptr;
|
||||||
|
output = -1;
|
||||||
|
depth = 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TrieNode root;
|
||||||
|
std::vector<llama_antiprompt> antiprompts;
|
||||||
|
std::unordered_map<llama_token, size_t> stop_tokens; // Single token antiprompts (and their index in antiprompts), if any.
|
||||||
|
|
||||||
|
void build_trie() {
|
||||||
|
// root = std::unique_ptr<TrieNode>(new TrieNode());
|
||||||
|
for (size_t i = 0; i < antiprompts.size(); ++i) {
|
||||||
|
TrieNode* node = &root;
|
||||||
|
const auto & pattern = antiprompts[i].value;
|
||||||
|
for (size_t j = 0; j < pattern.length(); ++j) {
|
||||||
|
char c = pattern[j];
|
||||||
|
auto & child = node->children[c];
|
||||||
|
if (child.depth == 0) {
|
||||||
|
child.depth = j + 1;
|
||||||
|
}
|
||||||
|
node = &child;
|
||||||
|
}
|
||||||
|
node->output = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void build_failure_and_dict_links() {
|
||||||
|
std::queue<TrieNode*> q;
|
||||||
|
for (auto& child : root.children) {
|
||||||
|
child.second.fail = &root;
|
||||||
|
q.push(&child.second);
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!q.empty()) {
|
||||||
|
auto node = q.front();
|
||||||
|
q.pop();
|
||||||
|
|
||||||
|
for (auto & pair : node->children) {
|
||||||
|
auto & c = pair.first;
|
||||||
|
auto & child = pair.second;
|
||||||
|
auto f = node->fail;
|
||||||
|
|
||||||
|
while (f != &root && f->children.find(c) == f->children.end()) {
|
||||||
|
f = f->fail;
|
||||||
|
}
|
||||||
|
|
||||||
|
child.fail = (f == &root && f->children.find(c) == f->children.end())
|
||||||
|
? &root : &f->children[c];
|
||||||
|
|
||||||
|
if (child.fail->output != -1) {
|
||||||
|
child.output = child.fail->output;
|
||||||
|
}
|
||||||
|
|
||||||
|
q.push(&child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
bool empty() const {
|
||||||
|
return antiprompts.empty() && stop_tokens.empty();
|
||||||
|
}
|
||||||
|
void clear() {
|
||||||
|
root.clear();
|
||||||
|
antiprompts.clear();
|
||||||
|
stop_tokens.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
|
||||||
|
build(
|
||||||
|
[&](const std::string & text) {
|
||||||
|
return llama_tokenize(ctx, text, /* special= */ true);
|
||||||
|
},
|
||||||
|
stop_words,
|
||||||
|
grammar_trigger_words
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
void build(const std::function<std::vector<llama_token>(const std::string)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
|
||||||
|
clear();
|
||||||
|
this->stop_words = stop_words;
|
||||||
|
this->grammar_trigger_words = grammar_trigger_words;
|
||||||
|
|
||||||
|
for (const std::string & stop_word : stop_words) {
|
||||||
|
antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
|
||||||
|
}
|
||||||
|
for (const std::string & trigger : grammar_trigger_words) {
|
||||||
|
antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0, n = antiprompts.size(); i < n; i++) {
|
||||||
|
const auto & antiprompt = antiprompts[i];
|
||||||
|
std::vector<llama_token> tokens = tokenizer(antiprompt.value);
|
||||||
|
if (tokens.size() == 1) {
|
||||||
|
stop_tokens[tokens[0]] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
build_trie();
|
||||||
|
build_failure_and_dict_links();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MatchResult {
|
||||||
|
size_t pos;
|
||||||
|
std::string pattern;
|
||||||
|
bool is_partial;
|
||||||
|
size_t matchLength;
|
||||||
|
bool is_grammar_trigger;
|
||||||
|
|
||||||
|
bool operator==(const MatchResult & other) const {
|
||||||
|
return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger;
|
||||||
|
}
|
||||||
|
operator std::string() const {
|
||||||
|
return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
MatchResult findSingleTokenMatch(llama_token token) const {
|
||||||
|
auto it = stop_tokens.find(token);
|
||||||
|
if (it != stop_tokens.end()) {
|
||||||
|
const auto & antiprompt = antiprompts[it->second];
|
||||||
|
return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger};
|
||||||
|
}
|
||||||
|
return {std::string::npos, "", false, 0, false};
|
||||||
|
}
|
||||||
|
|
||||||
|
MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
|
||||||
|
TrieNode* current = &root;
|
||||||
|
MatchResult partialMatch{std::string::npos, "", true, 0, false};
|
||||||
|
|
||||||
|
for (size_t i = offset; i < text.length(); ++i) {
|
||||||
|
char c = text[i];
|
||||||
|
while (current != &root && current->children.find(c) == current->children.end()) {
|
||||||
|
current = current->fail;
|
||||||
|
}
|
||||||
|
auto it = current->children.find(c);
|
||||||
|
if (it != current->children.end()) {
|
||||||
|
current = &it->second;
|
||||||
|
}
|
||||||
|
if (current->output != -1) {
|
||||||
|
const auto & antiprompt = antiprompts[current->output];
|
||||||
|
return {
|
||||||
|
i - antiprompt.value.length() + 1,
|
||||||
|
antiprompt.value,
|
||||||
|
false,
|
||||||
|
antiprompt.value.length(),
|
||||||
|
antiprompt.is_grammar_trigger,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
// Update partial match if we're at a deeper node
|
||||||
|
if (current->depth > partialMatch.matchLength) {
|
||||||
|
partialMatch.pos = i - current->depth + 1;
|
||||||
|
partialMatch.pattern = ""; // We don't know which pattern it partially matches
|
||||||
|
partialMatch.matchLength = current->depth;
|
||||||
|
partialMatch.is_grammar_trigger = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we've found a partial match and haven't returned a full match, return the partial match
|
||||||
|
if (partialMatch.pos != std::string::npos) {
|
||||||
|
return partialMatch;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {std::string::npos, "", false, 0, false};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
// Split utils
|
// Split utils
|
||||||
//
|
//
|
||||||
|
|
|
@ -139,6 +139,15 @@ std::string gpt_sampler_params::print() const {
|
||||||
return std::string(result);
|
return std::string(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Lazily activates the grammar sampler of `gsmpl`.
//
// If a grammar sampler is already attached, nothing happens and false is returned.
// Otherwise a grammar sampler is created from `gsmpl->params.grammar` (root rule
// "root"), the `trigger` text is fed to it via llama_sampler_accept_str, and true
// is returned.
bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger) {
    const bool needs_activation = !gsmpl->grmr;
    if (needs_activation) {
        gsmpl->grmr = llama_sampler_init_grammar(model, gsmpl->params.grammar.c_str(), "root");
        // Feed the trigger text to the freshly created grammar sampler.
        llama_sampler_accept_str(gsmpl->grmr, trigger.c_str());
    }
    return needs_activation;
}
|
||||||
|
|
||||||
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
|
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
|
||||||
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
||||||
|
|
||||||
|
@ -146,7 +155,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
||||||
|
|
||||||
auto * result = new gpt_sampler {
|
auto * result = new gpt_sampler {
|
||||||
/* .params = */ params,
|
/* .params = */ params,
|
||||||
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
|
/* .grmr = */ params.grammar_trigger_words.empty() ? llama_sampler_init_grammar(model, params.grammar.c_str(), "root") : nullptr,
|
||||||
/* .chain = */ llama_sampler_chain_init(lparams),
|
/* .chain = */ llama_sampler_chain_init(lparams),
|
||||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||||
/* .cur = */ {},
|
/* .cur = */ {},
|
||||||
|
@ -226,7 +235,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
||||||
|
|
||||||
void gpt_sampler_free(struct gpt_sampler * gsmpl) {
|
void gpt_sampler_free(struct gpt_sampler * gsmpl) {
|
||||||
if (gsmpl) {
|
if (gsmpl) {
|
||||||
llama_sampler_free(gsmpl->grmr);
|
if (gsmpl->grmr) {
|
||||||
|
llama_sampler_free(gsmpl->grmr);
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampler_free(gsmpl->chain);
|
llama_sampler_free(gsmpl->chain);
|
||||||
|
|
||||||
|
|
|
@ -79,5 +79,7 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n
|
||||||
char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
|
char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
|
||||||
std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
|
std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
|
||||||
|
|
||||||
|
bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger);
|
||||||
|
|
||||||
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
||||||
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);
|
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);
|
||||||
|
|
|
@ -36,7 +36,7 @@ static llama_model ** g_model;
|
||||||
static gpt_sampler ** g_smpl;
|
static gpt_sampler ** g_smpl;
|
||||||
static gpt_params * g_params;
|
static gpt_params * g_params;
|
||||||
static std::vector<llama_token> * g_input_tokens;
|
static std::vector<llama_token> * g_input_tokens;
|
||||||
static std::ostringstream * g_output_ss;
|
static std::string * g_output_s;
|
||||||
static std::vector<llama_token> * g_output_tokens;
|
static std::vector<llama_token> * g_output_tokens;
|
||||||
static bool is_interacting = false;
|
static bool is_interacting = false;
|
||||||
static bool need_insert_eot = false;
|
static bool need_insert_eot = false;
|
||||||
|
@ -115,7 +115,7 @@ static void sigint_handler(int signo) {
|
||||||
console::cleanup();
|
console::cleanup();
|
||||||
LOG("\n");
|
LOG("\n");
|
||||||
gpt_perf_print(*g_ctx, *g_smpl);
|
gpt_perf_print(*g_ctx, *g_smpl);
|
||||||
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
|
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, *g_output_s, *g_output_tokens);
|
||||||
|
|
||||||
// make sure all logs are flushed
|
// make sure all logs are flushed
|
||||||
LOG("Interrupted by user\n");
|
LOG("Interrupted by user\n");
|
||||||
|
@ -507,7 +507,8 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
|
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
|
||||||
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
|
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
|
||||||
std::ostringstream output_ss; g_output_ss = &output_ss;
|
std::string output_s; g_output_s = &output_s;
|
||||||
|
size_t last_partial_stop = std::string::npos;
|
||||||
std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
|
std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
|
||||||
|
|
||||||
// the first thing we will do is to output the prompt, so set color accordingly
|
// the first thing we will do is to output the prompt, so set color accordingly
|
||||||
|
@ -516,13 +517,8 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
|
|
||||||
// tokenized antiprompts
|
llama_antiprompts antiprompts;
|
||||||
std::vector<std::vector<llama_token>> antiprompt_ids;
|
antiprompts.build(ctx, params.antiprompt, {});
|
||||||
|
|
||||||
antiprompt_ids.reserve(params.antiprompt.size());
|
|
||||||
for (const std::string & antiprompt : params.antiprompt) {
|
|
||||||
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (llama_model_has_encoder(model)) {
|
if (llama_model_has_encoder(model)) {
|
||||||
int enc_input_size = embd_inp.size();
|
int enc_input_size = embd_inp.size();
|
||||||
|
@ -727,7 +723,7 @@ int main(int argc, char ** argv) {
|
||||||
} else {
|
} else {
|
||||||
// Outgoing Generated Tokens
|
// Outgoing Generated Tokens
|
||||||
output_tokens.push_back(id);
|
output_tokens.push_back(id);
|
||||||
output_ss << token_str;
|
output_s.append(token_str);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -740,44 +736,34 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// if not currently processing queued inputs;
|
// if not currently processing queued inputs;
|
||||||
if ((int) embd_inp.size() <= n_consumed) {
|
if ((int) embd_inp.size() <= n_consumed) {
|
||||||
// check for reverse prompt in the last n_prev tokens
|
// check for reverse prompt
|
||||||
if (!params.antiprompt.empty()) {
|
if (!antiprompts.empty()) {
|
||||||
const int n_prev = 32;
|
|
||||||
const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev);
|
|
||||||
|
|
||||||
is_antiprompt = false;
|
is_antiprompt = false;
|
||||||
// Check if each of the reverse prompts appears at the end of the output.
|
|
||||||
// If we're not running interactively, the reverse prompt might be tokenized with some following characters
|
|
||||||
// so we'll compensate for that by widening the search window a bit.
|
|
||||||
for (std::string & antiprompt : params.antiprompt) {
|
|
||||||
size_t extra_padding = params.interactive ? 0 : 2;
|
|
||||||
size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
|
|
||||||
? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
|
|
||||||
: 0;
|
|
||||||
|
|
||||||
if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
|
|
||||||
if (params.interactive) {
|
|
||||||
is_interacting = true;
|
|
||||||
}
|
|
||||||
is_antiprompt = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check for reverse prompt using special tokens
|
// check for reverse prompt using special tokens
|
||||||
llama_token last_token = gpt_sampler_last(smpl);
|
llama_token last_token = gpt_sampler_last(smpl);
|
||||||
for (std::vector<llama_token> ids : antiprompt_ids) {
|
auto match = antiprompts.findSingleTokenMatch(last_token);
|
||||||
if (ids.size() == 1 && last_token == ids[0]) {
|
if (match.pos != std::string::npos) {
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
is_interacting = true;
|
is_interacting = true;
|
||||||
|
}
|
||||||
|
is_antiprompt = true;
|
||||||
|
} else {
|
||||||
|
match = antiprompts.findFirstMatch(output_s, last_partial_stop == std::string::npos ? 0 : last_partial_stop);
|
||||||
|
if (match.pos != std::string::npos) {
|
||||||
|
if (match.is_partial) {
|
||||||
|
last_partial_stop = match.pos;
|
||||||
|
} else {
|
||||||
|
if (params.interactive) {
|
||||||
|
is_interacting = true;
|
||||||
|
}
|
||||||
|
is_antiprompt = true;
|
||||||
}
|
}
|
||||||
is_antiprompt = true;
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_antiprompt) {
|
if (is_antiprompt) {
|
||||||
LOG_DBG("found antiprompt: %s\n", last_output.c_str());
|
LOG_DBG("found antiprompt: %s\n", match.pattern.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -786,9 +772,9 @@ int main(int argc, char ** argv) {
|
||||||
LOG_DBG("found an EOG token\n");
|
LOG_DBG("found an EOG token\n");
|
||||||
|
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
if (!params.antiprompt.empty()) {
|
if (!antiprompts.stop_words.empty()) {
|
||||||
// tokenize and inject first reverse prompt
|
// tokenize and inject first reverse prompt
|
||||||
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
|
const auto first_antiprompt = ::llama_tokenize(ctx, antiprompts.stop_words.front(), false, true);
|
||||||
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
|
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
|
||||||
is_antiprompt = true;
|
is_antiprompt = true;
|
||||||
}
|
}
|
||||||
|
@ -882,7 +868,7 @@ int main(int argc, char ** argv) {
|
||||||
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
||||||
const llama_token token = embd_inp[i];
|
const llama_token token = embd_inp[i];
|
||||||
output_tokens.push_back(token);
|
output_tokens.push_back(token);
|
||||||
output_ss << llama_token_to_piece(ctx, token);
|
output_s.append(llama_token_to_piece(ctx, token));
|
||||||
}
|
}
|
||||||
|
|
||||||
// reset assistant message
|
// reset assistant message
|
||||||
|
@ -926,7 +912,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
LOG("\n\n");
|
LOG("\n\n");
|
||||||
gpt_perf_print(ctx, smpl);
|
gpt_perf_print(ctx, smpl);
|
||||||
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
|
write_logfile(ctx, params, model, input_tokens, output_s, output_tokens);
|
||||||
|
|
||||||
gpt_sampler_free(smpl);
|
gpt_sampler_free(smpl);
|
||||||
|
|
||||||
|
|
|
@ -131,8 +131,6 @@ struct slot_params {
|
||||||
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
||||||
int32_t n_predict = -1; // new tokens to predict
|
int32_t n_predict = -1; // new tokens to predict
|
||||||
|
|
||||||
std::vector<std::string> antiprompt;
|
|
||||||
|
|
||||||
json input_prefix;
|
json input_prefix;
|
||||||
json input_suffix;
|
json input_suffix;
|
||||||
};
|
};
|
||||||
|
@ -183,6 +181,8 @@ struct server_slot {
|
||||||
std::string oaicompat_model;
|
std::string oaicompat_model;
|
||||||
std::string stopping_word;
|
std::string stopping_word;
|
||||||
|
|
||||||
|
llama_antiprompts antiprompts;
|
||||||
|
|
||||||
// sampling
|
// sampling
|
||||||
json json_schema;
|
json json_schema;
|
||||||
|
|
||||||
|
@ -281,34 +281,6 @@ struct server_slot {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) {
|
|
||||||
size_t stop_pos = std::string::npos;
|
|
||||||
|
|
||||||
for (const std::string & word : params.antiprompt) {
|
|
||||||
size_t pos;
|
|
||||||
|
|
||||||
if (type == STOP_TYPE_FULL) {
|
|
||||||
const size_t tmp = word.size() + last_token_size;
|
|
||||||
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
|
|
||||||
|
|
||||||
pos = text.find(word, from_pos);
|
|
||||||
} else {
|
|
||||||
pos = find_partial_stop_string(word, text);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
|
|
||||||
if (type == STOP_TYPE_FULL) {
|
|
||||||
stopped_word = true;
|
|
||||||
stopping_word = word;
|
|
||||||
has_next_token = false;
|
|
||||||
}
|
|
||||||
stop_pos = pos;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return stop_pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_timings() const {
|
void print_timings() const {
|
||||||
const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
|
const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
|
||||||
const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
|
const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
|
||||||
|
@ -999,16 +971,26 @@ struct server_context {
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
slot.params.antiprompt.clear();
|
slot.antiprompts.clear();
|
||||||
|
|
||||||
const auto & stop = data.find("stop");
|
auto copy_string_array = [&](const json & data, const std::string & key, std::vector<std::string> & vec) {
|
||||||
if (stop != data.end() && stop->is_array()) {
|
const auto & arr = data.find(key);
|
||||||
for (const auto & word : *stop) {
|
if (arr != data.end() && arr->is_array()) {
|
||||||
if (!word.empty()) {
|
for (const auto & word : *arr) {
|
||||||
slot.params.antiprompt.push_back(word);
|
if (word.is_string()) {
|
||||||
|
vec.push_back(word);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
std::vector<std::string> stop_words;
|
||||||
|
std::vector<std::string> grammar_trigger_words;
|
||||||
|
|
||||||
|
copy_string_array(data, "stop", stop_words);
|
||||||
|
copy_string_array(data, "grammar_trigger_words", grammar_trigger_words);
|
||||||
|
|
||||||
|
slot.antiprompts.build(ctx, stop_words, grammar_trigger_words);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -1110,6 +1092,18 @@ struct server_context {
|
||||||
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
|
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
|
||||||
slot.sampled = result.tok;
|
slot.sampled = result.tok;
|
||||||
|
|
||||||
|
auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
|
||||||
|
if (match.pos != std::string::npos && !match.is_partial) {
|
||||||
|
if (match.is_grammar_trigger) {
|
||||||
|
gpt_sampler_trigger_grammar(model, slot.smpl, llama_token_to_piece(ctx, result.tok, params.special));
|
||||||
|
} else {
|
||||||
|
slot.stopped_word = true;
|
||||||
|
slot.stopping_word = match.pattern;
|
||||||
|
slot.has_next_token = false;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// search stop word and delete it
|
// search stop word and delete it
|
||||||
slot.generated_text += token_str;
|
slot.generated_text += token_str;
|
||||||
slot.has_next_token = true;
|
slot.has_next_token = true;
|
||||||
|
@ -1139,23 +1133,33 @@ struct server_context {
|
||||||
if (!incomplete) {
|
if (!incomplete) {
|
||||||
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
||||||
|
|
||||||
const std::string str_test = slot.generated_text.substr(pos);
|
match = slot.antiprompts.findFirstMatch(slot.generated_text, pos);
|
||||||
bool is_stop_full = false;
|
|
||||||
|
bool is_stop_full = false;
|
||||||
|
bool is_grammar_trigger = false;
|
||||||
|
size_t length = slot.generated_text.size();
|
||||||
|
|
||||||
|
// If there is a lazy grammar trigger word at stop_pos, enable the lazy grammar
|
||||||
|
if (match.is_grammar_trigger && gpt_sampler_trigger_grammar(model, slot.smpl, match.pattern)) {
|
||||||
|
is_grammar_trigger = true;
|
||||||
|
length = pos + match.pos + match.matchLength;
|
||||||
|
} else if (!match.is_grammar_trigger && match.pos != std::string::npos && !match.is_partial) {
|
||||||
|
slot.stopped_word = true;
|
||||||
|
slot.stopping_word = match.pattern;
|
||||||
|
slot.has_next_token = false;
|
||||||
|
|
||||||
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
|
|
||||||
if (stop_pos != std::string::npos) {
|
|
||||||
is_stop_full = true;
|
is_stop_full = true;
|
||||||
slot.generated_text.erase(
|
// length = pos + match.pos;
|
||||||
slot.generated_text.begin() + pos + stop_pos,
|
length = match.pos;
|
||||||
slot.generated_text.end());
|
|
||||||
pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
|
||||||
} else {
|
|
||||||
is_stop_full = false;
|
|
||||||
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slot.generated_text.erase(
|
||||||
|
slot.generated_text.begin() + length,
|
||||||
|
slot.generated_text.end());
|
||||||
|
pos = std::min(slot.n_sent_text, length);
|
||||||
|
|
||||||
// check if there is any token to predict
|
// check if there is any token to predict
|
||||||
if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
|
if (match.pos == std::string::npos || (!slot.has_next_token && !is_grammar_trigger && !is_stop_full && match.pos > 0)) {
|
||||||
// no send the stop word in the response
|
// no send the stop word in the response
|
||||||
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
||||||
slot.n_sent_text += result.text_to_send.size();
|
slot.n_sent_text += result.text_to_send.size();
|
||||||
|
@ -1243,7 +1247,8 @@ struct server_context {
|
||||||
{"mirostat_tau", slot.sparams.mirostat_tau},
|
{"mirostat_tau", slot.sparams.mirostat_tau},
|
||||||
{"mirostat_eta", slot.sparams.mirostat_eta},
|
{"mirostat_eta", slot.sparams.mirostat_eta},
|
||||||
{"penalize_nl", slot.sparams.penalize_nl},
|
{"penalize_nl", slot.sparams.penalize_nl},
|
||||||
{"stop", slot.params.antiprompt},
|
{"stop", slot.antiprompts.stop_words},
|
||||||
|
{"grammar_trigger", slot.antiprompts.grammar_trigger_words},
|
||||||
{"max_tokens", slot.params.n_predict}, // User configured n_predict
|
{"max_tokens", slot.params.n_predict}, // User configured n_predict
|
||||||
{"n_keep", slot.params.n_keep},
|
{"n_keep", slot.params.n_keep},
|
||||||
{"n_discard", slot.params.n_discard},
|
{"n_discard", slot.params.n_discard},
|
||||||
|
|
|
@ -196,20 +196,15 @@ static size_t common_part(const std::string & a, const std::string & b) {
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ends_with(const std::string & str, const std::string & suffix) {
|
static size_t find_partial_stop_string(const std::string & stop, const std::string & text) {
|
||||||
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
|
|
||||||
if (!text.empty() && !stop.empty()) {
|
if (!text.empty() && !stop.empty()) {
|
||||||
const char text_last_char = text.back();
|
auto it = std::find(stop.rbegin(), stop.rend(), text.back());
|
||||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
while (it != stop.rend()) {
|
||||||
if (stop[char_index] == text_last_char) {
|
size_t length = std::distance(it, stop.rend());
|
||||||
const std::string current_partial = stop.substr(0, char_index + 1);
|
if (text.length() >= length && 0 == text.compare(text.length() - length, length, stop)) {
|
||||||
if (ends_with(text, current_partial)) {
|
return text.length() - length;
|
||||||
return text.size() - char_index - 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
it = std::find(std::next(it), stop.rend(), text.back());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1121,7 +1121,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
|
const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
|
||||||
|
llama_grammar_accept_str(grammar, piece);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
|
||||||
// Note terminating 0 in decoded string
|
// Note terminating 0 in decoded string
|
||||||
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
||||||
const auto & code_points = decoded.first;
|
const auto & code_points = decoded.first;
|
||||||
|
|
|
@ -142,3 +142,7 @@ void llama_grammar_apply_impl(
|
||||||
void llama_grammar_accept_impl(
|
void llama_grammar_accept_impl(
|
||||||
struct llama_grammar & grammar,
|
struct llama_grammar & grammar,
|
||||||
llama_token token);
|
llama_token token);
|
||||||
|
|
||||||
|
void llama_grammar_accept_str(
|
||||||
|
struct llama_grammar & grammar,
|
||||||
|
const std::string & piece);
|
||||||
|
|
|
@ -193,6 +193,12 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_sampler_accept_str(struct llama_sampler * smpl, const char * piece) {
|
||||||
|
if (smpl->iface->accept_str) {
|
||||||
|
smpl->iface->accept_str(smpl, piece);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
||||||
GGML_ASSERT(smpl->iface->apply);
|
GGML_ASSERT(smpl->iface->apply);
|
||||||
smpl->iface->apply(smpl, cur_p);
|
smpl->iface->apply(smpl, cur_p);
|
||||||
|
@ -325,6 +331,7 @@ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_chain_i = {
|
static struct llama_sampler_i llama_sampler_chain_i = {
|
||||||
/* .name = */ llama_sampler_chain_name,
|
/* .name = */ llama_sampler_chain_name,
|
||||||
/* .accept = */ llama_sampler_chain_accept,
|
/* .accept = */ llama_sampler_chain_accept,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_chain_apply,
|
/* .apply = */ llama_sampler_chain_apply,
|
||||||
/* .reset = */ llama_sampler_chain_reset,
|
/* .reset = */ llama_sampler_chain_reset,
|
||||||
/* .clone = */ llama_sampler_chain_clone,
|
/* .clone = */ llama_sampler_chain_clone,
|
||||||
|
@ -399,6 +406,7 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
|
||||||
static struct llama_sampler_i llama_sampler_greedy_i = {
|
static struct llama_sampler_i llama_sampler_greedy_i = {
|
||||||
/* .name = */ llama_sampler_greedy_name,
|
/* .name = */ llama_sampler_greedy_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_greedy_apply,
|
/* .apply = */ llama_sampler_greedy_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ nullptr,
|
/* .clone = */ nullptr,
|
||||||
|
@ -457,6 +465,7 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_dist_i = {
|
static struct llama_sampler_i llama_sampler_dist_i = {
|
||||||
/* .name = */ llama_sampler_dist_name,
|
/* .name = */ llama_sampler_dist_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_dist_apply,
|
/* .apply = */ llama_sampler_dist_apply,
|
||||||
/* .reset = */ llama_sampler_dist_reset,
|
/* .reset = */ llama_sampler_dist_reset,
|
||||||
/* .clone = */ llama_sampler_dist_clone,
|
/* .clone = */ llama_sampler_dist_clone,
|
||||||
|
@ -488,6 +497,7 @@ static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_t
|
||||||
static struct llama_sampler_i llama_sampler_softmax_i = {
|
static struct llama_sampler_i llama_sampler_softmax_i = {
|
||||||
/* .name = */ llama_sampler_softmax_name,
|
/* .name = */ llama_sampler_softmax_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_softmax_apply,
|
/* .apply = */ llama_sampler_softmax_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ nullptr,
|
/* .clone = */ nullptr,
|
||||||
|
@ -528,6 +538,7 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_top_k_i = {
|
static struct llama_sampler_i llama_sampler_top_k_i = {
|
||||||
/* .name = */ llama_sampler_top_k_name,
|
/* .name = */ llama_sampler_top_k_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_top_k_apply,
|
/* .apply = */ llama_sampler_top_k_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ llama_sampler_top_k_clone,
|
/* .clone = */ llama_sampler_top_k_clone,
|
||||||
|
@ -594,6 +605,7 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_top_p_i = {
|
static struct llama_sampler_i llama_sampler_top_p_i = {
|
||||||
/* .name = */ llama_sampler_top_p_name,
|
/* .name = */ llama_sampler_top_p_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_top_p_apply,
|
/* .apply = */ llama_sampler_top_p_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ llama_sampler_top_p_clone,
|
/* .clone = */ llama_sampler_top_p_clone,
|
||||||
|
@ -690,6 +702,7 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_min_p_i = {
|
static struct llama_sampler_i llama_sampler_min_p_i = {
|
||||||
/* .name = */ llama_sampler_min_p_name,
|
/* .name = */ llama_sampler_min_p_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_min_p_apply,
|
/* .apply = */ llama_sampler_min_p_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ llama_sampler_min_p_clone,
|
/* .clone = */ llama_sampler_min_p_clone,
|
||||||
|
@ -785,6 +798,7 @@ static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_tail_free_i = {
|
static struct llama_sampler_i llama_sampler_tail_free_i = {
|
||||||
/* .name = */ llama_sampler_tail_free_name,
|
/* .name = */ llama_sampler_tail_free_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_tail_free_apply,
|
/* .apply = */ llama_sampler_tail_free_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ llama_sampler_tail_free_clone,
|
/* .clone = */ llama_sampler_tail_free_clone,
|
||||||
|
@ -884,6 +898,7 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_typical_i = {
|
static struct llama_sampler_i llama_sampler_typical_i = {
|
||||||
/* .name = */ llama_sampler_typical_name,
|
/* .name = */ llama_sampler_typical_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_typical_apply,
|
/* .apply = */ llama_sampler_typical_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ llama_sampler_typical_clone,
|
/* .clone = */ llama_sampler_typical_clone,
|
||||||
|
@ -929,6 +944,7 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_temp_i = {
|
static struct llama_sampler_i llama_sampler_temp_i = {
|
||||||
/* .name = */ llama_sampler_temp_name,
|
/* .name = */ llama_sampler_temp_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_temp_apply,
|
/* .apply = */ llama_sampler_temp_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ llama_sampler_temp_clone,
|
/* .clone = */ llama_sampler_temp_clone,
|
||||||
|
@ -1042,6 +1058,7 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
||||||
/* .name = */ llama_sampler_temp_ext_name,
|
/* .name = */ llama_sampler_temp_ext_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_temp_ext_apply,
|
/* .apply = */ llama_sampler_temp_ext_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ llama_sampler_temp_ext_clone,
|
/* .clone = */ llama_sampler_temp_ext_clone,
|
||||||
|
@ -1145,6 +1162,7 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_mirostat_i = {
|
static struct llama_sampler_i llama_sampler_mirostat_i = {
|
||||||
/* .name = */ llama_sampler_mirostat_name,
|
/* .name = */ llama_sampler_mirostat_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_mirostat_apply,
|
/* .apply = */ llama_sampler_mirostat_apply,
|
||||||
/* .reset = */ llama_sampler_mirostat_reset,
|
/* .reset = */ llama_sampler_mirostat_reset,
|
||||||
/* .clone = */ llama_sampler_mirostat_clone,
|
/* .clone = */ llama_sampler_mirostat_clone,
|
||||||
|
@ -1244,6 +1262,7 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
||||||
/* .name = */ llama_sampler_mirostat_v2_name,
|
/* .name = */ llama_sampler_mirostat_v2_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_mirostat_v2_apply,
|
/* .apply = */ llama_sampler_mirostat_v2_apply,
|
||||||
/* .reset = */ llama_sampler_mirostat_v2_reset,
|
/* .reset = */ llama_sampler_mirostat_v2_reset,
|
||||||
/* .clone = */ llama_sampler_mirostat_v2_clone,
|
/* .clone = */ llama_sampler_mirostat_v2_clone,
|
||||||
|
@ -1287,6 +1306,13 @@ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_grammar_accept_str(struct llama_sampler * smpl, const char * piece) {
|
||||||
|
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
||||||
|
if (ctx->grammar) {
|
||||||
|
llama_grammar_accept_str(*ctx->grammar, piece);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
||||||
if (ctx->grammar) {
|
if (ctx->grammar) {
|
||||||
|
@ -1339,6 +1365,7 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_grammar_i = {
|
static struct llama_sampler_i llama_sampler_grammar_i = {
|
||||||
/* .name = */ llama_sampler_grammar_name,
|
/* .name = */ llama_sampler_grammar_name,
|
||||||
/* .accept = */ llama_sampler_grammar_accept_impl,
|
/* .accept = */ llama_sampler_grammar_accept_impl,
|
||||||
|
/* .accept_str = */ llama_sampler_grammar_accept_str,
|
||||||
/* .apply = */ llama_sampler_grammar_apply,
|
/* .apply = */ llama_sampler_grammar_apply,
|
||||||
/* .reset = */ llama_sampler_grammar_reset,
|
/* .reset = */ llama_sampler_grammar_reset,
|
||||||
/* .clone = */ llama_sampler_grammar_clone,
|
/* .clone = */ llama_sampler_grammar_clone,
|
||||||
|
@ -1522,6 +1549,7 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_penalties_i = {
|
static struct llama_sampler_i llama_sampler_penalties_i = {
|
||||||
/* .name = */ llama_sampler_penalties_name,
|
/* .name = */ llama_sampler_penalties_name,
|
||||||
/* .accept = */ llama_sampler_penalties_accept,
|
/* .accept = */ llama_sampler_penalties_accept,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_penalties_apply,
|
/* .apply = */ llama_sampler_penalties_apply,
|
||||||
/* .reset = */ llama_sampler_penalties_reset,
|
/* .reset = */ llama_sampler_penalties_reset,
|
||||||
/* .clone = */ llama_sampler_penalties_clone,
|
/* .clone = */ llama_sampler_penalties_clone,
|
||||||
|
@ -1624,6 +1652,7 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
|
||||||
static struct llama_sampler_i llama_sampler_logit_bias_i = {
|
static struct llama_sampler_i llama_sampler_logit_bias_i = {
|
||||||
/* .name = */ llama_sampler_logit_bias_name,
|
/* .name = */ llama_sampler_logit_bias_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
/* .accept_str = */ nullptr,
|
||||||
/* .apply = */ llama_sampler_logit_bias_apply,
|
/* .apply = */ llama_sampler_logit_bias_apply,
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ llama_sampler_logit_bias_clone,
|
/* .clone = */ llama_sampler_logit_bias_clone,
|
||||||
|
|
|
@ -122,6 +122,7 @@ llama_target_and_test(test-grad0.cpp)
|
||||||
llama_target_and_test(test-barrier.cpp)
|
llama_target_and_test(test-barrier.cpp)
|
||||||
# llama_target_and_test(test-opt.cpp) # SLOW
|
# llama_target_and_test(test-opt.cpp) # SLOW
|
||||||
llama_target_and_test(test-backend-ops.cpp)
|
llama_target_and_test(test-backend-ops.cpp)
|
||||||
|
llama_target_and_test(test-antiprompts.cpp)
|
||||||
|
|
||||||
llama_target_and_test(test-rope.cpp)
|
llama_target_and_test(test-rope.cpp)
|
||||||
|
|
||||||
|
|
88
tests/test-antiprompts.cpp
Normal file
88
tests/test-antiprompts.cpp
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
#ifdef NDEBUG
|
||||||
|
#undef NDEBUG
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
// Test helper: verifies that `actual` equals `expected`.
// T must be comparable with == and explicitly convertible to std::string.
// On mismatch, both values are printed and an assert is tripped.
template <typename T>
void assert_equal(const T & actual, const T & expected) {
    if (expected == actual) {
        return;
    }
    const std::string expected_str = static_cast<std::string>(expected);
    const std::string actual_str   = static_cast<std::string>(actual);
    printf("Expected: %s, Actual: %s\n", expected_str.c_str(), actual_str.c_str());
    // assert() aborts the process and buffered stdout is not guaranteed to be
    // flushed on abort, so flush explicitly to keep the diagnostic above.
    fflush(stdout);
    assert(expected == actual);
}
|
||||||
|
|
||||||
|
// cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_CURL=1 && cmake --build build -j -t test-jinja -t test-antiprompts && ./build/bin/test-antiprompts
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
auto tokenizer = [&](const std::string & text) {
|
||||||
|
std::vector<llama_token> tokens;
|
||||||
|
for (size_t i = 0; i < text.length(); ++i) {
|
||||||
|
tokens.push_back(text[i]);
|
||||||
|
}
|
||||||
|
return tokens;
|
||||||
|
};
|
||||||
|
const std::vector<std::string> stop_words { };
|
||||||
|
const std::vector<std::string> grammar_trigger_words { };
|
||||||
|
|
||||||
|
printf("Testing antiprompts\n");
|
||||||
|
|
||||||
|
llama_antiprompts antiprompts;
|
||||||
|
antiprompts.build(tokenizer, {"abc", "bcd"}, {"bca", "x"});
|
||||||
|
|
||||||
|
assert_equal(antiprompts.findSingleTokenMatch('x'), {
|
||||||
|
.pos = 0,
|
||||||
|
.pattern = "x",
|
||||||
|
.is_partial = false,
|
||||||
|
.matchLength = 1,
|
||||||
|
.is_grammar_trigger = true,
|
||||||
|
});
|
||||||
|
assert_equal(antiprompts.findSingleTokenMatch('a'), {
|
||||||
|
.pos = std::string::npos,
|
||||||
|
.pattern = "",
|
||||||
|
.is_partial = false,
|
||||||
|
.matchLength = 0,
|
||||||
|
.is_grammar_trigger = false,
|
||||||
|
});
|
||||||
|
assert_equal(antiprompts.findFirstMatch(" ab", 0), {
|
||||||
|
.pos = 1,
|
||||||
|
.pattern = "",
|
||||||
|
.is_partial = true,
|
||||||
|
.matchLength = 2,
|
||||||
|
.is_grammar_trigger = false,
|
||||||
|
});
|
||||||
|
assert_equal(antiprompts.findFirstMatch(" abc", 0), {
|
||||||
|
.pos = 1,
|
||||||
|
.pattern = "abc",
|
||||||
|
.is_partial = false,
|
||||||
|
.matchLength = 3,
|
||||||
|
.is_grammar_trigger = false,
|
||||||
|
});
|
||||||
|
assert_equal(antiprompts.findFirstMatch(" bc", 0), {
|
||||||
|
.pos = 1,
|
||||||
|
.pattern = "",
|
||||||
|
.is_partial = true,
|
||||||
|
.matchLength = 2,
|
||||||
|
.is_grammar_trigger = false,
|
||||||
|
});
|
||||||
|
assert_equal(antiprompts.findFirstMatch(" bcd", 0), {
|
||||||
|
.pos = 1,
|
||||||
|
.pattern = "bcd",
|
||||||
|
.is_partial = false,
|
||||||
|
.matchLength = 3,
|
||||||
|
.is_grammar_trigger = false,
|
||||||
|
});
|
||||||
|
assert_equal(antiprompts.findFirstMatch(" bca", 0), {
|
||||||
|
.pos = 1,
|
||||||
|
.pattern = "bca",
|
||||||
|
.is_partial = false,
|
||||||
|
.matchLength = 3,
|
||||||
|
.is_grammar_trigger = true,
|
||||||
|
});
|
||||||
|
printf("OK\n");
|
||||||
|
// llama_antiprompts::MatchResult{0, "a", .is_partial = false, . 1, false});
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue