commit 21f8b73d60
parent c47c41cd35
Author: lhpqaq
Date:   2024-12-01 00:35:41 +08:00

2 changed files with 16 additions and 10 deletions

Changed file 1:

@@ -177,6 +177,8 @@ struct server_slot {
     bool stopped_word = false;
     bool stopped_limit = false;

+    bool timing_per_token = false;
+
     bool oaicompat = false;
     std::string oaicompat_model;
@@ -882,6 +884,8 @@ struct server_context {
             slot.oaicompat_model = "";
         }

+        slot.timing_per_token = json_value(data, "timing_per_token", false);
+
         slot.params.stream = json_value(data, "stream", false);
         slot.params.cache_prompt = json_value(data, "cache_prompt", true);
         slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
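
For readers unfamiliar with the helper: `json_value(data, key, default)` reads an optional typed field from the request JSON. A minimal sketch of its behavior, assuming the nlohmann::json-based helper the server defines elsewhere (details may differ from the real implementation):

#include <nlohmann/json.hpp>
#include <string>

using json = nlohmann::ordered_json; // assumption: matches the server's json alias

// Sketch: fetch body[key] as T, falling back to default_value when the key is
// missing, null, or has the wrong type. The real helper may differ in detail.
template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
    if (body.contains(key) && !body.at(key).is_null()) {
        try {
            return body.at(key).template get<T>();
        } catch (const json::exception &) {
            return default_value;
        }
    }
    return default_value;
}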
@@ -1269,7 +1273,6 @@ struct server_context {
             {"n_keep", slot.params.n_keep},
             {"n_discard", slot.params.n_discard},
             {"ignore_eos", slot.params.sampling.ignore_eos},
-            {"timing_per_token", slot.params.sampling.timing_per_token},
             {"stream", slot.params.stream},
           //{"logit_bias", slot.params.sampling.logit_bias},
             {"n_probs", slot.params.sampling.n_probs},
@@ -1280,6 +1283,7 @@ struct server_context {
             {"speculative.n_max", slot.params.speculative.n_max},
             {"speculative.n_min", slot.params.speculative.n_min},
             {"speculative.p_min", slot.params.speculative.p_min},
+            {"timing_per_token", slot.timing_per_token},
         };
     }
@@ -1314,7 +1318,6 @@ struct server_context {
             {"id_slot", slot.id},
             {"multimodal", false},
             {"index", slot.index},
-            {"timings", slot.get_formated_timings()},
         };

         if (slot.params.sampling.n_probs > 0) {
@@ -1338,6 +1341,10 @@ struct server_context {
             res.data["model"] = slot.oaicompat_model;
         }

+        if (slot.timing_per_token) {
+            res.data["timings"] = slot.get_formated_timings();
+        }
+
         queue_results.send(res);
     }
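
Taken together, the hunks above make per-token timing opt-in per request: the flag parsed into `slot.timing_per_token` decides whether the result JSON gains a `timings` object before it is sent. A hypothetical request body exercising the new flag (all fields other than `timing_per_token` are ordinary completion fields, shown only for illustration):

#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    // Hypothetical chat-completion request: opt in to per-token timings.
    json data = {
        {"messages", json::array({
            {{"role", "user"}, {"content", "Hello"}}
        })},
        {"stream",           true},
        {"timing_per_token", true}  // the new per-request flag
    };
    // Server side (as in the diff):
    //   slot.timing_per_token = json_value(data, "timing_per_token", false);
    // and later, per result:
    //   if (slot.timing_per_token) { res.data["timings"] = slot.get_formated_timings(); }
    return 0;
}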
@@ -3002,7 +3009,6 @@ int main(int argc, char ** argv) {
         ctx_server.queue_tasks.post(tasks);

         bool stream = json_value(data, "stream", false);
-        bool timings = json_value(data, "timing_per_token", false);

         const auto task_ids = server_task::get_list_id(tasks);
         const auto completion_id = gen_chatcmplid();
@@ -3010,7 +3016,7 @@ int main(int argc, char ** argv) {
         if (!stream) {
             ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
                 // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose, timings);
+                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
                 res_ok(res, result_oai);
             }, [&](const json & error_data) {
                 res_error(res, error_data);
@@ -3018,9 +3024,9 @@ int main(int argc, char ** argv) {
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, timings](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, timings);
+                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
                     for (auto & event_data : result_array) {
                         if (event_data.empty()) {
                             continue; // skip the stop token

Changed file 2:

@@ -604,7 +604,7 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }

-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool timings = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -650,7 +650,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
         res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
     }

-    if (timings) {
+    if (result.contains("timings")) {
         res.push_back({"timings", json_value(result, "timings", json::object())});
     }
@@ -658,7 +658,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
 }

 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool timings = false) {
+static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
@@ -745,7 +745,7 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
             {"object", "chat.completion.chunk"}
         };

-        if (timings) {
+        if (result.contains("timings")) {
            ret.push_back({"timings", json_value(result, "timings", json::object())});
         }