sync: minja
This commit is contained in:
parent
11594557e3
commit
43385b2ff2
3 changed files with 128 additions and 105 deletions
|
@ -61,7 +61,7 @@ class chat_template {
|
|||
});
|
||||
supports_tools_ = source.find("tools") != std::string::npos;
|
||||
|
||||
requires_object_arguments_ =
|
||||
requires_object_arguments_ =
|
||||
try_raw_render({
|
||||
{
|
||||
{"role", "user"},
|
||||
|
@ -298,7 +298,7 @@ class chat_template {
|
|||
if (!tools.is_null()) {
|
||||
auto tools_val = minja::Value(actual_tools);
|
||||
context->set("tools", tools_val);
|
||||
if (has_code_interpreter) {
|
||||
if (has_code_interpreter && !extra_context.contains("builtin_tools")) {
|
||||
auto builtin_tools_val = minja::Value(json {"code_interpreter"});
|
||||
context->set("builtin_tools", builtin_tools_val);
|
||||
}
|
||||
|
|
117
common/minja.hpp
117
common/minja.hpp
|
@ -2648,31 +2648,34 @@ inline std::shared_ptr<Context> Context::builtins() {
|
|||
return filter.call(context, actual_args);
|
||||
});
|
||||
};
|
||||
// https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject
|
||||
globals.set("reject", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
args.expectArgs("reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
|
||||
auto & items = args.args[0];
|
||||
auto filter_fn = context->get(args.args[1]);
|
||||
if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
|
||||
auto select_or_reject = [make_filter](bool is_select) {
|
||||
return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
|
||||
auto & items = args.args[0];
|
||||
auto filter_fn = context->get(args.args[1]);
|
||||
if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
|
||||
|
||||
auto filter_args = Value::array();
|
||||
for (size_t i = 2, n = args.args.size(); i < n; i++) {
|
||||
filter_args.push_back(args.args[i]);
|
||||
}
|
||||
auto filter = make_filter(filter_fn, filter_args);
|
||||
|
||||
auto res = Value::array();
|
||||
for (size_t i = 0, n = items.size(); i < n; i++) {
|
||||
auto & item = items.at(i);
|
||||
ArgumentsValue filter_args;
|
||||
filter_args.args.emplace_back(item);
|
||||
auto pred_res = filter.call(context, filter_args);
|
||||
if (!pred_res.to_bool()) {
|
||||
res.push_back(item);
|
||||
auto filter_args = Value::array();
|
||||
for (size_t i = 2, n = args.args.size(); i < n; i++) {
|
||||
filter_args.push_back(args.args[i]);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}));
|
||||
auto filter = make_filter(filter_fn, filter_args);
|
||||
|
||||
auto res = Value::array();
|
||||
for (size_t i = 0, n = items.size(); i < n; i++) {
|
||||
auto & item = items.at(i);
|
||||
ArgumentsValue filter_args;
|
||||
filter_args.args.emplace_back(item);
|
||||
auto pred_res = filter.call(context, filter_args);
|
||||
if (pred_res.to_bool() == (is_select ? true : false)) {
|
||||
res.push_back(item);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
});
|
||||
};
|
||||
globals.set("select", select_or_reject(/* is_select= */ true));
|
||||
globals.set("reject", select_or_reject(/* is_select= */ false));
|
||||
globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
auto res = Value::array();
|
||||
if (args.args.size() == 1 &&
|
||||
|
@ -2720,41 +2723,45 @@ inline std::shared_ptr<Context> Context::builtins() {
|
|||
if (!text.empty() && text.back() == '\n') out += "\n";
|
||||
return out;
|
||||
}));
|
||||
globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
args.expectArgs("selectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
|
||||
auto & items = args.args[0];
|
||||
if (items.is_null())
|
||||
return Value::array();
|
||||
auto attr_name = args.args[1].get<std::string>();
|
||||
auto select_or_reject_attr = [](bool is_select) {
|
||||
return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
|
||||
args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
|
||||
auto & items = args.args[0];
|
||||
if (items.is_null())
|
||||
return Value::array();
|
||||
auto attr_name = args.args[1].get<std::string>();
|
||||
|
||||
bool has_test = false;
|
||||
Value test_fn;
|
||||
ArgumentsValue test_args {{Value()}, {}};
|
||||
if (args.args.size() >= 3) {
|
||||
has_test = true;
|
||||
test_fn = context->get(args.args[2]);
|
||||
if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
|
||||
for (size_t i = 3, n = args.args.size(); i < n; i++) {
|
||||
test_args.args.emplace_back(args.args[i]);
|
||||
}
|
||||
test_args.kwargs = args.kwargs;
|
||||
}
|
||||
|
||||
auto res = Value::array();
|
||||
for (size_t i = 0, n = items.size(); i < n; i++) {
|
||||
auto & item = items.at(i);
|
||||
auto attr = item.get(attr_name);
|
||||
if (has_test) {
|
||||
test_args.args[0] = attr;
|
||||
if (test_fn.call(context, test_args).to_bool()) {
|
||||
res.push_back(item);
|
||||
bool has_test = false;
|
||||
Value test_fn;
|
||||
ArgumentsValue test_args {{Value()}, {}};
|
||||
if (args.args.size() >= 3) {
|
||||
has_test = true;
|
||||
test_fn = context->get(args.args[2]);
|
||||
if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
|
||||
for (size_t i = 3, n = args.args.size(); i < n; i++) {
|
||||
test_args.args.emplace_back(args.args[i]);
|
||||
}
|
||||
} else {
|
||||
res.push_back(attr);
|
||||
test_args.kwargs = args.kwargs;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}));
|
||||
|
||||
auto res = Value::array();
|
||||
for (size_t i = 0, n = items.size(); i < n; i++) {
|
||||
auto & item = items.at(i);
|
||||
auto attr = item.get(attr_name);
|
||||
if (has_test) {
|
||||
test_args.args[0] = attr;
|
||||
if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
|
||||
res.push_back(item);
|
||||
}
|
||||
} else {
|
||||
res.push_back(attr);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
});
|
||||
};
|
||||
globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
|
||||
globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
|
||||
globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
|
||||
std::vector<int64_t> startEndStep(3);
|
||||
std::vector<bool> param_set(3);
|
||||
|
|
|
@ -211,7 +211,6 @@ struct server_task {
|
|||
static slot_params params_from_json_cmpl(
|
||||
const llama_context * ctx,
|
||||
const common_params & params_base,
|
||||
const common_chat_template * tmpl,
|
||||
const json & data) {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
@ -330,30 +329,19 @@ struct server_task {
|
|||
}
|
||||
}
|
||||
|
||||
if (tmpl && params_base.use_jinja) {
|
||||
common_chat_params chat_params;
|
||||
chat_params.messages = json_value(data, "messages", json::array());
|
||||
chat_params.tools = json_value(data, "tools", json());
|
||||
chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
|
||||
chat_params.json_schema = json_value(data, "json_schema", json());
|
||||
chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
|
||||
chat_params.stream = json_value(data, "stream", false);
|
||||
|
||||
auto chat_data = common_chat_init(*tmpl, chat_params);
|
||||
params.chat_parser = std::move(chat_data.handler);
|
||||
params.sampling.grammar = chat_data.grammar;
|
||||
for (const auto & stop : chat_data.additional_stops) {
|
||||
params.antiprompt.push_back(stop);
|
||||
if (!params_base.use_jinja) {
|
||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||
}
|
||||
for (const auto & trigger : chat_data.grammar_triggers) {
|
||||
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
|
||||
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
|
||||
continue;
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
}
|
||||
LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
|
||||
params.sampling.grammar_trigger_words.push_back(trigger);
|
||||
} else {
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -363,15 +351,13 @@ struct server_task {
|
|||
}
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
params.sampling.grammar = json_schema_to_grammar(schema);
|
||||
params.sampling.grammar = json_schema_to_grammar(json_value(data, "json_schema", json::object()));
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
||||
}
|
||||
LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());
|
||||
|
||||
{
|
||||
params.sampling.logit_bias.clear();
|
||||
|
@ -2248,9 +2234,15 @@ struct server_context {
|
|||
}
|
||||
|
||||
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
|
||||
auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send);
|
||||
if (!opt_msg) {
|
||||
return;
|
||||
common_chat_msg msg;
|
||||
if (slot.params.chat_parser) {
|
||||
if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
|
||||
msg = *opt_msg;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
msg.content = tkn.text_to_send;
|
||||
}
|
||||
auto res = std::make_unique<server_task_result_cmpl_partial>();
|
||||
|
||||
|
@ -2267,7 +2259,7 @@ struct server_context {
|
|||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->oaicompat_chat_msg = *opt_msg;
|
||||
res->oaicompat_chat_msg = msg;
|
||||
|
||||
// populate res.probs_output
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
|
@ -2308,7 +2300,11 @@ struct server_context {
|
|||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
|
||||
res->oaicompat_chat_msg = slot.params.chat_parser ? slot.params.chat_parser->parse_final(slot.generated_text) : common_chat_msg {
|
||||
/* .role = */ "assistant",
|
||||
/* .content = */ slot.generated_text,
|
||||
/* .tool_calls = */ {}
|
||||
};
|
||||
|
||||
// populate res.probs_output
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
|
@ -3773,7 +3769,7 @@ int main(int argc, char ** argv) {
|
|||
std::function<bool()> is_connection_closed,
|
||||
httplib::Response & res,
|
||||
oaicompat_type oaicompat,
|
||||
const common_chat_template * tmpl) {
|
||||
const common_chat_template * tmpl = nullptr) {
|
||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
||||
|
||||
if (ctx_server.params_base.embedding) {
|
||||
|
@ -3785,21 +3781,29 @@ int main(int argc, char ** argv) {
|
|||
std::vector<server_task> tasks;
|
||||
|
||||
try {
|
||||
fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get<std::string>().c_str());
|
||||
std::string prompt;
|
||||
common_chat_data chat_data;
|
||||
if (tmpl && ctx_server.params_base.use_jinja) {
|
||||
auto chat_data = common_chat_init(*tmpl, {
|
||||
/* .messages = */ json_data(data, "messages", json::array()),
|
||||
/* .tools = */ json_data(data, "tools", json()),
|
||||
/
|
||||
chat_data = common_chat_init(*tmpl, {
|
||||
/* .messages = */ json_value(data, "messages", json::array()),
|
||||
/* .tools = */ json_value(data, "tools", json()),
|
||||
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
|
||||
/* .json_schema = */ json_value(data, "json_schema", json()),
|
||||
/* .parallel_tool_calls = */ json_value(data, "json_schema", true),
|
||||
/* .stream = */ json_value(data, "json_schema", false),
|
||||
/* .grammar = */ json_value(data, "grammar", std::string("")),
|
||||
});
|
||||
|
||||
prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
|
||||
if (data.contains("grammar")) {
|
||||
chat_data.grammar = data.at("grammar");
|
||||
}
|
||||
} else {
|
||||
prompt = data.at("prompt").get<std::string>();
|
||||
chat_data.prompt = data.at("prompt");
|
||||
if (data.contains("grammar")) {
|
||||
chat_data.grammar = data.at("grammar");
|
||||
} else if (data.contains("json_schema")) {
|
||||
chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
|
||||
}
|
||||
}
|
||||
task.params.chat_parser = common_chat_init()
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
|
||||
tasks.reserve(tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
@ -3811,16 +3815,27 @@ int main(int argc, char ** argv) {
|
|||
task.params = server_task::params_from_json_cmpl(
|
||||
ctx_server.ctx,
|
||||
ctx_server.params_base,
|
||||
nullptr,
|
||||
data);
|
||||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.chat_parser = common_chat_init()
|
||||
task.params.oaicompat_tools = json_value(data, "tools", json());
|
||||
task.params.oaicompat_tool_call_style = tool_call_style;
|
||||
task.params.sampling.grammar = chat_data.grammar;
|
||||
for (const auto & trigger : chat_data.grammar_triggers) {
|
||||
auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
|
||||
task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
|
||||
continue;
|
||||
}
|
||||
LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
|
||||
task.params.sampling.grammar_trigger_words.push_back(trigger);
|
||||
}
|
||||
task.params.antiprompt = chat_data.additional_stops;
|
||||
if (chat_data.parser) {
|
||||
task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
|
||||
}
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
tasks.push_back(task);
|
||||
|
@ -4005,7 +4020,8 @@ int main(int argc, char ** argv) {
|
|||
data,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_CHAT);
|
||||
OAICOMPAT_TYPE_CHAT,
|
||||
&chat_template);
|
||||
};
|
||||
|
||||
const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue