Unify content + message in server_task_result_cmpl_final (+ avoid string copy)

ochafik 2025-01-30 00:13:12 +00:00
parent 77c60e662e
commit d86a1ae80d


@@ -533,7 +533,7 @@ struct completion_token_output {
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;

-    std::string content;
+    common_chat_msg message;
     llama_tokens tokens;

     bool stream;
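
For context, the shape of common_chat_msg implied by this diff is roughly the sketch below: role/content/tool_calls come straight from the designated initializers and accesses in the hunks that follow, while the tool-call sub-fields are assumptions inferred from the OpenAI-style serialization (the real definition lives elsewhere in the tree and is not part of this commit):

    // Hypothetical sketch, not the actual definition.
    struct common_chat_tool_call {
        std::string name;        // assumed: emitted under the "function" object
        std::string arguments;   // assumed: JSON-encoded arguments string
    };

    struct common_chat_msg {
        std::string role;                               // "assistant" in this code path
        std::string content;                            // absorbs the old `content` field
        std::vector<common_chat_tool_call> tool_calls;  // empty unless a parser extracted calls
    };
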
@@ -559,7 +559,6 @@ struct server_task_result_cmpl_final : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    common_chat_msg oaicompat_chat_msg;

     virtual int get_index() override {
         return index;
@@ -585,7 +584,7 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_non_oaicompat() {
         json res = json {
             {"index", index},
-            {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+            {"content", stream ? "" : message.content}, // in stream mode, content is already in last partial chunk
             {"tokens", stream ? llama_tokens {} : tokens},
             {"id_slot", id_slot},
             {"stop", true},
@@ -622,7 +621,7 @@ struct server_task_result_cmpl_final : server_task_result {
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                    {"text", stream ? "" : message.content}, // in stream mode, content is already in last partial chunk
                     {"index", index},
                     {"logprobs", logprobs},
                     {"finish_reason", finish_reason},
@@ -654,13 +653,13 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_oaicompat_chat() {
         std::string finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-            finish_reason = oaicompat_chat_msg.tool_calls.empty() ? "stop" : "tool_calls";
+            finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
         }

         json tool_calls;
-        if (!oaicompat_chat_msg.tool_calls.empty()) {
+        if (!message.tool_calls.empty()) {
             tool_calls = json::array();
-            for (const auto & tc : oaicompat_chat_msg.tool_calls) {
+            for (const auto & tc : message.tool_calls) {
                 tool_calls.push_back({
                     {"type", "function"},
                     {"function", {
@@ -676,7 +675,7 @@ struct server_task_result_cmpl_final : server_task_result {
             {"finish_reason", finish_reason},
             {"index", 0},
             {"message", json {
-                {"content", oaicompat_chat_msg.content},
+                {"content", message.content},
                 {"tool_calls", tool_calls},
                 {"role", "assistant"},
             }},
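
Putting the previous two hunks together, a finished tool-call turn serializes roughly as below. This is an illustrative payload: the fields inside "function" are cut off by the hunk boundary above, so the name/arguments keys are an assumption based on the OpenAI chat-completions format this endpoint mimics:

    {
        "finish_reason": "tool_calls",
        "index": 0,
        "message": {
            "content": "",
            "tool_calls": [{
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\": \"Paris\"}"
                }
            }],
            "role": "assistant"
        }
    }
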
@@ -2283,7 +2282,6 @@ struct server_context {
         res->id_slot = slot.id;

         res->index = slot.index;
-        res->content = slot.generated_text;
         res->tokens = slot.generated_tokens;
         res->timings = slot.get_timings();
         res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
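
The dropped assignment is not lost data: the next hunk routes the same slot.generated_text into res->message (through the chat parser when one is configured, otherwise via a move), and the to_json_* methods above now read the text from message.content instead of the removed content field.
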
@@ -2304,11 +2302,11 @@ struct server_context {
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
         if (slot.params.chat_parser) {
-            res->oaicompat_chat_msg = slot.params.chat_parser(slot.generated_text);
+            res->message = slot.params.chat_parser(slot.generated_text);
         } else {
-            res->oaicompat_chat_msg = {
+            res->message = {
                 /* .role = */ "assistant",
-                /* .content = */ slot.generated_text,
+                /* .content = */ std::move(slot.generated_text),
                 /* .tool_calls = */ {}
             };
         }
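
The std::move here is the "(+ avoid string copy)" half of the commit title: rather than duplicating the accumulated generation buffer into the message, its storage is stolen, leaving slot.generated_text valid but empty until the slot is reset. A minimal self-contained sketch of the pattern (names hypothetical, not the server's actual types):

    #include <cassert>
    #include <string>
    #include <utility>

    struct chat_msg { std::string role, content; };  // stand-in for common_chat_msg

    int main() {
        std::string generated_text(1 << 20, 'x');    // large accumulated output
        chat_msg m { "assistant", std::move(generated_text) };  // buffer moved, not copied
        assert(m.content.size() == 1u << 20);
        // generated_text is now valid but unspecified (typically empty); the server
        // relies on the slot being reset before serving another request.
        return 0;
    }
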
@@ -3838,6 +3836,8 @@ int main(int argc, char ** argv) {
         // OAI-compat
         task.params.oaicompat = oaicompat;
         task.params.oaicompat_cmpl_id = completion_id;
+
+        // Grammar & tool-calls
         task.params.sampling.grammar = chat_params.grammar;
         task.params.sampling.grammar_lazy = chat_params.grammar_lazy;
         for (const auto & trigger : chat_params.grammar_triggers) {