wip [no ci]
This commit is contained in:
parent
1011a51b87
commit
0d6485f0f8
3 changed files with 51 additions and 48 deletions
|
@ -494,7 +494,9 @@ struct server_response {
|
|||
}
|
||||
|
||||
// Send a new result to a waiting id_task
|
||||
void send(server_task_result & result) {
|
||||
template<typename T>
|
||||
void send(T & result) {
|
||||
static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
|
||||
SRV_DBG("sending result for task id = %d\n", result.id);
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
|
@ -502,7 +504,7 @@ struct server_response {
|
|||
if (result.id == id_task) {
|
||||
SRV_DBG("task id = %d pushed to result queue\n", result.id);
|
||||
|
||||
queue_results.push_back(std::make_unique<server_task_result>(result));
|
||||
queue_results.push_back(std::make_unique<T>(std::move(result)));
|
||||
condition_results.notify_all();
|
||||
return;
|
||||
}
|
||||
|
@ -1166,8 +1168,10 @@ struct server_context {
|
|||
|
||||
void send_partial_response(server_slot & slot, completion_token_output tkn) {
|
||||
server_task_result_cmpl_partial res;
|
||||
res.id = slot.id_task;
|
||||
res.content = tkn.text_to_send;
|
||||
res.id = slot.id_task;
|
||||
res.n_decoded = slot.n_decoded;
|
||||
res.n_prompt_tokens = slot.n_prompt_tokens;
|
||||
res.content = tkn.text_to_send;
|
||||
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
||||
|
@ -1189,7 +1193,11 @@ struct server_context {
|
|||
queue_results.send(res);
|
||||
}
|
||||
|
||||
void send_final_response(const server_slot & slot) {
|
||||
void send_final_response(server_slot & slot) {
|
||||
if (slot.params.stream) {
|
||||
return send_partial_response(slot, {0, "", {}});
|
||||
}
|
||||
|
||||
server_task_result_cmpl_final res;
|
||||
res.id = slot.id_task;
|
||||
res.id_slot = slot.id;
|
||||
|
@ -1380,6 +1388,7 @@ struct server_context {
|
|||
const std::unordered_set<int> & id_tasks,
|
||||
const std::function<void(std::vector<T>&)> & result_handler,
|
||||
const std::function<void(json)> & error_handler) {
|
||||
static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
|
||||
std::vector<T> results(id_tasks.size());
|
||||
for (size_t i = 0; i < id_tasks.size(); i++) {
|
||||
task_result_ptr result_raw = queue_results.recv(id_tasks);
|
||||
|
@ -2815,7 +2824,7 @@ int main(int argc, char ** argv) {
|
|||
if (!stream) {
|
||||
ctx_server.receive_multi_results<server_task_result_cmpl_final>(task_ids, [&](std::vector<server_task_result_cmpl_final> & results) {
|
||||
// multitask is never support in chat completion, there is only one result
|
||||
json result_oai = format_final_response_oaicompat(data, results[0].to_json(), completion_id, /*.streaming =*/ false, verbose);
|
||||
json result_oai = format_final_response_oaicompat(data, results[0], completion_id, /*.streaming =*/ false, verbose);
|
||||
res_ok(res, result_oai);
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
|
@ -2823,9 +2832,10 @@ int main(int argc, char ** argv) {
|
|||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
} else {
|
||||
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
|
||||
std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
||||
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, model_name](size_t, httplib::DataSink & sink) {
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool {
|
||||
std::vector<json> result_array = format_partial_response_oaicompat(result.to_json(), completion_id);
|
||||
std::vector<json> result_array = format_partial_response_oaicompat(model_name, result, completion_id);
|
||||
for (auto & event_data : result_array) {
|
||||
if (event_data.empty()) {
|
||||
continue; // skip the stop token
|
||||
|
|
|
@ -281,6 +281,8 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
|
||||
int index = 0;
|
||||
std::string content;
|
||||
int32_t n_decoded;
|
||||
int32_t n_prompt_tokens;
|
||||
stop_type stop = STOP_TYPE_NONE;
|
||||
std::vector<completion_token_output> probs_output;
|
||||
result_timings timings;
|
||||
|
|
|
@ -583,15 +583,14 @@ static json oaicompat_completion_params_parse(
|
|||
return llama_params;
|
||||
}
|
||||
|
||||
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
|
||||
bool stopped_word = result.count("stopped_word") != 0;
|
||||
bool stopped_eos = json_value(result, "stopped_eos", false);
|
||||
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
||||
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
||||
std::string content = json_value(result, "content", std::string(""));
|
||||
|
||||
static json format_final_response_oaicompat(
|
||||
const json & request,
|
||||
server_task_result_cmpl_final & result,
|
||||
const std::string & completion_id,
|
||||
bool streaming = false,
|
||||
bool verbose = false) {
|
||||
std::string finish_reason = "length";
|
||||
if (stopped_word || stopped_eos) {
|
||||
if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
|
||||
|
@ -601,7 +600,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
|
|||
{"delta", json::object()}}})
|
||||
: json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", json{{"content", content},
|
||||
{"message", json{{"content", result.content},
|
||||
{"role", "assistant"}}}}});
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
|
@ -613,48 +612,42 @@ static json format_final_response_oaicompat(const json & request, const json & r
|
|||
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", num_tokens_predicted},
|
||||
{"prompt_tokens", num_prompt_tokens},
|
||||
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
|
||||
{"completion_tokens", result.n_decoded},
|
||||
{"prompt_tokens", result.n_prompt_tokens},
|
||||
{"total_tokens", result.n_decoded + result.n_prompt_tokens}
|
||||
}},
|
||||
{"id", completion_id}
|
||||
};
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
res["__verbose"] = result;
|
||||
res["__verbose"] = result.to_json();
|
||||
}
|
||||
|
||||
if (result.contains("completion_probabilities")) {
|
||||
res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
|
||||
}
|
||||
// TODO: fix this
|
||||
// if (result.contains("completion_probabilities")) {
|
||||
// res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
|
||||
// }
|
||||
|
||||
if (result.contains("timings")) {
|
||||
res.push_back({"timings", json_value(result, "timings", json::object())});
|
||||
if (result.timings.prompt_n >= 0) {
|
||||
res.push_back({"timings", result.timings.to_json()});
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// return value is vector as there is one case where we might need to generate two responses
|
||||
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
|
||||
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
|
||||
return std::vector<json>({result});
|
||||
}
|
||||
|
||||
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
|
||||
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
||||
|
||||
bool stopped_word = json_value(result, "stopped_word", false);
|
||||
bool stopped_eos = json_value(result, "stopped_eos", false);
|
||||
bool stopped_limit = json_value(result, "stopped_limit", false);
|
||||
std::string content = json_value(result, "content", std::string(""));
|
||||
static std::vector<json> format_partial_response_oaicompat(
|
||||
std::string modelname,
|
||||
server_task_result_cmpl_partial & result,
|
||||
const std::string & completion_id) {
|
||||
bool first = result.n_decoded == 0;
|
||||
std::string content = result.content;
|
||||
|
||||
std::string finish_reason;
|
||||
if (stopped_word || stopped_eos) {
|
||||
if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
if (stopped_limit) {
|
||||
} else if (result.stop == STOP_TYPE_LIMIT) {
|
||||
finish_reason = "length";
|
||||
}
|
||||
|
||||
|
@ -724,17 +717,15 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
|
|||
{"object", "chat.completion.chunk"}
|
||||
};
|
||||
|
||||
if (result.contains("timings")) {
|
||||
ret.push_back({"timings", json_value(result, "timings", json::object())});
|
||||
if (result.timings.prompt_n >= 0) {
|
||||
ret.push_back({"timings", result.timings.to_json()});
|
||||
}
|
||||
|
||||
if (!finish_reason.empty()) {
|
||||
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
||||
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
||||
ret.push_back({"usage", json {
|
||||
{"completion_tokens", num_tokens_predicted},
|
||||
{"prompt_tokens", num_prompt_tokens},
|
||||
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
|
||||
{"completion_tokens", result.n_decoded},
|
||||
{"prompt_tokens", result.n_prompt_tokens},
|
||||
{"total_tokens", result.n_decoded + result.n_prompt_tokens}
|
||||
}});
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue