sync: minja

ochafik 2025-01-26 21:36:25 +00:00
parent 11594557e3
commit 43385b2ff2
3 changed files with 128 additions and 105 deletions

View file

@@ -298,7 +298,7 @@ class chat_template {
         if (!tools.is_null()) {
             auto tools_val = minja::Value(actual_tools);
             context->set("tools", tools_val);
-            if (has_code_interpreter) {
+            if (has_code_interpreter && !extra_context.contains("builtin_tools")) {
                 auto builtin_tools_val = minja::Value(json {"code_interpreter"});
                 context->set("builtin_tools", builtin_tools_val);
             }
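
A minimal standalone sketch of the guarded default this hunk introduces — inject the implicit code_interpreter builtin only when the caller has not already supplied builtin_tools. The std::map context below is an illustrative stand-in, not the minja Context/Value API:

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
    // Illustrative stand-ins for the template's extra context and the values
    // handed to the chat template; not the minja API.
    std::map<std::string, std::vector<std::string>> extra_context = {
        {"builtin_tools", {"code_interpreter", "browser"}}, // caller-supplied override
    };
    std::map<std::string, std::vector<std::string>> context;

    bool has_code_interpreter = true;

    // Same guard as the hunk above: only inject the implicit "code_interpreter"
    // builtin when the caller did not already provide "builtin_tools".
    if (has_code_interpreter && !extra_context.count("builtin_tools")) {
        context["builtin_tools"] = {"code_interpreter"};
    }

    std::cout << "default injected: " << context.count("builtin_tools") << "\n"; // prints 0 here
}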

View file

@@ -2648,9 +2648,9 @@ inline std::shared_ptr<Context> Context::builtins() {
             return filter.call(context, actual_args);
         });
     };
-    // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject
-    globals.set("reject", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-        args.expectArgs("reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+    auto select_or_reject = [make_filter](bool is_select) {
+      return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+        args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
         auto & items = args.args[0];
         auto filter_fn = context->get(args.args[1]);
         if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
@@ -2667,12 +2667,15 @@ inline std::shared_ptr<Context> Context::builtins() {
             ArgumentsValue filter_args;
             filter_args.args.emplace_back(item);
             auto pred_res = filter.call(context, filter_args);
-            if (!pred_res.to_bool()) {
+            if (pred_res.to_bool() == (is_select ? true : false)) {
                 res.push_back(item);
             }
         }
         return res;
-    }));
+      });
+    };
+    globals.set("select", select_or_reject(/* is_select= */ true));
+    globals.set("reject", select_or_reject(/* is_select= */ false));
     globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
         auto res = Value::array();
         if (args.args.size() == 1 &&
@@ -2720,8 +2723,9 @@ inline std::shared_ptr<Context> Context::builtins() {
         if (!text.empty() && text.back() == '\n') out += "\n";
         return out;
     }));
-    globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-        args.expectArgs("selectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+    auto select_or_reject_attr = [](bool is_select) {
+      return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+        args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
         auto & items = args.args[0];
         if (items.is_null())
             return Value::array();
@@ -2746,7 +2750,7 @@ inline std::shared_ptr<Context> Context::builtins() {
                 auto attr = item.get(attr_name);
                 if (has_test) {
                     test_args.args[0] = attr;
-                    if (test_fn.call(context, test_args).to_bool()) {
+                    if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
                         res.push_back(item);
                     }
                 } else {
@@ -2754,7 +2758,10 @@ inline std::shared_ptr<Context> Context::builtins() {
                 }
             }
             return res;
-    }));
+      });
+    };
+    globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
+    globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
     globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
         std::vector<int64_t> startEndStep(3);
         std::vector<bool> param_set(3);
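
The minja hunks above fold the separate reject/select and selectattr/rejectattr callables into single lambdas parameterized on is_select, registered under both names, so the iteration logic lives in one place. A minimal standalone sketch of that factoring, using plain std::function over ints rather than minja's Value/Context types:

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy registry standing in for minja's `globals`: filter name -> callable over ints.
using Filter = std::function<std::vector<int>(const std::vector<int> &, const std::function<bool(int)> &)>;

int main() {
    std::map<std::string, Filter> globals;

    // Same factoring idea as the hunks above: one lambda parameterized on
    // `is_select` produces both filters.
    auto select_or_reject = [](bool is_select) -> Filter {
        return [is_select](const std::vector<int> & items, const std::function<bool(int)> & pred) {
            std::vector<int> res;
            for (int item : items) {
                if (pred(item) == is_select) {
                    res.push_back(item);
                }
            }
            return res;
        };
    };
    globals["select"] = select_or_reject(/* is_select= */ true);
    globals["reject"] = select_or_reject(/* is_select= */ false);

    auto is_even = [](int x) { return x % 2 == 0; };
    for (int x : globals["select"]({1, 2, 3, 4}, is_even)) std::cout << x << ' ';  // prints: 2 4
    std::cout << '\n';
    for (int x : globals["reject"]({1, 2, 3, 4}, is_even)) std::cout << x << ' ';  // prints: 1 3
    std::cout << '\n';
}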

View file

@@ -211,7 +211,6 @@ struct server_task {
     static slot_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
-            const common_chat_template * tmpl,
             const json & data) {
         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -330,34 +329,7 @@ struct server_task {
             }
         }

-        if (tmpl && params_base.use_jinja) {
-            common_chat_params chat_params;
-            chat_params.messages = json_value(data, "messages", json::array());
-            chat_params.tools = json_value(data, "tools", json());
-            chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
-            chat_params.json_schema = json_value(data, "json_schema", json());
-            chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
-            chat_params.stream = json_value(data, "stream", false);
-            auto chat_data = common_chat_init(*tmpl, chat_params);
-            params.chat_parser = std::move(chat_data.handler);
-            params.sampling.grammar = chat_data.grammar;
-            for (const auto & stop : chat_data.additional_stops) {
-                params.antiprompt.push_back(stop);
-            }
-            for (const auto & trigger : chat_data.grammar_triggers) {
-                auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
-                if (ids.size() == 1) {
-                    LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
-                    params.sampling.grammar_trigger_tokens.push_back(ids[0]);
-                    continue;
-                }
-                LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
-                params.sampling.grammar_trigger_words.push_back(trigger);
-            }
-        }
+        if (!params_base.use_jinja) {
+            // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
         }
@@ -371,7 +343,21 @@ struct server_task {
         } else {
             params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
         }
-        LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());
+        }
+
+        // process "json_schema" and "grammar"
+        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
+            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+        }
+        if (data.contains("json_schema") && !data.contains("grammar")) {
+            try {
+                params.sampling.grammar = json_schema_to_grammar(json_value(data, "json_schema", json::object()));
+            } catch (const std::exception & e) {
+                throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
+            }
+        } else {
+            params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+        }

         {
             params.sampling.logit_bias.clear();
@@ -2248,10 +2234,16 @@ struct server_context {
     }

     void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
-        auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send);
-        if (!opt_msg) {
+        common_chat_msg msg;
+        if (slot.params.chat_parser) {
+            if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
+                msg = *opt_msg;
+            } else {
                 return;
             }
+        } else {
+            msg.content = tkn.text_to_send;
+        }

         auto res = std::make_unique<server_task_result_cmpl_partial>();
         res->id = slot.id_task;
@@ -2267,7 +2259,7 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = *opt_msg;
+        res->oaicompat_chat_msg = msg;

         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -2308,7 +2300,11 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
+        res->oaicompat_chat_msg = slot.params.chat_parser ? slot.params.chat_parser->parse_final(slot.generated_text) : common_chat_msg {
+            /* .role = */ "assistant",
+            /* .content = */ slot.generated_text,
+            /* .tool_calls = */ {}
+        };

         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -3773,7 +3769,7 @@ int main(int argc, char ** argv) {
             std::function<bool()> is_connection_closed,
             httplib::Response & res,
             oaicompat_type oaicompat,
-            const common_chat_template * tmpl) {
+            const common_chat_template * tmpl = nullptr) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

         if (ctx_server.params_base.embedding) {
@@ -3785,21 +3781,29 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;

         try {
-            fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get<std::string>().c_str());
-            std::string prompt;
+            common_chat_data chat_data;
             if (tmpl && ctx_server.params_base.use_jinja) {
-                auto chat_data = common_chat_init(*tmpl, {
-                    /* .messages = */ json_data(data, "messages", json::array()),
-                    /* .tools = */ json_data(data, "tools", json()),
-                    /
+                chat_data = common_chat_init(*tmpl, {
+                    /* .messages = */ json_value(data, "messages", json::array()),
+                    /* .tools = */ json_value(data, "tools", json()),
+                    /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
+                    /* .json_schema = */ json_value(data, "json_schema", json()),
+                    /* .parallel_tool_calls = */ json_value(data, "json_schema", true),
+                    /* .stream = */ json_value(data, "json_schema", false),
+                    /* .grammar = */ json_value(data, "grammar", std::string("")),
                 });
-                prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
-            } else {
-                prompt = data.at("prompt").get<std::string>();
+                if (data.contains("grammar")) {
+                    chat_data.grammar = data.at("grammar");
+                }
+            } else {
+                chat_data.prompt = data.at("prompt");
+                if (data.contains("grammar")) {
+                    chat_data.grammar = data.at("grammar");
+                } else if (data.contains("json_schema")) {
+                    chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
+                }
             }
-            task.params.chat_parser = common_chat_init()
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
+            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(type);
@@ -3811,16 +3815,27 @@ int main(int argc, char ** argv) {
                 task.params = server_task::params_from_json_cmpl(
                         ctx_server.ctx,
                         ctx_server.params_base,
-                        nullptr,
                         data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);

                 // OAI-compat
                 task.params.oaicompat = oaicompat;
                 task.params.oaicompat_cmpl_id = completion_id;
-                task.params.chat_parser = common_chat_init()
-                task.params.oaicompat_tools = json_value(data, "tools", json());
-                task.params.oaicompat_tool_call_style = tool_call_style;
+                task.params.sampling.grammar = chat_data.grammar;
+                for (const auto & trigger : chat_data.grammar_triggers) {
+                    auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
+                        task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                        continue;
+                    }
+                    LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
+                    task.params.sampling.grammar_trigger_words.push_back(trigger);
+                }
+                task.params.antiprompt = chat_data.additional_stops;
+                if (chat_data.parser) {
+                    task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
+                }
                 // oaicompat_model is already populated by params_from_json_cmpl

                 tasks.push_back(task);
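
The trigger loop above routes each grammar trigger by how it tokenizes: a trigger that maps to exactly one token id is recorded as a token trigger, anything else is kept as a literal word trigger. A self-contained sketch of that routing, with fake_tokenize standing in for common_tokenize and two local vectors standing in for the sampling params:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

using token_id = int32_t;

// Hypothetical tokenizer: pretend "<tool_call>" is a single special token and
// everything else splits into several pieces.
std::vector<token_id> fake_tokenize(const std::string & word) {
    if (word == "<tool_call>") return {42};
    return {1, 2, 3};
}

int main() {
    std::vector<token_id>    trigger_tokens;
    std::vector<std::string> trigger_words;

    // Same routing as the hunk above: single-token triggers go into the token
    // list, multi-token triggers stay as literal words.
    for (const std::string & trigger : {std::string("<tool_call>"), std::string("Action:")}) {
        auto ids = fake_tokenize(trigger);
        if (ids.size() == 1) {
            trigger_tokens.push_back(ids[0]);
            continue;
        }
        trigger_words.push_back(trigger);
    }

    std::cout << "token triggers: " << trigger_tokens.size()
              << ", word triggers: " << trigger_words.size() << "\n"; // 1, 1
}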
@@ -4005,7 +4020,8 @@ int main(int argc, char ** argv) {
             data,
             req.is_connection_closed,
             res,
-            OAICOMPAT_TYPE_CHAT);
+            OAICOMPAT_TYPE_CHAT,
+            &chat_template);
     };

     const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {