diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0ab09db22..c8cb48b15 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -84,9 +84,6 @@ struct server_slot {
     bool truncated = false;
     stop_type stop;
 
-    bool oaicompat = false;
-
-    std::string oaicompat_model;
     std::string stopping_word;
 
     // sampling
@@ -494,17 +491,15 @@ struct server_response {
     }
 
     // Send a new result to a waiting id_task
-    template <typename T>
-    void send(T & result) {
-        static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
-        SRV_DBG("sending result for task id = %d\n", result.id);
+    void send(task_result_ptr && result) {
+        SRV_DBG("sending result for task id = %d\n", result->id);
 
         std::unique_lock<std::mutex> lock(mutex_results);
         for (const auto & id_task : waiting_task_ids) {
-            if (result.id == id_task) {
-                SRV_DBG("task id = %d pushed to result queue\n", result.id);
+            if (result->id == id_task) {
+                SRV_DBG("task id = %d pushed to result queue\n", result->id);
 
-                queue_results.push_back(std::make_unique<T>(std::move(result)));
+                queue_results.emplace_back(std::move(result));
                 condition_results.notify_all();
                 return;
             }
@@ -791,13 +786,16 @@ struct server_context {
         const auto & data = task.data;
 
         if (data.count("__oaicompat") != 0) {
-            slot.oaicompat = true;
-            slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+            slot.params.oaicompat = true;
+            slot.params.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+            slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string());
         } else {
-            slot.oaicompat = false;
-            slot.oaicompat_model = "";
+            slot.params.oaicompat = false;
         }
+
+        // enabling this will output extra debug information in the HTTP responses from the server
+        slot.params.verbose = params_base.verbosity > 9;
 
         slot.params.timings_per_token = json_value(data, "timings_per_token", false);
 
         slot.params.stream = json_value(data, "stream", false);
@@ -1158,25 +1156,29 @@ struct server_context {
     void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
 
-        server_task_result_error res;
-        res.id = id_task;
-        res.err_type = type;
-        res.err_msg = error;
+        auto res = std::make_unique<server_task_result_error>();
+        res->id = id_task;
+        res->err_type = type;
+        res->err_msg = error;
 
-        queue_results.send(res);
+        queue_results.send(std::move(res));
     }
 
     void send_partial_response(server_slot & slot, completion_token_output tkn) {
-        server_task_result_cmpl_partial res;
-        res.id = slot.id_task;
-        res.index = slot.index;
-        res.content = tkn.text_to_send;
+        auto res = std::make_unique<server_task_result_cmpl_partial>();
+        res->id = slot.id_task;
+        res->index = slot.index;
+        res->content = tkn.text_to_send;
 
-        res.truncated = slot.truncated;
-        res.n_decoded = slot.n_decoded;
-        res.n_prompt_tokens = slot.n_prompt_tokens;
+        res->truncated = slot.truncated;
+        res->n_decoded = slot.n_decoded;
+        res->n_prompt_tokens = slot.n_prompt_tokens;
 
-        res.stop = slot.stop;
+        res->stop = slot.stop;
+
+        res->oaicompat_model = slot.params.oaicompat_model;
+        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+        res->verbose = slot.params.verbose;
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -1186,7 +1188,7 @@ struct server_context {
             std::vector<completion_token_output> probs_output;
             if (probs_pos < probs_stop_pos) {
-                res.probs_output = std::vector<completion_token_output>(
+                res->probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin() + probs_pos,
                     slot.generated_token_probs.begin() + probs_stop_pos);
             }
@@ -1194,10 +1196,10 @@ struct server_context {
 
         // populate timings if this is final response or timings_per_token is enabled
         if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
-            res.timings = slot.get_timings();
+            res->timings = slot.get_timings();
         }
 
-        queue_results.send(res);
+        queue_results.send(std::move(res));
     }
 
     void send_final_response(server_slot & slot) {
@@ -1206,23 +1208,26 @@ struct server_context {
             return send_partial_response(slot, {0, "", {}});
         }
 
-        server_task_result_cmpl_final res;
-        res.id = slot.id_task;
-        res.id_slot = slot.id;
+        auto res = std::make_unique<server_task_result_cmpl_final>();
+        res->id = slot.id_task;
+        res->id_slot = slot.id;
 
-        res.index = slot.index;
-        res.content = slot.generated_text;
-        res.timings = slot.get_timings();
-        res.model_alias = slot.oaicompat_model;
-        res.prompt = common_detokenize(ctx, slot.prompt_tokens, true);
+        res->index = slot.index;
+        res->content = slot.generated_text;
+        res->timings = slot.get_timings();
+        res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
 
-        res.truncated = slot.truncated;
-        res.n_decoded = slot.n_decoded;
-        res.n_prompt_tokens = slot.n_prompt_tokens;
-        res.n_tokens_cached = slot.n_past;
-        res.has_new_line = slot.has_new_line;
-        res.stopping_word = slot.stopping_word;
-        res.stop = slot.stop;
+        res->truncated = slot.truncated;
+        res->n_decoded = slot.n_decoded;
+        res->n_prompt_tokens = slot.n_prompt_tokens;
+        res->n_tokens_cached = slot.n_past;
+        res->has_new_line = slot.has_new_line;
+        res->stopping_word = slot.stopping_word;
+        res->stop = slot.stop;
+
+        res->oaicompat_model = slot.params.oaicompat_model;
+        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+        res->verbose = slot.params.verbose;
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -1230,25 +1235,25 @@ struct server_context {
             const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
 
             size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
-            res.probs_output = std::vector<completion_token_output>(
+            res->probs_output = std::vector<completion_token_output>(
                 slot.generated_token_probs.begin(),
                 slot.generated_token_probs.end() - safe_offset);
         } else {
-            res.probs_output = std::vector<completion_token_output>(
+            res->probs_output = std::vector<completion_token_output>(
                 slot.generated_token_probs.begin(),
                 slot.generated_token_probs.end());
         }
     }
 
-        res.generation_params = slot.params; // copy the parameters
+        res->generation_params = slot.params; // copy the parameters
 
-        queue_results.send(res);
+        queue_results.send(std::move(res));
     }
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
-        server_task_result_embd res;
-        res.id = slot.id_task;
-        res.index = slot.index;
+        auto res = std::make_unique<server_task_result_embd>();
+        res->id = slot.id_task;
+        res->index = slot.index;
 
         const int n_embd = llama_n_embd(model);
 
@@ -1267,23 +1272,23 @@ struct server_context {
 
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
-                res.embedding = std::vector<float>(n_embd, 0.0f);
+                res->embedding = std::vector<float>(n_embd, 0.0f);
                 continue;
             }
 
             common_embd_normalize(embd, embd_res.data(), n_embd);
-            res.embedding = embd_res;
+            res->embedding = embd_res;
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");
 
-        queue_results.send(res);
+        queue_results.send(std::move(res));
     }
 
     void send_rerank(const server_slot & slot, const llama_batch & batch) {
-        server_task_result_rerank res;
-        res.id = slot.id_task;
-        res.index = slot.index;
+        auto res = std::make_unique<server_task_result_rerank>();
+        res->id = slot.id_task;
+        res->index = slot.index;
 
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -1298,16 +1303,16 @@ struct server_context {
 
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
-                res.score = -1e6;
+                res->score = -1e6;
                 continue;
             }
 
-            res.score = embd[0];
+            res->score = embd[0];
         }
 
-        SLT_DBG(slot, "sending rerank result, res.score = %f\n", res.score);
+        SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);
 
-        queue_results.send(res);
+        queue_results.send(std::move(res));
     }
 
     //
@@ -1398,35 +1403,28 @@ struct server_context {
     }
 
     // receive the results from task(s) created by create_tasks_inference
-    template <typename T>
     void receive_multi_results(
             const std::unordered_set<int> & id_tasks,
-            const std::function<void(std::vector<T>&)> & result_handler,
+            const std::function<void(std::vector<task_result_ptr>&)> & result_handler,
             const std::function<void(json)> & error_handler) {
-        static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
-        std::vector<T> results(id_tasks.size());
+        std::vector<task_result_ptr> results(id_tasks.size());
         for (size_t i = 0; i < id_tasks.size(); i++) {
-            task_result_ptr result_raw = queue_results.recv(id_tasks);
+            task_result_ptr result = queue_results.recv(id_tasks);
 
-            if (result_raw->type == RESULT_TYPE_ERROR) {
-                auto result = server_task_result_error::from_ptr(result_raw);
-                error_handler(format_error_response(result.err_msg, result.err_type));
+            if (result->is_error()) {
+                error_handler(result->to_json());
                 cancel_tasks(id_tasks);
                 return;
             }
 
-            if (
-                result_raw->type == RESULT_TYPE_CMPL_FINAL
-                || result_raw->type == RESULT_TYPE_EMBD
-                || result_raw->type == RESULT_TYPE_RERANK
-            ) {
-                T result = T::from_ptr(result_raw);
-                const size_t idx = result.index;
-                GGML_ASSERT(idx < results.size() && "index out of range");
-                results[idx] = std::move(result);
-            } else {
-                GGML_ASSERT(false && "unexpected result type");
-            }
+            GGML_ASSERT(
+                dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
+                || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
+            );
+            const size_t idx = result->get_index();
+            GGML_ASSERT(idx < results.size() && "index out of range");
+            results[idx] = std::move(result);
         }
         result_handler(results);
     }
@@ -1434,29 +1432,25 @@ struct server_context {
     // receive the results from task(s) created by create_tasks_inference, in stream mode
     void receive_cmpl_results_stream(
             const std::unordered_set<int> & id_tasks,
-            const std::function<bool(server_task_result_cmpl_partial&)> & result_handler,
+            const std::function<bool(task_result_ptr&)> & result_handler,
            const std::function<void(json)> & error_handler) {
        size_t n_finished = 0;
        while (true) {
-            task_result_ptr result_raw = queue_results.recv(id_tasks);
+            task_result_ptr result = queue_results.recv(id_tasks);
 
-            if (result_raw->type == RESULT_TYPE_ERROR) {
-                auto result = server_task_result_error::from_ptr(result_raw);
-                error_handler(format_error_response(result.err_msg, result.err_type));
+            if (result->is_error()) {
+                error_handler(result->to_json());
                cancel_tasks(id_tasks);
                return;
            }
 
-            GGML_ASSERT(result_raw->type == RESULT_TYPE_CMPL_PARTIAL);
-            auto result = server_task_result_cmpl_partial::from_ptr(result_raw);
+            GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr);
            if (!result_handler(result)) {
                cancel_tasks(id_tasks);
                break;
            }
 
-            SRV_ERR("received partial result, %s\n", result.to_json().dump().c_str());
-
-            if (result.stop != STOP_TYPE_NONE) {
+            if (result->is_stop()) {
                if (++n_finished == id_tasks.size()) {
                    break;
                }
@@ -1546,33 +1540,33 @@ struct server_context {
                    }
 
                    SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
 
-                    server_task_result_metrics res;
-                    res.id = task.id;
-                    res.n_idle_slots = n_idle_slots;
-                    res.n_processing_slots = n_processing_slots;
-                    res.n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
-                    res.t_start = metrics.t_start;
+                    auto res = std::make_unique<server_task_result_metrics>();
+                    res->id = task.id;
+                    res->n_idle_slots = n_idle_slots;
+                    res->n_processing_slots = n_processing_slots;
+                    res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
+                    res->t_start = metrics.t_start;
 
-                    res.kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
-                    res.kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
+                    res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
+                    res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
 
-                    res.n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
-                    res.t_prompt_processing_total = metrics.t_prompt_processing_total;
-                    res.n_tokens_predicted_total = metrics.n_tokens_predicted_total;
-                    res.t_tokens_generation_total = metrics.t_tokens_generation_total;
+                    res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
+                    res->t_prompt_processing_total = metrics.t_prompt_processing_total;
+                    res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
+                    res->t_tokens_generation_total = metrics.t_tokens_generation_total;
 
-                    res.n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
-                    res.t_prompt_processing = metrics.t_prompt_processing;
-                    res.n_tokens_predicted = metrics.n_tokens_predicted;
-                    res.t_tokens_generation = metrics.t_tokens_generation;
+                    res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
+                    res->t_prompt_processing = metrics.t_prompt_processing;
+                    res->n_tokens_predicted = metrics.n_tokens_predicted;
+                    res->t_tokens_generation = metrics.t_tokens_generation;
 
-                    res.n_decode_total = metrics.n_decode_total;
-                    res.n_busy_slots_total = metrics.n_busy_slots_total;
+                    res->n_decode_total = metrics.n_decode_total;
+                    res->n_busy_slots_total = metrics.n_busy_slots_total;
 
                    if (json_value(task.data, "reset_bucket", false)) {
                        metrics.reset_bucket();
                    }
 
-                    queue_results.send(res);
+                    queue_results.send(std::move(res));
                } break;
            case SERVER_TASK_TYPE_SLOT_SAVE: {
@@ -1600,15 +1594,15 @@ struct server_context {
                    const int64_t t_end = ggml_time_us();
                    const double t_save_ms = (t_end - t_start) / 1000.0;
 
-                    server_task_result_slot_save_load result;
-                    result.id = task.id;
-                    result.id_slot = id_slot;
-                    result.filename = filename;
-                    result.is_save = true;
-                    result.n_tokens = token_count;
-                    result.n_bytes = nwrite;
-                    result.t_ms = t_save_ms;
-                    queue_results.send(result);
+                    auto res = std::make_unique<server_task_result_slot_save_load>();
+                    res->id = task.id;
+                    res->id_slot = id_slot;
+                    res->filename = filename;
+                    res->is_save = true;
+                    res->n_tokens = token_count;
+                    res->n_bytes = nwrite;
+                    res->t_ms = t_save_ms;
+                    queue_results.send(std::move(res));
                } break;
            case SERVER_TASK_TYPE_SLOT_RESTORE: {
@@ -1643,15 +1637,15 @@ struct server_context {
                    const int64_t t_end = ggml_time_us();
                    const double t_restore_ms = (t_end - t_start) / 1000.0;
 
-                    server_task_result_slot_save_load result;
-                    result.id = task.id;
-                    result.id_slot = id_slot;
-                    result.filename = filename;
-                    result.is_save = false;
-                    result.n_tokens = token_count;
-                    result.n_bytes = nread;
-                    result.t_ms = t_restore_ms;
-                    queue_results.send(result);
+                    auto res = std::make_unique<server_task_result_slot_save_load>();
+                    res->id = task.id;
+                    res->id_slot = id_slot;
+                    res->filename = filename;
+                    res->is_save = false;
+                    res->n_tokens = token_count;
+                    res->n_bytes = nread;
+                    res->t_ms = t_restore_ms;
+                    queue_results.send(std::move(res));
                } break;
            case SERVER_TASK_TYPE_SLOT_ERASE: {
@@ -1673,18 +1667,18 @@ struct server_context {
                    llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
                    slot->cache_tokens.clear();
 
-                    server_task_result_slot_erase result;
-                    result.id = task.id;
-                    result.id_slot = id_slot;
-                    result.n_erased = n_erased;
-                    queue_results.send(result);
+                    auto res = std::make_unique<server_task_result_slot_erase>();
+                    res->id = task.id;
+                    res->id_slot = id_slot;
+                    res->n_erased = n_erased;
+                    queue_results.send(std::move(res));
                } break;
            case SERVER_TASK_TYPE_SET_LORA: {
                    common_lora_adapters_apply(ctx, loras);
-                    server_task_result_apply_lora result;
-                    result.id = task.id;
-                    queue_results.send(result);
+                    auto res = std::make_unique<server_task_result_apply_lora>();
+                    res->id = task.id;
+                    queue_results.send(std::move(res));
                } break;
        }
    }
@@ -2250,10 +2244,6 @@ int main(int argc, char ** argv) {
 
    common_init();
 
-    // enabling this will output extra debug information in the HTTP responses from the server
-    // see format_final_response_oaicompat()
-    const bool verbose = params.verbosity > 9;
-
    // struct that contains llama context and inference
    server_context ctx_server;
 
@@ -2445,26 +2435,27 @@ int main(int argc, char ** argv) {
        ctx_server.queue_tasks.post(task, true); // high-priority task
 
        // get the result
-        task_result_ptr result_raw = ctx_server.queue_results.recv(task.id);
+        task_result_ptr result = ctx_server.queue_results.recv(task.id);
        ctx_server.queue_results.remove_waiting_task_id(task.id);
 
-        if (result_raw->type != RESULT_TYPE_METRICS) {
-            SRV_ERR("Unexpected result type: %d\n", result_raw->type);
-            res_error(res, format_error_response("Unexpected result type", ERROR_TYPE_SERVER));
+        if (result->is_error()) {
+            res_error(res, result->to_json());
            return;
        }
 
-        auto result = server_task_result_metrics::from_ptr(result_raw);
+        // TODO: get rid of this dynamic_cast
+        auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
+        GGML_ASSERT(res_metrics != nullptr);
 
        // optionally return "fail_on_no_slot" error
        if (req.has_param("fail_on_no_slot")) {
-            if (result.n_idle_slots == 0) {
+            if (res_metrics->n_idle_slots == 0) {
                res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
                return;
            }
        }
 
-        res_ok(res, result.slots_data);
+        res_ok(res, res_metrics->slots_data);
    };
 
    const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
@@ -2484,68 +2475,69 @@ int main(int argc, char ** argv) {
        ctx_server.queue_tasks.post(task, true); // high-priority task
 
        // get the result
-        task_result_ptr result_raw = ctx_server.queue_results.recv(task.id);
+        task_result_ptr result = ctx_server.queue_results.recv(task.id);
        ctx_server.queue_results.remove_waiting_task_id(task.id);
-        if (result_raw->type == RESULT_TYPE_ERROR) {
-            auto result = server_task_result_error::from_ptr(result_raw);
-            res_error(res, format_error_response(result.err_msg, result.err_type));
+
+        if (result->is_error()) {
+            res_error(res, result->to_json());
            return;
        }
 
-        GGML_ASSERT(result_raw->type == RESULT_TYPE_METRICS);
-        auto result = server_task_result_metrics::from_ptr(result_raw);
+        // TODO: get rid of this dynamic_cast
+        auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
+        GGML_ASSERT(res_metrics != nullptr);
 
        // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
        json all_metrics_def = json {
            {"counter", {{
                    {"name", "prompt_tokens_total"},
                    {"help", "Number of prompt tokens processed."},
-                    {"value", (uint64_t) result.n_prompt_tokens_processed_total}
+                    {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total}
            }, {
                    {"name", "prompt_seconds_total"},
                    {"help", "Prompt process time"},
-                    {"value", (uint64_t) result.t_prompt_processing_total / 1.e3}
+                    {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3}
            }, {
                    {"name", "tokens_predicted_total"},
                    {"help", "Number of generation tokens processed."},
-                    {"value", (uint64_t) result.n_tokens_predicted_total}
+                    {"value", (uint64_t) res_metrics->n_tokens_predicted_total}
            }, {
                    {"name", "tokens_predicted_seconds_total"},
                    {"help", "Predict process time"},
-                    {"value", (uint64_t) result.t_tokens_generation_total / 1.e3}
+                    {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3}
            }, {
                    {"name", "n_decode_total"},
                    {"help", "Total number of llama_decode() calls"},
-                    {"value", result.n_decode_total}
+                    {"value", res_metrics->n_decode_total}
            }, {
                    {"name", "n_busy_slots_per_decode"},
                    {"help", "Average number of busy slots per llama_decode() call"},
-                    {"value", (float) result.n_busy_slots_total / (float) result.n_decode_total}
+                    {"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total}
            }}},
            {"gauge", {{
                    {"name", "prompt_tokens_seconds"},
                    {"help", "Average prompt throughput in tokens/s."},
-                    {"value", result.n_prompt_tokens_processed ? 1.e3 / result.t_prompt_processing * result.n_prompt_tokens_processed : 0.}
+                    {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.}
            },{
                    {"name", "predicted_tokens_seconds"},
                    {"help", "Average generation throughput in tokens/s."},
-                    {"value", result.n_tokens_predicted ? 1.e3 / result.t_tokens_generation * result.n_tokens_predicted : 0.}
+                    {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
            },{
                    {"name", "kv_cache_usage_ratio"},
                    {"help", "KV-cache usage. 1 means 100 percent usage."},
-                    {"value", 1. * result.kv_cache_used_cells / params.n_ctx}
+                    {"value", 1. * res_metrics->kv_cache_used_cells / params.n_ctx}
            },{
                    {"name", "kv_cache_tokens"},
                    {"help", "KV-cache tokens."},
-                    {"value", (uint64_t) result.kv_cache_tokens_count}
+                    {"value", (uint64_t) res_metrics->kv_cache_tokens_count}
            },{
                    {"name", "requests_processing"},
                    {"help", "Number of request processing."},
-                    {"value", (uint64_t) result.n_processing_slots}
+                    {"value", (uint64_t) res_metrics->n_processing_slots}
            },{
                    {"name", "requests_deferred"},
                    {"help", "Number of request deferred."},
-                    {"value", (uint64_t) result.n_tasks_deferred}
+                    {"value", (uint64_t) res_metrics->n_tasks_deferred}
            }}}
        };
 
@@ -2566,7 +2558,7 @@ int main(int argc, char ** argv) {
            }
        }
 
-        res.set_header("Process-Start-Time-Unix", std::to_string(result.t_start));
+        res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start));
 
        res.set_content(prometheus.str(), "text/plain; version=0.0.4");
        res.status = 200; // HTTP OK
@@ -2592,18 +2584,15 @@ int main(int argc, char ** argv) {
        const int id_task = ctx_server.queue_tasks.post(task);
        ctx_server.queue_results.add_waiting_task_id(id_task);
 
-        task_result_ptr result_raw = ctx_server.queue_results.recv(id_task);
+        task_result_ptr result = ctx_server.queue_results.recv(id_task);
        ctx_server.queue_results.remove_waiting_task_id(id_task);
 
-        if (result_raw->type == RESULT_TYPE_ERROR) {
-            auto result = server_task_result_error::from_ptr(result_raw);
-            res_error(res, format_error_response(result.err_msg, result.err_type));
+        if (result->is_error()) {
+            res_error(res, result->to_json());
            return;
        }
 
-        GGML_ASSERT(result_raw->type == RESULT_TYPE_SLOT_SAVE_LOAD);
-        auto result = server_task_result_slot_save_load::from_ptr(result_raw);
-        res_ok(res, result.to_json());
+        res_ok(res, result->to_json());
    };
 
    const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
@@ -2626,18 +2615,16 @@ int main(int argc, char ** argv) {
        const int id_task = ctx_server.queue_tasks.post(task);
        ctx_server.queue_results.add_waiting_task_id(id_task);
 
-        task_result_ptr result_raw = ctx_server.queue_results.recv(id_task);
+        task_result_ptr result = ctx_server.queue_results.recv(id_task);
        ctx_server.queue_results.remove_waiting_task_id(id_task);
 
-        if (result_raw->type == RESULT_TYPE_ERROR) {
-            auto result = server_task_result_error::from_ptr(result_raw);
-            res_error(res, format_error_response(result.err_msg, result.err_type));
+        if (result->is_error()) {
+            res_error(res, result->to_json());
            return;
        }
 
-        GGML_ASSERT(result_raw->type == RESULT_TYPE_SLOT_SAVE_LOAD);
-        auto result = server_task_result_slot_save_load::from_ptr(result_raw);
-        res_ok(res, result.to_json());
+        GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
+        res_ok(res, result->to_json());
    };
 
    const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
@@ -2650,18 +2637,16 @@ int main(int argc, char ** argv) {
        const int id_task = ctx_server.queue_tasks.post(task);
        ctx_server.queue_results.add_waiting_task_id(id_task);
 
-        task_result_ptr result_raw = ctx_server.queue_results.recv(id_task);
+        task_result_ptr result = ctx_server.queue_results.recv(id_task);
        ctx_server.queue_results.remove_waiting_task_id(id_task);
 
-        if (result_raw->type == RESULT_TYPE_ERROR) {
-            auto result = server_task_result_error::from_ptr(result_raw);
-            res_error(res, format_error_response(result.err_msg, result.err_type));
+        if (result->is_error()) {
+            res_error(res, result->to_json());
            return;
        }
 
-        GGML_ASSERT(result_raw->type == RESULT_TYPE_SLOT_ERASE);
-        auto result = server_task_result_slot_erase::from_ptr(result_raw);
-        res_ok(res, result.to_json());
+        GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
+        res_ok(res, result->to_json());
    };
 
    const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
@@ -2722,15 +2707,13 @@ int main(int argc, char ** argv) {
            server_task_inf_type inf_type,
            json & data,
            httplib::Response & res,
-            const std::function<std::vector<json>(server_task_result_cmpl_partial&)> & format_partial = nullptr,
-            const std::function<json(std::vector<server_task_result_cmpl_final>&)> & format_final = nullptr,
-            // wether to send [DONE] event after completion (required for OAI-compat)
-            bool send_done_event = false) {
+            bool oai_compat = false) {
        if (ctx_server.params_base.embedding) {
            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
 
+        data["completion_id"] = gen_chatcmplid();
        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(tasks);
@@ -2739,17 +2722,15 @@ int main(int argc, char ** argv) {
        const auto task_ids = server_task::get_list_id(tasks);
 
        if (!stream) {
-            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_cmpl_final> & results) {
-                if (format_final) {
-                    res_ok(res, format_final(results));
-                } else if (results.size() == 1) {
+            ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
+                if (results.size() == 1) {
                    // single result
-                    res_ok(res, results[0].to_json());
+                    res_ok(res, oai_compat ? results[0]->to_json_oai_compat() : results[0]->to_json());
                } else {
                    // multiple results (multitask)
                    json arr = json::array();
                    for (auto & res : results) {
-                        arr.push_back(res.to_json());
+                        arr.push_back(oai_compat ? res->to_json_oai_compat() : res->to_json());
                    }
                    res_ok(res, arr);
                }
@@ -2759,22 +2740,23 @@ int main(int argc, char ** argv) {
 
            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
        } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, format_partial = std::move(format_partial), send_done_event](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool {
-                    if (format_partial) {
-                        for (const auto & res : format_partial(result)) {
+            const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
+                ctx_server.receive_cmpl_results_stream(task_ids, [&](task_result_ptr & result) -> bool {
+                    json res_json = oai_compat ? result->to_json_oai_compat() : result->to_json();
+                    if (res_json.is_array()) {
+                        for (const auto & res : res_json) {
                            if (!server_sent_event(sink, "data", res)) {
                                return false;
                            }
                        }
                        return true;
                    } else {
-                        return server_sent_event(sink, "data", result.to_json());
+                        return server_sent_event(sink, "data", res_json);
                    }
                }, [&](const json & error_data) {
                    server_sent_event(sink, "error", error_data);
                });
-                if (send_done_event) {
+                if (oai_compat) {
                    static const std::string ev_done = "data: [DONE]\n\n";
                    sink.write(ev_done.data(), ev_done.size());
                }
@@ -2792,13 +2774,7 @@ int main(int argc, char ** argv) {
    const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
        json data = json::parse(req.body);
-        return handle_completions_generic(
-            SERVER_TASK_INF_TYPE_COMPLETION,
-            data,
-            res,
-            // TODO: support OAI-compat response via format_partial and format_final
-            /* format_partial */ nullptr,
-            /* format_final */ nullptr);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
    };
 
    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -2851,7 +2827,7 @@ int main(int argc, char ** argv) {
        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
    };
 
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic, verbose](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
        if (ctx_server.params_base.embedding) {
            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
@@ -2859,20 +2835,9 @@ int main(int argc, char ** argv) {
        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
-        const auto completion_id = gen_chatcmplid();
        std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
 
-        return handle_completions_generic(
-            SERVER_TASK_INF_TYPE_COMPLETION,
-            data,
-            res,
-            /* format_partial */ [data, model_name, completion_id](server_task_result_cmpl_partial & result) {
-                return format_partial_response_oaicompat(model_name, result, completion_id);
-            },
-            /* format_final */ [data, verbose, model_name](std::vector<server_task_result_cmpl_final> & results) {
-                return format_final_response_oaicompat(data, results[0], model_name, false, verbose);
-            },
-            /* send_done_event */ true);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
    };
 
    const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
@@ -2973,10 +2938,10 @@ int main(int argc, char ** argv) {
        // get the result
        std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
 
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_embd> & results) {
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
            for (auto & res : results) {
-                GGML_ASSERT(res.type == RESULT_TYPE_EMBD);
-                responses.push_back(res.to_json());
+                GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
+                responses.push_back(res->to_json());
            }
        }, [&](const json & error_data) {
            res_error(res, error_data);
@@ -3052,10 +3017,10 @@ int main(int argc, char ** argv) {
        // get the result
        std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
 
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_rerank> & results) {
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<task_result_ptr> & results) {
            for (auto & res : results) {
-                GGML_ASSERT(res.type == RESULT_TYPE_RERANK);
-                responses.push_back(res.to_json());
+                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
+                responses.push_back(res->to_json());
            }
        }, [&](const json & error_data) {
            res_error(res, error_data);
@@ -3111,18 +3076,16 @@ int main(int argc, char ** argv) {
        const int id_task = ctx_server.queue_tasks.post(task);
        ctx_server.queue_results.add_waiting_task_id(id_task);
 
-        task_result_ptr result_raw = ctx_server.queue_results.recv(id_task);
+        task_result_ptr result = ctx_server.queue_results.recv(id_task);
        ctx_server.queue_results.remove_waiting_task_id(id_task);
 
-        if (result_raw->type == RESULT_TYPE_ERROR) {
-            auto result = server_task_result_error::from_ptr(result_raw);
-            res_error(res, format_error_response(result.err_msg, result.err_type));
+        if (result->is_error()) {
+            res_error(res, result->to_json());
            return;
        }
 
-        GGML_ASSERT(result_raw->type == RESULT_TYPE_APPLY_LORA);
-        auto result = server_task_result_apply_lora::from_ptr(result_raw);
-        res_ok(res, result.to_json());
+        GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
+        res_ok(res, result->to_json());
    };
 
    //
diff --git a/examples/server/server.hpp b/examples/server/server.hpp
index 1e65614f6..201f15456 100644
--- a/examples/server/server.hpp
+++ b/examples/server/server.hpp
@@ -15,9 +15,6 @@
 using json = nlohmann::ordered_json;
 
-// cast a shared_ptr to a specific type using copy constructor
-#define copy_cast_ptr(TYPEOUT, ptr) *(static_cast<TYPEOUT*>(ptr.get()));
-
 enum stop_type {
     STOP_TYPE_NONE,
     STOP_TYPE_EOS,
@@ -68,19 +65,6 @@ enum error_type {
     ERROR_TYPE_NOT_SUPPORTED, // custom error
 };
 
-enum result_type {
-    RESULT_TYPE_CMPL_FINAL,
-    RESULT_TYPE_CMPL_PARTIAL,
-    RESULT_TYPE_EMBD,
-    RESULT_TYPE_RERANK,
-    RESULT_TYPE_METRICS,
-    RESULT_TYPE_SLOT_SAVE_LOAD,
-    RESULT_TYPE_SLOT_ERASE,
-    RESULT_TYPE_APPLY_LORA,
-    RESULT_TYPE_ERROR,
-    RESULT_TYPE_UNKNOWN, // will throw an error
-};
-
 struct server_task {
     int id        = -1; // to be filled by server_queue
     int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
@@ -126,6 +110,12 @@ struct slot_params {
     uint32_t seed_cur;
     bool can_speculative;
 
+    // OAI-compat fields
+    bool oaicompat = false;
+    std::string oaicompat_model;
+    std::string oaicompat_cmpl_id;
+    bool verbose = false;
+
     json to_json() {
         std::vector<std::string> samplers;
         samplers.reserve(sampling.samplers.size());
@@ -205,11 +195,24 @@ struct result_timings {
 };
 
 struct server_task_result {
-    result_type type = RESULT_TYPE_UNKNOWN;
     int id      = -1;
     int id_slot = -1;
-    server_task_result() = default;
-    server_task_result(result_type type) : type(type) {}
+    virtual bool is_error() {
+        // only used by server_task_result_error
+        return false;
+    }
+    virtual bool is_stop() {
+        // only used by server_task_result_cmpl_partial
+        return false;
+    }
+    virtual int get_index() {
+        return -1;
+    }
+    virtual json to_json() = 0;
+    virtual json to_json_oai_compat() {
+        // used by server_task_result_cmpl_final and server_task_result_cmpl_partial
+        return json();
+    }
     virtual ~server_task_result() = default;
 };
 
@@ -233,12 +236,10 @@ struct completion_token_output {
 };
 
 struct server_task_result_cmpl_final : server_task_result {
-    server_task_result_cmpl_final() : server_task_result(RESULT_TYPE_CMPL_FINAL) {}
     int index = 0;
     std::string content;
     bool stream;
     result_timings timings;
-    std::string model_alias;
     std::string prompt;
 
     bool truncated;
@@ -253,14 +254,23 @@ struct server_task_result_cmpl_final : server_task_result {
 
     slot_params generation_params;
 
-    json to_json() {
+    // OAI-compat fields
+    std::string oaicompat_model;
+    std::string oaicompat_cmpl_id;
+    bool verbose = false;
+
+    virtual int get_index() override {
+        return index;
+    }
+
+    virtual json to_json() override {
         // non-OAI-compat JSON
         return json {
             {"index",               index},
             {"content",             content},
             {"id_slot",             id_slot},
             {"stop",                true},
-            {"model",               model_alias},
+            {"model",               oaicompat_model},
             {"tokens_predicted",    n_decoded},
             {"tokens_evaluated",    n_prompt_tokens},
             {"generation_settings", generation_params.to_json()},
@@ -274,15 +284,55 @@ struct server_task_result_cmpl_final : server_task_result {
         };
     }
 
-    static server_task_result_cmpl_final from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
-        return copy_cast_ptr(server_task_result_cmpl_final, result_ptr);
-    }
+    virtual json to_json_oai_compat() override {
+        std::string finish_reason = "length";
+        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+            finish_reason = "stop";
+        }
 
-    virtual ~server_task_result_cmpl_final() = default;
+        json choices = json::array({json{
+            {"finish_reason", finish_reason},
+            {"index", 0},
+            {"message", json{
+                {"content", content},
+                {"role", "assistant"}
+            }
+        }}});
+
+        std::time_t t = std::time(0);
+
+        json res = json {
+            {"choices", choices},
+            {"created", t},
+            {"model", oaicompat_model},
+            {"object", "chat.completion"},
+            {"usage", json {
+                {"completion_tokens", n_decoded},
+                {"prompt_tokens",     n_prompt_tokens},
+                {"total_tokens",      n_decoded + n_prompt_tokens}
+            }},
+            {"id", oaicompat_cmpl_id}
+        };
+
+        // extra fields for debugging purposes
+        if (verbose) {
+            res["__verbose"] = to_json();
+        }
+
+        // TODO: fix this
+        // if (result.contains("completion_probabilities")) {
+        //     res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
+        // }
+
+        if (timings.prompt_n >= 0) {
+            res.push_back({"timings", timings.to_json()});
+        }
+
+        return res;
+    }
 };
 
 struct server_task_result_cmpl_partial : server_task_result {
-    server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
     int index = 0;
     std::string content;
 
@@ -295,7 +345,20 @@ struct server_task_result_cmpl_partial : server_task_result {
     std::vector<completion_token_output> probs_output;
     result_timings timings;
 
-    json to_json() {
+    // OAI-compat fields
+    std::string oaicompat_model;
+    std::string oaicompat_cmpl_id;
+    bool verbose = false;
+
+    virtual int get_index() override {
+        return index;
+    }
+
+    virtual bool is_stop() override {
+        return stop != STOP_TYPE_NONE;
+    }
+
+    virtual json to_json() override {
         bool is_stop = stop != STOP_TYPE_NONE;
         // non-OAI-compat JSON
         json res = json {
@@ -317,67 +380,186 @@ struct server_task_result_cmpl_partial : server_task_result {
         return res;
     }
 
-    static server_task_result_cmpl_partial from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
-        return copy_cast_ptr(server_task_result_cmpl_partial, result_ptr);
-    }
+    virtual json to_json_oai_compat() override {
+        bool first = n_decoded == 0;
 
-    virtual ~server_task_result_cmpl_partial() = default;
+        std::string finish_reason;
+        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+            finish_reason = "stop";
+        } else if (stop == STOP_TYPE_LIMIT) {
+            finish_reason = "length";
+        }
+
+        std::time_t t = std::time(0);
+
+        json choices;
+
+        if (!finish_reason.empty()) {
+            choices = json::array({json{{"finish_reason", finish_reason},
+                                        {"index", 0},
+                                        {"delta", json::object()}}});
+        } else {
+            if (first) {
+                if (content.empty()) {
+                    choices = json::array({json{{"finish_reason", nullptr},
                                                {"index", 0},
                                                {"delta", json{{"role", "assistant"}}}}});
+                } else {
+                    // We have to send this as two updates to conform to openai behavior
+                    json initial_ret = json{{"choices", json::array({json{
                                            {"finish_reason", nullptr},
                                            {"index", 0},
                                            {"delta", json{
                                                {"role", "assistant"}
                                            }}}})},
                            {"created", t},
                            {"id", oaicompat_cmpl_id},
                            {"model", oaicompat_model},
                            {"object", "chat.completion.chunk"}};
+
+                    json second_ret = json{
                            {"choices", json::array({json{{"finish_reason", nullptr},
                                                            {"index", 0},
                                                            {"delta", json{
                                                            {"content", content}}}
                                                            }})},
                            {"created", t},
                            {"id", oaicompat_cmpl_id},
                            {"model", oaicompat_model},
                            {"object", "chat.completion.chunk"}};
+
+                    return std::vector<json>({initial_ret, second_ret});
+                }
+            } else {
+                // Some idiosyncrasy in task processing logic makes several trailing calls
+                // with empty content; we ignore these at the callee site.
+                if (content.empty()) {
+                    return std::vector<json>({json::object()});
+                }
+
+                choices = json::array({json{
                    {"finish_reason", nullptr},
                    {"index", 0},
                    {"delta",
                    json{
                        {"content", content},
                    }},
                }});
+            }
+        }
+
+        json ret = json {
+            {"choices", choices},
+            {"created", t},
+            {"id",      oaicompat_cmpl_id},
+            {"model",   oaicompat_model},
+            {"object",  "chat.completion.chunk"}
+        };
+
+        if (timings.prompt_n >= 0) {
+            ret.push_back({"timings", timings.to_json()});
+        }
+
+        if (!finish_reason.empty()) {
+            ret.push_back({"usage", json {
+                {"completion_tokens", n_decoded},
+                {"prompt_tokens",     n_prompt_tokens},
+                {"total_tokens",      n_decoded + n_prompt_tokens},
+            }});
+        }
+
+        return std::vector<json>({ret});
+    }
 };
 
 struct server_task_result_embd : server_task_result {
-    server_task_result_embd() : server_task_result(RESULT_TYPE_EMBD) {}
-    result_type type = RESULT_TYPE_EMBD;
     int index = 0;
     std::vector<float> embedding;
 
-    json to_json() {
+    virtual int get_index() override {
+        return index;
+    }
+
+    virtual json to_json() override {
         return json {
             {"index",     index},
             {"embedding", embedding},
         };
     }
-
-    static server_task_result_embd from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
-        return copy_cast_ptr(server_task_result_embd, result_ptr);
-    }
-
-    virtual ~server_task_result_embd() = default;
 };
 
 struct server_task_result_rerank : server_task_result {
-    server_task_result_rerank() : server_task_result(RESULT_TYPE_RERANK) {}
     int index = 0;
     float score = -1e6;
 
-    json to_json() {
+    virtual int get_index() override {
+        return index;
+    }
+
+    virtual json to_json() override {
         return json {
             {"index", index},
             {"score", score},
         };
     }
-
-    static server_task_result_rerank from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
-        return copy_cast_ptr(server_task_result_rerank, result_ptr);
-    }
-
-    virtual ~server_task_result_rerank() = default;
 };
 
+// this function may be used outside of server_task_result_error
+static json format_error_response(const std::string & message, const enum error_type type) {
+    std::string type_str;
+    int code = 500;
+    switch (type) {
+        case ERROR_TYPE_INVALID_REQUEST:
+            type_str = "invalid_request_error";
+            code = 400;
+            break;
+        case ERROR_TYPE_AUTHENTICATION:
+            type_str = "authentication_error";
+            code = 401;
+            break;
+        case ERROR_TYPE_NOT_FOUND:
+            type_str = "not_found_error";
+            code = 404;
+            break;
+        case ERROR_TYPE_SERVER:
+            type_str = "server_error";
+            code = 500;
+            break;
+        case ERROR_TYPE_PERMISSION:
+            type_str = "permission_error";
+            code = 403;
+            break;
+        case ERROR_TYPE_NOT_SUPPORTED:
+            type_str = "not_supported_error";
+            code = 501;
+            break;
+        case ERROR_TYPE_UNAVAILABLE:
+            type_str = "unavailable_error";
+            code = 503;
+            break;
+    }
+    return json {
+        {"code", code},
+        {"message", message},
+        {"type", type_str},
+    };
+}
+
 struct server_task_result_error : server_task_result {
-    server_task_result_error() : server_task_result(RESULT_TYPE_ERROR) {}
     int index = 0;
     error_type err_type = ERROR_TYPE_SERVER;
     std::string err_msg;
 
-    static server_task_result_error from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
-        return copy_cast_ptr(server_task_result_error, result_ptr);
+    virtual bool is_error() override {
+        return true;
     }
 
-    virtual ~server_task_result_error() = default;
+    virtual json to_json() override {
+        return format_error_response(err_msg, err_type);
+    }
 };
 
 struct server_task_result_metrics : server_task_result {
-    server_task_result_metrics() : server_task_result(RESULT_TYPE_METRICS) {}
     int n_idle_slots;
     int n_processing_slots;
     int n_tasks_deferred;
@@ -404,7 +586,7 @@
// TODO: get rid of this json object and use to_json() instead json slots_data = json::array(); - json to_json() { + virtual json to_json() override { return json { { "idle", n_idle_slots }, { "processing", n_processing_slots }, @@ -430,16 +612,9 @@ struct server_task_result_metrics : server_task_result { { "slots", slots_data }, }; } - - static server_task_result_metrics from_ptr(std::unique_ptr & result_ptr) { - return copy_cast_ptr(server_task_result_metrics, result_ptr); - } - - virtual ~server_task_result_metrics() = default; }; struct server_task_result_slot_save_load : server_task_result { - server_task_result_slot_save_load() : server_task_result(RESULT_TYPE_SLOT_SAVE_LOAD) {} std::string filename; bool is_save; // true = save, false = load @@ -447,7 +622,7 @@ struct server_task_result_slot_save_load : server_task_result { size_t n_bytes; double t_ms; - json to_json() { + virtual json to_json() override { if (is_save) { return json { { "id_slot", id_slot }, @@ -470,39 +645,21 @@ struct server_task_result_slot_save_load : server_task_result { }; } } - - static server_task_result_slot_save_load from_ptr(std::unique_ptr & result_ptr) { - return copy_cast_ptr(server_task_result_slot_save_load, result_ptr); - } - - virtual ~server_task_result_slot_save_load() = default; }; struct server_task_result_slot_erase : server_task_result { - server_task_result_slot_erase() : server_task_result(RESULT_TYPE_SLOT_ERASE) {} size_t n_erased; - json to_json() { + virtual json to_json() override { return json { { "id_slot", id_slot }, { "n_erased", n_erased }, }; } - - static server_task_result_slot_erase from_ptr(std::unique_ptr & result_ptr) { - return copy_cast_ptr(server_task_result_slot_erase, result_ptr); - } - - virtual ~server_task_result_slot_erase() = default; }; struct server_task_result_apply_lora : server_task_result { - server_task_result_apply_lora() : server_task_result(RESULT_TYPE_APPLY_LORA) {} - json to_json() { + virtual json to_json() override { return json {{ "success", true }}; } - - static server_task_result_apply_lora from_ptr(std::unique_ptr & result_ptr) { - return copy_cast_ptr(server_task_result_apply_lora, result_ptr); - } }; diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 98a777192..8a8d9f8f7 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -583,155 +583,6 @@ static json oaicompat_completion_params_parse( return llama_params; } -static json format_final_response_oaicompat( - const json & request, - server_task_result_cmpl_final & result, - const std::string & completion_id, - bool streaming = false, - bool verbose = false) { - std::string finish_reason = "length"; - if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) { - finish_reason = "stop"; - } - - json choices = - streaming ? json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}) - : json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", result.content}, - {"role", "assistant"}}}}}); - - std::time_t t = std::time(0); - - json res = json { - {"choices", choices}, - {"created", t}, - {"model", - json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? 
"chat.completion.chunk" : "chat.completion"}, - {"usage", json { - {"completion_tokens", result.n_decoded}, - {"prompt_tokens", result.n_prompt_tokens}, - {"total_tokens", result.n_decoded + result.n_prompt_tokens} - }}, - {"id", completion_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = result.to_json(); - } - - // TODO: fix this - // if (result.contains("completion_probabilities")) { - // res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); - // } - - if (result.timings.prompt_n >= 0) { - res.push_back({"timings", result.timings.to_json()}); - } - - return res; -} - -// return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat( - std::string modelname, - server_task_result_cmpl_partial & result, - const std::string & completion_id) { - bool first = result.n_decoded == 0; - std::string content = result.content; - - std::string finish_reason; - if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) { - finish_reason = "stop"; - } else if (result.stop == STOP_TYPE_LIMIT) { - finish_reason = "length"; - } - - std::time_t t = std::time(0); - - json choices; - - if (!finish_reason.empty()) { - choices = json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}); - } else { - if (first) { - if (content.empty()) { - choices = json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}}}); - } else { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - return std::vector({initial_ret, second_ret}); - } - } else { - // Some idiosyncrasy in task processing logic makes several trailing calls - // with empty content, we ignore these at the calee site. 
- if (content.empty()) { - return std::vector({json::object()}); - } - - choices = json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - }}); - } - } - - json ret = json { - {"choices", choices}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"} - }; - - if (result.timings.prompt_n >= 0) { - ret.push_back({"timings", result.timings.to_json()}); - } - - if (!finish_reason.empty()) { - ret.push_back({"usage", json { - {"completion_tokens", result.n_decoded}, - {"prompt_tokens", result.n_prompt_tokens}, - {"total_tokens", result.n_decoded + result.n_prompt_tokens} - }}); - } - - return std::vector({ret}); -} - static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) { json data = json::array(); int i = 0; @@ -823,43 +674,3 @@ static json format_detokenized_response(const std::string & content) { {"content", content} }; } - -static json format_error_response(const std::string & message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; - } - return json { - {"code", code}, - {"message", message}, - {"type", type_str}, - }; -}
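
Editor's note (not part of the patch): below is a minimal, self-contained sketch of the polymorphic result flow this diff introduces — results are allocated as concrete subclasses, moved through the queue as a unique_ptr to the base class, and consumers rely on the virtual is_error()/to_json() interface instead of a result_type enum. The types here are simplified stand-ins for server_task_result and task_result_ptr, not the actual server classes; the real code returns nlohmann::json and synchronizes the queue with a mutex and condition variable.

#include <deque>
#include <iostream>
#include <memory>
#include <string>

// simplified stand-in for server_task_result / task_result_ptr
struct task_result {
    int id = -1;
    virtual bool is_error() { return false; }
    virtual std::string to_json() = 0;   // the real code returns nlohmann::json
    virtual ~task_result() = default;
};
using task_result_ptr = std::unique_ptr<task_result>;

struct result_error : task_result {
    std::string msg;
    bool is_error() override { return true; }
    std::string to_json() override { return "{\"error\":\"" + msg + "\"}"; }
};

struct result_completion : task_result {
    std::string content;
    std::string to_json() override { return "{\"content\":\"" + content + "\"}"; }
};

int main() {
    std::deque<task_result_ptr> queue;

    // producer side: build a concrete result and move it into the queue as a base-class pointer
    auto res = std::make_unique<result_completion>();
    res->id = 42;
    res->content = "hello";
    queue.emplace_back(std::move(res));

    // consumer side: no result_type enum, just the virtual interface
    task_result_ptr result = std::move(queue.front());
    queue.pop_front();
    if (result->is_error()) {
        std::cout << "error: " << result->to_json() << "\n";
    } else {
        std::cout << "ok: " << result->to_json() << "\n";
    }
    return 0;
}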