From 9f56c176697cade52acf01398ae4522006332af7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 2 Sep 2024 00:31:40 +0200 Subject: [PATCH] fix embeddings --- examples/server/server.cpp | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0c219c93f..6e0149023 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3037,8 +3037,6 @@ int main(int argc, char ** argv) { }, [&](json error_data) { server_sent_event(sink, "error", error_data); }); - std::string done_event = "[DONE]"; // OAI-compat behavior - sink.write(done_event.c_str(), done_event.size()); sink.done(); return true; }; @@ -3106,28 +3104,28 @@ int main(int argc, char ** argv) { } // create and queue the task - json responses; + json responses = json::array(); + bool error = false; { std::vector tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); // get the result - server_task_result result = ctx_server.queue_results.recv(tasks); - ctx_server.queue_results.remove_waiting_tasks(tasks); - if (!result.error) { - if (result.data.count("results")) { - // result for multi-task - responses = result.data.at("results"); - } else { - // result for single task - responses = std::vector{result.data}; + std::vector task_ids = server_task::get_list_id(tasks); + + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + for (const auto & res : results) { + responses.push_back(res.data); } - } else { - // error received, ignore everything else - res_error(res, result.data); - return; - } + }, [&](json error_data) { + res_error(res, error_data); + error = true; + }); + } + + if (error) { + return; } // write JSON response