sync: minja

This commit is contained in:
ochafik 2025-01-26 21:36:25 +00:00
parent 11594557e3
commit 43385b2ff2
3 changed files with 128 additions and 105 deletions

View file

@ -61,7 +61,7 @@ class chat_template {
});
supports_tools_ = source.find("tools") != std::string::npos;
requires_object_arguments_ =
requires_object_arguments_ =
try_raw_render({
{
{"role", "user"},
@ -298,7 +298,7 @@ class chat_template {
if (!tools.is_null()) {
auto tools_val = minja::Value(actual_tools);
context->set("tools", tools_val);
if (has_code_interpreter) {
if (has_code_interpreter && !extra_context.contains("builtin_tools")) {
auto builtin_tools_val = minja::Value(json {"code_interpreter"});
context->set("builtin_tools", builtin_tools_val);
}

View file

@ -2648,31 +2648,34 @@ inline std::shared_ptr<Context> Context::builtins() {
return filter.call(context, actual_args);
});
};
// https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject
globals.set("reject", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
args.expectArgs("reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
auto & items = args.args[0];
auto filter_fn = context->get(args.args[1]);
if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
auto select_or_reject = [make_filter](bool is_select) {
return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
auto & items = args.args[0];
auto filter_fn = context->get(args.args[1]);
if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
auto filter_args = Value::array();
for (size_t i = 2, n = args.args.size(); i < n; i++) {
filter_args.push_back(args.args[i]);
}
auto filter = make_filter(filter_fn, filter_args);
auto res = Value::array();
for (size_t i = 0, n = items.size(); i < n; i++) {
auto & item = items.at(i);
ArgumentsValue filter_args;
filter_args.args.emplace_back(item);
auto pred_res = filter.call(context, filter_args);
if (!pred_res.to_bool()) {
res.push_back(item);
auto filter_args = Value::array();
for (size_t i = 2, n = args.args.size(); i < n; i++) {
filter_args.push_back(args.args[i]);
}
}
return res;
}));
auto filter = make_filter(filter_fn, filter_args);
auto res = Value::array();
for (size_t i = 0, n = items.size(); i < n; i++) {
auto & item = items.at(i);
ArgumentsValue filter_args;
filter_args.args.emplace_back(item);
auto pred_res = filter.call(context, filter_args);
if (pred_res.to_bool() == (is_select ? true : false)) {
res.push_back(item);
}
}
return res;
});
};
globals.set("select", select_or_reject(/* is_select= */ true));
globals.set("reject", select_or_reject(/* is_select= */ false));
globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
auto res = Value::array();
if (args.args.size() == 1 &&
@ -2720,41 +2723,45 @@ inline std::shared_ptr<Context> Context::builtins() {
if (!text.empty() && text.back() == '\n') out += "\n";
return out;
}));
globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
args.expectArgs("selectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
auto & items = args.args[0];
if (items.is_null())
return Value::array();
auto attr_name = args.args[1].get<std::string>();
auto select_or_reject_attr = [](bool is_select) {
return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
auto & items = args.args[0];
if (items.is_null())
return Value::array();
auto attr_name = args.args[1].get<std::string>();
bool has_test = false;
Value test_fn;
ArgumentsValue test_args {{Value()}, {}};
if (args.args.size() >= 3) {
has_test = true;
test_fn = context->get(args.args[2]);
if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
for (size_t i = 3, n = args.args.size(); i < n; i++) {
test_args.args.emplace_back(args.args[i]);
}
test_args.kwargs = args.kwargs;
}
auto res = Value::array();
for (size_t i = 0, n = items.size(); i < n; i++) {
auto & item = items.at(i);
auto attr = item.get(attr_name);
if (has_test) {
test_args.args[0] = attr;
if (test_fn.call(context, test_args).to_bool()) {
res.push_back(item);
bool has_test = false;
Value test_fn;
ArgumentsValue test_args {{Value()}, {}};
if (args.args.size() >= 3) {
has_test = true;
test_fn = context->get(args.args[2]);
if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
for (size_t i = 3, n = args.args.size(); i < n; i++) {
test_args.args.emplace_back(args.args[i]);
}
} else {
res.push_back(attr);
test_args.kwargs = args.kwargs;
}
}
return res;
}));
auto res = Value::array();
for (size_t i = 0, n = items.size(); i < n; i++) {
auto & item = items.at(i);
auto attr = item.get(attr_name);
if (has_test) {
test_args.args[0] = attr;
if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
res.push_back(item);
}
} else {
res.push_back(attr);
}
}
return res;
});
};
globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
std::vector<int64_t> startEndStep(3);
std::vector<bool> param_set(3);

View file

