Merge branch 'master' into gg/llama-kv-cache

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-02-06 10:04:33 +02:00
commit 0f1c1cab2c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
83 changed files with 3593 additions and 1410 deletions

View file

@ -131,6 +131,11 @@ struct slot_params {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}
std::vector<std::string> grammar_trigger_words;
for (const auto & trigger : sampling.grammar_trigger_words) {
grammar_trigger_words.push_back(trigger.word);
}
return json {
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
@ -165,8 +170,9 @@ struct slot_params {
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
// {"grammar_trigger_words", sampling.grammar_trigger_words},
{"grammar_trigger_words", grammar_trigger_words},
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
{"preserved_tokens", sampling.preserved_tokens},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
@ -363,12 +369,26 @@ struct server_task {
if (ids.size() == 1) {
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
continue;
}
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto & t : *preserved_tokens) {
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_DBG("Preserved token: %d\n", ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
}
}
}
if (params.sampling.grammar_lazy) {
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
}
@ -695,19 +715,19 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat() {
std::string finish_reason = "length";
common_chat_msg message;
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
LOG_DBG("Parsing chat message: %s\n", content.c_str());
message = common_chat_parse(content, oaicompat_chat_format);
finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
} else {
message.content = content;
msg.content = content;
}
json tool_calls;
if (!message.tool_calls.empty()) {
if (!msg.tool_calls.empty()) {
tool_calls = json::array();
for (const auto & tc : message.tool_calls) {
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
@ -719,14 +739,19 @@ struct server_task_result_cmpl_final : server_task_result {
}
}
json message {
{"content", msg.content},
{"tool_calls", tool_calls},
{"role", "assistant"},
};
if (!msg.tool_plan.empty()) {
message["tool_plan"] = msg.tool_plan;
}
json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", json {
{"content", message.content},
{"tool_calls", tool_calls},
{"role", "assistant"},
}},
{"message", message},
};
if (!stream && probs_output.size() > 0) {
@ -2833,8 +2858,7 @@ struct server_context {
server_slot * slot_batched = nullptr;
auto accept_special_token = [&](server_slot & slot, llama_token token) {
const auto & trigger_tokens = slot.params.sampling.grammar_trigger_tokens;
return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
};
// frist, add sampled tokens from any ongoing sequences
@ -3329,6 +3353,8 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
return;
}
// reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
LOG_DBG("request: %s\n", req.body.c_str());
@ -3415,9 +3441,13 @@ int main(int argc, char ** argv) {
message = "Unknown Exception";
}
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
res_error(res, formatted_error);
try {
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
res_error(res, formatted_error);
} catch (const std::exception & e) {
LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str());
}
});
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {