From d86a1ae80d942a394da1805408f4f77be269247b Mon Sep 17 00:00:00 2001
From: ochafik
Date: Thu, 30 Jan 2025 00:13:12 +0000
Subject: [PATCH] Unify content + message in server_task_result_cmpl_final (+
 avoid string copy)

---
 examples/server/server.cpp | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8e39ddfc8..7fba2533e 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -533,7 +533,7 @@ struct completion_token_output {
 
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;
-    std::string content;
+    common_chat_msg message;
     llama_tokens tokens;
 
     bool stream;
@@ -559,7 +559,6 @@ struct server_task_result_cmpl_final : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    common_chat_msg oaicompat_chat_msg;
 
     virtual int get_index() override {
         return index;
@@ -585,7 +584,7 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_non_oaicompat() {
        json res = json {
            {"index", index},
-           {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+           {"content", stream ? "" : message.content}, // in stream mode, content is already in last partial chunk
            {"tokens", stream ? llama_tokens {} : tokens},
            {"id_slot", id_slot},
            {"stop", true},
@@ -622,7 +621,7 @@
        json res = json {
            {"choices", json::array({
                json{
-                   {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                   {"text", stream ? "" : message.content}, // in stream mode, content is already in last partial chunk
                    {"index", index},
                    {"logprobs", logprobs},
                    {"finish_reason", finish_reason},
@@ -654,13 +653,13 @@
    json to_json_oaicompat_chat() {
        std::string finish_reason = "length";
        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-           finish_reason = oaicompat_chat_msg.tool_calls.empty() ? "stop" : "tool_calls";
+           finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
        }
 
        json tool_calls;
-       if (!oaicompat_chat_msg.tool_calls.empty()) {
+       if (!message.tool_calls.empty()) {
            tool_calls = json::array();
-           for (const auto & tc : oaicompat_chat_msg.tool_calls) {
+           for (const auto & tc : message.tool_calls) {
                tool_calls.push_back({
                    {"type", "function"},
                    {"function", {
@@ -676,7 +675,7 @@
                {"finish_reason", finish_reason},
                {"index", 0},
                {"message", json {
-                   {"content", oaicompat_chat_msg.content},
+                   {"content", message.content},
                    {"tool_calls", tool_calls},
                    {"role", "assistant"},
                }},
@@ -2283,7 +2282,6 @@ struct server_context {
 
        res->id_slot = slot.id;
        res->index = slot.index;
-       res->content = slot.generated_text;
        res->tokens = slot.generated_tokens;
        res->timings = slot.get_timings();
        res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
@@ -2304,11 +2302,11 @@
        res->oaicompat_model = slot.params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
        if (slot.params.chat_parser) {
-           res->oaicompat_chat_msg = slot.params.chat_parser(slot.generated_text);
+           res->message = slot.params.chat_parser(slot.generated_text);
        } else {
-           res->oaicompat_chat_msg = {
+           res->message = {
                /* .role = */ "assistant",
-               /* .content = */ slot.generated_text,
+               /* .content = */ std::move(slot.generated_text),
                /* .tool_calls = */ {}
            };
        }
@@ -3838,6 +3836,8 @@ int main(int argc, char ** argv) {
        // OAI-compat
        task.params.oaicompat = oaicompat;
        task.params.oaicompat_cmpl_id = completion_id;
+
+       // Grammar & tool-calls
        task.params.sampling.grammar = chat_params.grammar;
        task.params.sampling.grammar_lazy = chat_params.grammar_lazy;
        for (const auto & trigger : chat_params.grammar_triggers) {