@ -211,7 +211,6 @@ struct server_task {
static slot_params params_from_json_cmpl(
const llama_context * ctx,
const common_params & params_base,
const common_chat_template * tmpl,
const json & data) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
@ -330,30 +329,19 @@ struct server_task {
}
}
if (tmpl && params_base.use_jinja) {
common_chat_params chat_params;
chat_params.messages = json_value(data, "messages", json::array());
chat_params.tools = json_value(data, "tools", json());
chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
chat_params.json_schema = json_value(data, "json_schema", json());
chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
chat_params.stream = json_value(data, "stream", false);
auto chat_data = common_chat_init(*tmpl, chat_params);
params.chat_parser = std::move(chat_data.handler);
params.sampling.grammar = chat_data.grammar;
for (const auto & stop : chat_data.additional_stops) {
params.antiprompt.push_back(stop);
if (!params_base.use_jinja) {
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
for (const auto & trigger : chat_data.grammar_triggers) {
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
continue;
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
params.sampling.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
} else {
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
}
}
@ -363,15 +351,13 @@ struct server_task {
}
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
params.sampling.grammar = json_schema_to_grammar(schema);
params.sampling.grammar = json_schema_to_grammar(json_value(data, "json_schema", json::object()));
} catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
}
LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());
{
params.sampling.logit_bias.clear();
@ -2248,9 +2234,15 @@ struct server_context {
}
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send);
if (!opt_msg) {
return;
common_chat_msg msg;
if (slot.params.chat_parser) {
if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
msg = *opt_msg;
} else {
return;
}
} else {
msg.content = tkn.text_to_send;
}
auto res = std::make_unique<server_task_result_cmpl_partial>();
@ -2267,7 +2259,7 @@ struct server_context {
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_msg = *opt_msg;
res->oaicompat_chat_msg = msg;
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@ -2308,7 +2300,11 @@ struct server_context {
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
res->oaicompat_chat_msg = slot.params.chat_parser ? slot.params.chat_parser->parse_final(slot.generated_text) : common_chat_msg {
/* .role = */ "assistant",
/* .content = */ slot.generated_text,
/* .tool_calls = */ {}
};
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@ -3773,7 +3769,7 @@ int main(int argc, char ** argv) {
std::function<bool()> is_connection_closed,
httplib::Response & res,
oaicompat_type oaicompat,
const common_chat_template * tmpl) {
const common_chat_template * tmpl = nullptr) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
if (ctx_server.params_base.embedding) {
@ -3785,21 +3781,29 @@ int main(int argc, char ** argv) {
std::vector<server_task> tasks;
try {
fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get<std::string>().c_str());
std::string prompt;
common_chat_data chat_data;
if (tmpl && ctx_server.params_base.use_jinja) {
auto chat_data = common_chat_init(*tmpl, {
/* .messages = */ json_data(data, "messages", json::array()),
/* .tools = */ json_data(data, "tools", json()),
/
chat_data = common_chat_init(*tmpl, {
/* .messages = */ json_value(data, "messages", json::array()),
/* .tools = */ json_value(data, "tools", json()),
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
/* .json_schema = */ json_value(data, "json_schema", json()),
/* .parallel_tool_calls = */ json_value(data, "json_schema", true),
/* .stream = */ json_value(data, "json_schema", false),
/* .grammar = */ json_value(data, "grammar", std::string("")),
});
prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
if (data.contains("grammar")) {
chat_data.grammar = data.at("grammar");
}
} else {
prompt = data.at("prompt").get<std::string>();
chat_data.prompt = data.at("prompt");
if (data.contains("grammar")) {
chat_data.grammar = data.at("grammar");
} else if (data.contains("json_schema")) {
chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
}
}
task.params.chat_parser = common_chat_init()
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type);
@ -3811,16 +3815,27 @@ int main(int argc, char ** argv) {
task.params = server_task::params_from_json_cmpl(
ctx_server.ctx,
ctx_server.params_base,
nullptr,
data);
task.id_selected_slot = json_value(data, "id_slot", -1);
// OAI-compat
task.params.oaicompat = oaicompat;
task.params.oaicompat_cmpl_id = completion_id;
task.params.chat_parser = common_chat_init()
task.params.oaicompat_tools = json_value(data, "tools", json());
task.params.oaicompat_tool_call_style = tool_call_style;
task.params.sampling.grammar = chat_data.grammar;
for (const auto & trigger : chat_data.grammar_triggers) {
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
continue;
}
LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
task.params.sampling.grammar_trigger_words.push_back(trigger);
}
task.params.antiprompt = chat_data.additional_stops;
if (chat_data.parser) {
task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
}
// oaicompat_model is already populated by params_from_json_cmpl
tasks.push_back(task);
@ -4005,7 +4020,8 @@ int main(int argc, char ** argv) {
data,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_CHAT);
OAICOMPAT_TYPE_CHAT,
&chat_template);
};
const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {