commit 21f8b73d60
parent c47c41cd35

    fix code

2 changed files with 16 additions and 10 deletions
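In short, as read from the diff below: the per-request `timing_per_token` flag moves off the sampling parameters and the HTTP handler and becomes a field on the slot itself (`slot.timing_per_token`). When the flag is set, the server attaches timings to each result (`res.data["timings"] = slot.get_formated_timings();`), and the OAI-compatible formatters emit a `timings` field whenever the result JSON contains one, instead of receiving an extra `bool timings` parameter.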
File 1 of 2:

@@ -177,6 +177,8 @@ struct server_slot {
     bool stopped_word = false;
     bool stopped_limit = false;
 
+    bool timing_per_token = false;
+
     bool oaicompat = false;
 
     std::string oaicompat_model;

@@ -882,6 +884,8 @@ struct server_context {
            slot.oaicompat_model = "";
        }
 
+       slot.timing_per_token = json_value(data, "timing_per_token", false);
+
        slot.params.stream = json_value(data, "stream", false);
        slot.params.cache_prompt = json_value(data, "cache_prompt", true);
        slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));

@@ -1269,7 +1273,6 @@ struct server_context {
            {"n_keep", slot.params.n_keep},
            {"n_discard", slot.params.n_discard},
            {"ignore_eos", slot.params.sampling.ignore_eos},
-           {"timing_per_token", slot.params.sampling.timing_per_token},
            {"stream", slot.params.stream},
            //{"logit_bias", slot.params.sampling.logit_bias},
            {"n_probs", slot.params.sampling.n_probs},

@@ -1280,6 +1283,7 @@ struct server_context {
            {"speculative.n_max", slot.params.speculative.n_max},
            {"speculative.n_min", slot.params.speculative.n_min},
            {"speculative.p_min", slot.params.speculative.p_min},
+           {"timing_per_token", slot.timing_per_token},
        };
    }
 

@@ -1314,7 +1318,6 @@ struct server_context {
            {"id_slot", slot.id},
            {"multimodal", false},
            {"index", slot.index},
-           {"timings", slot.get_formated_timings()},
        };
 
        if (slot.params.sampling.n_probs > 0) {

@@ -1338,6 +1341,10 @@ struct server_context {
            res.data["model"] = slot.oaicompat_model;
        }
 
+       if (slot.timing_per_token) {
+           res.data["timings"] = slot.get_formated_timings();
+       }
+
        queue_results.send(res);
    }
 

@@ -3002,7 +3009,6 @@ int main(int argc, char ** argv) {
        ctx_server.queue_tasks.post(tasks);
 
        bool stream = json_value(data, "stream", false);
-       bool timings = json_value(data, "timing_per_token", false);
 
        const auto task_ids = server_task::get_list_id(tasks);
        const auto completion_id = gen_chatcmplid();

@@ -3010,7 +3016,7 @@ int main(int argc, char ** argv) {
        if (!stream) {
            ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
                // multitask is never support in chat completion, there is only one result
-               json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose, timings);
+               json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
                res_ok(res, result_oai);
            }, [&](const json & error_data) {
                res_error(res, error_data);

@@ -3018,9 +3024,9 @@ int main(int argc, char ** argv) {
 
            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
        } else {
-           const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, timings](size_t, httplib::DataSink & sink) {
+           const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
                ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                   std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, timings);
+                   std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
                    for (auto & event_data : result_array) {
                        if (event_data.empty()) {
                            continue; // skip the stop token
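The hunks above stop threading a local `timings` bool through the chat-completion handler; the formatter hunks below key off the result JSON itself (`result.contains("timings")`) instead.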
File 2 of 2:

@@ -604,7 +604,7 @@ static json oaicompat_completion_params_parse(
    return llama_params;
}
 
-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool timings = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
    bool stopped_word = result.count("stopped_word") != 0;
    bool stopped_eos = json_value(result, "stopped_eos", false);
    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);

@@ -650,7 +650,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
    }
 
-   if (timings) {
+   if (result.contains("timings")) {
        res.push_back({"timings", json_value(result, "timings", json::object())});
    }
 

@@ -658,7 +658,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
}
 
// return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool timings = false) {
+static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
        return std::vector<json>({result});
    }

@@ -745,7 +745,7 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
        {"object", "chat.completion.chunk"}
    };
 
-   if (timings) {
+   if (result.contains("timings")) {
        ret.push_back({"timings", json_value(result, "timings", json::object())});
    }
 
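For illustration, here is a minimal client-side sketch of a request body that would enable this flag, written with nlohmann::json since that is the library the server code uses. The surrounding fields (model, messages, the chat-completion shape) are assumptions based on the OAI-compat handlers touched above, not something this diff states; only the "timing_per_token" key and its default of false come from the code itself.

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Hypothetical request body for the server's OpenAI-compatible
    // chat-completion endpoint. "timing_per_token" is the key the diff
    // reads via json_value(data, "timing_per_token", false).
    json request = {
        {"model", "default"},            // placeholder model name
        {"stream", true},
        {"timing_per_token", true},      // opt in to per-token timings
        {"messages", json::array({
            {{"role", "user"}, {"content", "Hello"}}
        })}
    };
    std::cout << request.dump(2) << std::endl;

    // With the flag set, each result the server emits should carry a
    // "timings" object, which format_partial_response_oaicompat and
    // format_final_response_oaicompat now forward whenever it is present.
    return 0;
}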