fix editorconfig lints

ochafik 2024-09-26 02:27:46 +01:00
parent ab25e3fbf9
commit 1b6280102b
17 changed files with 114 additions and 106 deletions

@@ -30,3 +30,11 @@ indent_style = tab
[examples/cvector-generator/*.txt]
trim_trailing_whitespace = unset
insert_final_newline = unset
[{tests/chat/templates/*.jinja,tests/chat/goldens/*.txt}]
indent_style = unset
indent_size = unset
end_of_line = unset
charset = unset
trim_trailing_whitespace = unset
insert_final_newline = unset

@@ -1516,7 +1516,7 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
nullptr,
tmpl.c_str(),
chat,
1,
/* add_ass= */ true,
/* buffer= */ nullptr,
/* length= */ 0,

@@ -624,7 +624,7 @@ private:
f = f->fail;
}
child.fail = (f == &root && f->children.find(c) == f->children.end())
? &root : &f->children[c];
if (child.fail->output != -1) {
@@ -654,7 +654,7 @@ private:
},
stop_words,
grammar_trigger_words
);
}
void build(const std::function<std::vector<llama_token>(const std::string)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
@@ -708,7 +708,7 @@ private:
MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
TrieNode* current = &root;
MatchResult partialMatch{std::string::npos, "", true, 0, false};
for (size_t i = offset; i < text.length(); ++i) {
char c = text[i];
while (current != &root && current->children.find(c) == current->children.end()) {
@@ -736,12 +736,12 @@ private:
partialMatch.is_grammar_trigger = false;
}
}
// If we've found a partial match and haven't returned a full match, return the partial match
if (partialMatch.pos != std::string::npos) {
return partialMatch;
}
return {std::string::npos, "", false, 0, false};
}
};

