wip [no ci]

Xuan Son Nguyen 2024-12-04 15:03:37 +01:00
parent 1011a51b87
commit 0d6485f0f8
3 changed files with 51 additions and 48 deletions

@@ -494,7 +494,9 @@ struct server_response {
     }

     // Send a new result to a waiting id_task
-    void send(server_task_result & result) {
+    template<typename T>
+    void send(T & result) {
+        static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
         SRV_DBG("sending result for task id = %d\n", result.id);

         std::unique_lock<std::mutex> lock(mutex_results);
@@ -502,7 +504,7 @@ struct server_response {
             if (result.id == id_task) {
                 SRV_DBG("task id = %d pushed to result queue\n", result.id);

-                queue_results.push_back(std::make_unique<server_task_result>(result));
+                queue_results.push_back(std::make_unique<T>(std::move(result)));
                 condition_results.notify_all();
                 return;
             }
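A likely motivation for the templated send(): the result queue holds owning pointers to the base class, so std::make_unique<server_task_result>(result) would copy only the base subobject and slice away the derived fields, while std::make_unique<T>(std::move(result)) moves the complete derived object behind the base pointer. A standalone sketch of the difference (base_result / partial_result are illustrative stand-ins, not names from this commit):

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// stand-in for server_task_result
struct base_result {
    int id = 0;
    virtual ~base_result() = default;
    virtual std::string to_json() const { return "{}"; }
};

// stand-in for server_task_result_cmpl_partial
struct partial_result : base_result {
    std::string content;
    std::string to_json() const override { return "{\"content\":\"" + content + "\"}"; }
};

int main() {
    std::vector<std::unique_ptr<base_result>> queue;

    partial_result res;
    res.id      = 1;
    res.content = "hello";

    // sliced: only the base_result subobject is copied, so to_json() prints "{}"
    queue.push_back(std::make_unique<base_result>(res));

    // preserved: the whole derived object is moved in, so to_json() prints the content
    queue.push_back(std::make_unique<partial_result>(std::move(res)));

    for (const auto & r : queue) {
        std::cout << r->to_json() << "\n";
    }
    return 0;
}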
@@ -1166,8 +1168,10 @@ struct server_context {
     void send_partial_response(server_slot & slot, completion_token_output tkn) {
         server_task_result_cmpl_partial res;
         res.id = slot.id_task;
-        res.content = tkn.text_to_send;
+        res.n_decoded = slot.n_decoded;
+        res.n_prompt_tokens = slot.n_prompt_tokens;
+        res.content = tkn.text_to_send;

         if (slot.params.sampling.n_probs > 0) {
             const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
@@ -1189,7 +1193,11 @@ struct server_context {
         queue_results.send(res);
     }

-    void send_final_response(const server_slot & slot) {
+    void send_final_response(server_slot & slot) {
+        if (slot.params.stream) {
+            return send_partial_response(slot, {0, "", {}});
+        }
+
         server_task_result_cmpl_final res;
         res.id = slot.id_task;
         res.id_slot = slot.id;
@@ -1380,6 +1388,7 @@ struct server_context {
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<T>&)> & result_handler,
             const std::function<void(json)> & error_handler) {
+        static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
         std::vector<T> results(id_tasks.size());
         for (size_t i = 0; i < id_tasks.size(); i++) {
             task_result_ptr result_raw = queue_results.recv(id_tasks);
@@ -2815,7 +2824,7 @@ int main(int argc, char ** argv) {
         if (!stream) {
             ctx_server.receive_multi_results<server_task_result_cmpl_final>(task_ids, [&](std::vector<server_task_result_cmpl_final> & results) {
                 // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0].to_json(), completion_id, /*.streaming =*/ false, verbose);
+                json result_oai = format_final_response_oaicompat(data, results[0], completion_id, /*.streaming =*/ false, verbose);
                 res_ok(res, result_oai);
             }, [&](const json & error_data) {
                 res_error(res, error_data);
@@ -2823,9 +2832,10 @@ int main(int argc, char ** argv) {
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+            std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, model_name](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool {
-                    std::vector<json> result_array = format_partial_response_oaicompat(result.to_json(), completion_id);
+                    std::vector<json> result_array = format_partial_response_oaicompat(model_name, result, completion_id);
                     for (auto & event_data : result_array) {
                         if (event_data.empty()) {
                             continue; // skip the stop token
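Since the partial result struct no longer carries the model name, the handler resolves it once from the request and copies it into the chunked_content_provider closure; the capture is by value because the provider can keep running after the handler's locals are gone. A minimal standalone sketch of that capture pattern (names here are placeholders, not from the commit):

#include <functional>
#include <iostream>
#include <string>

static std::function<void()> make_provider(const std::string & request_model) {
    std::string model_name = request_model.empty() ? "default-model" : request_model;
    // capture by value: the copy stays valid after make_provider() returns
    return [model_name]() {
        std::cout << "streaming chunk for model: " << model_name << "\n";
    };
}

int main() {
    auto provider = make_provider("my-model");
    provider(); // safe: model_name was copied into the closure
    return 0;
}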

@@ -281,6 +281,8 @@ struct server_task_result_cmpl_partial : server_task_result {
     server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
     int index = 0;
     std::string content;
+    int32_t n_decoded;
+    int32_t n_prompt_tokens;
     stop_type stop = STOP_TYPE_NONE;
     std::vector<completion_token_output> probs_output;
     result_timings timings;

@@ -583,15 +583,14 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }

-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
-    bool stopped_word = result.count("stopped_word") != 0;
-    bool stopped_eos = json_value(result, "stopped_eos", false);
-    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-    int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
-    std::string content = json_value(result, "content", std::string(""));
+static json format_final_response_oaicompat(
+        const json & request,
+        server_task_result_cmpl_final & result,
+        const std::string & completion_id,
+        bool streaming = false,
+        bool verbose = false) {

     std::string finish_reason = "length";
-    if (stopped_word || stopped_eos) {
+    if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
         finish_reason = "stop";
     }
@@ -601,7 +600,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
                                 {"delta", json::object()}}})
               : json::array({json{{"finish_reason", finish_reason},
                                   {"index", 0},
-                                  {"message", json{{"content", content},
+                                  {"message", json{{"content", result.content},
                                                    {"role", "assistant"}}}}});

     std::time_t t = std::time(0);
@@ -613,48 +612,42 @@ static json format_final_response_oaicompat(const json & request, const json & r
             json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
         {"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens", num_prompt_tokens},
-            {"total_tokens", num_tokens_predicted + num_prompt_tokens}
+            {"completion_tokens", result.n_decoded},
+            {"prompt_tokens", result.n_prompt_tokens},
+            {"total_tokens", result.n_decoded + result.n_prompt_tokens}
         }},
         {"id", completion_id}
     };

     // extra fields for debugging purposes
     if (verbose) {
-        res["__verbose"] = result;
+        res["__verbose"] = result.to_json();
     }

-    if (result.contains("completion_probabilities")) {
-        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-    }
+    // TODO: fix this
+    // if (result.contains("completion_probabilities")) {
+    //     res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
+    // }

-    if (result.contains("timings")) {
-        res.push_back({"timings", json_value(result, "timings", json::object())});
+    if (result.timings.prompt_n >= 0) {
+        res.push_back({"timings", result.timings.to_json()});
     }

     return res;
 }

 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
-    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
-        return std::vector<json>({result});
-    }
-
-    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
-    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-    bool stopped_word = json_value(result, "stopped_word", false);
-    bool stopped_eos = json_value(result, "stopped_eos", false);
-    bool stopped_limit = json_value(result, "stopped_limit", false);
-    std::string content = json_value(result, "content", std::string(""));
+static std::vector<json> format_partial_response_oaicompat(
+        std::string modelname,
+        server_task_result_cmpl_partial & result,
+        const std::string & completion_id) {
+    bool first = result.n_decoded == 0;
+    std::string content = result.content;

     std::string finish_reason;
-    if (stopped_word || stopped_eos) {
+    if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
         finish_reason = "stop";
-    }
-    if (stopped_limit) {
+    } else if (result.stop == STOP_TYPE_LIMIT) {
         finish_reason = "length";
     }
@@ -724,17 +717,15 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
         {"object", "chat.completion.chunk"}
     };

-    if (result.contains("timings")) {
-        ret.push_back({"timings", json_value(result, "timings", json::object())});
+    if (result.timings.prompt_n >= 0) {
+        ret.push_back({"timings", result.timings.to_json()});
     }

     if (!finish_reason.empty()) {
-        int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-        int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
         ret.push_back({"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens", num_prompt_tokens},
-            {"total_tokens", num_tokens_predicted + num_prompt_tokens}
+            {"completion_tokens", result.n_decoded},
+            {"prompt_tokens", result.n_prompt_tokens},
+            {"total_tokens", result.n_decoded + result.n_prompt_tokens}
         }});
     }
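With the struct-based results, the formatters derive the OpenAI-compatible finish_reason directly from the stop_type enum instead of reading stopped_* flags out of a JSON blob. A compact sketch of the mapping used on the streaming side (enum values are taken from the diff; the helper function is hypothetical):

#include <string>

enum stop_type {
    STOP_TYPE_NONE,
    STOP_TYPE_EOS,
    STOP_TYPE_WORD,
    STOP_TYPE_LIMIT,
};

static std::string to_finish_reason(stop_type stop) {
    if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
        return "stop";   // natural stop: stopping string or end-of-sequence token
    }
    if (stop == STOP_TYPE_LIMIT) {
        return "length"; // generation limit reached
    }
    return "";           // still generating: no finish_reason yet
}

In format_final_response_oaicompat the default is "length" rather than an empty string, since a final result always carries a finish reason.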