refactor handle_completions_generic
This commit is contained in:
parent eaa12887da
commit cb666718b1
1 changed file with 48 additions and 52 deletions
@@ -2716,7 +2716,16 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "success", true }});
     };

-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
+    // handle completion-like requests (completion, chat, infill)
+    // we can optionally provide a custom format for partial results and final results
+    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
+            server_task_inf_type inf_type,
+            json & data,
+            httplib::Response & res,
+            const std::function<std::vector<json>(server_task_result_cmpl_partial&)> & format_partial = nullptr,
+            const std::function<json(std::vector<server_task_result_cmpl_final>&)> & format_final = nullptr,
+            // whether to send [DONE] event after completion (required for OAI-compat)
+            bool send_done_event = false) {
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
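Annotation: the core move of the refactor is visible in this hunk. Response shaping becomes two optional std::function parameters plus a send_done_event flag, with nullptr defaults so existing call sites keep working unchanged. A minimal standalone sketch of that pattern, using illustrative names rather than the real server.cpp types:

    // Sketch: optional-formatter pattern used by handle_completions_generic.
    // All names here are illustrative stand-ins, not the server.cpp types.
    #include <functional>
    #include <iostream>
    #include <string>

    struct partial_result { std::string text; };

    // Defaulted std::function parameters keep old call sites working:
    // callers that pass nothing get the plain (non-OAI) serialization.
    void handle_generic(
            const std::string & prompt,
            const std::function<std::string(const partial_result &)> & format_partial = nullptr,
            bool send_done_event = false) {
        partial_result r{"hello from " + prompt};
        // null check on the std::function chooses custom vs. default formatting
        std::cout << (format_partial ? format_partial(r) : r.text) << "\n";
        if (send_done_event) {
            std::cout << "data: [DONE]\n\n"; // OAI-compat stream terminator
        }
    }

    int main() {
        handle_generic("plain"); // default path
        handle_generic("oai",
            [](const partial_result & r) { return "data: {\"text\":\"" + r.text + "\"}"; },
            /* send_done_event */ true);
    }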
@@ -2731,7 +2740,9 @@ int main(int argc, char ** argv) {

         if (!stream) {
             ctx_server.receive_multi_results<server_task_result_cmpl_final>(task_ids, [&](std::vector<server_task_result_cmpl_final> & results) {
-                if (results.size() == 1) {
+                if (format_final) {
+                    res_ok(res, format_final(results));
+                } else if (results.size() == 1) {
                     // single result
                     res_ok(res, results[0].to_json());
                 } else {
@ -2748,12 +2759,25 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||||
} else {
|
} else {
|
||||||
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
|
const auto chunked_content_provider = [task_ids, &ctx_server, format_partial = std::move(format_partial), send_done_event](size_t, httplib::DataSink & sink) {
|
||||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool {
|
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool {
|
||||||
|
if (format_partial) {
|
||||||
|
for (const auto & res : format_partial(result)) {
|
||||||
|
if (!server_sent_event(sink, "data", res)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
return server_sent_event(sink, "data", result.to_json());
|
return server_sent_event(sink, "data", result.to_json());
|
||||||
|
}
|
||||||
}, [&](const json & error_data) {
|
}, [&](const json & error_data) {
|
||||||
server_sent_event(sink, "error", error_data);
|
server_sent_event(sink, "error", error_data);
|
||||||
});
|
});
|
||||||
|
if (send_done_event) {
|
||||||
|
static const std::string ev_done = "data: [DONE]\n\n";
|
||||||
|
sink.write(ev_done.data(), ev_done.size());
|
||||||
|
}
|
||||||
sink.done();
|
sink.done();
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
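Annotation: in the streaming branch, format_partial may expand a single partial result into several SSE events (OAI-compat chunking), and the [DONE] sentinel is now emitted only on request. A hedged sketch of the wire format this should produce; write_sse is a hypothetical stand-in for server_sent_event, assuming it frames each event as "<name>: <json>" followed by a blank line:

    // Assumed SSE framing: "<name>: <json>\n\n" per event, and an OAI-compat
    // stream ends with a literal "data: [DONE]" event. write_sse is a
    // hypothetical stand-in for server_sent_event + httplib::DataSink.
    #include <iostream>
    #include <sstream>
    #include <string>

    bool write_sse(std::ostream & sink, const std::string & event, const std::string & json) {
        sink << event << ": " << json << "\n\n";
        return static_cast<bool>(sink); // false would mean the client disconnected
    }

    int main() {
        std::ostringstream out;
        write_sse(out, "data", "{\"content\":\"Hel\"}");
        write_sse(out, "data", "{\"content\":\"lo\"}");
        out << "data: [DONE]\n\n"; // only when send_done_event is true
        std::cout << out.str();
    }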
@ -2768,7 +2792,13 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
|
return handle_completions_generic(
|
||||||
|
SERVER_TASK_INF_TYPE_COMPLETION,
|
||||||
|
data,
|
||||||
|
res,
|
||||||
|
// TODO: support OAI-compat response via format_partial and format_final
|
||||||
|
/* format_partial */ nullptr,
|
||||||
|
/* format_final */ nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
@@ -2821,8 +2851,7 @@ int main(int argc, char ** argv) {
         return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
     };

-    // TODO: maybe merge this function with "handle_completions_generic"
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic, verbose](const httplib::Request & req, httplib::Response & res) {
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -2830,53 +2859,20 @@ int main(int argc, char ** argv) {

         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

-        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
-        ctx_server.queue_results.add_waiting_tasks(tasks);
-        ctx_server.queue_tasks.post(tasks);
-
-        bool stream = json_value(data, "stream", false);
-        const auto task_ids = server_task::get_list_id(tasks);
         const auto completion_id = gen_chatcmplid();
-
-        if (!stream) {
-            ctx_server.receive_multi_results<server_task_result_cmpl_final>(task_ids, [&](std::vector<server_task_result_cmpl_final> & results) {
-                // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0], completion_id, /*.streaming =*/ false, verbose);
-                res_ok(res, result_oai);
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-            });
-
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-        } else {
-            std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, model_name](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool {
-                    std::vector<json> result_array = format_partial_response_oaicompat(model_name, result, completion_id);
-                    for (auto & event_data : result_array) {
-                        if (event_data.empty()) {
-                            continue; // skip the stop token
-                        }
-                        if (!server_sent_event(sink, "data", event_data)) {
-                            return false; // connection is closed
-                        }
-                    }
-                    return true; // ok
-                }, [&](const json & error_data) {
-                    server_sent_event(sink, "error", error_data);
-                });
-                static const std::string ev_done = "data: [DONE]\n\n";
-                sink.write(ev_done.data(), ev_done.size());
-                sink.done();
-                return true;
-            };
-
-            auto on_complete = [task_ids, &ctx_server] (bool) {
-                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-            };
-
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-        }
+        std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+        return handle_completions_generic(
+            SERVER_TASK_INF_TYPE_COMPLETION,
+            data,
+            res,
+            /* format_partial */ [data, model_name, completion_id](server_task_result_cmpl_partial & result) {
+                return format_partial_response_oaicompat(model_name, result, completion_id);
+            },
+            /* format_final */ [data, verbose, model_name](std::vector<server_task_result_cmpl_final> & results) {
+                return format_final_response_oaicompat(data, results[0], model_name, false, verbose);
+            },
+            /* send_done_event */ true);
     };

     const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
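Annotation: handle_chat_completions is now a thin adapter. It parses the OAI-style request, then delegates all queueing and streaming plumbing to the generic handler, passing formatters that read only results[0], since chat completion never runs as a multitask. A toy illustration of that adapter shape, with invented names rather than the real server types:

    // Sketch: after the refactor, each endpoint is a thin adapter over the
    // generic handler. Illustrative types; not the real server.cpp signatures.
    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    using final_results   = std::vector<std::string>;
    using final_formatter = std::function<std::string(final_results &)>;

    std::string handle_generic(final_results & results, const final_formatter & format_final = nullptr) {
        if (format_final) {
            return format_final(results); // endpoint-specific shape (e.g. OAI chat)
        }
        return results.size() == 1 ? results[0] : "[multiple results]"; // default shape
    }

    int main() {
        final_results results = {"{\"content\":\"hi\"}"};
        // chat completions: always a single result, wrapped OAI-style
        std::cout << handle_generic(results, [](final_results & r) {
            return "{\"object\":\"chat.completion\",\"choices\":[" + r[0] + "]}";
        }) << "\n";
    }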