sync: minja

commit 43385b2ff2 (parent 11594557e3)

3 changed files with 128 additions and 105 deletions

@@ -298,7 +298,7 @@ class chat_template {
         if (!tools.is_null()) {
             auto tools_val = minja::Value(actual_tools);
             context->set("tools", tools_val);
-            if (has_code_interpreter) {
+            if (has_code_interpreter && !extra_context.contains("builtin_tools")) {
                 auto builtin_tools_val = minja::Value(json {"code_interpreter"});
                 context->set("builtin_tools", builtin_tools_val);
             }
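
The guard added above keeps a caller-supplied `builtin_tools` in `extra_context` from being overwritten by the template's default. A minimal sketch of the same "set a default only if absent" pattern, using nlohmann::json directly (the helper name `set_default` is illustrative, not part of minja):

    #include <nlohmann/json.hpp>
    #include <iostream>
    #include <string>
    #include <utility>

    using json = nlohmann::json;

    // Mirrors the `!extra_context.contains("builtin_tools")` check: only fill
    // in the default when the caller did not already provide a value.
    static void set_default(json & ctx, const std::string & key, json value) {
        if (!ctx.contains(key)) {
            ctx[key] = std::move(value);
        }
    }

    int main() {
        json extra_context = { {"builtin_tools", json::array({"browser"})} };
        set_default(extra_context, "builtin_tools", json::array({"code_interpreter"}));
        std::cout << extra_context.dump() << "\n";  // caller's ["browser"] survives
    }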

common/minja.hpp (117 changed lines)

@@ -2648,31 +2648,34 @@ inline std::shared_ptr<Context> Context::builtins() {
       return filter.call(context, actual_args);
     });
   };
-  // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject
-  globals.set("reject", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-    args.expectArgs("reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
-    auto & items = args.args[0];
-    auto filter_fn = context->get(args.args[1]);
-    if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
-
-    auto filter_args = Value::array();
-    for (size_t i = 2, n = args.args.size(); i < n; i++) {
-      filter_args.push_back(args.args[i]);
-    }
-    auto filter = make_filter(filter_fn, filter_args);
-
-    auto res = Value::array();
-    for (size_t i = 0, n = items.size(); i < n; i++) {
-      auto & item = items.at(i);
-      ArgumentsValue filter_args;
-      filter_args.args.emplace_back(item);
-      auto pred_res = filter.call(context, filter_args);
-      if (!pred_res.to_bool()) {
-        res.push_back(item);
-      }
-    }
-    return res;
-  }));
+  auto select_or_reject = [make_filter](bool is_select) {
+    return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+      args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+      auto & items = args.args[0];
+      auto filter_fn = context->get(args.args[1]);
+      if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
+
+      auto filter_args = Value::array();
+      for (size_t i = 2, n = args.args.size(); i < n; i++) {
+        filter_args.push_back(args.args[i]);
+      }
+      auto filter = make_filter(filter_fn, filter_args);
+
+      auto res = Value::array();
+      for (size_t i = 0, n = items.size(); i < n; i++) {
+        auto & item = items.at(i);
+        ArgumentsValue filter_args;
+        filter_args.args.emplace_back(item);
+        auto pred_res = filter.call(context, filter_args);
+        if (pred_res.to_bool() == (is_select ? true : false)) {
+          res.push_back(item);
+        }
+      }
+      return res;
+    });
+  };
+  globals.set("select", select_or_reject(/* is_select= */ true));
+  globals.set("reject", select_or_reject(/* is_select= */ false));
   globals.set("map", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
     auto res = Value::array();
     if (args.args.size() == 1 &&
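
The change above folds `select` and `reject` into a single factory parameterized by `is_select`: the loop body is shared and the predicate result is simply compared against the flag. A self-contained sketch of that pattern with plain STL types (minja's `Value`/`ArgumentsValue` machinery omitted):

    #include <functional>
    #include <iostream>
    #include <vector>

    // Returns a filter: `select` keeps items where pred(x) is true,
    // `reject` keeps items where pred(x) is false -- one body, one comparison.
    static std::function<std::vector<int>(const std::vector<int> &, const std::function<bool(int)> &)>
    select_or_reject(bool is_select) {
        return [is_select](const std::vector<int> & items, const std::function<bool(int)> & pred) {
            std::vector<int> res;
            for (int x : items) {
                if (pred(x) == is_select) {
                    res.push_back(x);
                }
            }
            return res;
        };
    }

    int main() {
        auto select = select_or_reject(/* is_select= */ true);
        auto reject = select_or_reject(/* is_select= */ false);
        auto odd = [](int x) { return x % 2 != 0; };
        for (int x : select({1, 2, 3, 4}, odd)) std::cout << x << ' ';  // 1 3
        std::cout << '\n';
        for (int x : reject({1, 2, 3, 4}, odd)) std::cout << x << ' ';  // 2 4
        std::cout << '\n';
    }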

@@ -2720,41 +2723,45 @@ inline std::shared_ptr<Context> Context::builtins() {
     if (!text.empty() && text.back() == '\n') out += "\n";
     return out;
   }));
-  globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-    args.expectArgs("selectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
-    auto & items = args.args[0];
-    if (items.is_null())
-      return Value::array();
-    auto attr_name = args.args[1].get<std::string>();
-
-    bool has_test = false;
-    Value test_fn;
-    ArgumentsValue test_args {{Value()}, {}};
-    if (args.args.size() >= 3) {
-      has_test = true;
-      test_fn = context->get(args.args[2]);
-      if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
-      for (size_t i = 3, n = args.args.size(); i < n; i++) {
-        test_args.args.emplace_back(args.args[i]);
-      }
-      test_args.kwargs = args.kwargs;
-    }
-
-    auto res = Value::array();
-    for (size_t i = 0, n = items.size(); i < n; i++) {
-      auto & item = items.at(i);
-      auto attr = item.get(attr_name);
-      if (has_test) {
-        test_args.args[0] = attr;
-        if (test_fn.call(context, test_args).to_bool()) {
-          res.push_back(item);
-        }
-      } else {
-        res.push_back(attr);
-      }
-    }
-    return res;
-  }));
+  auto select_or_reject_attr = [](bool is_select) {
+    return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
+      args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
+      auto & items = args.args[0];
+      if (items.is_null())
+        return Value::array();
+      auto attr_name = args.args[1].get<std::string>();
+
+      bool has_test = false;
+      Value test_fn;
+      ArgumentsValue test_args {{Value()}, {}};
+      if (args.args.size() >= 3) {
+        has_test = true;
+        test_fn = context->get(args.args[2]);
+        if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump());
+        for (size_t i = 3, n = args.args.size(); i < n; i++) {
+          test_args.args.emplace_back(args.args[i]);
+        }
+        test_args.kwargs = args.kwargs;
+      }
+
+      auto res = Value::array();
+      for (size_t i = 0, n = items.size(); i < n; i++) {
+        auto & item = items.at(i);
+        auto attr = item.get(attr_name);
+        if (has_test) {
+          test_args.args[0] = attr;
+          if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) {
+            res.push_back(item);
+          }
+        } else {
+          res.push_back(attr);
+        }
+      }
+      return res;
+    });
+  };
+  globals.set("selectattr", select_or_reject_attr(/* is_select= */ true));
+  globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false));
   globals.set("range", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
     std::vector<int64_t> startEndStep(3);
     std::vector<bool> param_set(3);

@@ -211,7 +211,6 @@ struct server_task {
     static slot_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
-            const common_chat_template * tmpl,
             const json & data) {
         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);

@@ -330,30 +329,19 @@ struct server_task {
             }
         }

-        if (tmpl && params_base.use_jinja) {
-            common_chat_params chat_params;
-            chat_params.messages = json_value(data, "messages", json::array());
-            chat_params.tools = json_value(data, "tools", json());
-            chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
-            chat_params.json_schema = json_value(data, "json_schema", json());
-            chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
-            chat_params.stream = json_value(data, "stream", false);
-
-            auto chat_data = common_chat_init(*tmpl, chat_params);
-            params.chat_parser = std::move(chat_data.handler);
-            params.sampling.grammar = chat_data.grammar;
-            for (const auto & stop : chat_data.additional_stops) {
-                params.antiprompt.push_back(stop);
-            }
-            for (const auto & trigger : chat_data.grammar_triggers) {
-                auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
-                if (ids.size() == 1) {
-                    LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
-                    params.sampling.grammar_trigger_tokens.push_back(ids[0]);
-                    continue;
-                }
-                LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
-                params.sampling.grammar_trigger_words.push_back(trigger);
-            }
-        }
+        if (!params_base.use_jinja) {
+            if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
+                throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+            }
+            if (data.contains("json_schema") && !data.contains("grammar")) {
+                try {
+                    auto schema = json_value(data, "json_schema", json::object());
+                    params.sampling.grammar = json_schema_to_grammar(schema);
+                } catch (const std::exception & e) {
+                    throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
+                }
+            } else {
+                params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+            }
+        }
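
In the non-Jinja path the server now rejects requests that set both `json_schema` and `grammar`, and otherwise compiles the schema into a grammar. A reduced sketch of that decision order; `compile_schema` is a stand-in for llama.cpp's `json_schema_to_grammar`, and the null checks of the real code are omitted:

    #include <nlohmann/json.hpp>
    #include <stdexcept>
    #include <string>

    using json = nlohmann::json;

    // Stand-in for json_schema_to_grammar().
    static std::string compile_schema(const json & schema) { return schema.dump(); }

    static std::string grammar_from_request(const json & data) {
        if (data.contains("json_schema") && data.contains("grammar")) {
            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
        }
        if (data.contains("json_schema")) {
            return compile_schema(data.at("json_schema"));
        }
        return data.value("grammar", std::string());  // may be empty: no constraint
    }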

@@ -363,15 +351,13 @@ struct server_task {
         }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema = json_value(data, "json_schema", json::object());
-                params.sampling.grammar = json_schema_to_grammar(schema);
+                params.sampling.grammar = json_schema_to_grammar(json_value(data, "json_schema", json::object()));
             } catch (const std::exception & e) {
                 throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
             }
         } else {
             params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
         }
-        LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());

         {
             params.sampling.logit_bias.clear();

@@ -2248,9 +2234,15 @@ struct server_context {
     }

     void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
-        auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send);
-        if (!opt_msg) {
-            return;
+        common_chat_msg msg;
+        if (slot.params.chat_parser) {
+            if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
+                msg = *opt_msg;
+            } else {
+                return;
+            }
+        } else {
+            msg.content = tkn.text_to_send;
         }
         auto res = std::make_unique<server_task_result_cmpl_partial>();
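
With the parser now optional, a partial result either comes from `parse_partial` (which may hold back incomplete output) or falls back to the raw token text. The control flow is a plain std::optional pattern; a compilable sketch with stand-in types (`chat_msg`, `parse_partial`, and `next_partial` are illustrative, not the server's API):

    #include <iostream>
    #include <optional>
    #include <string>

    struct chat_msg { std::string content; };

    // Stand-in parser: emits a message only once a full line has arrived,
    // approximating parse_partial() returning nullopt on incomplete input.
    static std::optional<chat_msg> parse_partial(const std::string & text) {
        if (text.empty() || text.back() != '\n') return std::nullopt;
        return chat_msg{text};
    }

    static std::optional<chat_msg> next_partial(const std::string & text, bool have_parser) {
        chat_msg msg;
        if (have_parser) {
            if (auto opt = parse_partial(text)) {
                msg = *opt;
            } else {
                return std::nullopt;  // nothing to send yet
            }
        } else {
            msg.content = text;  // no parser: forward the raw text
        }
        return msg;
    }

    int main() {
        std::cout << (next_partial("hi", true)  ? "send" : "skip") << '\n';  // skip
        std::cout << (next_partial("hi", false) ? "send" : "skip") << '\n';  // send
    }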

@@ -2267,7 +2259,7 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = *opt_msg;
+        res->oaicompat_chat_msg = msg;

         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {

@@ -2308,7 +2300,11 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
+        res->oaicompat_chat_msg = slot.params.chat_parser ? slot.params.chat_parser->parse_final(slot.generated_text) : common_chat_msg {
+            /* .role = */ "assistant",
+            /* .content = */ slot.generated_text,
+            /* .tool_calls = */ {}
+        };

         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {

@@ -3773,7 +3769,7 @@ int main(int argc, char ** argv) {
             std::function<bool()> is_connection_closed,
             httplib::Response & res,
             oaicompat_type oaicompat,
-            const common_chat_template * tmpl) {
+            const common_chat_template * tmpl = nullptr) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

         if (ctx_server.params_base.embedding) {

@@ -3785,21 +3781,29 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;

         try {
-            fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get<std::string>().c_str());
-            std::string prompt;
+            common_chat_data chat_data;
             if (tmpl && ctx_server.params_base.use_jinja) {
-                auto chat_data = common_chat_init(*tmpl, {
-                    /* .messages = */ json_data(data, "messages", json::array()),
-                    /* .tools = */ json_data(data, "tools", json()),
-                    /
+                chat_data = common_chat_init(*tmpl, {
+                    /* .messages = */ json_value(data, "messages", json::array()),
+                    /* .tools = */ json_value(data, "tools", json()),
+                    /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
+                    /* .json_schema = */ json_value(data, "json_schema", json()),
+                    /* .parallel_tool_calls = */ json_value(data, "json_schema", true),
+                    /* .stream = */ json_value(data, "json_schema", false),
+                    /* .grammar = */ json_value(data, "grammar", std::string("")),
                 });
-
-                prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
+                if (data.contains("grammar")) {
+                    chat_data.grammar = data.at("grammar");
+                }
             } else {
-                prompt = data.at("prompt").get<std::string>();
+                chat_data.prompt = data.at("prompt");
+                if (data.contains("grammar")) {
+                    chat_data.grammar = data.at("grammar");
+                } else if (data.contains("json_schema")) {
+                    chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
+                }
             }
-            task.params.chat_parser = common_chat_init()
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
+            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(type);

@@ -3811,16 +3815,27 @@ int main(int argc, char ** argv) {
                 task.params = server_task::params_from_json_cmpl(
                     ctx_server.ctx,
                     ctx_server.params_base,
-                    nullptr,
                     data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);

                 // OAI-compat
                 task.params.oaicompat = oaicompat;
                 task.params.oaicompat_cmpl_id = completion_id;
-                task.params.chat_parser = common_chat_init()
-                task.params.oaicompat_tools = json_value(data, "tools", json());
-                task.params.oaicompat_tool_call_style = tool_call_style;
+                task.params.sampling.grammar = chat_data.grammar;
+                for (const auto & trigger : chat_data.grammar_triggers) {
+                    auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
+                    if (ids.size() == 1) {
+                        LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
+                        task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                        continue;
+                    }
+                    LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
+                    task.params.sampling.grammar_trigger_words.push_back(trigger);
+                }
+                task.params.antiprompt = chat_data.additional_stops;
+                if (chat_data.parser) {
+                    task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
+                }
                 // oaicompat_model is already populated by params_from_json_cmpl

                 tasks.push_back(task);
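
The trigger loop above prefers token-level triggers: if a trigger word tokenizes to exactly one token, matching can happen on the token id; otherwise it falls back to matching the decoded text. A sketch of that split with a toy tokenizer (`tokenize` here is a hypothetical stand-in for common_tokenize):

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    using token_id = int32_t;

    // Toy stand-in: pretend short strings are a single token.
    static std::vector<token_id> tokenize(const std::string & word) {
        if (word.size() <= 4) return { (token_id) std::hash<std::string>{}(word) };
        return { 1, 2 };  // multiple tokens
    }

    int main() {
        std::vector<token_id>    trigger_tokens;
        std::vector<std::string> trigger_words;
        for (const std::string & w : {std::string("<t>"), std::string("<tool_call>")}) {
            auto ids = tokenize(w);
            if (ids.size() == 1) {
                trigger_tokens.push_back(ids[0]);  // fast path: match one token id
                continue;
            }
            trigger_words.push_back(w);  // fallback: match on decoded text
        }
        std::cout << trigger_tokens.size() << " token trigger(s), "
                  << trigger_words.size() << " word trigger(s)\n";
    }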

@@ -4005,7 +4020,8 @@ int main(int argc, char ** argv) {
             data,
             req.is_connection_closed,
             res,
-            OAICOMPAT_TYPE_CHAT);
+            OAICOMPAT_TYPE_CHAT,
+            &chat_template);
     };

     const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {