Add --jinja and --chat-template-file flags

This commit is contained in:
ochafik 2024-12-30 03:40:34 +00:00
parent abd274a48f
commit e5113e8d74
12 changed files with 289 additions and 50 deletions

View file

@ -129,7 +129,7 @@ The project is under active development, and we are [looking for feedback and co
| `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
| `--grammar-file FNAME` | file to read grammar from |
| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) |
**Example-specific params**

View file

@ -1623,15 +1623,35 @@ struct server_context {
return true;
}
bool validate_model_chat_template() const {
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template";
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
if (res >= 0) {
llama_chat_message chat[] = {{"user", "test"}};
std::string tmpl = std::string(model_template.data(), model_template.size());
int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
return chat_res > 0;
bool validate_model_chat_template(bool use_jinja) const {
llama_chat_message chat[] = {{"user", "test"}};
if (use_jinja) {
auto templates = llama_chat_templates_from_model(model, "");
try {
templates.default_template.apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
if (templates.tool_use_template) {
templates.tool_use_template->apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
}
return true;
} catch (const std::exception & e) {
SRV_ERR("failed to apply template: %s\n", e.what());
}
} else {
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template";
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
if (res >= 0) {
std::string tmpl = std::string(model_template.data(), model_template.size());
int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
return chat_res > 0;
}
}
return false;
}
@ -3476,15 +3496,30 @@ int main(int argc, char ** argv) {
}
};
const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
std::mutex chat_templates_mutex;
std::optional<llama_chat_templates> chat_templates;
auto get_chat_templates = [&ctx_server, &chat_templates_mutex, &chat_templates]() -> const llama_chat_templates & {
std::lock_guard<std::mutex> lock(chat_templates_mutex);
if (!chat_templates) {
chat_templates = llama_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template);
}
return *chat_templates;
};
const auto handle_props = [&ctx_server, &res_ok, &get_chat_templates](const httplib::Request &, httplib::Response & res) {
// this endpoint is publicly available, please only return what is safe to be exposed
const auto & templates = get_chat_templates();
json data = {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model },
{ "chat_template", llama_get_chat_template(ctx_server.model) },
{ "chat_template", templates.default_template.source() },
{ "build_info", build_info },
};
if (ctx_server.params_base.use_jinja && templates.tool_use_template) {
data["chat_template_tool_use"] = templates.tool_use_template->source();
}
res_ok(res, data);
};
@ -3685,13 +3720,17 @@ int main(int argc, char ** argv) {
return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
};
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic, &get_chat_templates](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
auto body = json::parse(req.body);
const auto & templates = get_chat_templates();
const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template;
json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja);
return handle_completions_generic(
SERVER_TASK_TYPE_COMPLETION,
data,
@ -4111,7 +4150,7 @@ int main(int argc, char ** argv) {
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (params.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {
if (!ctx_server.validate_model_chat_template(params.use_jinja)) {
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
params.chat_template = "chatml";
}

View file

@ -4,22 +4,24 @@ from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja",
[
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False),
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"model": model,
@ -102,6 +104,7 @@ def test_chat_completion_with_openai_library():
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
({"type": "json_schema", "json_schema": {"const": "42"}}, 6, "\"42\""),
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
({"type": "json_object"}, 10, "(\\{|John)+"),
({"type": "sound"}, 0, None),

View file

@ -68,8 +68,9 @@ class ServerProcess:
pooling: str | None = None
draft: int | None = None
api_key: str | None = None
response_format: str | None = None
lora_files: List[str] | None = None
chat_template_file: str | None = None
jinja: bool | None = None
disable_ctx_shift: int | None = False
draft_min: int | None = None
draft_max: int | None = None
@ -154,6 +155,10 @@ class ServerProcess:
if self.lora_files:
for lora_file in self.lora_files:
server_args.extend(["--lora", lora_file])
if self.chat_template_file:
server_args.extend(["--chat-template-file", self.chat_template_file])
if self.jinja:
server_args.append("--jinja")
if self.disable_ctx_shift:
server_args.extend(["--no-context-shift"])
if self.api_key:

View file

@ -16,6 +16,8 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "minja.hpp"
#include "chat-template.hpp"
#include <random>
#include <sstream>
@ -382,19 +384,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
return formatted_chat;
}
static std::string llama_get_chat_template(const struct llama_model * model) {
std::string template_key = "tokenizer.chat_template";
// call with NULL buffer to get the total size of the string
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
if (res < 2) {
return "";
} else {
std::vector<char> model_template(res + 1, 0);
llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
return std::string(model_template.data(), model_template.size() - 1);
}
}
//
// base64 utils (TODO: move to common in the future)
//
@ -552,11 +541,21 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
static json oaicompat_completion_params_parse(
const struct llama_model * model,
const json & body, /* openai api json semantics */
const std::string & chat_template) {
const minja::chat_template & tmpl,
bool use_jinja)
{
json llama_params;
// Apply chat template to the list of messages
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
if (has_tools) {
if (use_jinja) {
LOG_WRN("tools param is not fully supported yet\n");
} else {
throw std::runtime_error("tools param requires --jinja flag");
}
}
// Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) {
@ -579,6 +578,13 @@ static json oaicompat_completion_params_parse(
}
}
// Apply chat template to the list of messages
if (use_jinja) {
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
} else {
llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
}
// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
@ -594,7 +600,7 @@ static json oaicompat_completion_params_parse(
}
// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
static const std::vector<std::string> unsupported_params { "tool_choice" };
for (const auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);