wip [no ci]

Xuan Son Nguyen 2024-12-04 15:03:37 +01:00
parent 1011a51b87
commit 0d6485f0f8
3 changed files with 51 additions and 48 deletions

Changed file 1 of 3:

@@ -494,7 +494,9 @@ struct server_response {
}
// Send a new result to a waiting id_task
void send(server_task_result & result) {
template<typename T>
void send(T & result) {
static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
SRV_DBG("sending result for task id = %d\n", result.id);
std::unique_lock<std::mutex> lock(mutex_results);
@@ -502,7 +504,7 @@ struct server_response {
if (result.id == id_task) {
SRV_DBG("task id = %d pushed to result queue\n", result.id);
queue_results.push_back(std::make_unique<server_task_result>(result));
queue_results.push_back(std::make_unique<T>(std::move(result)));
condition_results.notify_all();
return;
}
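For reference, the pattern introduced in this hunk can be shown as a minimal, self-contained sketch (type and member names here are simplified stand-ins, not the server's actual definitions): a templated send() accepts only types derived from a common result base, moves the concrete object into a std::unique_ptr<T>, and stores it through the base pointer so the dynamic type is preserved.

    #include <cstdio>
    #include <memory>
    #include <string>
    #include <type_traits>
    #include <vector>

    // Simplified stand-ins for server_task_result and its derived types (illustration only).
    struct result_base {
        int id = 0;
        virtual ~result_base() = default;
    };

    struct result_partial : result_base {
        std::string content;
    };

    struct result_queue {
        std::vector<std::unique_ptr<result_base>> queue;

        // Accept only derived result types; the concrete type is preserved by
        // constructing std::unique_ptr<T> and letting it convert to the base pointer.
        template<typename T>
        void send(T & result) {
            static_assert(std::is_base_of<result_base, T>::value, "T must be derived from result_base");
            queue.push_back(std::make_unique<T>(std::move(result)));
        }
    };

    int main() {
        result_queue q;
        result_partial r;
        r.id      = 42;
        r.content = "hello";
        q.send(r); // r is moved into the queue, stored behind a base-class pointer
        std::printf("queued %zu result(s), first id = %d\n", q.queue.size(), q.queue.front()->id);
        return 0;
    }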
@@ -1166,8 +1168,10 @@ struct server_context {
void send_partial_response(server_slot & slot, completion_token_output tkn) {
server_task_result_cmpl_partial res;
res.id = slot.id_task;
res.content = tkn.text_to_send;
res.id = slot.id_task;
res.n_decoded = slot.n_decoded;
res.n_prompt_tokens = slot.n_prompt_tokens;
res.content = tkn.text_to_send;
if (slot.params.sampling.n_probs > 0) {
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
@@ -1189,7 +1193,11 @@ struct server_context {
queue_results.send(res);
}
void send_final_response(const server_slot & slot) {
void send_final_response(server_slot & slot) {
if (slot.params.stream) {
return send_partial_response(slot, {0, "", {}});
}
server_task_result_cmpl_final res;
res.id = slot.id_task;
res.id_slot = slot.id;
@@ -1380,6 +1388,7 @@ struct server_context {
const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<T>&)> & result_handler,
const std::function<void(json)> & error_handler) {
static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
std::vector<T> results(id_tasks.size());
for (size_t i = 0; i < id_tasks.size(); i++) {
task_result_ptr result_raw = queue_results.recv(id_tasks);
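The receiving side applies the same base-class constraint. A rough standalone sketch of the collect-then-dispatch idea (all names assumed; the real task-id bookkeeping and result queue are omitted): gather one result per entry, downcast each to the expected concrete type, and hand the whole batch to a success handler, or bail out to an error handler.

    #include <cstdio>
    #include <functional>
    #include <memory>
    #include <string>
    #include <type_traits>
    #include <vector>

    // Simplified stand-ins for the server result types (illustration only).
    struct result_base {
        virtual ~result_base() = default;
        bool is_error = false;
        std::string error_msg;
    };

    struct result_final : result_base {
        std::string content;
    };

    using result_ptr = std::unique_ptr<result_base>;

    // Collect one result per raw entry; on error, call error_handler and stop,
    // otherwise downcast to T and pass the full batch to result_handler.
    template<typename T>
    void receive_multi_results(
            std::vector<result_ptr> & raw_results,
            const std::function<void(std::vector<T> &)> & result_handler,
            const std::function<void(const std::string &)> & error_handler) {
        static_assert(std::is_base_of<result_base, T>::value, "T must be derived from result_base");
        std::vector<T> results;
        results.reserve(raw_results.size());
        for (auto & raw : raw_results) {
            if (raw->is_error) {
                error_handler(raw->error_msg);
                return;
            }
            T * typed = dynamic_cast<T *>(raw.get());
            if (typed == nullptr) {
                error_handler("unexpected result type");
                return;
            }
            results.push_back(std::move(*typed));
        }
        result_handler(results);
    }

    int main() {
        std::vector<result_ptr> raw;
        auto r = std::make_unique<result_final>();
        r->content = "done";
        raw.push_back(std::move(r));

        receive_multi_results<result_final>(raw,
            [](std::vector<result_final> & results) { std::printf("got %zu result(s): %s\n", results.size(), results[0].content.c_str()); },
            [](const std::string & err)             { std::printf("error: %s\n", err.c_str()); });
        return 0;
    }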
@@ -2815,7 +2824,7 @@ int main(int argc, char ** argv) {
if (!stream) {
ctx_server.receive_multi_results<server_task_result_cmpl_final>(task_ids, [&](std::vector<server_task_result_cmpl_final> & results) {
// multitask is never supported in chat completion, so there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].to_json(), completion_id, /*.streaming =*/ false, verbose);
json result_oai = format_final_response_oaicompat(data, results[0], completion_id, /*.streaming =*/ false, verbose);
res_ok(res, result_oai);
}, [&](const json & error_data) {
res_error(res, error_data);
@@ -2823,9 +2832,10 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, model_name](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool {
std::vector<json> result_array = format_partial_response_oaicompat(result.to_json(), completion_id);
std::vector<json> result_array = format_partial_response_oaicompat(model_name, result, completion_id);
for (auto & event_data : result_array) {
if (event_data.empty()) {
continue; // skip the stop token

Changed file 2 of 3:

@@ -281,6 +281,8 @@ struct server_task_result_cmpl_partial : server_task_result {
server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
int index = 0;
std::string content;
int32_t n_decoded;
int32_t n_prompt_tokens;
stop_type stop = STOP_TYPE_NONE;
std::vector<completion_token_output> probs_output;
result_timings timings;

Changed file 3 of 3:

@@ -583,15 +583,14 @@ static json oaicompat_completion_params_parse(
return llama_params;
}
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));
static json format_final_response_oaicompat(
const json & request,
server_task_result_cmpl_final & result,
const std::string & completion_id,
bool streaming = false,
bool verbose = false) {
std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}
@@ -601,7 +600,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
{"delta", json::object()}}})
: json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", content},
{"message", json{{"content", result.content},
{"role", "assistant"}}}}});
std::time_t t = std::time(0);
@@ -613,48 +612,42 @@ static json format_final_response_oaicompat(const json & request, const json & r
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
{"usage", json {
{"completion_tokens", num_tokens_predicted},
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
{"completion_tokens", result.n_decoded},
{"prompt_tokens", result.n_prompt_tokens},
{"total_tokens", result.n_decoded + result.n_prompt_tokens}
}},
{"id", completion_id}
};
// extra fields for debugging purposes
if (verbose) {
res["__verbose"] = result;
res["__verbose"] = result.to_json();
}
if (result.contains("completion_probabilities")) {
res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
}
// TODO: fix this
// if (result.contains("completion_probabilities")) {
// res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
// }
if (result.contains("timings")) {
res.push_back({"timings", json_value(result, "timings", json::object())});
if (result.timings.prompt_n >= 0) {
res.push_back({"timings", result.timings.to_json()});
}
return res;
}
// return value is a vector as there is one case where we might need to generate two responses
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
return std::vector<json>({result});
}
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
bool stopped_word = json_value(result, "stopped_word", false);
bool stopped_eos = json_value(result, "stopped_eos", false);
bool stopped_limit = json_value(result, "stopped_limit", false);
std::string content = json_value(result, "content", std::string(""));
static std::vector<json> format_partial_response_oaicompat(
std::string modelname,
server_task_result_cmpl_partial & result,
const std::string & completion_id) {
bool first = result.n_decoded == 0;
std::string content = result.content;
std::string finish_reason;
if (stopped_word || stopped_eos) {
if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}
if (stopped_limit) {
} else if (result.stop == STOP_TYPE_LIMIT) {
finish_reason = "length";
}
@@ -724,17 +717,15 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
{"object", "chat.completion.chunk"}
};
if (result.contains("timings")) {
ret.push_back({"timings", json_value(result, "timings", json::object())});
if (result.timings.prompt_n >= 0) {
ret.push_back({"timings", result.timings.to_json()});
}
if (!finish_reason.empty()) {
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
ret.push_back({"usage", json {
{"completion_tokens", num_tokens_predicted},
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
{"completion_tokens", result.n_decoded},
{"prompt_tokens", result.n_prompt_tokens},
{"total_tokens", result.n_decoded + result.n_prompt_tokens}
}});
}
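Putting the last two hunks together, a small self-contained sketch, assuming simplified enum and field names and the nlohmann::json library (which the server's json alias is based on), of how a stop type maps to an OpenAI-style finish_reason and how the usage counters are assembled:

    #include <cstdint>
    #include <cstdio>
    #include <string>

    #include <nlohmann/json.hpp>

    using json = nlohmann::ordered_json;

    // Simplified stand-ins for the server's stop handling (illustration only).
    enum stop_type { STOP_TYPE_NONE, STOP_TYPE_EOS, STOP_TYPE_WORD, STOP_TYPE_LIMIT };

    struct cmpl_result {
        stop_type stop          = STOP_TYPE_NONE;
        int32_t n_decoded       = 0;
        int32_t n_prompt_tokens = 0;
    };

    // "stop" when generation ended on EOS or a stop word, "length" when it hit the token limit.
    static std::string to_finish_reason(const cmpl_result & r) {
        if (r.stop == STOP_TYPE_WORD || r.stop == STOP_TYPE_EOS) {
            return "stop";
        }
        return "length";
    }

    // OpenAI-style usage block derived from the decode counters.
    static json to_usage(const cmpl_result & r) {
        return json {
            {"completion_tokens", r.n_decoded},
            {"prompt_tokens",     r.n_prompt_tokens},
            {"total_tokens",      r.n_decoded + r.n_prompt_tokens}
        };
    }

    int main() {
        cmpl_result r;
        r.stop            = STOP_TYPE_EOS;
        r.n_decoded       = 12;
        r.n_prompt_tokens = 34;
        std::printf("finish_reason = %s\nusage = %s\n",
                    to_finish_reason(r).c_str(), to_usage(r).dump(2).c_str());
        return 0;
    }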