@@ -48,7 +48,7 @@ public:
}
return Value();
}
bool empty() {
return args.empty() && kwargs.empty();
}
@@ -61,7 +61,7 @@ public:
}
}
};
using CallableType = std::function<Value(const std::shared_ptr<Context> &, Arguments &)>;
using FilterType = std::function<Value(const std::shared_ptr<Context> &, Arguments &)>;
@@ -143,7 +143,7 @@ private:
} else if (is_boolean()) {
out << (this->to_bool() ? "True" : "False");
} else if (is_string()) {
dump_string(primitive_, out, string_quote);
} else {
out << primitive_.dump();
}
@@ -175,7 +175,7 @@ public:
primitive_ = v;
}
}
std::vector<Value> keys() {
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
std::vector<Value> res;
@@ -267,7 +267,7 @@ public:
if (is_string()) return !get<std::string>().empty();
if (is_array()) return !empty();
return true;
}
bool operator<(const Value & other) const {
if (is_null())
@@ -369,7 +369,7 @@ public:
if (!contains(key)) return default_value;
return at(key).get<T>();
}
template <typename T>
T get() const {
if (is_primitive()) return primitive_.get<T>();
@@ -730,7 +730,7 @@ class TemplateNode {
Location location_;
protected:
virtual void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const = 0;
public:
TemplateNode(const Location & location) : location_(location) {}
void render(std::ostringstream & out, const std::shared_ptr<Context> & context) const {
@@ -817,7 +817,7 @@ public:
ForNode(const Location & location, std::vector<std::string> && var_names, std::unique_ptr<Expression> && iterable,
std::unique_ptr<Expression> && condition, std::unique_ptr<TemplateNode> && body, bool recursive, std::unique_ptr<TemplateNode> && else_body)
: TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
// https://jinja.palletsprojects.com/en/3.0.x/templates/#for
@@ -920,7 +920,7 @@ public:
auto & arg_name = arg.first;
auto it = named_param_positions.find(arg_name);
if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);
call_context->set(arg_name, arg.second);
param_set[it->second] = true;
}
@@ -1098,7 +1098,7 @@ public:
: Expression(location), left(std::move(l)), right(std::move(r)), op(o) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
auto l = left->evaluate(context);
auto do_eval = [&](const Value & l) -> Value {
if (op == Op::Is || op == Op::IsNot) {
auto t = dynamic_cast<VariableExpr*>(right.get());
@@ -1297,7 +1297,7 @@ private:
std::shared_ptr<std::string> template_str;
CharIterator start, end, it;
Options options;
Parser(const std::shared_ptr<std::string>& template_str, const Options & options) : template_str(template_str), options(options) {
if (!template_str) throw std::runtime_error("Template string is null");
start = it = this->template_str->begin();
@@ -1326,7 +1326,7 @@ private:
case 'b': result += '\b'; break;
case 'f': result += '\f'; break;
case '\\': result += '\\'; break;
default:
if (*it == quote) {
result += quote;
} else {
@@ -1562,7 +1562,7 @@ private:
if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword");
return nonstd_make_unique<BinaryOpExpr>(
left->location,
std::move(left), std::move(identifier),
negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is);
}
@@ -1588,7 +1588,7 @@ private:
if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list");
Expression::Parameters result;
while (it != end) {
if (!consumeToken(")").empty()) {
return result;
@@ -1622,7 +1622,7 @@ private:
if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args");
Expression::Arguments result;
while (it != end) {
if (!consumeToken(")").empty()) {
return result;
@@ -1655,7 +1655,7 @@ private:
static std::regex ident_regex(R"((?!not|is|and|or|del)[a-zA-Z_]\w*)");
auto location = get_location();
auto ident = consumeToken(ident_regex);
if (ident.empty())
return nullptr;
return nonstd_make_unique<VariableExpr>(location, ident);
}
@@ -1699,7 +1699,7 @@ private:
}
return left;
}
std::unique_ptr<Expression> parseMathMulDiv() {
auto left = parseMathUnaryPlusMinus();
if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression");
@@ -1709,9 +1709,9 @@ private:
while (!(op_str = consumeToken(mul_div_tok)).empty()) {
auto right = parseMathUnaryPlusMinus();
if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression");
auto op = op_str == "*" ? BinaryOpExpr::Op::Mul
: op_str == "**" ? BinaryOpExpr::Op::MulMul
: op_str == "/" ? BinaryOpExpr::Op::Div
: op_str == "//" ? BinaryOpExpr::Op::DivDiv
: BinaryOpExpr::Op::Mod;
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
@@ -1741,14 +1741,14 @@ private:
auto op_str = consumeToken(unary_plus_minus_tok);
auto expr = parseValueExpression();
if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus' expression");
if (!op_str.empty()) {
auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus;
return nonstd_make_unique<UnaryOpExpr>(get_location(), std::move(expr), op);
}
return expr;
}
std::unique_ptr<Expression> parseValueExpression() {
auto parseValue = [&]() -> std::unique_ptr<Expression> {
auto location = get_location();
@@ -1774,7 +1774,7 @@ private:
};
auto value = parseValue();
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
if (!consumeToken("[").empty()) {
std::unique_ptr<Expression> index;
@@ -1797,7 +1797,7 @@ private:
}
if (!index) throw std::runtime_error("Empty index in subscript");
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
value = nonstd_make_unique<SubscriptExpr>(value->location, std::move(value), std::move(index));
} else if (!consumeToken(".").empty()) {
auto identifier = parseIdentifier();
@@ -1825,10 +1825,10 @@ private:
std::unique_ptr<Expression> parseBracedExpressionOrArray() {
if (consumeToken("(").empty()) return nullptr;
auto expr = parseExpression();
if (!expr) throw std::runtime_error("Expected expression in braced expression");
if (!consumeToken(")").empty()) {
return expr; // Drop the parentheses
}
@@ -1851,7 +1851,7 @@ private:
std::unique_ptr<Expression> parseArray() {
if (consumeToken("[").empty()) return nullptr;
std::vector<std::unique_ptr<Expression>> elements;
if (!consumeToken("]").empty()) {
return nonstd_make_unique<ArrayExpr>(get_location(), std::move(elements));
@@ -1876,7 +1876,7 @@ private:
std::unique_ptr<Expression> parseDictionary() {
if (consumeToken("{").empty()) return nullptr;
std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> elements;
if (!consumeToken("}").empty()) {
return nonstd_make_unique<DictExpr>(get_location(), std::move(elements));
@@ -1892,7 +1892,7 @@ private:
};
parseKeyValuePair();
while (it != end) {
if (!consumeToken(",").empty()) {
parseKeyValuePair();
@@ -1950,15 +1950,15 @@ private:
static std::regex text_regex(R"([\s\S\n]*?($|(?=\{\{|\{%|\{#)))");
static std::regex expr_close_regex(R"([\s\n]*([-~])?\}\})");
static std::regex block_close_regex(R"([\s\n]*([-~])?%\})");
TemplateTokenVector tokens;
std::vector<std::string> group;
std::string text;
try {
while (it != end) {
auto location = get_location();
if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) {
auto pre_space = parsePreSpace(group[1]);
auto content = group[2];
@@ -1985,7 +1985,7 @@ private:
};
if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword");
if (keyword == "if") {
auto condition = parseExpression();
if (!condition) throw std::runtime_error("Expected condition in if block");
@@ -2019,7 +2019,7 @@ private:
condition = parseExpression();
}
auto recursive = !consumeToken(recursive_tok).empty();
auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<ForTemplateToken>(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive));
} else if (keyword == "endfor") {
@@ -2034,7 +2034,7 @@ private:
if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) {
ns = group[1];
var_names.push_back(group[2]);
if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block");
value = parseExpression();
@@ -2115,7 +2115,7 @@ private:
} else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
auto text = text_token->text;
if (pre_space == SpaceHandling::Strip) {
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
@@ -2131,7 +2131,7 @@ private:
static std::regex trailing_last_line_space_regex(R"((^|\n)[ \t]*$)");
text = std::regex_replace(text, trailing_last_line_space_regex, "$1");
}
if (it == end && !options.keep_trailing_newline) {
static std::regex r(R"([\n\r]$)");
text = std::regex_replace(text, r, ""); // Strip one trailing newline
@@ -2473,7 +2473,7 @@ inline std::shared_ptr<Context> Context::builtins() {
int64_t start = param_set[0] ? startEndStep[0] : 0;
int64_t end = startEndStep[1];
int64_t step = param_set[2] ? startEndStep[2] : 1;
auto res = Value::array();
if (step > 0) {
for (int64_t i = start; i < end; i += step) {

@@ -147,7 +147,7 @@ bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler *
llama_sampler_accept_str(gsmpl->grmr, trigger.c_str());
return true;
}
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

@@ -84,7 +84,7 @@ static llama_tool_calls parse_hermes_tool_calls(const std::string& input) {
std::regex start_pattern(R"([\n\s]*<tool_call>)");
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
auto end = input.end();
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
@@ -176,7 +176,7 @@ static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const std::str
it = rit->suffix().first;
auto name = rit->str(1);
json arguments;
if (!parse_json(it, end, arguments)) {
throw std::runtime_error("Failed to parse json tool call arguments");
@@ -229,7 +229,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
const nlohmann::ordered_json & tools)
{
llama_tool_call_handler handler;
if (needs_functionary_v3_tool_call(chat_template)) {
// MeetKaiFunctionary_3_2
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
@@ -312,7 +312,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
handler.grammar_trigger_words.push_back("<|python_tag|>");
}
} else {
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
tool_rules.push_back(
builder.add_rule(
name + "-call",

@@ -182,7 +182,7 @@ struct server_slot {
std::string stopping_word;
llama_antiprompts antiprompts;
// sampling
json json_schema;

@@ -654,7 +654,7 @@ async def step_tool_called(context, expected_name, expected_arguments):
expected_name = expected_name if expected_name else None
expected_arguments = json.loads(expected_arguments) if expected_arguments else None
def check(tool_calls):
if tool_calls is None:
assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}'
@@ -1055,7 +1055,7 @@ async def oai_chat_completions(user_prompt,
user_api_key = user_api_key if user_api_key is not None else 'nope'
assert isinstance(seed, int), f'seed: {seed}'
seed = seed if seed is not None else 42
enable_streaming = enable_streaming if enable_streaming is not None else False
messages = []
if system_prompt:

@@ -353,7 +353,7 @@ static json oaicompat_completion_params_parse(
auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
// Apply chat template to the list of messages
auto chat_template = chat_template_src.empty() ? llama_model_meta_val_str(model, "tokenizer.chat_template") : chat_template_src;
llama_params["chat_template"] = chat_template;
@@ -420,7 +420,7 @@ static json oaicompat_completion_params_parse(
llama_params["parse_tool_calls"] = true;
llama_params["parallel_tool_calls"] = parallel_tool_calls;
}
// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {

@@ -12,4 +12,4 @@
"add_generation_prompt": true,
"bos_token": "<|startoftext|>",
"eos_token": "<|endoftext|>"
}

@@ -16,4 +16,4 @@
"add_generation_prompt": true,
"bos_token": "<|startoftext|>",
"eos_token": "<|endoftext|>"
}

@@ -161,4 +161,4 @@
}
}
]
}

@@ -26,12 +26,12 @@ int main()
};
const std::vector<std::string> stop_words { };
const std::vector<std::string> grammar_trigger_words { };
printf("Testing antiprompts\n");
llama_antiprompts antiprompts;
antiprompts.build(tokenizer, {"abc", "bcd"}, {"bca", "x"});
assert_equal(antiprompts.findSingleTokenMatch('x'), {
.pos = 0,
.pattern = "x",

@@ -17,7 +17,7 @@ int main(void) {
std::string expected_output;
std::string jinja_expected_output;
};
std::vector<llama_chat_message> conversation {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
@@ -100,7 +100,7 @@ int main(void) {
.tmpl = "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
.expected_output = "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
},
{
// No template included in tokenizer_config.json, so this template likely needs to be manually set.
.name = "Orca-Vicuna",
.tmpl = "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
@@ -157,7 +157,7 @@ int main(void) {
.expected_output = u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<end▁of▁sentence>User: Who are you\n\nAssistant: I am an assistant <end▁of▁sentence>User: Another question\n\nAssistant:",
}
};
std::vector<char> formatted_chat(1024);
int32_t res;

@@ -1,6 +1,6 @@
/*
Minimalistic Jinja templating engine for llama.cpp. C++11, no deps (single-header), decent language support but very few functions (easy to extend), just whats needed for actual prompt templates.
Models have increasingly complex templates (e.g. Llama 3.1, Hermes 2 Pro w/ tool_use), so we need a proper template engine to get the best out of them.
Supports:
@@ -20,7 +20,7 @@
- No tuples (templates seem to rely on lists only)
- No `if` expressions w/o `else` (but `if` statements are fine)
- No `{% raw %}`, `{% block %}`, `{% include %}`, `{% extends %},
Model templates verified to work:
- Meta-Llama-3.1-8B-Instruct
- Phi-3.5-mini-instruct
@@ -160,7 +160,7 @@ static void test_template_features() {
test_render(R"({{ {"a": "b"} | tojson }})", {}, {}, R"({"a": "b"})");
test_render(R"({{ {"a": "b"} }})", {}, {}, R"({'a': 'b'})");
std::string trim_tmpl =
"\n"
" {% if true %}Hello{% endif %} \n"
"...\n"
@@ -228,7 +228,7 @@ static void test_template_features() {
({{ i }}, {{ loop.cycle('odd', 'even') }}),
{%- endfor -%}
)", {}, {}, "(0, odd),(1, even),(2, odd),(3, even),(4, odd),");
test_render(
"{%- for i in range(5) if i % 2 == 0 -%}\n"
"{{ i }}, first={{ loop.first }}, last={{ loop.last }}, index={{ loop.index }}, index0={{ loop.index0 }}, revindex={{ loop.revindex }}, revindex0={{ loop.revindex0 }}, prev={{ loop.previtem }}, next={{ loop.nextitem }},\n"
@@ -237,7 +237,7 @@ static void test_template_features() {
"0, first=True, last=False, index=1, index0=0, revindex=3, revindex0=2, prev=, next=2,\n"
"2, first=False, last=False, index=2, index0=1, revindex=2, revindex0=1, prev=0, next=4,\n"
"4, first=False, last=True, index=3, index0=2, revindex=1, revindex0=0, prev=2, next=,\n");
test_render(
R"(
{%- set res = [] -%}
@@ -262,7 +262,7 @@ static void test_template_features() {
{% macro input(name, value='', type='text', size=20) -%}
<input type="{{ type }}" name="{{ name }}" value="{{ value|e }}" size="{{ size }}">
{%- endmacro -%}
<p>{{ input('username') }}</p>
<p>{{ input('password', type='password') }}</p>)",
{}, {}, R"(
@@ -314,14 +314,14 @@ static void test_template_features() {
{{- x }},{{ y -}};
{%- endfor -%}
)", {{"z", json({json({1, 10}), json({2, 20})})}}, {}, "1,10;2,20;");
test_render(" a {{ 'b' -}} c ", {}, {}, " a bc ");
test_render(" a {{- 'b' }} c ", {}, {}, " ab c ");
test_render("a\n{{- 'b' }}\nc", {}, {}, "ab\nc");
test_render("a\n{{ 'b' -}}\nc", {}, {}, "a\nbc");
test_error_contains("{{ raise_exception('hey') }}", {}, {}, "hey");
test_render("{{ [] is iterable }}", {}, {}, "True");
test_render("{{ [] is not number }}", {}, {}, "True");
test_render("{% set x = [0, 1, 2, 3] %}{{ x[1:] }}{{ x[:2] }}{{ x[1:3] }}", {}, {}, "[1, 2, 3][0, 1][1, 2]");
@@ -343,16 +343,16 @@ static void test_template_features() {
test_error_contains("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}, "Unterminated if");
test_render("{% if 1 %}{% elif 1 %}{% else %}{% endif %}", {}, {}, "");
test_render(
"{% set x = [] %}{% set _ = x.append(1) %}{{ x | tojson(indent=2) }}", {}, {},
"[\n 1\n]");
test_render(
"{{ not [] }}", {}, {},
"True");
test_render("{{ tool.function.name == 'ipython' }}",
json({{"tool", json({
{"function", {{"name", "ipython"}}}
})}}),
@@ -369,7 +369,7 @@ static void test_template_features() {
static void test_chat_templates_with_common_contexts_against_goldens() {
auto jinja_template_files = find_files("tests/chat/templates", ".jinja");
auto context_files = find_files("tests/chat/contexts", ".json");
auto get_golden_file = [&](const std::string & tmpl_file, const std::string & ctx_file) {
auto tmpl_name = filename_without_extension(tmpl_file);
auto ctx_name = filename_without_extension(ctx_file);
@@ -431,4 +431,4 @@ int main() {
}
return 0;
}

@@ -58,7 +58,7 @@ int main() {
json request = {
{"tools", tools}
};
std::string hermes_2_pro_like_tmpl = "Hermes 2 Pro template should have <tool_call> inside it";
test_parse_tool_call(tools, hermes_2_pro_like_tmpl,
"<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
@@ -71,7 +71,7 @@ int main() {
}).dump()}
}}
}});
std::string functionary_v3_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it";
test_parse_tool_call(tools, functionary_v3_like_tmpl,
">>>ipython\nprint('Hello, world!')",
@@ -84,7 +84,7 @@ int main() {
}).dump()}
}}
}});
std::string functionary_v3_llama_3_1_like_tmpl = "Functionary 3.2 template for llama 3.1 should have <|start_header_id|> and then some <function=foo>{...}</function> inside it";
test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
"Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
@@ -107,7 +107,7 @@ int main() {
}}
},
});
std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it";
test_parse_tool_call(tools, llama_3_1_like_tmpl,
"<|python_tag|>this could be anything",
@@ -145,4 +145,4 @@ int main() {
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
return 0;
}

@@ -8,10 +8,10 @@
# ///
'''
Fetches the Jinja2 templates of a few known models and use them to generate prompt goldens for a few predefined chat contexts.
Examples:
python ./tests/update_jinja_goldens.py
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
'''
@@ -33,12 +33,12 @@ model_ids = [
"Qwen/Qwen2-7B-Instruct",
"Qwen/Qwen2-VL-7B-Instruct",
"Qwen/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-Math-7B-Instruct",
"microsoft/Phi-3-mini-4k-instruct",
"microsoft/Phi-3-small-8k-instruct",
"microsoft/Phi-3-medium-4k-instruct",
"microsoft/Phi-3.5-mini-instruct",
"indischepartij/MiniCPM-3B-OpenHermes-2.5-v2",
"teknium/OpenHermes-2.5-Mistral-7B",
"TheBloke/FusionNet_34Bx2_MoE-AWQ",
"bofenghuang/vigogne-2-70b-chat",
@@ -46,18 +46,18 @@ model_ids = [
"OrionStarAI/Orion-14B-Chat",
"openchat/openchat-3.5-0106",
"deepseek-ai/deepseek-coder-33b-instruct",
"abacusai/Fewshot-Metamath-OrcaVicuna-Mistral",
"CohereForAI/c4ai-command-r-plus",
"THUDM/chatglm3-6b",
"derek33125/project-angel-chatglm4",
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
"deepseek-ai/DeepSeek-V2.5",
# Needs debugging:
# "eachadea/vicuna-13b-1.1",
# "microsoft/Phi-3-vision-instruct",
# Gated models:
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"google/gemma-7b-it",
@@ -83,9 +83,9 @@ def handle_chat_template(model_id, variant, template_src):
print(f'template_file: {template_file}')
with open(template_file, 'w') as f:
f.write(template_src)
print(f"- {template_file}", flush=True)
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
@@ -99,25 +99,25 @@ def handle_chat_template(model_id, variant, template_src):
template_handles_tools = 'tools' in template_src
template_hates_the_system = 'System role not supported' in template_src
template = env.from_string(template_src)
context_files = glob.glob('tests/chat/contexts/*.json')
for context_file in context_files:
context_name = context_file.split("/")[-1].replace(".json", "")
with open(context_file, 'r') as f:
context = json.load(f)
if not template_handles_tools and 'tools' in context:
continue
if template_hates_the_system and any(m['role'] == 'system' for m in context['messages']):
continue
output_file = f'tests/chat/goldens/{base_name}-{context_name}.txt'
print(f"- {output_file}", flush=True)
try:
output = template.render(**context)
except:
# Some templates (e.g. Phi-3-medium-128k's) expect a non-null "content" key in each message.
for message in context["messages"]:
@@ -132,27 +132,27 @@ def handle_chat_template(model_id, variant, template_src):
with open(output_file, 'w') as f:
f.write(output)
print()
def main():
for dir in ['tests/chat/templates', 'tests/chat/goldens']:
if not os.path.isdir(dir):
os.mkdir(dir)
for model_id in model_ids:
# response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
# response.raise_for_status()
# config_str = response.text
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
config_str = f.read()
try:
config = json.loads(config_str)
except json.JSONDecodeError as e:
# Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
# (Remove extra '}' near the end of the file)
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
chat_template = config['chat_template']
if isinstance(chat_template, str):
@@ -162,4 +162,4 @@ def main():
handle_chat_template(model_id, ct['name'], ct['template'])
if __name__ == '__main__':
main()