add virtual functions

Author: Xuan Son Nguyen
Date: 2024-12-05 14:44:06 +01:00
parent cb666718b1
commit 8ab173c865
3 changed files with 441 additions and 510 deletions
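At its core, the refactor drops the result_type enum and the copy_cast_ptr/from_ptr helpers and instead gives server_task_result virtual functions (is_error(), get_index(), to_json(), to_json_oai_compat()), with results owned through std::unique_ptr (task_result_ptr) and moved between producer and consumer. The snippet below is a minimal standalone sketch of that pattern, not code from this commit: the derived type example_result and the trivial send() are made up for illustration, and nlohmann/json is assumed to be available, as it is in the real server.

// Minimal sketch of the pattern this commit adopts: a polymorphic result base
// owned through std::unique_ptr, replacing enum tags and copy-cast helpers.
// example_result and send() below are illustrative, not the real llama.cpp code.
#include <nlohmann/json.hpp>
#include <iostream>
#include <memory>
#include <string>

using json = nlohmann::ordered_json;

struct server_task_result {
    int id = -1;
    virtual bool is_error() { return false; }   // overridden only by the error result
    virtual int  get_index() { return -1; }     // overridden by per-slot results
    virtual json to_json() = 0;                 // every concrete result serializes itself
    virtual ~server_task_result() = default;
};

using task_result_ptr = std::unique_ptr<server_task_result>;

// hypothetical derived type, standing in for server_task_result_cmpl_final and friends
struct example_result : server_task_result {
    int index = 0;
    std::string content;
    int  get_index() override { return index; }
    json to_json() override { return json { {"index", index}, {"content", content} }; }
};

// ownership moves into the consumer; no downcast is needed just to serialize
static void send(task_result_ptr && result) {
    std::cout << "task " << result->id << ": " << result->to_json().dump() << "\n";
}

int main() {
    auto res = std::make_unique<example_result>();
    res->id      = 7;
    res->content = "hello";
    send(std::move(res));
}

The same shape repeats throughout the diff: every send site builds a concrete result with std::make_unique, fills it through ->, and hands it off with std::move; consumers either stay on the virtual interface or dynamic_cast when they genuinely need the concrete type.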

(changed file 1 of 3)

@@ -84,9 +84,6 @@ struct server_slot {
bool truncated = false;
stop_type stop;
bool oaicompat = false;
std::string oaicompat_model;
std::string stopping_word;
// sampling
@@ -494,17 +491,15 @@ struct server_response {
}
// Send a new result to a waiting id_task
template<typename T>
void send(T & result) {
static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
SRV_DBG("sending result for task id = %d\n", result.id);
void send(task_result_ptr && result) {
SRV_DBG("sending result for task id = %d\n", result->id);
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto & id_task : waiting_task_ids) {
if (result.id == id_task) {
SRV_DBG("task id = %d pushed to result queue\n", result.id);
if (result->id == id_task) {
SRV_DBG("task id = %d pushed to result queue\n", result->id);
queue_results.push_back(std::make_unique<T>(std::move(result)));
queue_results.emplace_back(std::move(result));
condition_results.notify_all();
return;
}
@@ -791,13 +786,16 @@ struct server_context {
const auto & data = task.data;
if (data.count("__oaicompat") != 0) {
slot.oaicompat = true;
slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
slot.params.oaicompat = true;
slot.params.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string());
} else {
slot.oaicompat = false;
slot.oaicompat_model = "";
slot.params.oaicompat = false;
}
// enabling this will output extra debug information in the HTTP responses from the server
slot.params.verbose = params_base.verbosity > 9;
slot.params.timings_per_token = json_value(data, "timings_per_token", false);
slot.params.stream = json_value(data, "stream", false);
@@ -1158,25 +1156,29 @@ struct server_context {
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
server_task_result_error res;
res.id = id_task;
res.err_type = type;
res.err_msg = error;
auto res = std::make_unique<server_task_result_error>();
res->id = id_task;
res->err_type = type;
res->err_msg = error;
queue_results.send(res);
queue_results.send(std::move(res));
}
void send_partial_response(server_slot & slot, completion_token_output tkn) {
server_task_result_cmpl_partial res;
res.id = slot.id_task;
res.index = slot.index;
res.content = tkn.text_to_send;
auto res = std::make_unique<server_task_result_cmpl_partial>();
res->id = slot.id_task;
res->index = slot.index;
res->content = tkn.text_to_send;
res.truncated = slot.truncated;
res.n_decoded = slot.n_decoded;
res.n_prompt_tokens = slot.n_prompt_tokens;
res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;
res.stop = slot.stop;
res->stop = slot.stop;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->verbose = slot.params.verbose;
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@@ -1186,7 +1188,7 @@ struct server_context {
std::vector<completion_token_output> probs_output;
if (probs_pos < probs_stop_pos) {
res.probs_output = std::vector<completion_token_output>(
res->probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin() + probs_pos,
slot.generated_token_probs.begin() + probs_stop_pos);
}
@@ -1194,10 +1196,10 @@ struct server_context {
// populate timings if this is final response or timings_per_token is enabled
if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
res.timings = slot.get_timings();
res->timings = slot.get_timings();
}
queue_results.send(res);
queue_results.send(std::move(res));
}
void send_final_response(server_slot & slot) {
@@ -1206,23 +1208,26 @@ struct server_context {
return send_partial_response(slot, {0, "", {}});
}
server_task_result_cmpl_final res;
res.id = slot.id_task;
res.id_slot = slot.id;
auto res = std::make_unique<server_task_result_cmpl_final>();
res->id = slot.id_task;
res->id_slot = slot.id;
res.index = slot.index;
res.content = slot.generated_text;
res.timings = slot.get_timings();
res.model_alias = slot.oaicompat_model;
res.prompt = common_detokenize(ctx, slot.prompt_tokens, true);
res->index = slot.index;
res->content = slot.generated_text;
res->timings = slot.get_timings();
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
res.truncated = slot.truncated;
res.n_decoded = slot.n_decoded;
res.n_prompt_tokens = slot.n_prompt_tokens;
res.n_tokens_cached = slot.n_past;
res.has_new_line = slot.has_new_line;
res.stopping_word = slot.stopping_word;
res.stop = slot.stop;
res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;
res->n_tokens_cached = slot.n_past;
res->has_new_line = slot.has_new_line;
res->stopping_word = slot.stopping_word;
res->stop = slot.stop;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->verbose = slot.params.verbose;
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@@ -1230,25 +1235,25 @@ struct server_context {
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
res.probs_output = std::vector<completion_token_output>(
res->probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin(),
slot.generated_token_probs.end() - safe_offset);
} else {
res.probs_output = std::vector<completion_token_output>(
res->probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin(),
slot.generated_token_probs.end());
}
}
res.generation_params = slot.params; // copy the parameters
res->generation_params = slot.params; // copy the parameters
queue_results.send(res);
queue_results.send(std::move(res));
}
void send_embedding(const server_slot & slot, const llama_batch & batch) {
server_task_result_embd res;
res.id = slot.id_task;
res.index = slot.index;
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.id_task;
res->index = slot.index;
const int n_embd = llama_n_embd(model);
@@ -1267,23 +1272,23 @@ struct server_context {
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
res.embedding = std::vector<float>(n_embd, 0.0f);
res->embedding = std::vector<float>(n_embd, 0.0f);
continue;
}
common_embd_normalize(embd, embd_res.data(), n_embd);
res.embedding = embd_res;
res->embedding = embd_res;
}
SLT_DBG(slot, "%s", "sending embeddings\n");
queue_results.send(res);
queue_results.send(std::move(res));
}
void send_rerank(const server_slot & slot, const llama_batch & batch) {
server_task_result_rerank res;
res.id = slot.id_task;
res.index = slot.index;
auto res = std::make_unique<server_task_result_rerank>();
res->id = slot.id_task;
res->index = slot.index;
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -1298,16 +1303,16 @@ struct server_context {
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
res.score = -1e6;
res->score = -1e6;
continue;
}
res.score = embd[0];
res->score = embd[0];
}
SLT_DBG(slot, "sending rerank result, res.score = %f\n", res.score);
SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);
queue_results.send(res);
queue_results.send(std::move(res));
}
//
@@ -1398,35 +1403,28 @@ struct server_context {
}
// receive the results from task(s) created by create_tasks_inference
template<typename T>
void receive_multi_results(
const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<T>&)> & result_handler,
const std::function<void(std::vector<task_result_ptr>&)> & result_handler,
const std::function<void(json)> & error_handler) {
static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
std::vector<T> results(id_tasks.size());
std::vector<task_result_ptr> results(id_tasks.size());
for (size_t i = 0; i < id_tasks.size(); i++) {
task_result_ptr result_raw = queue_results.recv(id_tasks);
task_result_ptr result = queue_results.recv(id_tasks);
if (result_raw->type == RESULT_TYPE_ERROR) {
auto result = server_task_result_error::from_ptr(result_raw);
error_handler(format_error_response(result.err_msg, result.err_type));
if (result->is_error()) {
error_handler(result->to_json());
cancel_tasks(id_tasks);
return;
}
if (
result_raw->type == RESULT_TYPE_CMPL_FINAL
|| result_raw->type == RESULT_TYPE_EMBD
|| result_raw->type == RESULT_TYPE_RERANK
) {
T result = T::from_ptr(result_raw);
const size_t idx = result.index;
GGML_ASSERT(idx < results.size() && "index out of range");
results[idx] = std::move(result);
} else {
GGML_ASSERT(false && "unexpected result type");
}
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
);
const size_t idx = result->get_index();
GGML_ASSERT(idx < results.size() && "index out of range");
results[idx] = std::move(result);
}
result_handler(results);
}
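receive_multi_results above no longer switches on a result_type tag: it reads the index through the virtual get_index() and only sanity-checks the concrete type with a dynamic_cast on the raw pointer. The snippet below isolates that check under the same assumptions as the sketch near the top; the *_example types are made up, and assert stands in for GGML_ASSERT.

// Isolated sketch of the dynamic_cast check: the pointer stays owned by the
// unique_ptr, dynamic_cast is applied to .get(), and a null result means the
// queue delivered an unexpected result type.
#include <cassert>
#include <iostream>
#include <memory>

struct server_task_result {
    virtual int get_index() { return -1; }
    virtual ~server_task_result() = default;   // polymorphic base, so dynamic_cast works
};
using task_result_ptr = std::unique_ptr<server_task_result>;

struct final_result_example : server_task_result {
    int index = 0;
    int get_index() override { return index; }
};
struct partial_result_example : server_task_result {};

int main() {
    task_result_ptr result = std::make_unique<final_result_example>();

    // accept only the expected concrete type, mirroring the GGML_ASSERT above
    assert(dynamic_cast<final_result_example*>(result.get()) != nullptr);

    // a mismatching type fails the same check (prints "false")
    task_result_ptr other = std::make_unique<partial_result_example>();
    std::cout << std::boolalpha
              << (dynamic_cast<final_result_example*>(other.get()) != nullptr) << "\n";

    std::cout << "index = " << result->get_index() << "\n";  // virtual dispatch, no cast needed
}

The cast is only a guard; the data flow itself goes through the virtual interface, which is presumably why the remaining casts in the handlers are marked with "TODO: get rid of this dynamic_cast".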
@@ -1434,29 +1432,25 @@ struct server_context {
// receive the results from task(s) created by create_tasks_inference, in stream mode
void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks, const
std::function<bool(server_task_result_cmpl_partial&)> & result_handler, const
std::function<bool(task_result_ptr&)> & result_handler, const
std::function<void(json)> & error_handler) {
size_t n_finished = 0;
while (true) {
task_result_ptr result_raw = queue_results.recv(id_tasks);
task_result_ptr result = queue_results.recv(id_tasks);
if (result_raw->type == RESULT_TYPE_ERROR) {
auto result = server_task_result_error::from_ptr(result_raw);
error_handler(format_error_response(result.err_msg, result.err_type));
if (result->is_error()) {
error_handler(result->to_json());
cancel_tasks(id_tasks);
return;
}
GGML_ASSERT(result_raw->type == RESULT_TYPE_CMPL_PARTIAL);
auto result = server_task_result_cmpl_partial::from_ptr(result_raw);
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr);
if (!result_handler(result)) {
cancel_tasks(id_tasks);
break;
}
SRV_ERR("received partial result, %s\n", result.to_json().dump().c_str());
if (result.stop != STOP_TYPE_NONE) {
if (result->is_stop()) {
if (++n_finished == id_tasks.size()) {
break;
}
@@ -1546,33 +1540,33 @@ struct server_context {
}
SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
server_task_result_metrics res;
res.id = task.id;
res.n_idle_slots = n_idle_slots;
res.n_processing_slots = n_processing_slots;
res.n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
res.t_start = metrics.t_start;
auto res = std::make_unique<server_task_result_metrics>();
res->id = task.id;
res->n_idle_slots = n_idle_slots;
res->n_processing_slots = n_processing_slots;
res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
res->t_start = metrics.t_start;
res.kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
res.kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
res.n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
res.t_prompt_processing_total = metrics.t_prompt_processing_total;
res.n_tokens_predicted_total = metrics.n_tokens_predicted_total;
res.t_tokens_generation_total = metrics.t_tokens_generation_total;
res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
res->t_prompt_processing_total = metrics.t_prompt_processing_total;
res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
res->t_tokens_generation_total = metrics.t_tokens_generation_total;
res.n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
res.t_prompt_processing = metrics.t_prompt_processing;
res.n_tokens_predicted = metrics.n_tokens_predicted;
res.t_tokens_generation = metrics.t_tokens_generation;
res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
res->t_prompt_processing = metrics.t_prompt_processing;
res->n_tokens_predicted = metrics.n_tokens_predicted;
res->t_tokens_generation = metrics.t_tokens_generation;
res.n_decode_total = metrics.n_decode_total;
res.n_busy_slots_total = metrics.n_busy_slots_total;
res->n_decode_total = metrics.n_decode_total;
res->n_busy_slots_total = metrics.n_busy_slots_total;
if (json_value(task.data, "reset_bucket", false)) {
metrics.reset_bucket();
}
queue_results.send(res);
queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SLOT_SAVE:
{
@@ -1600,15 +1594,15 @@ struct server_context {
const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0;
server_task_result_slot_save_load result;
result.id = task.id;
result.id_slot = id_slot;
result.filename = filename;
result.is_save = true;
result.n_tokens = token_count;
result.n_bytes = nwrite;
result.t_ms = t_save_ms;
queue_results.send(result);
auto res = std::make_unique<server_task_result_slot_save_load>();
res->id = task.id;
res->id_slot = id_slot;
res->filename = filename;
res->is_save = true;
res->n_tokens = token_count;
res->n_bytes = nwrite;
res->t_ms = t_save_ms;
queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SLOT_RESTORE:
{
@@ -1643,15 +1637,15 @@ struct server_context {
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
server_task_result_slot_save_load result;
result.id = task.id;
result.id_slot = id_slot;
result.filename = filename;
result.is_save = false;
result.n_tokens = token_count;
result.n_bytes = nread;
result.t_ms = t_restore_ms;
queue_results.send(result);
auto res = std::make_unique<server_task_result_slot_save_load>();
res->id = task.id;
res->id_slot = id_slot;
res->filename = filename;
res->is_save = false;
res->n_tokens = token_count;
res->n_bytes = nread;
res->t_ms = t_restore_ms;
queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SLOT_ERASE:
{
@@ -1673,18 +1667,18 @@ struct server_context {
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear();
server_task_result_slot_erase result;
result.id = task.id;
result.id_slot = id_slot;
result.n_erased = n_erased;
queue_results.send(result);
auto res = std::make_unique<server_task_result_slot_erase>();
res->id = task.id;
res->id_slot = id_slot;
res->n_erased = n_erased;
queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SET_LORA:
{
common_lora_adapters_apply(ctx, loras);
server_task_result_apply_lora result;
result.id = task.id;
queue_results.send(result);
auto res = std::make_unique<server_task_result_apply_lora>();
res->id = task.id;
queue_results.send(std::move(res));
} break;
}
}
@@ -2250,10 +2244,6 @@ int main(int argc, char ** argv) {
common_init();
// enabling this will output extra debug information in the HTTP responses from the server
// see format_final_response_oaicompat()
const bool verbose = params.verbosity > 9;
// struct that contains llama context and inference
server_context ctx_server;
@@ -2445,26 +2435,27 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.post(task, true); // high-priority task
// get the result
task_result_ptr result_raw = ctx_server.queue_results.recv(task.id);
task_result_ptr result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result_raw->type != RESULT_TYPE_METRICS) {
SRV_ERR("Unexpected result type: %d\n", result_raw->type);
res_error(res, format_error_response("Unexpected result type", ERROR_TYPE_SERVER));
if (result->is_error()) {
res_error(res, result->to_json());
return;
}
auto result = server_task_result_metrics::from_ptr(result_raw);
// TODO: get rid of this dynamic_cast
auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
GGML_ASSERT(res_metrics != nullptr);
// optionally return "fail_on_no_slot" error
if (req.has_param("fail_on_no_slot")) {
if (result.n_idle_slots == 0) {
if (res_metrics->n_idle_slots == 0) {
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
return;
}
}
res_ok(res, result.slots_data);
res_ok(res, res_metrics->slots_data);
};
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
@@ -2484,68 +2475,69 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.post(task, true); // high-priority task
// get the result
task_result_ptr result_raw = ctx_server.queue_results.recv(task.id);
task_result_ptr result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result_raw->type == RESULT_TYPE_ERROR) {
auto result = server_task_result_error::from_ptr(result_raw);
res_error(res, format_error_response(result.err_msg, result.err_type));
if (result->is_error()) {
res_error(res, result->to_json());
return;
}
GGML_ASSERT(result_raw->type == RESULT_TYPE_METRICS);
auto result = server_task_result_metrics::from_ptr(result_raw);
// TODO: get rid of this dynamic_cast
auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
GGML_ASSERT(res_metrics != nullptr);
// metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
json all_metrics_def = json {
{"counter", {{
{"name", "prompt_tokens_total"},
{"help", "Number of prompt tokens processed."},
{"value", (uint64_t) result.n_prompt_tokens_processed_total}
{"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total}
}, {
{"name", "prompt_seconds_total"},
{"help", "Prompt process time"},
{"value", (uint64_t) result.t_prompt_processing_total / 1.e3}
{"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3}
}, {
{"name", "tokens_predicted_total"},
{"help", "Number of generation tokens processed."},
{"value", (uint64_t) result.n_tokens_predicted_total}
{"value", (uint64_t) res_metrics->n_tokens_predicted_total}
}, {
{"name", "tokens_predicted_seconds_total"},
{"help", "Predict process time"},
{"value", (uint64_t) result.t_tokens_generation_total / 1.e3}
{"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3}
}, {
{"name", "n_decode_total"},
{"help", "Total number of llama_decode() calls"},
{"value", result.n_decode_total}
{"value", res_metrics->n_decode_total}
}, {
{"name", "n_busy_slots_per_decode"},
{"help", "Average number of busy slots per llama_decode() call"},
{"value", (float) result.n_busy_slots_total / (float) result.n_decode_total}
{"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total}
}}},
{"gauge", {{
{"name", "prompt_tokens_seconds"},
{"help", "Average prompt throughput in tokens/s."},
{"value", result.n_prompt_tokens_processed ? 1.e3 / result.t_prompt_processing * result.n_prompt_tokens_processed : 0.}
{"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.}
},{
{"name", "predicted_tokens_seconds"},
{"help", "Average generation throughput in tokens/s."},
{"value", result.n_tokens_predicted ? 1.e3 / result.t_tokens_generation * result.n_tokens_predicted : 0.}
{"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
},{
{"name", "kv_cache_usage_ratio"},
{"help", "KV-cache usage. 1 means 100 percent usage."},
{"value", 1. * result.kv_cache_used_cells / params.n_ctx}
{"value", 1. * res_metrics->kv_cache_used_cells / params.n_ctx}
},{
{"name", "kv_cache_tokens"},
{"help", "KV-cache tokens."},
{"value", (uint64_t) result.kv_cache_tokens_count}
{"value", (uint64_t) res_metrics->kv_cache_tokens_count}
},{
{"name", "requests_processing"},
{"help", "Number of request processing."},
{"value", (uint64_t) result.n_processing_slots}
{"value", (uint64_t) res_metrics->n_processing_slots}
},{
{"name", "requests_deferred"},
{"help", "Number of request deferred."},
{"value", (uint64_t) result.n_tasks_deferred}
{"value", (uint64_t) res_metrics->n_tasks_deferred}
}}}
};
@@ -2566,7 +2558,7 @@ int main(int argc, char ** argv) {
}
}
res.set_header("Process-Start-Time-Unix", std::to_string(result.t_start));
res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start));
res.set_content(prometheus.str(), "text/plain; version=0.0.4");
res.status = 200; // HTTP OK
@@ -2592,18 +2584,15 @@ int main(int argc, char ** argv) {
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
task_result_ptr result_raw = ctx_server.queue_results.recv(id_task);
task_result_ptr result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result_raw->type == RESULT_TYPE_ERROR) {
auto result = server_task_result_error::from_ptr(result_raw);
res_error(res, format_error_response(result.err_msg, result.err_type));
if (result->is_error()) {
res_error(res, result->to_json());
return;
}
GGML_ASSERT(result_raw->type == RESULT_TYPE_SLOT_SAVE_LOAD);
auto result = server_task_result_slot_save_load::from_ptr(result_raw);
res_ok(res, result.to_json());
res_ok(res, result->to_json());
};
const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
@@ -2626,18 +2615,16 @@ int main(int argc, char ** argv) {
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
task_result_ptr result_raw = ctx_server.queue_results.recv(id_task);
task_result_ptr result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result_raw->type == RESULT_TYPE_ERROR) {
auto result = server_task_result_error::from_ptr(result_raw);
res_error(res, format_error_response(result.err_msg, result.err_type));
if (result->is_error()) {
res_error(res, result->to_json());
return;
}
GGML_ASSERT(result_raw->type == RESULT_TYPE_SLOT_SAVE_LOAD);
auto result = server_task_result_slot_save_load::from_ptr(result_raw);
res_ok(res, result.to_json());
GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
res_ok(res, result->to_json());
};
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
@@ -2650,18 +2637,16 @@ int main(int argc, char ** argv) {
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
task_result_ptr result_raw = ctx_server.queue_results.recv(id_task);
task_result_ptr result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result_raw->type == RESULT_TYPE_ERROR) {
auto result = server_task_result_error::from_ptr(result_raw);
res_error(res, format_error_response(result.err_msg, result.err_type));
if (result->is_error()) {
res_error(res, result->to_json());
return;
}
GGML_ASSERT(result_raw->type == RESULT_TYPE_SLOT_ERASE);
auto result = server_task_result_slot_erase::from_ptr(result_raw);
res_ok(res, result.to_json());
GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
res_ok(res, result->to_json());
};
const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
@@ -2722,15 +2707,13 @@ int main(int argc, char ** argv) {
server_task_inf_type inf_type,
json & data,
httplib::Response & res,
const std::function<std::vector<json>(server_task_result_cmpl_partial&)> & format_partial = nullptr,
const std::function<json(std::vector<server_task_result_cmpl_final>&)> & format_final = nullptr,
// whether to send [DONE] event after completion (required for OAI-compat)
bool send_done_event = false) {
bool oai_compat = false) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
data["completion_id"] = gen_chatcmplid();
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@@ -2739,17 +2722,15 @@ int main(int argc, char ** argv) {
const auto task_ids = server_task::get_list_id(tasks);
if (!stream) {
ctx_server.receive_multi_results<server_task_result_cmpl_final>(task_ids, [&](std::vector<server_task_result_cmpl_final> & results) {
if (format_final) {
res_ok(res, format_final(results));
} else if (results.size() == 1) {
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
if (results.size() == 1) {
// single result
res_ok(res, results[0].to_json());
res_ok(res, oai_compat ? results[0]->to_json_oai_compat() : results[0]->to_json());
} else {
// multiple results (multitask)
json arr = json::array();
for (auto & res : results) {
arr.push_back(res.to_json());
arr.push_back(oai_compat ? res->to_json_oai_compat() : res->to_json());
}
res_ok(res, arr);
}
@@ -2759,22 +2740,23 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, format_partial = std::move(format_partial), send_done_event](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool {
if (format_partial) {
for (const auto & res : format_partial(result)) {
const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](task_result_ptr & result) -> bool {
json res_json = oai_compat ? result->to_json_oai_compat() : result->to_json();
if (res_json.is_array()) {
for (const auto & res : res_json) {
if (!server_sent_event(sink, "data", res)) {
return false;
}
}
return true;
} else {
return server_sent_event(sink, "data", result.to_json());
return server_sent_event(sink, "data", res_json);
}
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
});
if (send_done_event) {
if (oai_compat) {
static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size());
}
@@ -2792,13 +2774,7 @@ int main(int argc, char ** argv) {
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(
SERVER_TASK_INF_TYPE_COMPLETION,
data,
res,
// TODO: support OAI-compat response via format_partial and format_final
/* format_partial */ nullptr,
/* format_final */ nullptr);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -2851,7 +2827,7 @@ int main(int argc, char ** argv) {
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
};
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic, verbose](const httplib::Request & req, httplib::Response & res) {
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
@@ -2859,20 +2835,9 @@ int main(int argc, char ** argv) {
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
const auto completion_id = gen_chatcmplid();
std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
return handle_completions_generic(
SERVER_TASK_INF_TYPE_COMPLETION,
data,
res,
/* format_partial */ [data, model_name, completion_id](server_task_result_cmpl_partial & result) {
return format_partial_response_oaicompat(model_name, result, completion_id);
},
/* format_final */ [data, verbose, model_name](std::vector<server_task_result_cmpl_final> & results) {
return format_final_response_oaicompat(data, results[0], model_name, false, verbose);
},
/* send_done_event */ true);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
};
const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
@@ -2973,10 +2938,10 @@ int main(int argc, char ** argv) {
// get the result
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
ctx_server.receive_multi_results<server_task_result_embd>(task_ids, [&](std::vector<server_task_result_embd> & results) {
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
for (auto & res : results) {
GGML_ASSERT(res.type == RESULT_TYPE_EMBD);
responses.push_back(res.to_json());
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
responses.push_back(res->to_json());
}
}, [&](const json & error_data) {
res_error(res, error_data);
@@ -3052,10 +3017,10 @@ int main(int argc, char ** argv) {
// get the result
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
ctx_server.receive_multi_results<server_task_result_rerank>(task_ids, [&](std::vector<server_task_result_rerank> & results) {
ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
for (auto & res : results) {
GGML_ASSERT(res.type == RESULT_TYPE_RERANK);
responses.push_back(res.to_json());
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
responses.push_back(res->to_json());
}
}, [&](const json & error_data) {
res_error(res, error_data);
@@ -3111,18 +3076,16 @@ int main(int argc, char ** argv) {
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
task_result_ptr result_raw = ctx_server.queue_results.recv(id_task);
task_result_ptr result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result_raw->type == RESULT_TYPE_ERROR) {
auto result = server_task_result_error::from_ptr(result_raw);
res_error(res, format_error_response(result.err_msg, result.err_type));
if (result->is_error()) {
res_error(res, result->to_json());
return;
}
GGML_ASSERT(result_raw->type == RESULT_TYPE_APPLY_LORA);
auto result = server_task_result_apply_lora::from_ptr(result_raw);
res_ok(res, result.to_json());
GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
res_ok(res, result->to_json());
};
//
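The other half of this file's change is ownership: server_response::send() now takes a task_result_ptr by rvalue reference and moves it into queue_results, and each handler gets it back from queue_results.recv(). The snippet below reduces that hand-off to a single deque guarded by a mutex and condition variable; the waiting-task-id bookkeeping of the real server_response is omitted, and simple_response_queue is a made-up name.

// Reduced sketch of the result hand-off after this commit: results are heap-allocated,
// pushed as std::unique_ptr, and popped by the waiting handler thread.
#include <condition_variable>
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>

struct server_task_result {
    int id = -1;
    virtual ~server_task_result() = default;
};
using task_result_ptr = std::unique_ptr<server_task_result>;

struct simple_response_queue {                        // stands in for server_response
    std::mutex mutex_results;
    std::condition_variable condition_results;
    std::deque<task_result_ptr> queue_results;

    void send(task_result_ptr && result) {            // producer: move ownership into the queue
        std::unique_lock<std::mutex> lock(mutex_results);
        queue_results.emplace_back(std::move(result));
        condition_results.notify_all();
    }

    task_result_ptr recv() {                           // consumer: block until a result arrives
        std::unique_lock<std::mutex> lock(mutex_results);
        condition_results.wait(lock, [&] { return !queue_results.empty(); });
        task_result_ptr res = std::move(queue_results.front());
        queue_results.pop_front();
        return res;
    }
};

int main() {
    simple_response_queue q;
    std::thread worker([&] {
        auto res = std::make_unique<server_task_result>();
        res->id = 42;
        q.send(std::move(res));
    });
    task_result_ptr res = q.recv();
    std::cout << "received result for task " << res->id << "\n";
    worker.join();
}

Moving the unique_ptr means each result is allocated once and never copied between the worker and the HTTP thread, which is what lets send() lose its template parameter and the queue hold a single element type.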

(changed file 2 of 3)

@@ -15,9 +15,6 @@
using json = nlohmann::ordered_json;
// cast a shared_ptr to a specific type using copy constructor
#define copy_cast_ptr(TYPEOUT, ptr) *(static_cast<TYPEOUT*>(ptr.get()));
enum stop_type {
STOP_TYPE_NONE,
STOP_TYPE_EOS,
@@ -68,19 +65,6 @@ enum error_type {
ERROR_TYPE_NOT_SUPPORTED, // custom error
};
enum result_type {
RESULT_TYPE_CMPL_FINAL,
RESULT_TYPE_CMPL_PARTIAL,
RESULT_TYPE_EMBD,
RESULT_TYPE_RERANK,
RESULT_TYPE_METRICS,
RESULT_TYPE_SLOT_SAVE_LOAD,
RESULT_TYPE_SLOT_ERASE,
RESULT_TYPE_APPLY_LORA,
RESULT_TYPE_ERROR,
RESULT_TYPE_UNKNOWN, // will throw an error
};
struct server_task {
int id = -1; // to be filled by server_queue
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
@@ -126,6 +110,12 @@ struct slot_params {
uint32_t seed_cur;
bool can_speculative;
// OAI-compat fields
bool oaicompat = false;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
bool verbose = false;
json to_json() {
std::vector<std::string> samplers;
samplers.reserve(sampling.samplers.size());
@@ -205,11 +195,24 @@ struct result_timings {
};
struct server_task_result {
result_type type = RESULT_TYPE_UNKNOWN;
int id = -1;
int id_slot = -1;
server_task_result() = default;
server_task_result(result_type type) : type(type) {}
virtual bool is_error() {
// only used by server_task_result_error
return false;
}
virtual bool is_stop() {
// only used by server_task_result_cmpl_partial
return false;
}
virtual int get_index() {
return -1;
}
virtual json to_json() = 0;
virtual json to_json_oai_compat() {
// used by server_task_result_cmpl_final and server_task_result_cmpl_partial
return json();
}
virtual ~server_task_result() = default;
};
@@ -233,12 +236,10 @@ struct completion_token_output {
};
struct server_task_result_cmpl_final : server_task_result {
server_task_result_cmpl_final() : server_task_result(RESULT_TYPE_CMPL_FINAL) {}
int index = 0;
std::string content;
bool stream;
result_timings timings;
std::string model_alias;
std::string prompt;
bool truncated;
@@ -253,14 +254,23 @@ struct server_task_result_cmpl_final : server_task_result {
slot_params generation_params;
json to_json() {
// OAI-compat fields
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
bool verbose = false;
virtual int get_index() override {
return index;
}
virtual json to_json() override {
// non-OAI-compat JSON
return json {
{"index", index},
{"content", content},
{"id_slot", id_slot},
{"stop", true},
{"model", model_alias},
{"model", oaicompat_model},
{"tokens_predicted", n_decoded},
{"tokens_evaluated", n_prompt_tokens},
{"generation_settings", generation_params.to_json()},
@@ -274,15 +284,55 @@ struct server_task_result_cmpl_final : server_task_result {
};
}
static server_task_result_cmpl_final from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
return copy_cast_ptr(server_task_result_cmpl_final, result_ptr);
}
virtual json to_json_oai_compat() override {
std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}
virtual ~server_task_result_cmpl_final() = default;
json choices = json::array({json{
{"finish_reason", finish_reason},
{"index", 0},
{"message", json{
{"content", content},
{"role", "assistant"}
}
}}});
std::time_t t = std::time(0);
json res = json {
{"choices", choices},
{"created", t},
{"model", oaicompat_model},
{"object", "chat.completion"},
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens}
}},
{"id", oaicompat_cmpl_id}
};
// extra fields for debugging purposes
if (verbose) {
res["__verbose"] = to_json();
}
// TODO: fix this
// if (result.contains("completion_probabilities")) {
// res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
// }
if (timings.prompt_n >= 0) {
res.push_back({"timings", timings.to_json()});
}
return res;
}
};
struct server_task_result_cmpl_partial : server_task_result {
server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
int index = 0;
std::string content;
@@ -295,7 +345,20 @@ struct server_task_result_cmpl_partial : server_task_result {
std::vector<completion_token_output> probs_output;
result_timings timings;
json to_json() {
// OAI-compat fields
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
bool verbose = false;
virtual int get_index() override {
return index;
}
virtual bool is_stop() override {
return stop != STOP_TYPE_NONE;
}
virtual json to_json() override {
bool is_stop = stop != STOP_TYPE_NONE;
// non-OAI-compat JSON
json res = json {
@@ -317,67 +380,186 @@ struct server_task_result_cmpl_partial : server_task_result {
return res;
}
static server_task_result_cmpl_partial from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
return copy_cast_ptr(server_task_result_cmpl_partial, result_ptr);
}
virtual json to_json_oai_compat() override {
bool first = n_decoded == 0;
virtual ~server_task_result_cmpl_partial() = default;
std::string finish_reason;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
} else if (stop == STOP_TYPE_LIMIT) {
finish_reason = "length";
}
std::time_t t = std::time(0);
json choices;
if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};
return std::vector<json>({initial_ret, second_ret});
}
} else {
// Some idiosyncrasy in task processing logic makes several trailing calls
// with empty content; we ignore these at the callee site.
if (content.empty()) {
return std::vector<json>({json::object()});
}
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
}});
}
}
json ret = json {
{"choices", choices},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}
};
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
}
if (!finish_reason.empty()) {
ret.push_back({"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}});
}
return std::vector<json>({ret});
}
};
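One detail worth calling out: to_json_oai_compat() above is declared to return json, yet the first streamed token returns std::vector<json>({initial_ret, second_ret}). nlohmann::json converts that vector into a JSON array implicitly, which is what the streaming handler in server.cpp checks with res_json.is_array() before emitting one SSE event per element. A small illustration of that convention, with make_chunks as a hypothetical stand-in:

// Illustration of the array-vs-object convention used by to_json_oai_compat():
// returning a std::vector<json> from a function declared to return json yields
// a JSON array, which the streaming path detects with is_array().
#include <nlohmann/json.hpp>
#include <iostream>
#include <vector>

using json = nlohmann::ordered_json;

// hypothetical helper: the first token produces two OAI-style chunks, later tokens one
static json make_chunks(bool first) {
    json chunk = { {"object", "chat.completion.chunk"} };
    if (first) {
        return std::vector<json>({chunk, chunk});   // implicitly becomes a JSON array
    }
    return chunk;                                   // a single object
}

int main() {
    for (bool first : {true, false}) {
        json res_json = make_chunks(first);
        if (res_json.is_array()) {
            for (const auto & c : res_json) {
                std::cout << "event: " << c.dump() << "\n";   // one SSE event per chunk
            }
        } else {
            std::cout << "event: " << res_json.dump() << "\n";
        }
    }
}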
struct server_task_result_embd : server_task_result {
server_task_result_embd() : server_task_result(RESULT_TYPE_EMBD) {}
result_type type = RESULT_TYPE_EMBD;
int index = 0;
std::vector<float> embedding;
json to_json() {
virtual int get_index() override {
return index;
}
virtual json to_json() override {
return json {
{"index", index},
{"embedding", embedding},
};
}
static server_task_result_embd from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
return copy_cast_ptr(server_task_result_embd, result_ptr);
}
virtual ~server_task_result_embd() = default;
};
struct server_task_result_rerank : server_task_result {
server_task_result_rerank() : server_task_result(RESULT_TYPE_RERANK) {}
int index = 0;
float score = -1e6;
json to_json() {
virtual int get_index() override {
return index;
}
virtual json to_json() override {
return json {
{"index", index},
{"score", score},
};
}
static server_task_result_rerank from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
return copy_cast_ptr(server_task_result_rerank, result_ptr);
}
virtual ~server_task_result_rerank() = default;
};
// this function may be used outside of server_task_result_error
static json format_error_response(const std::string & message, const enum error_type type) {
std::string type_str;
int code = 500;
switch (type) {
case ERROR_TYPE_INVALID_REQUEST:
type_str = "invalid_request_error";
code = 400;
break;
case ERROR_TYPE_AUTHENTICATION:
type_str = "authentication_error";
code = 401;
break;
case ERROR_TYPE_NOT_FOUND:
type_str = "not_found_error";
code = 404;
break;
case ERROR_TYPE_SERVER:
type_str = "server_error";
code = 500;
break;
case ERROR_TYPE_PERMISSION:
type_str = "permission_error";
code = 403;
break;
case ERROR_TYPE_NOT_SUPPORTED:
type_str = "not_supported_error";
code = 501;
break;
case ERROR_TYPE_UNAVAILABLE:
type_str = "unavailable_error";
code = 503;
break;
}
return json {
{"code", code},
{"message", message},
{"type", type_str},
};
}
struct server_task_result_error : server_task_result {
server_task_result_error() : server_task_result(RESULT_TYPE_ERROR) {}
int index = 0;
error_type err_type = ERROR_TYPE_SERVER;
std::string err_msg;
static server_task_result_error from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
return copy_cast_ptr(server_task_result_error, result_ptr);
virtual bool is_error() override {
return true;
}
virtual ~server_task_result_error() = default;
virtual json to_json() override {
return format_error_response(err_msg, err_type);
}
};
struct server_task_result_metrics : server_task_result {
server_task_result_metrics() : server_task_result(RESULT_TYPE_METRICS) {}
int n_idle_slots;
int n_processing_slots;
int n_tasks_deferred;
@@ -404,7 +586,7 @@ struct server_task_result_metrics : server_task_result {
// TODO: get rid of this json object and use to_json() instead
json slots_data = json::array();
json to_json() {
virtual json to_json() override {
return json {
{ "idle", n_idle_slots },
{ "processing", n_processing_slots },
@@ -430,16 +612,9 @@ struct server_task_result_metrics : server_task_result {
{ "slots", slots_data },
};
}
static server_task_result_metrics from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
return copy_cast_ptr(server_task_result_metrics, result_ptr);
}
virtual ~server_task_result_metrics() = default;
};
struct server_task_result_slot_save_load : server_task_result {
server_task_result_slot_save_load() : server_task_result(RESULT_TYPE_SLOT_SAVE_LOAD) {}
std::string filename;
bool is_save; // true = save, false = load
@@ -447,7 +622,7 @@ struct server_task_result_slot_save_load : server_task_result {
size_t n_bytes;
double t_ms;
json to_json() {
virtual json to_json() override {
if (is_save) {
return json {
{ "id_slot", id_slot },
@@ -470,39 +645,21 @@ struct server_task_result_slot_save_load : server_task_result {
};
}
}
static server_task_result_slot_save_load from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
return copy_cast_ptr(server_task_result_slot_save_load, result_ptr);
}
virtual ~server_task_result_slot_save_load() = default;
};
struct server_task_result_slot_erase : server_task_result {
server_task_result_slot_erase() : server_task_result(RESULT_TYPE_SLOT_ERASE) {}
size_t n_erased;
json to_json() {
virtual json to_json() override {
return json {
{ "id_slot", id_slot },
{ "n_erased", n_erased },
};
}
static server_task_result_slot_erase from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
return copy_cast_ptr(server_task_result_slot_erase, result_ptr);
}
virtual ~server_task_result_slot_erase() = default;
};
struct server_task_result_apply_lora : server_task_result {
server_task_result_apply_lora() : server_task_result(RESULT_TYPE_APPLY_LORA) {}
json to_json() {
virtual json to_json() override {
return json {{ "success", true }};
}
static server_task_result_apply_lora from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
return copy_cast_ptr(server_task_result_apply_lora, result_ptr);
}
};
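Since server_task_result_error above serializes itself through format_error_response(), every endpoint can collapse its error handling to the same two lines: check is_error(), forward to_json(). The snippet below shows that flow in isolation; error_result, ok_result and handle_response are made-up stand-ins for server_task_result_error, the other result types and the res_error/res_ok pair.

// Sketch of the unified error path: the handler no longer needs to know the
// concrete error type, it just forwards result->to_json(), which for the error
// result is the format_error_response() payload.
#include <nlohmann/json.hpp>
#include <iostream>
#include <memory>
#include <string>

using json = nlohmann::ordered_json;

struct server_task_result {
    int id = -1;
    virtual bool is_error() { return false; }
    virtual json to_json() = 0;
    virtual ~server_task_result() = default;
};
using task_result_ptr = std::unique_ptr<server_task_result>;

struct error_result : server_task_result {              // stands in for server_task_result_error
    std::string err_msg = "something went wrong";
    int code = 500;
    bool is_error() override { return true; }
    json to_json() override {
        return json { {"code", code}, {"message", err_msg}, {"type", "server_error"} };
    }
};

struct ok_result : server_task_result {                 // stands in for any successful result
    json to_json() override { return json { {"success", true} }; }
};

// hypothetical handler tail: the same two branches every endpoint now uses
static void handle_response(task_result_ptr result) {
    if (result->is_error()) {
        std::cout << "error: " << result->to_json().dump() << "\n";   // res_error(res, ...)
        return;
    }
    std::cout << "ok: " << result->to_json().dump() << "\n";          // res_ok(res, ...)
}

int main() {
    handle_response(std::make_unique<error_result>());
    handle_response(std::make_unique<ok_result>());
}

Handlers still call format_error_response() directly for locally generated errors (for example "no slot available"), but anything coming back through the result queue is already in that shape.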

(changed file 3 of 3)

@@ -583,155 +583,6 @@ static json oaicompat_completion_params_parse(
return llama_params;
}
static json format_final_response_oaicompat(
const json & request,
server_task_result_cmpl_final & result,
const std::string & completion_id,
bool streaming = false,
bool verbose = false) {
std::string finish_reason = "length";
if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}
json choices =
streaming ? json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}})
: json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", result.content},
{"role", "assistant"}}}}});
std::time_t t = std::time(0);
json res = json {
{"choices", choices},
{"created", t},
{"model",
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
{"usage", json {
{"completion_tokens", result.n_decoded},
{"prompt_tokens", result.n_prompt_tokens},
{"total_tokens", result.n_decoded + result.n_prompt_tokens}
}},
{"id", completion_id}
};
// extra fields for debugging purposes
if (verbose) {
res["__verbose"] = result.to_json();
}
// TODO: fix this
// if (result.contains("completion_probabilities")) {
// res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
// }
if (result.timings.prompt_n >= 0) {
res.push_back({"timings", result.timings.to_json()});
}
return res;
}
// return value is vector as there is one case where we might need to generate two responses
static std::vector<json> format_partial_response_oaicompat(
std::string modelname,
server_task_result_cmpl_partial & result,
const std::string & completion_id) {
bool first = result.n_decoded == 0;
std::string content = result.content;
std::string finish_reason;
if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
finish_reason = "stop";
} else if (result.stop == STOP_TYPE_LIMIT) {
finish_reason = "length";
}
std::time_t t = std::time(0);
json choices;
if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}};
return std::vector<json>({initial_ret, second_ret});
}
} else {
// Some idiosyncrasy in task processing logic makes several trailing calls
// with empty content; we ignore these at the callee site.
if (content.empty()) {
return std::vector<json>({json::object()});
}
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
}});
}
}
json ret = json {
{"choices", choices},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
if (result.timings.prompt_n >= 0) {
ret.push_back({"timings", result.timings.to_json()});
}
if (!finish_reason.empty()) {
ret.push_back({"usage", json {
{"completion_tokens", result.n_decoded},
{"prompt_tokens", result.n_prompt_tokens},
{"total_tokens", result.n_decoded + result.n_prompt_tokens}
}});
}
return std::vector<json>({ret});
}
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
json data = json::array();
int i = 0;
@@ -823,43 +674,3 @@ static json format_detokenized_response(const std::string & content) {
{"content", content}
};
}
static json format_error_response(const std::string & message, const enum error_type type) {
std::string type_str;
int code = 500;
switch (type) {
case ERROR_TYPE_INVALID_REQUEST:
type_str = "invalid_request_error";
code = 400;
break;
case ERROR_TYPE_AUTHENTICATION:
type_str = "authentication_error";
code = 401;
break;
case ERROR_TYPE_NOT_FOUND:
type_str = "not_found_error";
code = 404;
break;
case ERROR_TYPE_SERVER:
type_str = "server_error";
code = 500;
break;
case ERROR_TYPE_PERMISSION:
type_str = "permission_error";
code = 403;
break;
case ERROR_TYPE_NOT_SUPPORTED:
type_str = "not_supported_error";
code = 501;
break;
case ERROR_TYPE_UNAVAILABLE:
type_str = "unavailable_error";
code = 503;
break;
}
return json {
{"code", code},
{"message", message},
{"type", type_str},
};
}