diff --git a/common/chat-template.hpp b/common/chat-template.hpp
index 05f093159..e0a9a1c56 100644
--- a/common/chat-template.hpp
+++ b/common/chat-template.hpp
@@ -61,7 +61,7 @@ class chat_template {
         });
         supports_tools_ = source.find("tools") != std::string::npos;
 
-        requires_object_arguments_ =
+        requires_object_arguments_ =
             try_raw_render({
                 {
                     {"role", "user"},
@@ -298,7 +298,7 @@ class chat_template {
         if (!tools.is_null()) {
             auto tools_val = minja::Value(actual_tools);
             context->set("tools", tools_val);
-            if (has_code_interpreter) {
+            if (has_code_interpreter && !extra_context.contains("builtin_tools")) {
                 auto builtin_tools_val = minja::Value(json {"code_interpreter"});
                 context->set("builtin_tools", builtin_tools_val);
             }
diff --git a/common/minja.hpp b/common/minja.hpp
index 80bdd4b41..604e61389 100644
--- a/common/minja.hpp
+++ b/common/minja.hpp
@@ -2648,31 +2648,34 @@ inline std::shared_ptr<Context> Context::builtins() {
             return filter.call(context, actual_args);
         });
     };
-    // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject
-    globals.set("reject", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-        args.expectArgs("reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
-        auto & items = args.args[0];
-        auto filter_fn = context->get(args.args[1]);
-        if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
+    auto select_or_reject = [make_filter](bool is_select) {
+        return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+            args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+            auto & items = args.args[0];
+            auto filter_fn = context->get(args.args[1]);
+            if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
 
-        auto filter_args = Value::array();
-        for (size_t i = 2, n = args.args.size(); i < n; i++) {
-            filter_args.push_back(args.args[i]);
-        }
-        auto filter = make_filter(filter_fn, filter_args);
-
-        auto res = Value::array();
-        for (size_t i = 0, n = items.size(); i < n; i++) {
-            auto & item = items.at(i);
-            ArgumentsValue filter_args;
-            filter_args.args.emplace_back(item);
-            auto pred_res = filter.call(context, filter_args);
-            if (!pred_res.to_bool()) {
-                res.push_back(item);
+            auto filter_args = Value::array();
+            for (size_t i = 2, n = args.args.size(); i < n; i++) {
+                filter_args.push_back(args.args[i]);
             }
-        }
-        return res;
-    }));
+            auto filter = make_filter(filter_fn, filter_args);
+
+            auto res = Value::array();
+            for (size_t i = 0, n = items.size(); i < n; i++) {
+                auto & item = items.at(i);
+                ArgumentsValue filter_args;
+                filter_args.args.emplace_back(item);
+                auto pred_res = filter.call(context, filter_args);
+                if (pred_res.to_bool() == (is_select ? true : false)) {
+                    res.push_back(item);
+                }
+            }
+            return res;
+        });
+    };
+    globals.set("select", select_or_reject(/* is_select= */ true));
+    globals.set("reject", select_or_reject(/* is_select= */ false));
     globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
         auto res = Value::array();
         if (args.args.size() == 1 &&
@@ -2720,41 +2723,45 @@ inline std::shared_ptr<Context> Context::builtins() {
         if (!text.empty() && text.back() == '\n') out += "\n";
         return out;
     }));
-    globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-        args.expectArgs("selectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
-        auto & items = args.args[0];
-        if (items.is_null())
-            return Value::array();
-        auto attr_name = args.args[1].get<std::string>();
+    auto select_or_reject_attr = [](bool is_select) {
+        return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+            args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+            auto & items = args.args[0];
+            if (items.is_null())
+                return Value::array();
+            auto attr_name = args.args[1].get<std::string>();
 
-        bool has_test = false;
-        Value test_fn;
-        ArgumentsValue test_args {{Value()}, {}};
-        if (args.args.size() >= 3) {
-            has_test = true;
-            test_fn = context->get(args.args[2]);
-            if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
-            for (size_t i = 3, n = args.args.size(); i < n; i++) {
-                test_args.args.emplace_back(args.args[i]);
-            }
-            test_args.kwargs = args.kwargs;
-        }
-
-        auto res = Value::array();
-        for (size_t i = 0, n = items.size(); i < n; i++) {
-            auto & item = items.at(i);
-            auto attr = item.get(attr_name);
-            if (has_test) {
-                test_args.args[0] = attr;
-                if (test_fn.call(context, test_args).to_bool()) {
-                    res.push_back(item);
+            bool has_test = false;
+            Value test_fn;
+            ArgumentsValue test_args {{Value()}, {}};
+            if (args.args.size() >= 3) {
+                has_test = true;
+                test_fn = context->get(args.args[2]);
+                if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
+                for (size_t i = 3, n = args.args.size(); i < n; i++) {
+                    test_args.args.emplace_back(args.args[i]);
                 }
-            } else {
-                res.push_back(attr);
+                test_args.kwargs = args.kwargs;
             }
-        }
-        return res;
-    }));
+
+            auto res = Value::array();
+            for (size_t i = 0, n = items.size(); i < n; i++) {
+                auto & item = items.at(i);
+                auto attr = item.get(attr_name);
+                if (has_test) {
+                    test_args.args[0] = attr;
+                    if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
+                        res.push_back(item);
+                    }
+                } else {
+                    res.push_back(attr);
+                }
+            }
+            return res;
+        });
+    };
+    globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
+    globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
     globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
         std::vector<int64_t> startEndStep(3);
         std::vector<bool> param_set(3);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 345b6ee8a..925f4f8ef 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -211,7 +211,6 @@ struct server_task {
     static slot_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
-            const common_chat_template * tmpl,
             const json & data) {
         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -330,30 +329,19 @@ struct server_task {
             }
         }
 
-        if (tmpl && params_base.use_jinja) {
-            common_chat_params chat_params;
-            chat_params.messages = json_value(data, "messages", json::array());
-            chat_params.tools = json_value(data, "tools", json());
-            chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
-            chat_params.json_schema = json_value(data, "json_schema", json());
-            chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
-            chat_params.stream = json_value(data, "stream", false);
-
-            auto chat_data = common_chat_init(*tmpl, chat_params);
-            params.chat_parser = std::move(chat_data.handler);
-            params.sampling.grammar = chat_data.grammar;
-            for (const auto & stop : chat_data.additional_stops) {
-                params.antiprompt.push_back(stop);
+        if (!params_base.use_jinja) {
+            if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
+                throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
             }
-            for (const auto & trigger : chat_data.grammar_triggers) {
-                auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
-                if (ids.size() == 1) {
-                    LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
-                    params.sampling.grammar_trigger_tokens.push_back(ids[0]);
-                    continue;
+            if (data.contains("json_schema") && !data.contains("grammar")) {
+                try {
+                    auto schema = json_value(data, "json_schema", json::object());
+                    params.sampling.grammar = json_schema_to_grammar(schema);
+                } catch (const std::exception & e) {
+                    throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
                 }
-                LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
-                params.sampling.grammar_trigger_words.push_back(trigger);
+            } else {
+                params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
             }
         }
@@ -363,15 +351,13 @@ struct server_task {
         }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema = json_value(data, "json_schema", json::object());
-                params.sampling.grammar = json_schema_to_grammar(schema);
+                params.sampling.grammar = json_schema_to_grammar(json_value(data, "json_schema", json::object()));
             } catch (const std::exception & e) {
                 throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
             }
         } else {
             params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
         }
-        LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());
 
         {
             params.sampling.logit_bias.clear();
@@ -2248,9 +2234,15 @@ struct server_context {
     }
 
     void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
-        auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send);
-        if (!opt_msg) {
-            return;
+        common_chat_msg msg;
+        if (slot.params.chat_parser) {
+            if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
+                msg = *opt_msg;
+            } else {
+                return;
+            }
+        } else {
+            msg.content = tkn.text_to_send;
         }
 
         auto res = std::make_unique<server_task_result_cmpl_partial>();
@@ -2267,7 +2259,7 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = *opt_msg;
+        res->oaicompat_chat_msg = msg;
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -2308,7 +2300,11 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
+        res->oaicompat_chat_msg = slot.params.chat_parser ? slot.params.chat_parser->parse_final(slot.generated_text) : common_chat_msg {
+            /* .role = */ "assistant",
+            /* .content = */ slot.generated_text,
+            /* .tool_calls = */ {}
+        };
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -3773,7 +3769,7 @@ int main(int argc, char ** argv) {
             std::function<bool()> is_connection_closed,
             httplib::Response & res,
             oaicompat_type oaicompat,
-            const common_chat_template * tmpl) {
+            const common_chat_template * tmpl = nullptr) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
 
         if (ctx_server.params_base.embedding) {
@@ -3785,21 +3781,29 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;
 
         try {
-            fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get<std::string>().c_str());
-            std::string prompt;
+            common_chat_data chat_data;
             if (tmpl && ctx_server.params_base.use_jinja) {
-                auto chat_data = common_chat_init(*tmpl, {
-                    /* .messages = */ json_data(data, "messages", json::array()),
-                    /* .tools = */ json_data(data, "tools", json()),
-                    /
+                chat_data = common_chat_init(*tmpl, {
+                    /* .messages = */ json_value(data, "messages", json::array()),
+                    /* .tools = */ json_value(data, "tools", json()),
+                    /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
+                    /* .json_schema = */ json_value(data, "json_schema", json()),
+                    /* .parallel_tool_calls = */ json_value(data, "parallel_tool_calls", true),
+                    /* .stream = */ json_value(data, "stream", false),
+                    /* .grammar = */ json_value(data, "grammar", std::string("")),
                 });
-
-                prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
+                if (data.contains("grammar")) {
+                    chat_data.grammar = data.at("grammar");
+                }
             } else {
-                prompt = data.at("prompt").get<std::string>();
+                chat_data.prompt = data.at("prompt");
+                if (data.contains("grammar")) {
+                    chat_data.grammar = data.at("grammar");
+                } else if (data.contains("json_schema")) {
+                    chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
+                }
             }
-            task.params.chat_parser = common_chat_init()
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
+            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(type);
@@ -3811,16 +3815,27 @@ int main(int argc, char ** argv) {
                 task.params = server_task::params_from_json_cmpl(
                         ctx_server.ctx,
                         ctx_server.params_base,
-                        nullptr,
                         data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);
 
                 // OAI-compat
                 task.params.oaicompat = oaicompat;
                 task.params.oaicompat_cmpl_id = completion_id;
-                task.params.chat_parser = common_chat_init()
-                task.params.oaicompat_tools = json_value(data, "tools", json());
-                task.params.oaicompat_tool_call_style = tool_call_style;
+                task.params.sampling.grammar = chat_data.grammar;
+                for (const auto & trigger : chat_data.grammar_triggers) {
+                    auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
+                        task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                        continue;
+                    }
+                    LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
+                    task.params.sampling.grammar_trigger_words.push_back(trigger);
+                }
+                task.params.antiprompt = chat_data.additional_stops;
+                if (chat_data.parser) {
+                    task.params.chat_parser = i == tokenized_prompts.size() - 1 ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
+                }
 
                 // oaicompat_model is already populated by params_from_json_cmpl
                 tasks.push_back(task);
@@ -4005,7 +4020,8 @@ int main(int argc, char ** argv) {
            data,
            req.is_connection_closed,
            res,
-            OAICOMPAT_TYPE_CHAT);
+            OAICOMPAT_TYPE_CHAT,
+            &chat_template);
    };

    const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
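
For reference, a minimal sketch of a chat-completions-style request body that exercises the Jinja tool-call path above. Only the key names (messages, tools, tool_choice, parallel_tool_calls, stream, json_schema, grammar) come from the json_value() reads in the diff; the get_weather tool and its schema are purely hypothetical, and in the non-Jinja path json_schema and grammar remain mutually exclusive.

#include <cstdio>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// Illustrative request body only; field names mirror the keys read by
// params_from_json_cmpl / handle_completions_impl in the patch above.
static json make_example_chat_request() {
    return json {
        {"messages", json::array({
            json {{"role", "user"}, {"content", "What is the weather in Paris?"}},
        })},
        {"tools", json::array({
            json {
                {"type", "function"},
                {"function", json {
                    {"name", "get_weather"},          // hypothetical tool
                    {"parameters", json {
                        {"type", "object"},
                        {"properties", json {{"location", json {{"type", "string"}}}}},
                        {"required", json::array({"location"})},
                    }},
                }},
            },
        })},
        {"tool_choice", "auto"},
        {"parallel_tool_calls", true},
        {"stream", false},
    };
}

int main() {
    // Print the example body; in practice it would be POSTed to the
    // OAI-compatible chat endpoint served by handle_chat_completions.
    printf("%s\n", make_example_chat_request().dump(2).c_str());
    return 0;
}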